前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow中生成手写笔迹的Demo

TensorFlow中生成手写笔迹的Demo

作者头像
Zach展
发布2018-02-06 15:45:18
2.5K0
发布2018-02-06 15:45:18
LSTM MDN生成的手写样本和下一个点的概率密集分布
LSTM MDN生成的手写样本和下一个点的概率密集分布
更长的手写样本
更长的手写样本

这项操作现在在github上已经可以使用了。

在我们的最后一篇文章中,我们讨论了混合密度网络,以及它们是怎样成为一种非常有用的,可以用各种状态模拟数据的工具,而不会试图揣测数据点的期望值。它使我们能够预测很多应用程序数据的整个概率密度函数,我们认为这无论是对于应用程序还是对于生成任务都是非常有用的。

在这篇文章中,我将会讨论一些能够将MDN与LSTM结合起来,以生成人造手写笔迹的例子。在研究了Alex Grave的论文之后,我们将尝试使用RNNs来实现他的部分工作,用于生成连续的数据。并且对他这个非常棒的demo进行一些有趣的操作。对于LSTMs和Recurrent网络的介绍,我推荐Colah的《理解LSTM网络》和Karparthy的《递归神经网络的不合理有效性》,Karparthy在书中讨论了使用RNN生成文本数据序列的方法,例如Shakesphere,LaTeX文档,甚至是假的Linux内核C代码都有涉及!在了解了这种方法的工作原理,和一些关于MDN的内容之后,你应该就能够理解这个算法是如何工作的了。

我已经使用TensorFlow在Python中实现了这个Demo,而且我依靠这个由sherjilozair制作的char-rnn-tensorflow工具实现了字符级的文本预测。他的例子极大地教会了我如何让LSTMs在TensorFlow中工作。

训练数据

为了能让我们的神经网络写出东西,它必须先训练一组相对较大的笔迹数据。我们将使用格雷夫斯在他的论文《IAM手写数据库》中相同的数据。因为下载这些数据需要请求权限,所以我不能把它们放在github上。如果你想训练这个网络,你需要自己把文件 lineStrokes-all.tar.gz 解压缩到data子目录中。

在IAM数据库中,大约有13000条不同的手写笔迹的例子,这些例子都是从一个数字化的笔划数据中记录下来的。这些数据通过xml格式记录,每个数据中包含一组笔划,每个笔划都是由一系列用笔在纸上连续画出来的点组成。下面的例子可以让你知道这些数据看起来是什么样子的。我已经写了一些代码来提取所有这些数据,并在IPython会话中用交互的方式绘制它们:

代码语言:txt
复制
    %run -i utils.py
    data_loader = DataLoader()
    for i in range(5):
      draw_strokes(random.choice(data_loader.raw_data))

有些数据是杂乱的,可能会包含一些错误,并且用户会在出错后手动抹掉他的笔迹。在格雷夫斯的论文中,他使用了一些过滤器来检测不好的例子,但是在这个demo中,我把所有的数据都放进去了。我只缩放数据的尺寸,使其与神经网的输出更加兼容,并且限制了从一个笔画到另一个笔画间距的大小。

每个训练样例可以被看成组成了一个笔画的一系列点:

代码语言:txt
复制
    sample = random.choice(data_loader.raw_data)
    draw_strokes_random_color(sample, per_stroke_mode = False)
    draw_strokes_random_color(sample)

从上面的图中我们可以看出,每个点或笔划的颜色都是随机的,每个例子都是由一组连接点组成,而每组连接点形成一个笔划。我们将这些数据建模为一系列向量,这些向量包含x和y方向到下一个点的步长,以及笔划的终点值(值可以是0或1),这个终点值可以表示要么下一个点仍是当前笔划的一部分,要么我们需要抬起笔并开始新的笔划。

模型描述

我们既不是用我们的网络预测下一个点的确切位置,也不是预测当前点是否是笔划的结束,而是使用MDN方法使网络输出一组关于下一个笔划位置的相对概率分布的参数(△x,△y),以及一个简单的用于推测笔画结束位置概率的伯努利 “硬币翻转”分布。直接预测方法不起作用的原因是,下一个笔划的位置会被太多不同的状态和环境所影响。我们所要做的就是预测下一个点的平均预期位置,虽然说这个位置可能是一个非常琐碎的结果,做个比喻,就像一条飘渺不定的线一样。

在前一篇关于MDNs的倒转正弦数据中,我们想要模拟数据中不同的潜在状态和环境,并能够产生下一个点的合理分布,这个分布的条件是基于整个历史出现过的点的,然后我们可以从中进行绘制并生成我们的手写示例。因此,在这个MDN中,到网络的输入将会是以下几点:最近一次有关联的笔划的运动轨迹,最近一次相关联的笔划的结束信号,以及网络先前的隐藏状态。而网络的输出可以是一组下一笔划运动轨迹和下一个笔画结束信号的参数化概率分布。

在我们根据过去的数据对网络进行了训练,并生成准确的未来分布之后,我们可以从概率分布中抽样来生成我们的手写笔迹样本。就像神经网络通过反馈自己以前生成的笔划来创造出一些手写的例子一样。在我们的demo中,我们使用了一个每层有256个节点,2层堆叠的基本LSTM网络(无窥孔连接)。

我们对于笔划轨迹矢量的概率分布模型将会是一个联合二维正态混合分布模式,表征为二维正态分布的概率加权的和。每个分布都有自己的均值和协方差。我们在演示中使用了20种混合,与Graves的论文一致,但是我们发现实际上其实5-10种混合的效果就很不错了,但是额外的混合数量并没有真的引起算法性能的大幅下降,因为大多数权重都在LSTM层中,所以我们依然保留了20个混合的使用。如果你想试验不同数量的节点,节点类型(RNN,GRU等),或者启用LSTM窥视孔连接,更改混合分布的数量,使用不同的DropOut概率 - 你可以通过在运行train.py时设置不同的标志来完成这些更改.

总的来说,我们需要从我们网络和MDN中输出121个值(用Z表示),以推断我们的分布。其中一个值将被用作笔画的结束概率,20个值将定义每个混合的概率,而其余的100个值构成20组2D正态分布参数。由于输出值可能是不受限制的实数,我们将执行一些转换以获得它在参数空间中的值:

就像在之前的MDN例子中那样使用这些转换,IIk这个值会经过softmax操作符进行转换,所以它的总和为1。而笔划结束概率e也被限制在0和1之间。值的标准偏差参数将为正,并且在应用指数和双曲线切线变换之后,两个坐标之间的相关性将会在-1和1之间。在获得参数之后,下一个笔划的概率密度将被定义为:

与前面的例子不同的是,所有权重都会存储在一个叫做球张量(global tensor)的变量类型中。因为这个任务涉及到更多的东西,并且有更多的移动部分,所以我们喜欢将模型很好地打包成一个类的形式,以便更容易的使用面向对象的接口。我们还为LSTM图层的每个输出层引入了DropOut来规范训练,以避免进行过度训练。但是我们没有将DropOut应用于输入层,因为写东西的顺序性和路径依赖性意味着它不会错过一个笔划的结束。我们发现DropOut在这个任务中是相当有效的,而且TensorFlow使得在这个特性中“drop”变得更容易一些。TensorFlow的rnn_cell模块使得使用DropOut实现堆叠RNN相当容易。例如,下面是使用DropOut构建我们的网络中使用的两层LSTM层所需的全部代码:

代码语言:txt
复制
    cell = rnn_cell.BasicLSTMCell(256)
    cell = rnn_cell.MultiRNNCell([cell] * 2)
    cell = rnn_cell.DropoutWrapper(cell, output_keep_prob = 0.8)

对于训练来说,我们将像以前一样对整个生成序列的最大相似预估使用交叉熵的方式。虽然梯度的有效封闭形式派生是可用的,但我们还是会依靠TensorFlow,通过其符号引擎来自动计算梯度。当我们向时间的反向传播派生的时候,我们使用10.0的梯度剪辑来避免梯度被过度放大。

具体的实施细节请参考model.py

批量梯度下降(Mini-Batch)训练的具体细节

为了预处理IAM的手写数据,我已经编写了模型来训练上述网络。如果预处理数据尚未建立的话,模块将从原始xml文件构建一个cPickle预处理数据库。不过对于训练来说,有一点比较棘手。我们如果想要使用批量梯度下降,为了保证操作有效,那么它们必须都要是相同的长度。我并不想连接每一个笔画并训练一组相同大小的笔画数据,因为这些笔划的线之间会有很多不自然的间距。而且我们还要对这种人为造成的错误进行训练。

最终,我选择了一个序列长度,这个序列会有300个点来供我们进行训练。我们扔掉那些少于300点的训练数据序列(其实不会扔掉很多,因为大部分的训练数据都有差不多300-2000个点)。之后,在创建批量梯度下降时,我将从每个样本中随机抽取连续的300个点的部分。例如,如果一个训练样本有400个数据点,插入到批量梯度下降中的样本将会是从0:300到100:400之间的任何地方,所以这实际上可能有助于更多地推广数据(如扭曲MNIST图像创建更多的数据点)。另外,对于哪些包含300个点以上的样本,比如说一个有1500点的样本,我会使用这个大的样本5倍于那些只有300-400分的样本,以确保更大的那个样本没有被训练不足。整个训练过程会持续大约30个时期。在没有使用GPU的情况下,在MacBook Pro上运行会花大约半天的时间。

从网络中生成样本

在训练数据结束后,我们的网络可以生成样本并保存为.svg文件。我想出了如何在IPython中显示它们的方法,并编写了一些模块来自动显示一些示例。

当我们对手写序列进行采样时,我们首先清空LSTM网络的状态,并将初始输入的值传入网络。

代码语言:txt
复制
    prev_x = np.zeros((1, 1, 3), dtype=np.float32)
    prev_x[0, 0, 2] = 1 # initially, we want to see beginning of new stroke
    prev_state = sess.run(self.cell.zero_state(1, tf.float32))

最初的初始输入值只是一个零矢量,但它会触发笔画结束的信号,并向网络发出信号,指出它产生的下一个点将是新笔画的开始,而不是现有笔画的延续,因此我认为我们可以用更有意思的方式获得多样化的起点。

在初始输入值和零状态传入网络之后,我们将从网络输出中得到一组参数,这组参数将是混合二维高斯分布的参数,其定义了下一个点位置的的概率分布。另外还有一个参数定义了下一个点是另一个笔划开始的概率。

我们随机地从这个分布中抽取一组值,然后把这个点加到我们在这个过程中建立的一个.svg文件中,并记录下网络的状态。之后,我们重复这个过程,并将采样点和网络状态作为输入返回,以获得另一个概率分布为了从下一个点开始采样,我们一遍一遍的重复,直到得到800个点(或者用户指定)。下面是采样过程的python虚拟程序代码。

代码语言:txt
复制
    strokes = np.zeros((num, 3), dtype=np.float32)

    for i in xrange(num):

        # get the model parameters from the network
        feed = {model.input_data: prev_x, model.initial_state:prev_state}
        [model_params, next_state] = sess.run([model.model_params, model.final_state],feed)

        # sample whether we want to end the stroke
        eos = sample_eos(model_params.eos)

        # sample which mixture to use
        idx = sample_mixture_index(random.random(), model_params.pi)

        # sample location of the next stroke
        next_x1, next_x2 = sample_gaussian_2d(model_params, idx)

        # put the current generated stroke as the next input
        prev_x = [next_x1, next_x2, eos]

        # record the generated stroke, since we want to draw it later
        strokes[i,:] = prev_x

        # save the current RNN's state to feed it back in next time
        prev_state = next_state

样品结果

下面是我们如何在IPython中交互地使用代码来生成和绘制的一些例子,比如说,800个点的例子:

代码语言:txt
复制
    %run -i sample.py
    [strokes, params] = model.sample(sess, 800)
    draw_strokes_random_color(strokes)

其实这些结果看起来还挺好的。这些笔迹有时在印刷和草书之间转换,仿佛是被一个有精神妄想症的疯子写在一个荒漠城堡里面一间空荡荡的房间里一样。(我相信我们都碰到过这样的人)

除了保存采样点之外,我们还保存了概率分布参数的历史记录,以进一步显示实际情况。在下面的示例中,我们绘制了生成的样本,并额外绘制了两个不同的分布图以得到一个结论。

代码语言:txt
复制
    draw_strokes_pdf(strokes, params)
    draw_strokes_eos_weighted(strokes, params)

第二个和第三个可视化分布图是网络在构思样本时,它书写过程的概率分布。在第二个图中,我们绘制出实际的采样路径,加上每一个点到下一个点的概率密度。在第三个图中,我们将采样路径与每个点的结束概率重叠。我们可以看到,当我们接近每个笔画的末尾时,笔画发出结束信号的可能性自然地增加 - 这时线条变得越来越暗。另外,我们在第二幅图中看到,当网络进行连续的书写时,它对下一个点的位置相当有自信,因为图中的小红点暗示着这一小块目标区域的密集分布。同时,随着笔划结束的临近,下一个点的概率密度也变得越来越稀疏。并且,图中那些更大更透明的圆形气泡说明了我们的网络有时候会产生出更多样化的下一个点。

下面是另外两组拥有更长序列长度的样本图表:

下一步我们要做的是

接下来要做的事情可能就是尝试在Graves的论文中实现手写合成。这可能包含在char-rnn中使用的字符嵌入。他发现,字符预测和笔画预测的组合是生成自然的合成手写序列的关键,因为网络需要了解特定字符的特定笔画是如何衔接到另一个字符序列里面去的。

另外一项有趣的工作是将生成副本网络的方法合并到现网中。训练一个鉴别网络去区分假的手写笔迹和真的手写笔迹,而另一个网络产生假得手写来欺骗这个鉴别网络。虽然这对于RNN来说可能有点困难,但我们还是需要尝试的!

有的人认为,相对于GANs来说,生成匹配网络可能是生成RNN的一种更好的方法。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 训练数据
  • 模型描述
  • 批量梯度下降(Mini-Batch)训练的具体细节
  • 从网络中生成样本
  • 样品结果
  • 下一步我们要做的是
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档