首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >关于在TFLearn中使用非二进制标签的问题

关于在TFLearn中使用非二进制标签的问题
EN

Stack Overflow用户
提问于 2019-08-15 22:11:48
回答 1查看 130关注 0票数 0

我正在尝试用Python3中的TFLearn编写一个神经网络,我有一个与标签相关的问题。神经网络的输入是长度为11的1维向量。与这些输入对应的标签也是长度为11的1维向量。但是,它们的值不是0和1。在大多数示例中,标签通常由1,0或0,1组成,例如在对猫和狗的图像进行分类的情况下。然而,我的情况更独特,我需要使用4,9,11,2,1,6,4,6,6,10,1,0这样的标签向量。有许多不同的标签类型,而不是典型的例子中只有2,我不能用由0和1组成的向量来重新编码我的标签。我的问题是,当使用所示形式的标签时,它似乎不起作用。基本上,我想知道这是为什么,以及如何让TFLearn正确地处理它。当我训练神经网络时,我得到了大约78%的准确率。但是,当我随后尝试.predict()函数时,它会输出一个由所有小于1的值组成的向量。

我的标签向量中的值始终是0到11之间的整数。所以我希望输出的值也在这个范围内,但实际上它输出的值在0到1之间。我该如何解决这个问题?到目前为止,我已经尝试将我标签中的每个值除以11,这样4,9,11,2,1,6,4,6,10,1,0将成为4/11,9/11,11/11,2/11,1/11,6/11,4/11,6/11,10/11,1/11,0/11.,然而,这似乎也不起作用。我仍然得到了大约78%的精度结果和0到1的输出向量,但它的表现就像它的过度拟合一样。我不确定我是否在编程上做错了什么,使用了错误的过程,或者是否真的太合适了。下面是我使用的TFLearn代码。

代码语言:javascript
复制
#NN starts
net = tflearn.input_data(shape=[None, len(input[0])])
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, len(labels[0]), activation='softmax')
net = tflearn.regression(net)

# Define model
model = tflearn.DNN(net)

# Start training (apply gradient descent algorithm)
model.fit(input, labels, n_epoch=10, batch_size=16, show_metric=True)

#Predict
pred = model.predict(testvector)
print('output = ', pred)`
EN

Stack Overflow用户

发布于 2019-08-19 16:11:55

获取0到1之间的输出预测值的原因是您正在使用softmax激活功能。Softmax function为0,1范围内的每个输出节点分配一个加起来为1的概率。因此,您不会获得任何整数值作为输出。归一化(即除以11)也不起作用,因为每个输出值都是相互独立的。

您可以通过以下方式将您的问题转换为multiclass-multilabel classification问题:

输出标签的

  1. 将每个数字转换为由四位数组成的二进制(因为最大值是11,这需要4位二进制数)。例: 4,9,...变成0100,1001,...
  2. 将每个二进制数作为一个类。所以你的问题变成了一个分类问题。例如: 0100,1001,...变成0,1,0,0,1,0,0,1,...其中每四位数字将代表你实际输出的数字。
  3. 这种方法将你的训练和测试数据集标签转换成二进制数字数组。
  4. 然后在你的模型中,代替使用softmax激活函数,你可以使用sigmoid激活函数。Sigmoid函数将为0,1范围内的每个输出节点分配一个概率。因此,在训练模型后,您的预测输出应该类似于: 0.0012,0.7890,0.0001,0.0100,0.8801,0.0030,0.0440,0.9120,...
  5. 选择适当的阈值将预测转换为二进制数。例如:如果我们选择0.5作为阈值,那么输出就变成: 0,1,0,0,1,0,0,1,...然后,您可以将每四位数视为实际输出标签的单个数字。
票数 0
EN
查看全部 1 条回答
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57510987

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档