用Tensorflow搭建卷积神经网络

卷积神经网络(convolutional neural networks,简称CNN)它已经在很多实际应用中获得了优秀的结果,特别是在图像的识别领域。本文将基于NotMnist数据集,使用tensorflow搭建卷积网络,完成分类功能。

关于对数据集的读取和预处理,在之前的一篇《基于Tensorflow搭建神经网络》里讲过,这里就不在重复讲述,小伙伴们可以去这篇文章里找到相关代码。

这里建一个两层卷积网络,并将训练结果与之前的两层隐含层的神经网络结果进行对比。网络结构是两层卷积+一层全连接。

网络参数如下:

num_channels=1 #图像通道数,灰度图,因此为1

batch_size = 16 #每次训练的批量大小

patch_size = 5 #卷积核大小

depth = 16 #卷积通道数

num_hidden = 64 #全连接层节点数

对应的各层参数为:

layer1_weights = tf.Variable(tf.truncated_normal(

[patch_size, patch_size,num_channels, depth], stddev=0.1))

layer1_biases = tf.Variable(tf.zeros([depth]))

layer2_weights = tf.Variable(tf.truncated_normal(

[patch_size, patch_size, depth, depth], stddev=0.1))

layer2_biases = tf.Variable(tf.constant(1.0, shape=[depth]))

layer3_weights= tf.Variable(tf.truncated_normal(

[image_size // 4 * image_size // 4 * depth,num_hidden], stddev=0.1))

layer3_biases = tf.Variable(tf.constant(1.0, shape=[num_hidden]))

layer4_weights = tf.Variable(tf.truncated_normal(

[num_hidden, num_labels], stddev=0.1))

layer4_biases = tf.Variable(tf.constant(1.0, shape=[num_labels]))

模型计算过程:

defmodel(data):

shape = hidden.get_shape().as_list()

reshape = tf.reshape(hidden, [shape[0], shape[1] * shape[2] * shape[3]])

return tf.matmul(hidden, layer4_weights) + layer4_biases

如果小伙伴们看过《基于Tensorflow搭建神经网络》这篇的话就会发现,计算过程的实现,跟DNN的差别并不大,就是参数换成了卷积核,每层网络的运算,由WX+b变成了卷积运算,激活函数仍然是Relu。看到这里,大家对CNN应该如何实现应该了然了,如何搭建更深层次的卷积网络,大家应该也心中有数了,哈哈,没错,我讲东西,靠的就是咱们的默契!其实搭建神经网络就像搭积木,一层层的加起来,真是想要多深就能有多深。

两层网络中卷积运算的stride为2,这意味着输出的宽高为输入的一半,因此两层卷积计算完成后的输出维度为(image_size // 4 , image_size // 4 , depth),这也就是全连接层的输入,因此layer3_weight的维度是(image_size // 4 *image_size // 4 * depth, num_hidden)。关于卷积运算的输入输出维度关系,可参考《卷积与转置卷积输入输出计算》中总结的公式,强烈建议大家收藏这篇文章,或者,把里面的英文版宝典下载下来,它一定是一本你会经常查阅的工具书。

下面给出模型定义完整代码:

将训练次数设为2W次,得到的结果和2层隐含层DNN训练结果进行对比:

卷积网络训练结果

2层隐含层网络训练结果

从结果我们可以看到,与2层隐含层网络相比,2层卷积网络有非常明显的提升。

卷积神经网络从AlexNet到现在,经历了爆发式的发展,非常多的网络被陆续提出,并有着越来越好的表现,模型也是朝着越来越深的方向发展。不过小伙伴们在选择模型时,还是建议结合具体任务和实际数据,从较浅较小的模型开始尝试,选择合适的模型,不必盲目追求模型深度或者复杂度,毕竟更复杂的网络意味着更多的参数、更复杂的计算、更长的训练时间和更高的过拟合可能。

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180513G0FRN800?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码关注云+社区

领取腾讯云代金券