建议阅读时长 10分钟
TensorFlow 是谷歌基于 DistBelief 进行研发的第二代人工智能学习系统,自 2015 年问世,并在去年 11 月迎来三周岁生日,已经发展为世界上最受欢迎和被广泛采用的机器学习平台之一。
TensorFlow 等于 Tensor+Flow,即数据 + 流动。
本人最开始接触 TF1.0 ,觉得它非常抽象,对新手不是很友好,故而转向 keras 。但是,单纯的学习 keras 并不能胜任更多的工作。 以下是从拉钩网中几家公司的招聘要求,基本上所有的公司都会要求使用 TF、caffe、pytorch 等深度学习框架,但是以这三者居多。
给几个我选择 TF 的几个理由,确切的说 TF2.0:
不管你之前有没有接触过或者 TF 或者其他深度学习框架,都无关紧要,当然,有其他框架的使用经历可能会有更好的理解。在这个教程中,我不会去和之前的版本进行比较,因为我没有使用过的经历,我更专注的是最新版 TF 的使用
在学习的过程中,肯定会遇到相当多的问题,但是,坚持下去总能收获,欢迎有同样爱好,或者在学习 TF2.0 的开发者联系我,一起学习。因为,一群人能走得更远!
在 TensorFlow2.0 中,Keras 是一个用于构建和训练深度学习模型的高阶 API。以下将介绍 keras 中的几个常用模块。
常用的几个模块使用
如何使用:
1# 导入包
2from tensorflow import keras
3# 定义数据集对象
4mnist = keras.datasets.mnist
5boston_housing = keras.datasets.boston_housing
6cifar10 = keras.datasets.cifar10
7cifar100= keras.datasets.cifar100
8fashion_mnist =keras.datasets.fashion_mnist
9imdb= keras.datasets.imdb
10reuters= keras.datasets.reuters
11
12# 导入数据,其他类似
13(x_train, y_train), (x_test, y_test) = mnist.load_data()
fashion_mnist 是 mnist 的一个升级版,有人曾调侃道:"如果一个算法在 MNIST 不 work,那么它就根本没法用;而如果它在 MNIST 上 work,它在其他数据上也可能不 work"。所有 fashion_mnist 是一个好的替代品。
与神经网络相关的层,包括卷积层、池化层、全连接层、上采样层,你会发现,这些和 keras 框架中的方法类似
介绍一个入门的例程,称作机器学习的 Hello World 的 Mnist 手写字符识别
1# Python 兼容的一些包,更多可以参考:https://docs.python.org/2/library/__future__.html
2
3# Install TensorFlow
4# !pip install tensorflow==2.0.0-alpha0 # 未安装的需要安装 TF2.0 这里已经安装了
5import tensorflow as tf
6
7# 导入数据集,并进行归一化
8mnist = tf.keras.datasets.mnist
9
10(x_train, y_train), (x_test, y_test) = mnist.load_data()
11x_train, x_test = x_train / 255.0, x_test / 255.0
12
13# 构建模型,这里没有用到卷积
14model = tf.keras.models.Sequential([
15 tf.keras.layers.Flatten(input_shape=(28, 28)),
16 tf.keras.layers.Dense(128, activation='relu'),
17 tf.keras.layers.Dropout(0.2),
18 tf.keras.layers.Dense(10, activation='softmax')
19])
20
21model.summary()
22model.compile(optimizer='adam',
23 loss='sparse_categorical_crossentropy',
24 metrics=['accuracy'])
25
26"""输出:
27Model: "sequential"
28_________________________________________________________________
29Layer (type) Output Shape Param #
30=================================================================
31flatten (Flatten) (None, 784) 0
32_________________________________________________________________
33dense (Dense) (None, 128) 100480
34_________________________________________________________________
35dropout (Dropout) (None, 128) 0
36_________________________________________________________________
37dense_1 (Dense) (None, 10) 1290
38=================================================================
39Total params: 101,770
40Trainable params: 101,770
41Non-trainable params: 0
42_________________________________________________________________
43"""
1# 训练模型
2model.fit(x_train, y_train, epochs=5)
3
4"""训练结果
5Epoch 1/5
660000/60000 [==============================] - 9s 156us/sample - loss: 0.2956 - accuracy: 0.9143
7Epoch 2/5
860000/60000 [==============================] - 9s 147us/sample - loss: 0.1405 - accuracy: 0.9585
9Epoch 3/5
1060000/60000 [==============================] - 9s 154us/sample - loss: 0.1050 - accuracy: 0.9678
11Epoch 4/5
1260000/60000 [==============================] - 9s 153us/sample - loss: 0.0867 - accuracy: 0.9727
13Epoch 5/5
1460000/60000 [==============================] - 9s 156us/sample - loss: 0.0737 - accuracy: 0.9764
15"""
1# 模型评估
2model.evaluate(x_test, y_test)
3
4"""预测效果
510000/10000 [==============================] - 1s 99us/sample - loss: 0.0706 - accuracy: 0.9763
6
7[0.07059373209443874, 0.9763]
8"""
TF2.0 中的 keras 模块应该会是最经常使用的模块,与 keras 框架有许多相似之处。
一句话送给大家:如果你在你当前的行业积累并不是很深的时候,而另外一个行业(比如:机器学习)又有比较好的前景的时候,可以考虑 ALL in 一个新的行业。
工具千千万,效率第一条,任何框架的迭代速度都在加快,不应该只是学会调用框架,一些最基础的概念、算法应该要相当熟悉,这也是我经常提醒自己的。
参考资料: