本篇文章在上篇TensorFlow-手写数字识别(一)的基础上进行改进,主要实现以下3点:
上次的代码每次进行模型训练时,都会重新开始进行训练,之前的训练结果都被覆盖掉了,极不方便。
在backwork.py中加入ckpt操作,可以实现断点续训功能。
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
for i in range(STEPS):
xs, ys = sess.run([img_batch, label_batch])
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
if i % 100 == 0:
print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
注解:
tf.train.get_checkpoint_state(checkpoint_dir,latest_filename=None)
该函数表示如果断点文件夹中包含有效断点状态文件,则返回该文件。checkpoint_dir
:表示存储断点文件的目录latest_filename=None
:断点文件的可选名称,默认为“checkpoint”saver.restore(sess, ckpt.model_checkpoint_path)
该函数表示恢复当前会话,将ckpt中的值赋给w和b。sess
:表示当前会话,之前保存的结果将被加载入这个会话ckpt.model_checkpoint_path
:表示模型存储的位置,不需要提供模型的名字,它会去查看 checkpoint 文件,看看最新的是谁,叫做什么。代码运行效果:
RESTART: G:\TestProject\python\tensorflow\...\mnist_backward.py
After 16203 training step(s), loss on training batch is 0.155758.
After 16303 training step(s), loss on training batch is 0.173135.
After 16403 training step(s), loss on training batch is 0.159716.
可以看出,程序可以接着之前的训练数据接着训练
上次的代码只能使用MNIST自带数据集中的数据进行训练,这次通过编写mnist_app.py函数,实现真实图片数据的预测。
def application():
testNum = input("input the number of test pictures:")
for i in range(testNum):
testPic = raw_input("the path of test picture:")
testPicArr = pre_pic(testPic)
preValue = restore_model(testPicArr)
print "The prediction number is:", preValue
任务分两个函数完成
#预处理函数,包括resize、转变灰度图、二值化操作
def pre_pic(picName):
img = Image.open(picName) #加载待测试图片(白底)
reIm = img.resize((28,28), Image.ANTIALIAS) #调整大小到28x28
im_arr = np.array(reIm.convert('L'))
threshold = 50 #二进制阈值
for i in range(28):
for j in range(28):
im_arr[i][j] = 255 - im_arr[i][j] #反色(黑底)
if (im_arr[i][j] < threshold): #黑底白字
im_arr[i][j] = 0
else:
im_arr[i][j] = 255
nm_arr = im_arr.reshape([1, 784]) #图片转成1行
nm_arr = nm_arr.astype(np.float32)
img_ready = np.multiply(nm_arr, 1.0/255.0) #取值范围限制在0~1之间
return img_ready
def restore_model(testPicArr):
with tf.Graph().as_default() as tg:
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y = mnist_forward.forward(x, None)
preValue = tf.argmax(y, 1)
variable_averages = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
preValue = sess.run(preValue, feed_dict={x:testPicArr})
return preValue
else:
print("No checkpoint file found")
return -1
注解:
1)main 函数中调用的application()
函数:输入要识别的几张图片(注意要给出待识别图片的路径和名称)。
2)代码处理过程:
1)运行 mnist_backward.py 首先对模型进行训练
RESTART: G:\TestProject\python\tensorflow\...\mnist_backward.py
After 16203 training step(s), loss on training batch is 0.155758.
After 16303 training step(s), loss on training batch is 0.173135.
After 16403 training step(s), loss on training batch is 0.159716.
2)运行 mnist_test.py 使用测试集,监测模型的准确率
RESTART: G:\TestProject\python\tensorflow\...\mnist_test.py
After 16703 training step(s), test accuracy = 0.9798
3)运行 mnist_app.py 输入1~10之间的数(表示循环验证的图片数量)
RESTART: G:\TestProject\python\tensorflow\...\mnist_app.py
input the number of test pictures:5
the path of test picture:pic\0.png
The prediction number is: [0]
the path of test picture:pic\1.png
The prediction number is: [3]
the path of test picture:pic\5.png
The prediction number is: [5]
the path of test picture:pic\8.png
The prediction number is: [8]
the path of test picture:pic\9.png
The prediction number is: [9]
>>>
上次的程序使用的MNIST整理好的特定格式的数据,如果想要用自己的图片进行模型训练,就需要自己制作数据集。
数据集的制作的不仅仅是将图片整理在一起,通过转换成特定的格式,可以加速图片读取的效率。
下面将MNIST数据集转换成tfrecords格式,该方法也可以将普通图片转换为该格式。
tfrecords
:一种二进制文件,可先将图片和标签制作成该格式的文件,使用tfrecords进行数据读取会提高内存利用率tf.train.Example
:用来存储训练数据,训练数据的特征用键值对的形式表示SerializeToString( )
:把数据序列化成字符串存储读取原始图片和标签文件,转换为tfrecord格式
def write_tfRecord(tfRecordName, image_path, label_path):
writer = tf.python_io.TFRecordWriter(tfRecordName) #新建一个writer
num_pic = 0
f = open(label_path, 'r') #打开标签文件
contents = f.readlines() #读入(格式如:2028_7.jpg 7)
f.close()
for content in contents: #遍历每张图片和对应标签
value = content.split() #拆分:图片名+对应标签
img_path = image_path + value[0]
img = Image.open(img_path) #打开对应的图片文件
img_raw = img.tobytes()
labels = [0] * 10
labels[int(value[1])] = 1
#把每张图片和标签封装到 example 中
example = tf.train.Example(features=tf.train.Features(feature={
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
}))
writer.write(example.SerializeToString()) #把example进行序列化
num_pic += 1
print ("the number of picture:", num_pic)
writer.close() #关闭writer
print("write tfrecord successful")
注解:
writer = tf.python_io.TFRecordWriter( tfRecordName)
:新建一个 writerfor循环
:遍历每张图和标签writer.write(example.SerializeToString())
:把 example 进行序列化writer.close()
:关闭 writer保存tfrecord格式文件
def generate_tfRecord():
isExists = os.path.exists(data_path) #检查用于存放数据集的路径是否存在
if not isExists:
os.makedirs(data_path)
print('The directory was created successfully')
else:
print('directory already exists')
write_tfRecord(tfRecord_train, image_train_path, label_train_path)
write_tfRecord(tfRecord_test, image_test_path, label_test_path)
获取tfrecords文件接口函数
def get_tfrecord(num, isTrain=True):
if isTrain:
tfRecord_path = tfRecord_train
else:
tfRecord_path = tfRecord_test
img, label = read_tfRecord(tfRecord_path)
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size = num,
num_threads = 2,
capacity = 1000,
min_after_dequeue = 700)
#返回的图片和标签为随机抽取的 batch_size 组
return img_batch, label_batch
注解:
tf.train.shuffle_batch(tensors, batch_size, capacity,
min_after_dequeue, num_threads=1, seed=None,
enqueue_many=False, shapes=None, allow_smaller_final_batch=False,
shared_name=None, name=None)
tensors
: 待乱序处理的列表中的样本(图像和标签)batch_size
: 从队列中提取的新批量大小capacity
:队列中元素的最大数量min_after_dequeue
: 出队后队列中的最小数量元素,用于确保元素的混合级别num_threads
: 排列 tensors 的线程数seed
:用于队列内的随机洗牌enqueue_many
: tensor 中的每个张量是否是一个例子shapes
: 每个示例的形状allow_smaller_final_batch
: (可选)布尔值。如果为 True,则在队列中剩余数量不足时允许最终批次更小。shared_name
:(可选)如果设置,该队列将在多个会话中以给定名称共享。name
:操作的名称(可选)读取tfrecords文件
def read_tfRecord(tfRecord_path):
filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True)
reader = tf.TFRecordReader() # 新建一个reader
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([10], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
})
img = tf.decode_raw(features['img_raw'], tf.uint8)# 将img_raw字符串转换为8位无符号整型
img.set_shape([784])# 将形状变为一行784列
img = tf.cast(img, tf.float32) * (1./255)# 变成0到1之间的浮点数
label = tf.cast(features['label'], tf.float32) # 把标签列表变为浮点数
return img, label # 返回图片和标签(跳回到 get_tfrecord)
注解:
tf.train.string_input_producer( string_tensor, num_epochs=None,
shuffle=True,seed=None,capacity=32,
shared_name=None,name=None,cancel_op=None)
该函数会生成一个先入先出的队列,文件阅读器会使用它来读取数据
string_tensor
: 存储图像和标签信息的 TFRecord 文件名列表num_epochs
: 循环读取的轮数(可选)shuffle
:布尔值(可选),如果为 True,则在每轮随机打乱读取顺序seed
:随机读取时设置的种子(可选)capacity
:设置队列容量shared_name
:(可选) 如果设置,该队列将在多个会话中以给定名称共享。
所有具有此队列的设备都可以通过 shared_name 访问它。在分布式设置中使用这种方法意味着每个名称只能被访问此操作的其中一个会话看到。name
:操作的名称(可选)cancel_op
:取消队列(None)_, serialized_example = reader.read(filename_queue)
把读出的每个样本保存在 serialized_example 中进行解序列化,标签和图片的键名应该和制作 tfrecords 的键名相同,其中标签给出几分类
tf.parse_single_example(serialized,features,name=None,example_names=None)
该函数可以将 tf.train.Example 协议内存块(protocol buffer)解析为张量。
serialized
: 一个标量字符串张量features
: 一个字典映射功能键 FixedLenFeature 或 VarLenFeature值,也就是在协议内存块中储存的name
:操作的名称(可选)example_names
: 标量字符串联的名称(可选)将批获取的操作放到线程协调器开启和关闭之间
coord = tf.train.Coordinator( )
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
coord.request_stop( )
coord.join(threads)
注解:
tf.train.start_queue_runners( sess=None, coord=None, daemon=True,
start=True, collection=tf.GraphKeys.QUEUE_RUNNERS)
这个函数将会启动输入队列的线程,填充训练样本到队列中,以便出队操作可以从队列中拿到样本。这种情况下最好配合使用一个tf.train.Coordinator
,这样可以在发生错误的情况下正确地关闭这些线程。
sess
:用于运行队列操作的会话。默认为默认会话coord
:可选协调器,用于协调启动的线程daemon
: 守护进程,线程是否应该标记为守护进程,这意味着它们不会阻止程序退出start
:设置为 False 只创建线程,不启动它们collection
:指定图集合以获取启动队列的 GraphKey,默认为GraphKeys.QUEUE_RUNNERS
修改后的mnist_backward.py的关键部分:
...
import mnist_generateds#【1】
...
train_num_examples = 60000#【2】 训练集图片的个数
def backward():
...
saver = tf.train.Saver()
img_batch, label_batch = mnist_generateds.get_tfrecord(BATCH_SIZE, isTrain=True)#【3】
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
coord = tf.train.Coordinator()#【4】 开启线程协调器
threads = tf.train.start_queue_runners(sess=sess, coord=coord)#【5】
for i in range(STEPS):
xs, ys = sess.run([img_batch, label_batch])#【6】
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
if i % 100 == 0:
print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
coord.request_stop()#【7】 关闭线程协调器
coord.join(threads)#【8】
注解:
train_num_examples=60000
在梯度下降学习率中需要计算多少轮更新一次学习率,这个值是:总样本数/batch size
image_batch, label_batch=mnist_generateds.get_tfrecord(BATCH_SIZE,isTrain=True)
xs,ys=sess.run([img_batch,label_batch])
修改后的mnist_test.py的关键部分:
#coding:utf-8
...
TEST_NUM = 10000#【1】
def test():
with tf.Graph().as_default() as g:
...
img_batch, label_batch = mnist_generateds.get_tfrecord(TEST_NUM, isTrain=False)#【2】
while True:
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
coord = tf.train.Coordinator()#【3】
threads = tf.train.start_queue_runners(sess=sess, coord=coord)#【4】
xs, ys = sess.run([img_batch, label_batch])#【5】
accuracy_score = sess.run(accuracy, feed_dict={x: xs, y_: ys})
print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score))
coord.request_stop()#【6】
coord.join(threads)#【7】
else:
print('No checkpoint file found')
return
time.sleep(TEST_INTERVAL_SECS)
注解:
TEST_NUM=10000
image_batch, label_batch=mnist_generateds.get_tfrecord(TEST_NUM,isTrain=False)
xs,ys=sess.run([img_batch,label_batch])
运行测试代码 mnist_test.py
RESTART: G:\TestProject\python\tensorflow\...\mnist_test.py
After 16703 training step(s), test accuracy = 0.9794
After 16703 training step(s), test accuracy = 0.9797
After 16703 training step(s), test accuracy = 0.9795
After 16703 training step(s), test accuracy = 0.9792
运行测试代码 mnist_app.py
RESTART: G:\TestProject\python\tensorflow\...\mnist_app.py
input the number of test pictures:5
the path of test picture:pic\0.png
The prediction number is: [0]
the path of test picture:pic\1.png
The prediction number is: [3]
the path of test picture:pic\5.png
The prediction number is: [5]
the path of test picture:pic\8.png
The prediction number is: [8]
the path of test picture:pic\9.png
The prediction number is: [9]
>>>
可以看出和之前的结果一样,代码可用。
注:以上测试图片用的是下面教程中自带的图片,测试结果100%准确,我自己用Windows画图板手写了0~9的数字,准确度只有50%左右,可能是我手写字体和MNIST库中的风格差异较大,或是目前的网络还不够好,下一篇通过搭建CNN网络继续测试。
参考:人工智能实践:Tensorflow笔记