专栏首页MixLab科技+设计实验室深度学习生成舞蹈影片02之MDN代码练习

深度学习生成舞蹈影片02之MDN代码练习

阅读难度:★★★★☆

技能要求:机器学习基础、Keras、numpy、matplotlib

字数:960字

阅读时长:5分钟

本文接上一期,补充一些MDN的代码练习。本教程开发环境是python+jupyter,引用了一个用keras写的mdn包,目标是拟合反正弦函数曲线:

y=7.0sin(0.85x)+0.5x+r

该函数在每个点都有多个解,因此要求ANN模型需要有能力处理它的损失函数。 MDN是预测这些多输出值的好方法。

1

引入相关依赖

import keras
import mdn
import numpy as np
import matplotlib.pyplot as plt

2

生成模拟数据

#y=7.0sin(0.85x)+0.5x+r
#r标准的高斯随机噪声

NSAMPLE = 3000

y_data = np.float32(np.random.uniform(-10.5, 10.5, NSAMPLE))

r_data = np.random.normal(size=NSAMPLE)

x_data = np.sin(0.85 * y_data) * 7.0 + y_data * 0.5 + r_data * 1.0

x_data = x_data.reshape((NSAMPLE, 1))

plt.figure(figsize=(4, 4))

plt.plot(x_data,y_data,'ro', alpha=0.3,markersize = 1)

plt.show()

3

建模

接下来,我们在Keras中构建MDN模型。 使用了Keras中的Sequential模型,其中MDN层位于一个或多个Dense层之后。 您需要为MDN定义输出维度和混合状态的数量,比如:

MDN(output_dimension,number_mixtures)

对于本教程的问题,我们只需要定义输出维度为1,因为我们预测的y值维度为1。 添加更多的混合状态数量会增加更多参数(模型更复杂,需要更长时间训练),但可能有助于使预测结果更好。 你可以从训练数据中看到曲线评估混合状态的数量有5个,因此设置混合状态的数量N_MIXES = 5是比较好的方式。

对于MDN,我们需定义适合的损失函数,使其可以处理混合状态参数,损失函数必须考虑输出维数和混合状态的数量。

N_HIDDEN = 12
N_MIXES = 6

model = keras.Sequential()

model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))

model.add(keras.layers.Dense(N_HIDDEN, activation='relu'))

model.add(mdn.MDN(1, N_MIXES))

model.compile(loss=mdn.get_mixture_loss_func(1,N_MIXES), optimizer=keras.optimizers.Adam())

model.summary()

网络结果如下图所示:

4

训练模型

history = model.fit(x=x_data, y=y_data, batch_size=128, epochs=500, validation_split=0.2)

5

可视化

我们通过图表的方式查看模型是如何训练的。 从下图,我们可以看到,经过一定的训练后,训练效果的提升相当缓慢。对于本教程,1.5左右的损失值产生了相当好的结果。

代码如下:

plt.figure(figsize=(10, 5))
plt.ylim([0,9])
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.show()

6

预测

现在我们可以通过在x轴上产生3000个均匀间隔点来预测y轴的数值,测试下训练好的模型。注意,y_test包含分布的参数,而不是图上的实际点。要在图上找到点,我们需要从每个分布中进行采样,采样后的结果为y_samples。

x_test = np.float32(np.arange(-15,15,0.01))

NTEST = x_test.size

print("Testing:", NTEST, "samples.")

x_test = x_test.reshape(NTEST,1) 

y_test = model.predict(x_test)

y_samples = np.apply_along_axis(mdn.sample_from_output, 1, y_test, 1, N_MIXES,temp=1.0)

对比下预测结果:

plt.figure(figsize=(4, 4))

plt.plot(x_data,y_data,'ro',x_test, y_samples[:,:,0], 'bo',alpha=0.3,markersize = 1)

plt.show()

附上keras实现的MDN:

https://github.com/cpmpercussion/keras-mdn-layer

本文分享自微信公众号 - 无界社区mixlab(Design-AI-Lab),作者:shadow chi

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2018-08-16

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 量化交易×AI音乐 | 数学之美 I.

    ibrandup对社区的 ML.413 Yuiant 进行了采访,以下为采访的正文:

    mixlab
  • 人工智能未来简史2020年-2805年

    人工智能裁判开始全面取代人类裁判员,部分人类已经无法与AI竞争,无法找到工作岗位;人类进入无人驾驶时代,出行自由,驾驶证这一历史证件被纳入博物馆珍藏;由Spac...

    mixlab
  • AI版知乎

    作为一个个体知识量的储备是有限的,况且每天信息更新这么快,我们又这么忙,哪有空余的时间是一个个信息学习,理解,消化。

    mixlab
  • TensorFlow ML cookbook 第一章7、8节 实现激活功能和使用数据源

    问题导读: 1、TensorFlow中有哪些激活函数? 2、如何运行激活函数? 3、TensorFlow有哪些数据源? 4、如何获得及使用数据源? 上...

    用户1410343
  • Python进阶之Matplotlib入门(二)

    Matplotlib是Python的画图领域使用最广泛的绘图库,它能让使用者很轻松地将数据图形化以及利用它可以画出许多高质量的图像,是用Python画图的必备技...

    HuangWeiAI
  • 微信小程序开放关键词搜索,让你的小程序更快被找到

    怎样可以找到一个想用的小程序?可能是线下扫码、公众号、好友分享、长按小程序码、搜索小程序名称…… 今天起,多了一个新方式——小程序后台新增自定义关键词功能: 已...

    极乐君
  • 小程序“自定义关键词”功能的常见问答

      我们知道小程序可以通过线下扫码、公众号、好友分享、长按小程序码、搜索小程序名称来找到,现在又多了一个新方式——小程序后台新增自定义关键词功能:已发布小程序的...

    ytkah
  • Python模拟登陆 —— 征服验证码 8 微信网页版

    微信登录界面 微信网页版使用了UUID含义是通用唯一识别码来保证二维码的唯一性。 先用一个伪造的appid获得uuid。 params = { ...

    SeanCheney
  • AccessibilityService+WindowManager+SurfaceView开系统权限

      本文是基于辅助功能+悬浮窗+SurfaceView来实现自动获取用户权限的具体方案设计与实现。辅助功能抢红包插件相信大家并不陌生,但是微信官方不允许,但是在...

    用户1155943
  • 来也科技完成B+轮3500万美元融资,携手UiBot

    合并后的来也科技,汪冠春继续担任董事长兼CEO,原奥森科技CEO李玮任联席CEO兼总裁,原来也CTO胡一川继续任CTO,原奥森科技CTO褚瑞任高级副总裁。

    新智元

扫码关注云+社区

领取腾讯云代金券