tf28: 手写汉字识别

MNIST手写数字数据集通常做为深度学习的练习数据集,这个数据集恐怕早已经被大家玩坏了。本帖就介绍一个和MNIST类似,同时又适合国人练习的数据集-手写汉字数据集,然后训练一个简单的Deep Convolutional Network识别手写汉字。

识别手写汉字要把识别手写洋文难上很多。首先,英文字符的分类少,总共10+26*2;而中文总共50,000多汉字,常用的就有3000多。其次,汉字有书法,每个人书写风格多样。

手写汉字数据集: CASIA-HWDB

下载HWDB1.1数据集:

$ wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip
# zip解压没得说, 之后还要解压alz压缩文件
$ wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip

这个数据集由模式识别国家重点实验室共享,它还共享了其它几个数据库,先mark:

  • 行为分析数据库
  • 三维人脸数据库
  • 中文语言资源库
  • 步态数据库
  • 掌纹数据库
  • 虹膜库数据

手写汉字的样子:

import os
import numpy as np
import struct
import PIL.Image
 
train_data_dir = "HWDB1.1trn_gnt"
test_data_dir = "HWDB1.1tst_gnt"
 
# 读取图像和对应的汉字
def read_from_gnt_dir(gnt_dir=train_data_dir):
	def one_file(f):
		header_size = 10
		while True:
			header = np.fromfile(f, dtype='uint8', count=header_size)
			if not header.size: break
			sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)
			tagcode = header[5] + (header[4]<<8)
			width = header[6] + (header[7]<<8)
			height = header[8] + (header[9]<<8)
			if header_size + width*height != sample_size:
				break
			image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))
			yield image, tagcode
 
	for file_name in os.listdir(gnt_dir):
		if file_name.endswith('.gnt'):
			file_path = os.path.join(gnt_dir, file_name)
			with open(file_path, 'rb') as f:
				for image, tagcode in one_file(f):
					yield image, tagcode
 
# 统计样本数
train_counter = 0
test_counter = 0
for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):
	tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
	"""
	# 提取点图像, 看看什么样
	if train_counter < 1000:
		im = PIL.Image.fromarray(image)
		im.convert('RGB').save('png/' + tagcode_unicode + str(train_counter) + '.png')
	"""
	train_counter += 1
for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):
	tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
	test_counter += 1
 
# 样本数
print(train_counter, test_counter)

由于时间和系统资源有限,我只使用数据集的一部分(只识别最常用的140个汉字)。

训练模型

  • 模型架构参看论文:Deep Convolutional Network for Handwritten Chinese Character Recognition
import os
import numpy as np
import struct
import PIL.Image
 
train_data_dir = "HWDB1.1trn_gnt"
test_data_dir = "HWDB1.1tst_gnt"
 
# 读取图像和对应的汉字
def read_from_gnt_dir(gnt_dir=train_data_dir):
	def one_file(f):
		header_size = 10
		while True:
			header = np.fromfile(f, dtype='uint8', count=header_size)
			if not header.size: break
			sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)
			tagcode = header[5] + (header[4]<<8)
			width = header[6] + (header[7]<<8)
			height = header[8] + (header[9]<<8)
			if header_size + width*height != sample_size:
				break
			image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))
			yield image, tagcode
 
	for file_name in os.listdir(gnt_dir):
		if file_name.endswith('.gnt'):
			file_path = os.path.join(gnt_dir, file_name)
			with open(file_path, 'rb') as f:
				for image, tagcode in one_file(f):
					yield image, tagcode
 
import scipy.misc
from sklearn.utils import shuffle
import tensorflow as tf
 
# 我取常用的前140个汉字进行测试
char_set = "的一是了我不人在他有这个上们来到时大地为子中你说生国年着就那和要她出也得里后自以会家可下而过天去能对小多然于心学么之都好看起发当没成只如事把还用第样道想作种开美总从无情己面最女但现前些所同日手又行意动方期它头经长儿回位分爱老因很给名法间斯知世什两次使身者被高已亲其进此话常与活正感"
 
def resize_and_normalize_image(img):
	# 补方
	pad_size = abs(img.shape[0]-img.shape[1]) // 2
	if img.shape[0] < img.shape[1]:
		pad_dims = ((pad_size, pad_size), (0, 0))
	else:
		pad_dims = ((0, 0), (pad_size, pad_size))
	img = np.lib.pad(img, pad_dims, mode='constant', constant_values=255)
	# 缩放
	img = scipy.misc.imresize(img, (64 - 4*2, 64 - 4*2))
	img = np.lib.pad(img, ((4, 4), (4, 4)), mode='constant', constant_values=255)
	assert img.shape == (64, 64)
 
	img = img.flatten()
	# 像素值范围-1到1
	img = (img - 128) / 128
	return img
 
# one hot
def convert_to_one_hot(char):
	vector = np.zeros(len(char_set))
	vector[char_set.index(char)] = 1
	return vector
 
# 由于数据量不大, 可一次全部加载到RAM
train_data_x = []
train_data_y = []
 
for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):
	tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
	if tagcode_unicode in char_set:
		train_data_x.append(resize_and_normalize_image(image))
		train_data_y.append(convert_to_one_hot(tagcode_unicode))
 
# shuffle样本
train_data_x, train_data_y = shuffle(train_data_x, train_data_y, random_state=0)
 
batch_size = 128
num_batch = len(train_data_x) // batch_size
 
text_data_x = []
text_data_y = []
for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):
	tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
	if tagcode_unicode in char_set:
		text_data_x.append(resize_and_normalize_image(image))
		text_data_y.append(convert_to_one_hot(tagcode_unicode))
# shuffle样本
text_data_x, text_data_y = shuffle(text_data_x, text_data_y, random_state=0)
 
 
X = tf.placeholder(tf.float32, [None, 64*64])
Y = tf.placeholder(tf.float32, [None, 140])
keep_prob = tf.placeholder(tf.float32)
 
def chinese_hand_write_cnn():
	x = tf.reshape(X, shape=[-1, 64, 64, 1])
	# 3 conv layers
	w_c1 = tf.Variable(tf.random_normal([3, 3, 1, 32], stddev=0.01))
	b_c1 = tf.Variable(tf.zeros([32]))
	conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, w_c1, strides=[1, 1, 1, 1], padding='SAME'), b_c1))
	conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
	
	w_c2 = tf.Variable(tf.random_normal([3, 3, 32, 64], stddev=0.01))
	b_c2 = tf.Variable(tf.zeros([64]))
	conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, w_c2, strides=[1, 1, 1, 1], padding='SAME'), b_c2))
	conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
	
	"""
	# 训练开始之后我就去睡觉了, 早晨起来一看, 白跑了, 准确率不足10%; 把网络变量改少了再来一发
	w_c3 = tf.Variable(tf.random_normal([3, 3, 64, 128], stddev=0.01))
	b_c3 = tf.Variable(tf.zeros([128]))
	conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, w_c3, strides=[1, 1, 1, 1], padding='SAME'), b_c3))
	conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
	conv3 = tf.nn.dropout(conv3, keep_prob)
	"""
 
	# fully connect layer
	w_d = tf.Variable(tf.random_normal([8*32*64, 1024], stddev=0.01))
	b_d = tf.Variable(tf.zeros([1024]))
	dense = tf.reshape(conv2, [-1, w_d.get_shape().as_list()[0]])
	dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))
	dense = tf.nn.dropout(dense, keep_prob)
 
	w_out = tf.Variable(tf.random_normal([1024, 140], stddev=0.01))
	b_out = tf.Variable(tf.zeros([140]))
	out = tf.add(tf.matmul(dense, w_out), b_out)
 
	return out
 
def train_hand_write_cnn():
	output = chinese_hand_write_cnn()
 
	loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(output, Y))
	optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
 
	accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(output, 1), tf.argmax(Y, 1)), tf.float32))
 
	# TensorBoard
	tf.scalar_summary("loss", loss)
	tf.scalar_summary("accuracy", accuracy)
	merged_summary_op = tf.merge_all_summaries()
 
	saver = tf.train.Saver()
	with tf.Session() as sess:
		sess.run(tf.global_variables_initializer())
 
		# 命令行执行 tensorboard --logdir=./log  打开浏览器访问http://0.0.0.0:6006
		summary_writer = tf.train.SummaryWriter('./log', graph=tf.get_default_graph())
 
		for e in range(50):
			for i in range(num_batch):
				batch_x = train_data_x[i*batch_size : (i+1)*batch_size]
				batch_y = train_data_y[i*batch_size : (i+1)*batch_size]
				_, loss_, summary = sess.run([optimizer, loss, merged_summary_op], feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.5})
				# 每次迭代都保存日志
				summary_writer.add_summary(summary, e*num_batch+i)
				print(e*num_batch+i, loss_)
 
				if e*num_batch+i % 100 == 0:
					# 计算准确率
					acc = accuracy.eval({X: text_data_x[:500], Y: text_data_y[:500], keep_prob: 1.})
					#acc = sess.run(accuracy, feed_dict={X: text_data_x[:500], Y: text_data_y[:500], keep_prob: 1.})
					print(e*num_batch+i, acc)
 
train_hand_write_cnn()

Computation Graph:

loss:

准确率:

Share the post "TensorFlow练习22: 手写汉字识别"

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人工智能

如何仅使用TensorFlow C+来训练深度神经网络

作者|Florian Courtial 译者|Debra 编辑|Emily AI 前线导读:训练神经网络是一件十分复杂,难度非常大的工作,有没有可能让训练的过程...

2155
来自专栏ATYUN订阅号

Deep Photo Styletransfer的一种纯Tensorflow实现,教你如何转换图片风格

通过深度学习,一秒钟让你的照片高大上,这是康奈尔大学和 Adobe 的工程师合作的一个新项目,通过卷积神经网络把图片进行风格迁移。项目已开源,名字叫「Deep ...

4575
来自专栏点滴积累

geotrellis使用(十六)使用缓冲区分析的方式解决投影变换中边缘数据值计算的问题

Geotrellis系列文章链接地址http://www.cnblogs.com/shoufengwei/p/5619419.html 目录 前言 问题探索 ...

3434
来自专栏专知

【资源】Python实现多种模型(Naive Bayes, SVM, CNN, LSTM, etc)用于推文情感分析

【导读】近日,Abdul Fatir 在自己的CS5228课程报告使用不同的方法进行Tweets情感分析(作为二分类问题),并对这些方法的性能进行比较,主要是基...

4079
来自专栏ATYUN订阅号

如何用自己的数据训练MASK R-CNN模型

计算机一直很擅长数字处理,却苦于分析图像中的大量数据。直到最近,创建图形处理单元库已不仅仅用于游戏,现在我们可以利用数千个核心的原始力量来揭示图片背后的含义。

3176
来自专栏日常学python

教你如何用Python玩转小程序游戏「一笔画完」

欢迎大家关注。话不多说,直接上张图,看看一笔画完游戏是什么样的,我们需要将灰格子都走完且只走一遍 Z

694
来自专栏素质云笔记

图像增强︱window7+opencv3.2+keras/theano简单应用(函数解读)

在服务器上安装opencv遇到跟CUDA8.0不适配的问题,于是不得不看看其他机器是否可以预装并使用。 . 一、python+opencv3.2安装 ope...

38410
来自专栏程序员宝库

Python验证码识别:利用pytesser识别简单图形验证码

来源: j_hao104 my.oschina.net/jhao104/blog/647326 一、探讨 识别图形验证码可以说是做爬虫的必修课,涉及到计算机图...

53810
来自专栏北京马哥教育

20行 Python 代码实现验证码识别

一、探讨 识别图形验证码可以说是做爬虫的必修课,涉及到计算机图形学,机器学习,机器视觉,人工智能等等高深领域…… 简单地说,计算机图形学的主要研究内容就是研究如...

5098
来自专栏小詹同学

人脸检测(一)——基于单文档的应用台程序

Opencv自带训练好的人脸模型(人脸的人眼、口等器官类似),此文基于vs2013建立应用台单文档程序,具体建立过程不予详细叙述,主要记录利用的Opencv自带...

3915

扫码关注云+社区