前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow 2.0 识别MNIST手写数字

TensorFlow 2.0 识别MNIST手写数字

作者头像
用户6021899
发布2019-12-23 22:02:43
1.4K0
发布2019-12-23 22:02:43
举报

TensorFlow 2.0 在 1.x版本上进行了大量改进,主要变化如下:

  • 以Eager模式为默认的运行模式,不必构建Session
  • 删除tf.contrib库,将其中的高阶API整合到tf.kears库下。
  • 将1.x版本中大量重复重叠的API进行合并精简

下面是TF2.0 入门demo, 训练集是MNIST。代码略有更改:

代码语言:javascript
复制
"""TF2.0 官方demo,略有更改,加了Dropout"""from __future__ import absolute_import, division, print_function, unicode_literalsimport tensorflow as tffrom tensorflow.keras.layers import Dense, Flatten, Conv2D,Dropoutfrom tensorflow.keras import Model#加载并准备 MNIST 数据集。mnist = tf.keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()x_train, x_test = x_train / 255.0, x_test / 255.0# Add a channels dimensionx_train = x_train[..., tf.newaxis]x_test = x_test[..., tf.newaxis]#使用 tf.data 来将数据集切分为 batch 以及混淆数据集:train_ds = tf.data.Dataset.from_tensor_slices(    (x_train, y_train)).shuffle(10000).batch(100) #打乱样本顺序,训练batchsize设为100test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)#使用 Keras 模型子类化(model subclassing) API 构建 tf.keras 模型:
class MyModel(Model):  def __init__(self):    super(MyModel, self).__init__()    self.conv1 = Conv2D(filters=64, kernel_size=(3,3), activation='relu')    #self.dropout1 = Dropout(0.2)    self.flatten = Flatten() # 展平    self.d1 = Dense(256, activation='relu')#全连接层    self.dropout2 = Dropout(0.5)    self.d2 = Dense(10, activation='softmax')      def call(self, x):    x = self.conv1(x)    #x = self.dropout1(x)    x = self.flatten(x)    x = self.d1(x)    x = self.dropout2(x)    return self.d2(x)    model = MyModel()#为训练选择优化器与损失函数:loss_object = tf.keras.losses.SparseCategoricalCrossentropy()optimizer = tf.keras.optimizers.Adam()
#选择衡量指标来度量模型的损失值(loss)和准确率(accuracy)。这些指标在 epoch 上累积值,然后打印出整体结果。train_loss = tf.keras.metrics.Mean(name='train_loss')train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')test_loss = tf.keras.metrics.Mean(name='test_loss')test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
#使用 tf.GradientTape 来训练模型:@tf.functiondef train_step(images, labels):  with tf.GradientTape() as tape:    predictions = model(images)    loss = loss_object(labels, predictions)  gradients = tape.gradient(loss, model.trainable_variables)  optimizer.apply_gradients(zip(gradients, model.trainable_variables))  train_loss(loss)  train_accuracy(labels, predictions)   #测试模型:@tf.functiondef test_step(images, labels):  predictions = model(images)  t_loss = loss_object(labels, predictions)  test_loss(t_loss)  test_accuracy(labels, predictions)    EPOCHS = 5for epoch in range(EPOCHS):  for images, labels in train_ds:    train_step(images, labels)  for test_images, test_labels in test_ds:    test_step(test_images, test_labels)  template = 'Epoch {}, Training Loss: {}, Training Accuracy: {}; Test Loss: {}, Test Accuracy: {}'  print (template.format(epoch+1,                         train_loss.result(),                         train_accuracy.result()*100,                         test_loss.result(),                         test_accuracy.result()*100))

5个Epoch后,测试集上的预测准确度达98.43%:

代码语言:javascript
复制
Epoch 1, Training Loss: 0.13766814768314362, Training Accuracy: 95.86332702636719; Test Loss: 0.05851453170180321, Test Accuracy: 98.0999984741211Epoch 2, Training Loss: 0.08793511241674423, Training Accuracy: 97.33499908447266; Test Loss: 0.05357005447149277, Test Accuracy: 98.25999450683594Epoch 3, Training Loss: 0.06439810246229172, Training Accuracy: 98.03555297851562; Test Loss: 0.05198133364319801, Test Accuracy: 98.32333374023438Epoch 4, Training Loss: 0.05080028995871544, Training Accuracy: 98.44750213623047; Test Loss: 0.05018126592040062, Test Accuracy: 98.39500427246094Epoch 5, Training Loss: 0.04206531494855881, Training Accuracy: 98.71266174316406; Test Loss: 0.05160188674926758, Test Accuracy: 98.43399810791016
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-12-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python可视化编程机器学习OpenCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档