首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

TensorFlow中如何优雅给模型批量输入训练样本

在使用TensorFlow进行模型训练的时候,我们一般不会在每一步训练的时候输入所有训练样本数据,而是通过batch的方式,每一步都随机输入少量的样本数据,这样可以防止过拟合。

所以,对训练样本的shuffle和batch是很常用的操作。

这里再说明一点,为什么需要打乱训练样本即shuffle呢?

举个例子:比如我们在做一个分类模型,前面部分的样本的标签都是A,后面部分的样本的标签全是B,那你如果不打乱样本顺序的话,就会出现前面训练出来的模型,在预测的时候会偏向于输出A,因为模型一直在标签A的方向拟合,而后面的模型,会偏向于预测B

直接看代码例子,有详细注释!!

```python

import tensorflow as tf

import numpy as np

d = np.arange(0,60).reshape([6, 10])

# 将array转化为tensor

data = tf.data.Dataset.from_tensor_slices(d)

# 从data数据集中按顺序抽取buffer_size个样本放在buffer中,然后打乱buffer中的样本

# buffer中样本个数不足buffer_size,继续从data数据集中安顺序填充至buffer_size,

# 此时会再次打乱

data = data.shuffle(buffer_size=3)

# 每次从buffer中抽取4个样本

data = data.batch(4)

# 将data数据集重复,其实就是2个epoch数据集

data = data.repeat(2)

# 构造获取数据的迭代器

iters = data.make_one_shot_iterator()

# 每次从迭代器中获取一批数据

batch = iters.get_next()

sess = tf.Session()

sess.run(batch)

# 数据集完成遍历完之后,继续抽取的话会报错:OutOfRangeError

```

```

In [21]: d

Out[21]:

array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],

[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],

[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],

[30, 31, 32, 33, 34, 35, 36, 37, 38, 39],

[40, 41, 42, 43, 44, 45, 46, 47, 48, 49],

[50, 51, 52, 53, 54, 55, 56, 57, 58, 59]])

In [22]: sess.run(batch)

Out[22]:

array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],

[30, 31, 32, 33, 34, 35, 36, 37, 38, 39],

[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],

[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])

In [23]: sess.run(batch)

Out[23]:

array([[40, 41, 42, 43, 44, 45, 46, 47, 48, 49],

[50, 51, 52, 53, 54, 55, 56, 57, 58, 59]])

```

从输出结果可以看出:

1. shuffle是按顺序将数据放入buffer里面的;

2. 当repeat函数在shuffle之后的话,是将一个epoch的数据集抽取完毕,再进行下一个epoch的。

那么,当repeat函数在shuffle之前会怎么样呢?如下:

```python

data = data.repeat(2)

data = data.shuffle(buffer_size=3)

data = data.batch(4)

```

```

In [25]: sess.run(batch)

Out[25]:

array([[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],

[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],

[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],

[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])

In [26]: sess.run(batch)

Out[26]:

array([[50, 51, 52, 53, 54, 55, 56, 57, 58, 59],

[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],

[30, 31, 32, 33, 34, 35, 36, 37, 38, 39],

[30, 31, 32, 33, 34, 35, 36, 37, 38, 39]])

In [27]: sess.run(batch)

Out[27]:

array([[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],

[50, 51, 52, 53, 54, 55, 56, 57, 58, 59],

[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],

[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])

```

可以看出,其实它就是先将数据集复制一遍,然后把两个epoch当成同一个新的数据集,一直shuffle和batch下去。

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20200512A0PXOR00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券