前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CIFAR-10数据集 图像识别

CIFAR-10数据集 图像识别

作者头像
用户6021899
发布2019-12-25 14:42:15
1.2K0
发布2019-12-25 14:42:15
举报

之前我是在CPU上跑Tensorflow,计算速度着实让人捉急。最近更新了显卡驱动,安装了CUDA和 GPU版的TensorFlow,同样的神经网络结构,学习速度有了百倍提升。

下面言归正传,我们来讲代码。本篇我们还是用序列化的(串行的)卷积神经网络,基于CIFAR-10数据集创建图像识别模型。由于我的GTX750Ti连入门级显卡都算不上,因此仅仅用了3个卷积层+1个池化层+两个全连接层(中间还加了两个Dropout以避免过拟合)。

代码语言:javascript
复制
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 10 20:04:58 2019
@author: wsp
Tensorflow version:2.0
Python version 3.7
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from load_dataset import train_dataset, train_labels, valid_dataset, valid_labels
from matplotlib import pyplot as plt

cifar10 = tf.keras.datasets.cifar10
#(x_train, y_train), (x_test, y_test) = cifar10.load_data() #从网络下载数据集
x_train, y_train = train_dataset, train_labels
x_test, y_test= valid_dataset, valid_labels
x_train, x_test = x_train / 255.0, x_test / 255.0

tf.keras.models.Sequential()用于创建序列化的神经网络模型。

tf.keras.layers.Flatten() 用于将tensor展平,展平后才能做全连接层的input。

tf.keras.layers.Dense()用于创建全连接层。

代码语言:javascript
复制
model = tf.keras.models.Sequential([
  tf.keras.layers.Conv2D(input_shape=(32,32,3),filters= 16,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu'),
  tf.keras.layers.Conv2D(filters= 32,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu'),
  tf.keras.layers.Conv2D(filters= 64,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu'), 
  tf.keras.layers.MaxPooling2D(pool_size=(2,2),strides=(1,1)),
  tf.keras.layers.Dropout(0.25),

  #tf.keras.layers.Conv2D(filters= 32,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu'),
  #tf.keras.layers.Conv2D(filters= 64,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu'),
  #tf.keras.layers.Conv2D(filters= 64,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu'),
  #tf.keras.layers.MaxPooling2D(pool_size=(2,2),strides=(1,1)),
  #tf.keras.layers.Dropout(0.25),

  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(500, activation='relu'),
  tf.keras.layers.Dropout(0.25),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile()用于配置模型的训练流程。tf.keras.Model.compile 采用三个重要参数:

  • optimizer:此对象会指定训练过程。从tf.train模块向其传递优化器实例,例如AdamOptimizer、RMSPropOptimizer或GradientDescentOptimizer。
  • loss:要在优化期间最小化的函数。常见选择包括均方误差(mse)、categorical_crossentropy 和 binary_crossentropy。损失函数由名称或通过从 tf.keras.losses 模块传递可调用对象来指定。
  • metrics:用于监控训练。它们是 tf.keras.metrics 模块中的字符串名称或可调用对象。
代码语言:javascript
复制
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

tf.keras.Model.fit()进行测试试数据与模型的拟合

model.fit(data,labels,epochs=20,batch_size=500,validation_data=(val_data, val_labels))

代码语言:javascript
复制
history = model.fit(x_train, y_train, epochs=50) # 训练50轮

tf.keras.Model.evaluate() 用于评估模型

verbose:日志显示 verbose = 0 为不在标准输出流输出日志信息 verbose = 1 为输出进度条记录 verbose = 2 为每个epoch输出一行记录

代码语言:javascript
复制
model.evaluate(x_test,  y_test, verbose=2)

保存模型:

代码语言:javascript
复制
model.save('my_cifar10_model.h5')

下面这段的作用仅仅是绘制出Loss和预测准确度曲线:

代码语言:javascript
复制
plt.subplot(121)
accuracy = history.history['accuracy']
plt.plot(range(1,len(accuracy)+1), accuracy)
#plt.plot(history.history['val_acc'])
plt.title('train set accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.subplot(122)
plt.plot(range(1,len(accuracy)+1), history.history['loss'])
#plt.plot(history.history['val_acc'])
plt.title('train set loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

可以看出,训练完后模型在验证集上的预测准确度高达98%。

下面我们可以使用已经保存好的模型来预测从网上下载的图片的分类:

代码语言:javascript
复制
# -*- coding: utf-8 -*-
"""

@author: wsp
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
#from load_dataset import train_dataset, train_labels, valid_dataset, valid_labels,names
from matplotlib import pyplot as plt
#x_train, y_train = train_dataset, train_labels
#x_test, y_test= valid_dataset, valid_labels
#x_train, x_test = x_train / 255.0, x_test / 255.0

#加载模型
new_model = tf.keras.models.load_model('my_cifar10_model.h5')
#利用加载后的模型对整个验证集做预测
#预测一组样本
#y_ = new_model.predict(x_test) 
#result = tf.argmax(y_, 1) #
#print(result)
#for i in range(5):
    #index = int(result[i])
    #print("验证集第%d张图片的分类索引是 %d:"% (i, index))
    #print("分类名称是:%s "% names[index])
#print()

#预测单张图片
def resize(img_path): 
    '''将图片resize为 32x32x3'''
    image = plt.imread(img_path)
    resized = tf.image.resize(image,[32,32],method='bilinear')
    #return tf.cast(resized, tf.uint8)
    return resized
 
from load_dataset import names
my_img = resize('4.jpg')/255.0 #要符合xinput的格式
my_input = tf.reshape(my_img,(1,32,32,3))
result = tf.argmax(new_model.predict(my_input) , 1)
print("分类名称是:%s "% names[int(result[0])])

分类名称是:b'bird'

分类名称是:b'horse'

分类名称是:b'cat'

还行!

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-12-12,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python可视化编程机器学习OpenCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
文件存储
文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档