首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

Tensorflow2.1实践-IMDB影评文本分类(Kears

本文将文本形式的影评分为“正面”或“负面”影评。这是一个二元分类(又称为两类分类)的示例,也是一种重要且广泛适用的机器学习问题。

下载IMDB数据集

参数 num_words=10000 会保留训练数据中出现频次在前 10000 位的字词。为确保数据规模处于可管理的水平,罕见字词将被舍弃。

查看数据集中的数据格式

数据集已经过预处理:每个样本都是一个整数数组,表示影评中的字词。每个标签都是整数值 0 或 1,其中 0 表示负面影评,1 表示正面影评。

将数据集中的整数转换回单词

了解如何将整数转换回文本很有用。在以下代码中,我们将创建一个辅助函数word_index来查询包含整数到字符串映射的字典对象:

现在,我们可以使用 decode_review 函数显示第一条影评的文本:

输出如下:

数据准备

1、影评(整数数组)必须转换为Tensor,然后才能馈送到神经网络中。我们可以通过以下两种方法实现这种转换:对数组进行One-hot编码,将它们转换为由 0 和 1 构成的向量。例如,序列 [3, 5] 将变成一个 10000 维的向量,除索引 3 和 5 转换为 1 之外,其余全转换为 0。然后,将它作为网络的第一层,一个可以处理浮点向量数据的密集层。不过,这种方法会占用大量内存,需要一个大小为 num_words * num_reviews 的矩阵。

2、或者,我们可以填充数组,使它们都具有相同的长度,然后创建一个形状为 max_length * num_reviews 的整数张量。我们可以使用一个能够处理这种形状的嵌入层作为网络中的第一层。在本教程中,我们将使用第二种方法,当然你也可以用word2vec,不过效果都差不多,因为第二种会和神经网络一起训练,而如果用word2vec的话embedding层的参数就不会进行训练。

由于影评的长度必须相同,我们将使用 pad_sequences 函数将长度标准化:

检查(已填充的)第一条影评:

构建模型

神经网络通过堆叠层创建而成,这需要做出两个架构方面的主要决策:

1、要在模型中使用多少个层?

2、要针对每个层使用多少个隐藏单元?

在本示例中,输入数据由字词-索引数组(word-index)构成。要预测的标签是 0 或 1(好or坏)。接下来,我们为此问题构建一个模型:

第一层是 Embedding 层。该层会在整数编码的词汇表中查找每个字词-索引的嵌入向量。模型在接受训练时会学习这些向量。这些向量会向输出数组添加一个维度。生成的维度为:(batch, sequence, embedding)。接下来,一个 GlobalAveragePooling1D 层通过对序列维度求平均值,针对每个样本返回一个长度固定的输出向量。这样,模型便能够以尽可能简单的方式处理各种长度的输入。该长度固定的输出向量会传入一个全连接 (Dense) 层(包含 16 个隐藏单元)。最后一层与单个输出节点密集连接。应用 sigmoid 激活函数后,结果是介于 0 到 1 之间的浮点值,表示概率或置信水平。

隐藏单元

上述模型在输入和输出之间有两个中间层(也称为“隐藏”层)。输出(单元、节点或神经元)的数量是相应层的表示法空间的维度。换句话说,该数值表示学习内部表示法时网络所允许的自由度。如果模型具有更多隐藏单元(更高维度的表示空间)和/或更多层,则说明网络可以学习更复杂的表示法。不过,这会使网络耗费更多计算资源,并且可能导致学习不必要的模式(可以优化在训练数据上的表现,但不会优化在测试数据上的表现)。这称为过拟合,我们稍后会加以探讨。

损失函数和优化器

模型在训练时需要一个损失函数和一个优化器。由于这是一个二元分类问题且模型会输出一个概率(应用 S 型激活函数的单个单元层),因此我们将使用 binary_crossentropy 损失函数。该函数并不是唯一的损失函数,例如,您可以选择 mean_squared_error。但一般来说,binary_crossentropy 更适合处理概率问题,它可测量概率分布之间的“差距”,在本例中则为实际分布和预测之间的“差距”。稍后,在探索回归问题(比如预测房价)时,我们将了解如何使用另一个称为均方误差的损失函数。现在,配置模型以使用优化器和损失函数:

训练模型

用有 512 个样本的小批次训练模型 40 个周期。这将对 x_train 和 y_train 张量中的所有样本进行 40 次迭代。在训练期间,监控模型在验证集的 10000 个样本上的损失和准确率:

评估模型

我们来看看模型的表现如何。模型会返回两个值:损失(表示误差的数字,越低越好)和准确率。

为了防止过拟合可以加入checkpoint,dropout等。

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20201021A04X1X00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券