
Getting Started with TensorFlow: CNN for MNIST

Steve Wang · 2019-05-26
[Figure: the CNN architecture, three conv + max-pool + dropout blocks (32, 64, and 128 filters), a 625-unit dense layer with dropout, and a 10-way output layer]

This post uses TensorFlow to build the CNN shown above and applies it to softmax classification on the MNIST dataset.

The theory is not repeated here; the complete code follows:

import tensorflow as tf
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
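# note: the input_data module used above is deprecated in newer TensorFlow
# releases and triggers a long warning on import (see the note after the results)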

# hyperparameters
learning_rate = 0.001
training_epochs = 20
batch_size = 100

class Model:
    
    def __init__(self, sess, name):
        self.sess = sess
        self.name = name
        self._build_net()
        
    def _build_net(self):
        # with tf.variable_scope(self.name):
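        # note: even with the scope commented out, tf.layers auto-uniquifies
        # layer (and hence variable) names within the graph, so the two
        # Model instances created below still keep separate weights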
        self.training = tf.placeholder(tf.bool)  # toggles dropout on (train) / off (test)
        # input placeholder for X & Y
        self.X = tf.placeholder(tf.float32, [None, 784])
        self.Y = tf.placeholder(tf.float32, [None, 10])
        # img 28x28x1 (black/white)
        X_img = tf.reshape(self.X, [-1, 28, 28, 1])
            
        # convolutional layer 1 & pooling layer 1
        conv1 = tf.layers.conv2d(inputs=X_img, filters=32, kernel_size=[3, 3], 
                                 padding="SAME", activation=tf.nn.relu)
        pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2],
                                        padding="SAME", strides=2)
        dropout1 = tf.layers.dropout(inputs=pool1, rate=0.3, training=self.training)

        # convolutional layer 2 & pooling layer 2
        conv2 = tf.layers.conv2d(inputs=dropout1, filters=64, kernel_size=[3, 3],
                                 padding="SAME", activation=tf.nn.relu)
        pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2],
                                        padding="SAME", strides=2)
        dropout2 = tf.layers.dropout(inputs=pool2, rate=0.3, training=self.training)

        # convolutional layer 3 & pooling layer 3
        conv3 = tf.layers.conv2d(inputs=dropout2, filters=128, kernel_size=[3, 3],
                                 padding="SAME", activation=tf.nn.relu)
        pool3 = tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2],
                                        padding="SAME", strides=2)
        dropout3 = tf.layers.dropout(inputs=pool3, rate=0.3, training=self.training)

        # dense layer with ReLU
        # flatten: three SAME-padded 2x2 poolings shrink 28 -> 14 -> 7 -> 4
        flat = tf.reshape(dropout3, [-1, 128 * 4 * 4])
        dense4 = tf.layers.dense(inputs=flat, units=625, activation=tf.nn.relu)
        dropout4 = tf.layers.dropout(inputs=dense4, rate=0.5, training=self.training)

        # FC layer 625 input -> 10 output, no activation function
        self.logits = tf.layers.dense(inputs=dropout4, units=10)

        # define loss & optimizer; the loss op applies softmax internally,
        # which is why self.logits has no activation function
        self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                                   logits=self.logits, labels=self.Y))
        self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.cost)

        # accuracy
        correct_prediction = tf.equal(tf.argmax(self.logits, 1), tf.argmax(self.Y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        
    def train(self, x_data, y_data, training=True):
        return self.sess.run([self.cost, self.optimizer], 
                             feed_dict={self.X: x_data, self.Y: y_data, self.training: training})
    
    def predict(self, x_test, training=False):
        return self.sess.run(self.logits, 
                             feed_dict={self.X: x_test, self.training: training})

    def get_accuracy(self, x_test, y_test, training=False):
        return self.sess.run(self.accuracy, 
                             feed_dict={self.X: x_test, self.Y: y_test, self.training: training})
    

# train the models
with tf.Session() as sess:
    models = []
    num_models = 2
    
    for m in range(num_models):
        models.append(Model(sess, "model" + str(m)))
        
    sess.run(tf.global_variables_initializer())
    
    print('Learning Start!')
    
    for epoch in range(training_epochs):
        avg_cost_list = np.zeros(len(models))
        total_batch = int(mnist.train.num_examples / batch_size)
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            
            # train each model
            for m_id, m in enumerate(models):
                
                c, _ = m.train(batch_xs, batch_ys)
                avg_cost_list[m_id] += c / total_batch
        
        print('Epoch: ', '%04d' %(epoch+1), 'cost=', avg_cost_list)
        
    print('Learning finished!')

    # test & accuracy
    test_size = len(mnist.test.labels)
    predictions = np.zeros([test_size, 10])

    for m_id, m in enumerate(models):
        print(m_id, "Accuracy:", m.get_accuracy(mnist.test.images, mnist.test.labels))
        p = m.predict(mnist.test.images)
        predictions += p

    ensemble_correct_prediction = tf.equal(tf.argmax(predictions, 1), tf.argmax(mnist.test.labels, 1))
    ensemble_accuracy = tf.reduce_mean(tf.cast(ensemble_correct_prediction, tf.float32))
    print("Ensemble_accuracy:", sess.run(ensemble_accuracy))

Result:

Learning Start!
Epoch:  0001 cost= [0.29211415 0.28355632]
Epoch:  0002 cost= [0.08716567 0.0870499 ]
Epoch:  0003 cost= [0.06902521 0.06623169]
Epoch:  0004 cost= [0.05563359 0.05452387]
Epoch:  0005 cost= [0.04963774 0.04871382]
Epoch:  0006 cost= [0.04462749 0.04449957]
Epoch:  0007 cost= [0.04132144 0.03907955]
Epoch:  0008 cost= [0.03792324 0.03861412]
Epoch:  0009 cost= [0.0354344  0.03323769]
Epoch:  0010 cost= [0.03516847 0.03405525]
Epoch:  0011 cost= [0.03143759 0.03219781]
Epoch:  0012 cost= [0.03051504 0.02993162]
Epoch:  0013 cost= [0.02906878 0.02711077]
Epoch:  0014 cost= [0.02729127 0.02754832]
Epoch:  0015 cost= [0.02729633 0.02632647]
Epoch:  0016 cost= [0.02438517 0.02701174]
Epoch:  0017 cost= [0.02482958 0.0244114 ]
Epoch:  0018 cost= [0.02455271 0.02649499]
Epoch:  0019 cost= [0.02371975 0.02178147]
Epoch:  0020 cost= [0.02260135 0.0213784 ]
Learning finished!
0 Accuracy: 0.995
1 Accuracy: 0.9949
Ensemble_accuracy: 0.9954
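
As a quick usage sketch (hypothetical, not part of the original script): if placed inside the same with tf.Session() block after training, one of the trained models can classify a single test image like this:

    # classify one test image; predict() returns raw logits of shape (1, 10),
    # so argmax along axis 1 gives the predicted digit
    sample = mnist.test.images[:1]   # shape (1, 784)
    logits = models[0].predict(sample)
    print('Predicted digit:', np.argmax(logits, axis=1)[0])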

The output above is actually preceded by a long warning, which is omitted here. The warning says that newer versions of TensorFlow have moved the MNIST dataset elsewhere and recommends importing it from the new location. This post is only meant as an example; when you use TensorFlow in practice, you write your own data-loading functions, and the Python code differs depending on how the dataset is stored.
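
For reference, a minimal sketch of the replacement loading path on newer TensorFlow versions, assuming tf.keras is available (variable names are illustrative):

import numpy as np
import tensorflow as tf

# tf.keras ships MNIST directly; load_data() returns integer labels,
# so convert them to one-hot vectors to match the feed_dict above
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype(np.float32) / 255.0
x_test = x_test.reshape(-1, 784).astype(np.float32) / 255.0
y_train = np.eye(10)[y_train]
y_test = np.eye(10)[y_test]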
