前端也可以搞 AI?今天就来看看 @周林 给我们带来的深度神经网络的JS框架 Keras.js 的初体验。
简介
Keras是一款非常流行的深度学习模型开发框架,基于python,语法简洁,封装程度高,只需十几行代码就可以构建一个深度神经网络。
Keras.js是一个可以在浏览器中运行深度神经网络的JS框架,支持CPU,GPU计算。区别于Keras,Keras.js只能运行已经调试好的模型,无法进行模型训练。
KerasJS开发流程如下,首先使用Keras开发训练神经网络,将神经网络模型和参数导出为文件,KerasJS在浏览器端加载此文件,这样才能进行预测。
模型
借鉴这篇文章,开发一个识别圣诞老人的神经网络。本文不涉及Keras的开发细节,感兴趣的同学可以去原文查看。这里直接给出python代码。
数据
标注数据是AI模型的原料,数据搜集特别是图片搜集是前端可以介入的一个环节。笔者基于React,开发了一款Chrome图片批量下载插件GetThemAll,方便我们进行标记图片搜集。
安装好插件后,去谷歌图片搜索“santa”, 使用插件标记不需要的图片,然后下载到本地的santa文件夹,通过谷歌图片可以搜集到400张圣诞老人的图片。
接着我们再下载一些非圣诞老人的图片,搜索“object”,同样的使用GetThemAll插件下载大约400张图片到本地的non_santa文件夹中。
除了训练数据集,我们还需要一个测试数据集用来衡量模型的泛化能力。在本地新建一个test文件夹,把刚刚准备好的训练集里面的最后100张圣诞老人图片移到test文件夹下的santa文件中,同样的,移动100张非圣诞老人图片到non_stanta文件中。这样,你可以得到如下的本地图片集:
有了标记数据,我们就可以进行模型训练啦。具体的训练过程请见pyton代码,这里直接给出训练的结果,蓝点表示训练数据集准确率,蓝线表示测试数据集准备率,模型有着明显的High Variance问题,不过这个bug留给深度学习的专家们解决吧,这里就假设这个模型可用。
迁移
上一步训练出的模型keras_santa.h5(h5是文件后缀,和HTML5没啥关系)不能直接给KerasJS使用,需要通过KerasJS提供的转换工具转换后,方可被KerasJS加载解析。
转换后,得到了keras_santa.bin文件,20M左右,这个文件包含了神经网络结构和所有参数,可以被KerasJS加载。
KerasJS
通过上面的步骤,我们得到了一个训练完成的CNN神经网络以及全部参数,这个网络结构和参数全部保存在keras_santa.bin文件中。接下来,我们只需要在浏览器中复原上面的神经网络,然后就可以开始做预测啦。
使用webpack配合React,搭建一套简单的开发环境。做好了基础工作,就可以开始第一步开发,加载神经网络模型文件keras_santa.bin:
使用上面的模型做预测前,需要将数据转化成模型能够接受的数据格式。这个圣诞老人网络需要数据输入格式为(128,128,3),也即是图片需要为128x128分辨率,只能包含RGB三个分量。
借助canvas,可以实现图片分辨率转换:
注意preprocess方法,通过canvas获取到的图片资源包含了rgba四个维度,prepross返回这4个维度中的前3个维度,也即rgb,同时将数据标准化:
最后,使用上面返回的数据做预测
思考
可以看到,KerasJS在预测过程中,整个页面无法响应用户操作。这是因为神经网络计算过程中占用了大量CPU资源,从而致使页面卡顿。下一篇文章中,我们将介绍如何使用WebGL,将计算过程转移到GPU,达到实现前端高性能计算的目的。
同时,模型参数文件体积超过20M。如何对模型文件进行压缩,满足生产级别可用的要求,也是前端同学可以深挖的一个方向。
相关资源
Image classification with Keras and deep learning, Adrain Rosebrock
GetThemAll, eeandrew
React Keras,eeandrew
领取专属 10元无门槛券
私享最新 技术干货