前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow (2) CIFAR-10 简单图像识别

TensorFlow (2) CIFAR-10 简单图像识别

作者头像
coding01
发布2021-02-24 11:17:54
1.3K0
发布2021-02-24 11:17:54
举报
文章被收录于专栏:Coding01Coding01
本文 2345.5字,需要 5.86 分钟

CIFAR-10 图像识别

本文主要学习获取 CIFAR-10 数据集,通过简单的模型对数据集进行训练和识别。

下载数据集

和之前一样,使用 http_proxy 代理:

代码语言:javascript
复制
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
import tensorflow_datasets as tfds
import os
import ssl

os.environ['HTTP_PROXY'] = 'http://0.0.0.0:8888'
os.environ['HTTPS_PROXY'] = 'http://0.0.0.0:8888'
ssl._create_default_https_context =  ssl._create_unverified_context

下载 CIFAR-10 数据集[1]:

代码语言:javascript
复制
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

返回值 Tuple of Numpy arrays: (x_train, y_train), (x_test, y_test). x_train, x_test: uint8 arrays of RGB image data with shape (num_samples, 3, 32, 32) if tf.keras.backend.image_data_format() is 'channels_first', or (num_samples, 32, 32, 3) if the data format is 'channels_last'. y_train, y_test: uint8 arrays of category labels (integers in range 0-9) each with shape (num_samples, 1).

数据集中包含 50000 张 32*32 的彩色图片和这些图相对应的类别标签的训练集,10000 张测试图像。

其中,x_train 是训练集图片,y_train 是对应的标签,x_test 是测试集图片,y_test 是测试集对应的标签。

读取图片

我们将测试集的前 16 张图片和类名打印出来,来确保数据集被正确加载。

代码语言:javascript
复制
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure(figsize=(10,10))
for i in range(16):
    plt.subplot(4,4,i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i], cmap=plt.cm.binary)
    # 由于 CIFAR 的标签是 array, 
    # 因此您需要额外的索引(index)。
    plt.xlabel(class_names[train_labels[i][0]])
plt.show()

验证了数据的真确性。下面就开始模式训练和模式测试了。

模式训练

建立模型:

代码语言:javascript
复制
# 将像素的值标准化至0到1的区间内。
x_train, x_test = x_train / 255.0, x_test / 255.0

model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))

model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))

model.summary()
代码语言:javascript
复制
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 15, 15, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 13, 13, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 6, 6, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 4, 4, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 64)                65600     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
=================================================================
Total params: 122,570
Trainable params: 122,570
Non-trainable params: 0
_________________________________________________________________

编译并训练模型:

代码语言:javascript
复制
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(x_train, y_train, epochs=10, 
                    validation_data=(x_test, y_test))

训练过程:

代码语言:javascript
复制
Epoch 1/10
1563/1563 [==============================] - 57s 36ms/step - loss: 1.5198 - accuracy: 0.4505 - val_loss: 1.2727 - val_accuracy: 0.5481
Epoch 2/10
1563/1563 [==============================] - 55s 35ms/step - loss: 1.1464 - accuracy: 0.5947 - val_loss: 1.0817 - val_accuracy: 0.6183
Epoch 3/10
1563/1563 [==============================] - 62s 40ms/step - loss: 0.9908 - accuracy: 0.6552 - val_loss: 0.9716 - val_accuracy: 0.6606
Epoch 4/10
1563/1563 [==============================] - 58s 37ms/step - loss: 0.8882 - accuracy: 0.6899 - val_loss: 0.9284 - val_accuracy: 0.6737
Epoch 5/10
1563/1563 [==============================] - 54s 35ms/step - loss: 0.8148 - accuracy: 0.7145 - val_loss: 0.9531 - val_accuracy: 0.6779
Epoch 6/10
1563/1563 [==============================] - 53s 34ms/step - loss: 0.7595 - accuracy: 0.7333 - val_loss: 0.8891 - val_accuracy: 0.6948
Epoch 7/10
1563/1563 [==============================] - 48s 31ms/step - loss: 0.7137 - accuracy: 0.7508 - val_loss: 0.8636 - val_accuracy: 0.7020
Epoch 8/10
1563/1563 [==============================] - 44s 28ms/step - loss: 0.6681 - accuracy: 0.7656 - val_loss: 0.8790 - val_accuracy: 0.7027
Epoch 9/10
1563/1563 [==============================] - 41s 26ms/step - loss: 0.6332 - accuracy: 0.7779 - val_loss: 0.8627 - val_accuracy: 0.7145
Epoch 10/10
1563/1563 [==============================] - 42s 27ms/step - loss: 0.5896 - accuracy: 0.7920 - val_loss: 0.8551 - val_accuracy: 0.7114

评估模型

代码语言:javascript
复制
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()

test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)

至此,基本完成模式建立和训练,并对模型进行评估,简简单单的几行代码就能让识别率在 70% 左右。

总结

这识别率还是可以再提高的,随着下一阶段的学习,把增强型模型做好,以提高识别率。

参考

[1] CIFAR-10 数据集 https://www.tensorflow.org/datasets/catalog/cifar10

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-02-01,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 coding01 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档