DanceNet:帮你生成会跳舞的小姐姐

来源:GitHub, 机器之心

项目地址:https://github.com/jsn5/dancenet

主要模块

DanceNet 中最主要的三个模块是变分自编码器、LSTM 与 MDN。其中变分自编码器(VAE)是最常见的生成模型之一,它能以无监督的方式学习复杂的分布,因此常被用来生成图像数据。VAE 非常优秀的属性是可以使用深度神经网络和随机梯度下降进行训练,并且中间的隐藏编码还表示了图像的某些属性。

如下变分自编码器中的编码器使用三个卷积层和一个全连接层,以生成隐藏编码 z 分布的均值与方差。因为 z 服从于高斯分布,因此确定方差与均值后 z 的分布就完全确定了,随后解码器从 z 分布中采样隐藏编码就能解码为与输出图像相近的输出。

z_mean 和 z_log_var 为编码器编码的均值与方差,z 为从它们确定的分布中所采样的隐藏编码,编码器最后输出这三个变量。以下展示了 Jaison 所采用解码器的架构,其首通过全连接层对隐藏编码 z 执行仿射变换,再交叉通过 4 个卷积层与 3 个上采样层以将隐藏编码恢复为原始输入图像大小。

在训练 VAE 后,我们从高斯分布中采样一个隐藏编码 z,并将其馈送到解码器,那么模型就能生成新的图像。因此我们可以设想给定不同的隐藏编码 z,解码器最终能生成不同的舞姿图像。

最后,我们还需要长短期记忆网络(LSTM)和混合密度层以将这些舞姿图像连接在一起,并生成真正的舞蹈动作。如下 Jaison 堆叠了三个 LSTM 层,且每一个 LSTM 层后面都采用 Dropout 操作以实现更好的鲁棒性。将循环层的输出接入全连接层与混合密度网络层,最后输出一套舞蹈动作每一个时间步的图像应该是怎样的。

以下是实践该项目的环境与过程,机器之心也尝试使用了预训练模型,并生成了效果还不错的视频。此外,根据试验结果,VAE 中的编码器参数数量约 172 万,解码器约为 174 万,但 LSTM+MDN 却有 1219 万参数。最后我们生成了一个 16 秒的舞蹈视频:

要求:

Python = 3.5.2

工具包:

keras==2.2.0

sklearn==0.19.1

numpy==1.14.3

opencv-python==3.4.1

如下展示了用于训练的数据集视频:

本地实现已训练模型:

下载预训练权重

提取到 dancenet 目录

运行 dancegen.ipynb

预训练权重下载地址:https://drive.google.com/file/d/1LWtERyPAzYeZjL816gBoLyQdC2MDK961/view?usp=sharing

如何在浏览器上运行:

打开 FloydHub 工作区

训练的权重数据集将自动与环境相连

运行 dancegen.ipynb

FloydHu 工作区:bhttps://floydhub.com/run

预训练权重:https://www.floydhub.com/whatrocks/datasets/dancenet-weights

从头开始训练:

在 imgs/ 文件夹下补充标签为 1.jpg, 2.jpg ... 的舞蹈序列图像

运行 model.py

运行 gen_lv.py 以编码图像

运行 video_from_lv.py 以测试解码视频

运行 jupyter notebook 中的 dancegen.ipynb 以训练 dancenet 并生成新视频

参考资料:

Cary Huang:https://www.youtube.com/watch?v=Sc7RiNgHHaE&t=9s

Francois Chollet:https://blog.keras.io/building-autoencoders-in-keras.html

通过 Python 和 Keras 使用 LSTM 循环神经网络构建时序预测模型:https://machinelearningmastery.com/time-series-prediction-lstm-recurrent-neural-networks-python-keras/

使用深度学习的生成编舞:https://arxiv.org/abs/1605.06921

David Ha:http://blog.otoro.net/2015/06/14/mixture-density-networks/

Charles Martin:https://github.com/cpmpercussion/keras-mdn-layer

- 加入AI学院学习 -

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180813A0PRG700?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 yunjia_community@tencent.com 删除。

扫码关注云+社区

领取腾讯云代金券