前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >卷积神经网络处理图像识别(三)

卷积神经网络处理图像识别(三)

作者头像
用户6021899
发布2019-11-25 16:47:41
7990
发布2019-11-25 16:47:41
举报

本篇接着上一篇来介绍卷积神经网络的训练(即反向传播)和应用。

训练神经网络和保存训练结果的代码如下:

代码语言:python
复制
import  tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import os
import numpy as np
import CNN_MNIST_inference

# --- Checkpoint location -----------------------------------------------------
# Directory and file-name prefix passed to tf.train.Saver.save() in train().
MODEL_SAVE_PATH = "E:/Python36/my tensorflow/CNN/model_path/"
MODEL_NAME = "MNIST_CNNmodel.ckpt"
# Echo the full checkpoint prefix once when the module is loaded.
print(os.path.join(MODEL_SAVE_PATH, MODEL_NAME))

# --- Training hyper-parameters -----------------------------------------------
BATCH_SIZE = 100                # examples drawn per training step
LEARNING_RATE_BASE = 0.8        # initial rate given to exponential_decay
LEARNING_RATE_DECAY = 0.99      # decay factor given to exponential_decay
REGULARIZATION_RATE = 1e-4      # scale of the L2 regularizer
MOVING_AVERAGE_DECAY = 0.99     # decay of the ExponentialMovingAverage
TRAINING_STEPS = 20000          # number of iterations of the training loop

def train(mnist):
    """Train the MNIST CNN, checkpoint it periodically and plot the curves.

    Builds the training graph on top of ``CNN_MNIST_inference.inference``,
    runs ``TRAINING_STEPS`` mini-batch updates and, every 25 iterations,
    measures validation accuracy and writes a checkpoint under
    ``MODEL_SAVE_PATH``.  Afterwards it prints the test accuracy, writes the
    graph for TensorBoard, saves one extra checkpoint and shows a
    loss / accuracy figure with matplotlib.

    Args:
        mnist: Dataset object as returned by ``input_data.read_data_sets``;
            its ``train``, ``validation`` and ``test`` splits are read here.
            Labels are reduced with ``tf.argmax(y_, 1)``, i.e. they are
            expected to be one-hot encoded.
    """
    image_dims = [CNN_MNIST_inference.IMAGE_HEIGHT,
                  CNN_MNIST_inference.IMAGE_WIDTH,
                  CNN_MNIST_inference.NUM_CHANNELS]

    def as_image_batch(images):
        """Reshape a batch of images into the 4-D layout fed to ``x``."""
        return np.reshape(images, [-1] + image_dims)

    x = tf.placeholder(tf.float32, [None] + image_dims, name='x-input')
    y_ = tf.placeholder(tf.float32, [None, CNN_MNIST_inference.OUTPUT_NODE], name='y-input')

    # L2 regularization
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = CNN_MNIST_inference.inference(x, True, regularizer, None, reuse=False)
    global_step = tf.Variable(0, trainable=False)

    # Moving average of the trainable variables
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # NOTE(review): this second forward pass also receives True and the
    # regularizer; confirm against CNN_MNIST_inference.inference that this
    # neither adds the L2 terms to the 'losses' collection a second time nor
    # enables training-only behaviour in the graph used to measure accuracy.
    average_y = CNN_MNIST_inference.inference(x, True, regularizer, variable_averages, reuse=True)

    # Loss: mean cross entropy plus everything else in the 'losses' collection
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    tf.add_to_collection('losses', cross_entropy_mean)
    loss = tf.add_n(tf.get_collection('losses'))

    # Learning rate with staircase exponential decay (one decay per epoch)
    # NOTE(review): a base rate of 0.8 is unusually large for plain SGD on a
    # CNN - confirm it is intended.
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step,
                                               mnist.train.num_examples / BATCH_SIZE,
                                               LEARNING_RATE_DECAY, staircase=True)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    # One op that runs the gradient step and the moving-average update
    train_op = tf.group(train_step, variables_averages_op)

    # Accuracy is measured on the output computed with the averaged variables
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()
    checkpoint_prefix = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()  # actually initialise the variables

        # Validation set
        validate_feed = {x: as_image_batch(mnist.validation.images),
                         y_: mnist.validation.labels}
        # Test set
        test_feed = {x: as_image_batch(mnist.test.images),
                     y_: mnist.test.labels}

        # Collected only for the plot shown after training
        steps = []
        accs = []
        losses = []
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step],
                                           feed_dict={x: as_image_batch(xs), y_: ys})

            if i % 25 == 0:
                # Accuracy on the validation set
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                steps.append(step)
                accs.append(validate_acc * 100)
                losses.append(loss_value)
                # loss_value comes from the training mini-batch just fed in,
                # so the message must not call it a test loss.
                print("After %d training steps, validation dataset accuracy after this batch is %g%%, training loss on this batch is %g" % (step, validate_acc * 100, loss_value))
                saver.save(sess, checkpoint_prefix, global_step=global_step)

        saver.save(sess, checkpoint_prefix, global_step=global_step)
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training steps, test accuracy using average model is %g%%" %
              (TRAINING_STEPS, test_acc * 100))

        # Write the graph for TensorBoard, then close the writer so the event
        # file is flushed and its handle released.
        writer = tf.summary.FileWriter("E://TensorBoard//test", sess.graph)
        writer.close()

        saver.save(sess, r"E:\Python36\my tensorflow\ckpt files\mode_mnist.ckpt")

    # Plot only: training-batch loss (top) and validation accuracy (bottom)
    from matplotlib import pyplot as plt
    import matplotlib.ticker as mtick
    plt.subplot(211)
    plt.plot(steps, losses, color="red")
    plt.scatter(steps, losses, s=20, color="red")
    plt.xlabel("训练的步数(Batch数)")
    plt.ylabel("训练batch上的Loss(含L2正则Loss)")
    plt.subplot(212)
    plt.plot(steps, accs, color="green")
    plt.scatter(steps, accs, s=20, color="green")
    yticks = mtick.FormatStrFormatter("%.3f%%")
    plt.gca().yaxis.set_major_formatter(yticks)
    plt.xlabel("step")
    plt.ylabel("验证集上的预测准确率")
    plt.show()
 
def main(argv=None):
    """Load the MNIST data with one-hot labels and pass it to ``train``.

    Args:
        argv: Accepted so the function can be used as a ``tf.app.run`` entry
            point; it is not read.
    """
    data_dir = r"E:\Python36\my tensorflow\MNIST_data"
    train(input_data.read_data_sets(data_dir, one_hot=True))

if __name__ == "__main__":
    # Let tf.app.run parse the flags and then invoke main().
    tf.app.run(main=main)

下面是训练Batch的总Loss和验证集上的准确率的收敛趋势图。由于我的电脑性能不好,所以我大幅度削减了待训练参数个数。尽管如此,2000轮训练之后,在验证集上5000个图片的预测正确率已达98.3%。如若不削减参数,准确率可达99.4%。

下面的代码利用训练好的卷积神经网络模型来评估它在验证集上的准确率(这样在正式训练时可以不做评估,从而节省训练时间),并用它来识别单张图片。

代码语言:python
复制
import  tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import os
import numpy as np
import CNN_MNIST_inference
import CNN_MNIST_train
import matplotlib.pyplot as plt

def evaluate(mnist):
    """Print the validation accuracy of the most recent checkpoint.

    The network is rebuilt in a fresh graph, and its variables are restored
    through the name map produced by
    ``ExponentialMovingAverage.variables_to_restore()``.  The checkpoint is
    looked up in ``CNN_MNIST_train.MODEL_SAVE_PATH``; if none is found a
    message is printed and nothing else happens.

    Args:
        mnist: Dataset object whose ``validation`` split is evaluated.
    """
    image_dims = [CNN_MNIST_inference.IMAGE_HEIGHT,
                  CNN_MNIST_inference.IMAGE_WIDTH,
                  CNN_MNIST_inference.NUM_CHANNELS]

    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None] + image_dims, name='x-input')
        y_ = tf.placeholder(tf.float32, [None, CNN_MNIST_inference.OUTPUT_NODE], name='y-input')

        # Validation set, reshaped to the layout of the input placeholder
        validation_images = np.reshape(mnist.validation.images, [-1] + image_dims)
        validate_feed = {x: validation_images, y_: mnist.validation.labels}

        logits = CNN_MNIST_inference.inference(x, False, None, None, reuse=False)
        is_correct = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

        # Moving average: restore through the shadow-variable name map
        ema = tf.train.ExponentialMovingAverage(CNN_MNIST_train.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(ema.variables_to_restore())

        with tf.Session() as sess:
            # Find the newest model file in the checkpoint directory
            ckpt = tf.train.get_checkpoint_state(CNN_MNIST_train.MODEL_SAVE_PATH)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return

            # Load the model
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Training step count, taken from the checkpoint file name (a str)
            file_name = ckpt.model_checkpoint_path.split('/')[-1]
            trained_steps = file_name.split("-")[-1]
            score = sess.run(accuracy, feed_dict=validate_feed)
            print("After %s training steps, validation accuary = %g" % (trained_steps, score))
                
# 把所有输入数据(input_data)和声明的常量都放进 with tf.Graph().as_default(): 里面,就能统一到同一个 Graph 中;
# 否则 input_data 会被放到系统默认创建的 Graph 里,与 with tf.Graph().as_default(): 新建的 Graph 不是同一个,就会报错。
def recognize(input_x):
    """Classify one image tensor with the latest checkpoint and print the label.

    The network is built in the default graph and its variables are restored
    from the newest checkpoint in ``CNN_MNIST_train.MODEL_SAVE_PATH``.  The
    restore goes through the moving-average name map, exactly as
    ``evaluate`` does, so single-image recognition uses the same averaged
    weights whose accuracy is reported.  If no checkpoint is found a message
    is printed and nothing else happens.

    Args:
        input_x: Image tensor created in the default graph, already shaped
            for ``CNN_MNIST_inference.inference``.
    """
    # input_x lives in the default graph, so the rest of the computation is
    # built in that same graph.
    g = tf.get_default_graph()
    with g.as_default():
        y = CNN_MNIST_inference.inference(input_x, False, None, None, reuse=False)
        predicted_label = tf.argmax(y, 1)

        # Restore the shadow (moving-average) values into the model variables
        # instead of the raw weights a plain tf.train.Saver() would load.
        variable_averages = tf.train.ExponentialMovingAverage(CNN_MNIST_train.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session() as sess:
            # Find the newest model file in the checkpoint directory
            ckpt = tf.train.get_checkpoint_state(CNN_MNIST_train.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                # Load the model
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("predicted_label: ", sess.run(predicted_label)[0])
            else:
                print('No checkpoint file found')
                return
                
def plotImage(path):
    """Show the JPEG at ``path`` as a grayscale image (visual check only).

    Args:
        path: File-system path of the JPEG to display.
    """
    # Read through a context manager so the file handle is always closed.
    with tf.gfile.GFile(path, "rb") as image_file:
        image_rawdata = image_file.read()
    img_data = tf.image.decode_jpeg(image_rawdata)
    # Convert to float32 unconditionally; the former dtype check was redundant
    # because the decoded image is not float32.
    img_data = tf.image.convert_image_dtype(img_data, dtype=tf.float32)
    with tf.Session() as sess:
        image_data = img_data.eval()  # tensors must be evaluated in a session; yields a numpy array
    # Drop the channel axis for imshow.
    # NOTE(review): this reshape only succeeds for single-channel images.
    image_data_shaped1 = image_data.reshape(image_data.shape[0], image_data.shape[1])
    plt.imshow(image_data_shaped1, cmap='gray')
    plt.show()
    
def main(argv=None):
    """Evaluate the trained model, then recognise one sample picture.

    Prints the validation accuracy through ``evaluate``, loads a single JPEG,
    shows it with ``plotImage`` and prints its predicted label through
    ``recognize``.

    Args:
        argv: Accepted for ``tf.app.run`` compatibility; it is not read.
    """
    mnist = input_data.read_data_sets(r"E:\Python36\my tensorflow\MNIST_data", one_hot=True)
    evaluate(mnist)  # accuracy on the validation set

    # Input picture
    image_path = r"E:\Python36\MNIST picture\test\50.jpg"
    # Read through a context manager so the file handle is always closed.
    with tf.gfile.GFile(image_path, "rb") as image_file:
        image_rawdata = image_file.read()
    img_data0 = tf.image.decode_jpeg(image_rawdata)
    # Convert unconditionally: binding img_data only inside an
    # `if dtype != float32` branch would leave it undefined otherwise.
    img_data = tf.image.convert_image_dtype(img_data0, dtype=tf.float32)

    # Reshape the picture to the input shape required by the network
    input_x = tf.reshape(img_data, [1,
                                    CNN_MNIST_inference.IMAGE_HEIGHT,
                                    CNN_MNIST_inference.IMAGE_WIDTH,
                                    CNN_MNIST_inference.NUM_CHANNELS])
    plotImage(image_path)
    recognize(input_x)

if __name__ == "__main__":
    # main() is called directly rather than through tf.app.run().
    main()
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-11-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python可视化编程机器学习OpenCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档