
TensorFlow Keras: MNIST classification demo

Original
Author: 千万别过来
Published 2023-06-26 13:42:42
From the column: 推荐算法学习 (Recommendation Algorithm Study)

The Keras bundled with TF2 is very convenient; simple models can be assembled quickly. Below is a demo on the classic MNIST dataset that shows some commonly used methods.

1 Import packages and check version numbers

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

2 Get the dataset and normalize it

Without normalization the model does not converge here; we use sklearn for the normalization.

Note the following:

  1. fit_transform is for the training data: it records the mean and variance while normalizing.
  2. transform normalizes the validation and test sets using the mean and variance saved from the training set.
  3. Normalization involves division, so convert to float first with astype(np.float32).
  4. The scaler expects 2-D input but the data here is 3-D, so use reshape: x_train: [None, 28, 28] -> [None, 784].
  5. After normalizing, reshape back to the original shape.
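The fit/transform split in points 1–2 can be mimicked with plain NumPy (a sketch of what StandardScaler does, not its actual implementation; the random data here is made up for illustration): the statistics come from the training data only and are reused for the other splits.

```python
import numpy as np

rng = np.random.default_rng(0)
x_train = rng.normal(loc=5.0, scale=2.0, size=(100, 784)).astype(np.float32)
x_test = rng.normal(loc=5.0, scale=2.0, size=(20, 784)).astype(np.float32)

# "fit": record mean and std from the training data only
u, std = x_train.mean(), x_train.std()

# "transform": both splits are scaled with the training statistics
x_train_scaled = (x_train - u) / std
x_test_scaled = (x_test - u) / std

print(x_train_scaled.mean(), x_train_scaled.std())  # ~0 and ~1 by construction
```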
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
# x = (x - u) / std
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
# x_train: [None, 28, 28] -> [None, 784]
x_train_scaled = scaler.fit_transform(
    x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_valid_scaled = scaler.transform(
    x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_test_scaled = scaler.transform(
    x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
Output:
(5000, 28, 28) (5000,)
(55000, 28, 28) (55000,)
(10000, 28, 28) (10000,)

3 Build the model

  1. The model is built with Sequential(); there are two ways to construct it, one of which is commented out below.
  2. Since the input is a 28x28 image, the input layer needs a Flatten to turn it into a vector.
  3. The loss is sparse_categorical_crossentropy, which automatically turns integer class labels into one-hot probability distributions; if the labels are already probability distributions, use categorical_crossentropy instead.
  4. Other optimizers such as adam are available; just pass the name as a string, see the official API.
  5. Other metrics such as mse are available too, see the official API.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'  # use GPU 3
# tf.keras.models.Sequential()
model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28, 28]))
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu"))
model.add(keras.layers.Dense(10, activation="softmax"))
# model = keras.models.Sequential([
#     keras.layers.Flatten(input_shape=[28, 28]),
#     keras.layers.Dense(300, activation='relu'),
#     keras.layers.Dense(100, activation='relu'),
#     keras.layers.Dense(10, activation='softmax')
# ])
# relu: y = max(0, x)
# softmax: turns a vector into a probability distribution. x = [x1, x2, x3],
#          y = [e^x1/sum, e^x2/sum, e^x3/sum], sum = e^x1 + e^x2 + e^x3
# reason for "sparse": y is a class index; it is expanded to one-hot internally
model.compile(loss="sparse_categorical_crossentropy",
              optimizer = "sgd",
              metrics = ["acc"])
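The difference between the two losses in note 3 can be illustrated with plain NumPy (a minimal sketch, not the actual Keras implementation): sparse_categorical_crossentropy takes the class index directly, while categorical_crossentropy takes a one-hot (or probability) vector; on the same prediction they give the same value.

```python
import numpy as np

def softmax(x):
    # turn a vector of logits into a probability distribution
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.1])
probs = softmax(logits)

y_sparse = 2                            # sparse label: just the class index
y_onehot = np.array([0.0, 0.0, 1.0])    # the equivalent one-hot label

# cross entropy computed both ways gives the same value
loss_sparse = -np.log(probs[y_sparse])
loss_categorical = -np.sum(y_onehot * np.log(probs))
print(loss_sparse, loss_categorical)
```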

4 Train the model

  • Note that the training logs of 2.0 and later 2.x versions differ slightly; from 2.x on the default batch size is 32.
  • Much like sklearn, training uses the fit function, which returns a history object holding the training history.
  • callbacks are callback functions; there are many kinds, only 3 are shown here, see the API for the rest. To use them, add a callbacks parameter to fit and pass them in as a list.
  • TensorBoard needs a log directory.
  • ModelCheckpoint needs a file path to save to; the suffix is .h5 (apparently .ckpt also works), and .h5 is easier to port to Caffe or Keras. save_best_only saves the best model; without it the default is to save the most recent one.
  • EarlyStopping stops training early: patience is how many extra epochs to wait for improvement, see the API for details; min_delta is the improvement threshold. You can see that although I set 30 epochs, early stopping kicked in at epoch 20.
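The patience/min_delta logic can be sketched in plain Python (an illustrative re-implementation, not Keras's actual code): training stops once the monitored loss has failed to improve by at least min_delta for patience consecutive epochs.

```python
def early_stop_epoch(val_losses, patience=5, min_delta=1e-3):
    """Return the 0-based epoch at which training would stop,
    or None if it runs to completion."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:  # improved by at least min_delta
            best = loss
            wait = 0
        else:                        # no meaningful improvement
            wait += 1
            if wait >= patience:
                return epoch
    return None

# toy example: the loss improves, then plateaus for 5 epochs -> stop
losses = [0.43, 0.37, 0.36, 0.35, 0.33, 0.331, 0.332, 0.330, 0.333, 0.331]
print(early_stop_epoch(losses))  # prints 9
```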
# Tensorboard, earlystopping, ModelCheckpoint
logdir = './callbacks'
if not os.path.exists(logdir):
    os.mkdir(logdir)
output_model_file = os.path.join(logdir,
                                 "fashion_mnist_model.h5")
callbacks = [
    keras.callbacks.TensorBoard(logdir),
    keras.callbacks.ModelCheckpoint(output_model_file,
                                    save_best_only = True),
    keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]
history = model.fit(x_train_scaled, y_train, epochs=30,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks = callbacks)

Training log:

Epoch 1/30
1/1719 [..............................] - ETA: 0s - loss: 3.0171 - acc: 0.0000e+00WARNING:tensorflow:From /data1/home/zly/anaconda3/envs/tf2.3/lib/python3.7/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use tf.profiler.experimental.stop instead.
WARNING:tensorflow:Callbacks method on_train_batch_end is slow compared to the batch time (batch time: 0.0027s vs on_train_batch_end time: 0.0207s). Check your callbacks.
60/1719 [>.............................] - ETA: 4s - loss: 1.2842 - acc: 0.5719

1719/1719 [==============================] - 4s 2ms/step - loss: 0.5355 - acc: 0.8089 - val_loss: 0.4270 - val_acc: 0.8524
Epoch 2/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.3914 - acc: 0.8575 - val_loss: 0.3716 - val_acc: 0.8652
Epoch 3/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.3516 - acc: 0.8746 - val_loss: 0.3620 - val_acc: 0.8680
Epoch 4/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.3261 - acc: 0.8819 - val_loss: 0.3487 - val_acc: 0.8736
Epoch 5/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.3078 - acc: 0.8892 - val_loss: 0.3266 - val_acc: 0.8856
Epoch 6/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2926 - acc: 0.8930 - val_loss: 0.3133 - val_acc: 0.8812
Epoch 7/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2790 - acc: 0.8979 - val_loss: 0.3315 - val_acc: 0.8774
Epoch 8/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2669 - acc: 0.9030 - val_loss: 0.3103 - val_acc: 0.8900
Epoch 9/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2563 - acc: 0.9064 - val_loss: 0.3039 - val_acc: 0.8900
Epoch 10/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2461 - acc: 0.9108 - val_loss: 0.3175 - val_acc: 0.8836
Epoch 11/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2366 - acc: 0.9135 - val_loss: 0.3059 - val_acc: 0.8894
Epoch 12/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2276 - acc: 0.9162 - val_loss: 0.3144 - val_acc: 0.8846
Epoch 13/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2209 - acc: 0.9196 - val_loss: 0.3020 - val_acc: 0.8900
Epoch 14/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2113 - acc: 0.9239 - val_loss: 0.3216 - val_acc: 0.8854
Epoch 15/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.2043 - acc: 0.9265 - val_loss: 0.2941 - val_acc: 0.8926
Epoch 16/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.1967 - acc: 0.9297 - val_loss: 0.3036 - val_acc: 0.8920
Epoch 17/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.1892 - acc: 0.9328 - val_loss: 0.3082 - val_acc: 0.8894
Epoch 18/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.1835 - acc: 0.9336 - val_loss: 0.2951 - val_acc: 0.8936
Epoch 19/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.1771 - acc: 0.9363 - val_loss: 0.3003 - val_acc: 0.8956
Epoch 20/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.1704 - acc: 0.9385 - val_loss: 0.3261 - val_acc: 0.8854
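The 1719 steps per epoch in the log follow from the default batch size of 32 mentioned above; a quick arithmetic check:

```python
import math

train_samples = 55000  # size of x_train after splitting off 5000 validation samples
batch_size = 32        # the Keras default since TF 2.x
steps_per_epoch = math.ceil(train_samples / batch_size)
print(steps_per_epoch)  # prints 1719
```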

Next, inspect the history:

# inspect the training history
history.history

Output:

{'loss': [0.5354679226875305,
0.39141619205474854,
0.3516489565372467,
0.32609444856643677,
0.3078019320964813,
0.2926427125930786,
0.27895301580429077,
0.2669494152069092,
0.2563493847846985,
0.24608077108860016,
0.23657047748565674,
0.2275625765323639,
0.2209150642156601,
0.21127846837043762,
0.20427322387695312,
0.19672366976737976,
0.1892261505126953,
0.1835436224937439,
0.1771387904882431,
0.17038710415363312],
'acc': [0.8089091181755066,
0.8575454354286194,
0.8745636343955994,
0.8818908929824829,
0.889163613319397,
0.8929636478424072,
0.8978727459907532,
0.902999997138977,
0.9064363837242126,
0.9108181595802307,
0.913527250289917,
0.9161636233329773,
0.9196363687515259,
0.923909068107605,
0.9265090823173523,
0.9297090768814087,
0.9327636361122131,
0.9336363673210144,
0.9363454580307007,
0.9385091066360474],
'val_loss': [0.4270329475402832,
0.3716042935848236,
0.3619808852672577,
0.34866154193878174,
0.32663166522979736,
0.31333956122398376,
0.3315422832965851,
0.3103165328502655,
0.3038505017757416,
0.3175022304058075,
0.30592072010040283,
0.3144492208957672,
0.3020140528678894,
0.32157447934150696,
0.2940865457057953,
0.3035639524459839,
0.30824777483940125,
0.29505348205566406,
0.3002834916114807,
0.326121985912323],
'val_acc': [0.852400004863739,
0.8651999831199646,
0.8679999709129333,
0.8736000061035156,
0.8855999708175659,
0.8812000155448914,
0.8773999810218811,
0.8899999856948853,
0.8899999856948853,
0.8835999965667725,
0.8894000053405762,
0.8845999836921692,
0.8899999856948853,
0.8853999972343445,
0.8925999999046326,
0.8920000195503235,
0.8894000053405762,
0.8935999870300293,
0.8956000208854675,
0.8853999972343445]}

5 Plot the loss and acc curves

Using history, first convert it to a DataFrame, then plot:

def plot_learning_curves(history):
    pd.DataFrame(history.history).plot(figsize=(8, 5))
    plt.grid(True)            # show grid lines
    plt.gca().set_ylim(0, 1)  # set the y-axis range
    plt.show()
plot_learning_curves(history)

6 Test the model

model.evaluate(x_test_scaled, y_test)

Output:

1/313 [..............................] - ETA: 0s - loss: 0.3203 - acc: 0.8750WARNING:tensorflow:Callbacks method on_test_batch_end is slow compared to the batch time (batch time: 0.0010s vs on_test_batch_end time: 0.0038s). Check your callbacks.
313/313 [==============================] - 0s 2ms/step - loss: 0.3541 - acc: 0.8774
[0.35407549142837524, 0.8773999810218811]

TensorBoard view: (screenshot not preserved here)

Original-work statement: this article is published in the Tencent Cloud Developer Community with the author's authorization; reproduction without permission is prohibited.

For infringement concerns, contact cloudcommunity@tencent.com for removal.

