用 TensorFlow.js 在浏览器中训练神经网络

本文结构:

  1. 什么是 TensorFlow.js
  2. 为什么要在浏览器中运行机器学习算法
  3. 应用举例:regression
  4. 和 tflearn 的代码比较

1. 什么是 TensorFlow.js

TensorFlow.js 是一个开源库,不仅可以在浏览器中运行机器学习模型,还可以训练模型。 具有 GPU 加速功能,并自动支持 WebGL 可以导入已经训练好的模型,也可以在浏览器中重新训练现有的所有机器学习模型 运行 Tensorflow.js 只需要你的浏览器,而且在本地开发的代码与发送给用户的代码是相同的。

TensorFlow.js 对未来 web 开发有着重要的影响,JS 开发者可以更容易地实现机器学习,工程师和数据科学家们可以有一种新的方法来训练算法,例如官网上 Emoji Scavenger Hunt 这样的游戏界面,让用户一边玩游戏一边将模型训练地更好。

用 Tensorflow.js 可以做很多事情, 例如 object detection in images, speech recognition, music composition, 而且 不需要安装任何库,也不用一次又一次地编译这些代码。


2. 为什么要在浏览器中运行机器学习算法

TensorFlow.js 可以为用户解锁巨大价值:

  1. 隐私:用户端的机器学习,用来训练模型的数据还有模型的使用都在用户的设备上完成,这意味着不需要把数据传送或存储在服务器上。
  2. 更广泛的使用:几乎每个电脑手机平板上都有浏览器,并且几乎每个浏览器都可以运行JS,无需下载或安装任何应用程序,在浏览器中就可以运行机器学习框架来实现更高的用户转换率,提高满意度,例如虚拟试衣间等服务。
  3. 分布式计算:每次用户使用系统时,他都是在自己的设备上运行机器学习算法,之后新的数据点将被推送到服务器来帮助改进模型,那么未来的用户就可以使用训练的更好的算法了,这样可以减少训练成本,并且持续训练模型。

3. 应用举例:regression

为了很快地看看效果,有下面三种方式:

  1. 可以直接从浏览器里写代码,例如 chrome 的 View > Developer > Javascript Console,
  2. 还可以在线写 有三个流行的在线 JS 平台:CodePen, JSFiddle, JSBin. https://codepen.io/thekevinscott/pen/aGapZL https://jsfiddle.net/ https://jsbin.com/?html,output
  3. 当然还可以在本地把代码保存为.html文件并用浏览器打开

那么先来看一下下面这段代码,可以在 codepen 中运行: https://codepen.io/pen?&editors=1011

这段代码的目的是做个回归预测,

数据集为: 构造符合 Y=2X-1 的几个点, 那么当 X 取 [-1, 0, 1, 2, 3, 4] 时, y 为 [-3, -1, 1, 3, 5, 7],

<html>

 <head>
    <!-- Load TensorFlow.js -->
    <!-- Get latest version at https://github.com/tensorflow/tfjs -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.2">   
    </script>
 </head>
 
 <body>
   <div id="output_field"></div>
 </body>
 
 <script>
    async function learnLinear(){
    
        const model = tf.sequential();
        model.add(tf.layers.dense({
            units: 1, 
            inputShape: [1]
        }));
        
        model.compile({
            loss: 'meanSquaredError',
            optimizer: 'sgd'
        });
  
        const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
        const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
  
        await model.fit(xs, ys, {epochs: 500});
  
        document.getElementById('output_field').innerText =
            model.predict( tf.tensor2d([10], [1, 1]) );
    }
    
    learnLinear();
 </script>
 
<html>
  • 首先是熟悉的 js 的基础结构:
<html>
<head></head>
<body></body>
</html>
        const model = tf.sequential();
        model.add(tf.layers.dense({
            units: 1, 
            inputShape: [1]
        }));
  • 接着定义 loss 为 MSE 和 optimizer 为 SGD:
        model.compile({
            loss: 'meanSquaredError',
            optimizer: 'sgd'
        });
  • 同时需要定义 input 的 tensor,X 和 y,以及它们的维度都是 [6, 1]:
        const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
        const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
  • 然后用 fit 来训练模型,因为要等模型训练完才能预测,所以要用 await:
        await model.fit(xs, ys, {epochs: 500});
  • 训练结束后,用 predict 进行预测,输入的是 [1, 1] 维的 值为 10 的tensor ,
        document.getElementById('output_field').innerText =
            model.predict( tf.tensor2d([10], [1, 1]) );
  • 最后得到的输出为
Tensor 
[[18.9862976],]

4. 和 tflearn 的代码比较

再来通过一个简单的例子来比较一下 Tensorflow.js 和 tflearn, 可以看出如果熟悉 tflearn 的话,那么 Tensorflow.js 会非常容易上手,


学习资料: https://medium.com/tensorflow/getting-started-with-tensorflow-js-50f6783489b2 https://thekevinscott.com/reasons-for-machine-learning-in-the-browser/ https://www.analyticsvidhya.com/blog/2018/04/tensorflow-js-build-machine-learning-models-javascript/ https://hackernoon.com/introducing-tensorflow-js-3f31d70f5904 https://thekevinscott.com/tensorflowjs-hello-world/


推荐阅读 历史技术博文链接汇总 http://www.jianshu.com/p/28f02bb59fe5 也许可以找到你想要的: [入门问题][TensorFlow][深度学习][强化学习][神经网络][机器学习][自然语言处理][聊天机器人]

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人工智能头条

TensorFlow 发布新版本v1.9(附应用实践教程)

【人工智能头条导读】TensorFlow 是一个开放源代码软件库,用于进行高性能数值计算。借助灵活的架构,用户可以轻松地将计算工作部署到多种平台(CPU、GPU...

523
来自专栏新智元

【谷歌新项目公开】无需学编程,用手机摄像头和浏览器即可机器学习

【新智元导读】谷歌最新的 Teachable Machine 项目,可以让用户无需编程就能利用摄像头采集数据、设计机器学习。作为 AI Experiment 的...

2565
来自专栏老秦求学

数据增强利器--Augmentor

Augmentor是一个Python包,旨在帮助机器学习任务的图像数据人工生成和数据增强。它主要是一种数据增强工具,但也将包含基本的图像预处理功能。

943
来自专栏人工智能LeadAI

pytorch入门教程 | 第四章:准备图片数据集

在训练神经网络之前,我们必须有数据,作为资深伸手党,必须知道以下几个数据提供源: 1 CIFAR-10 ? CIFAR-10图片样本截图 CIFAR-10是多...

7068
来自专栏iOSDevLog

人工智能的 "hello world":在 iOS 实现 MNIST 数学识别MNIST: http://yann.lecun.com/exdb/mnist/ 目标步骤

3208
来自专栏大数据文摘

干脆面君,你给我站住!你已经被TensorFlow盯上了

1593
来自专栏北京马哥教育

20行 Python 代码实现验证码识别

一、探讨 识别图形验证码可以说是做爬虫的必修课,涉及到计算机图形学,机器学习,机器视觉,人工智能等等高深领域…… 简单地说,计算机图形学的主要研究内容就是研究如...

4518
来自专栏机器之心

资源 | DMLC团队发布GluonCV和GluonNLP:两种简单易用的DL工具箱

选自 Gluon 机器之心编译 参与:思源、李亚洲 近日,DMLC 发布了简单易用的深度学习工具箱 GluonCV 和 GluonNLP,它们分别为计算机视觉和...

2828
来自专栏和蔼的张星的图像处理专栏

LCT代码跑起来先文章思路总结

论文才刚开始看,但是代码先跑了一下看结果,有一点小坑,记录下: 首先去论文的github上去下载代码:点这里 readme里其实写了怎么搞:

1163
来自专栏ATYUN订阅号

使用Java部署训练好的Keras深度学习模型

Keras库为深度学习提供了一个相对简单的接口,使神经网络可以被大众使用。然而,我们面临的挑战之一是将Keras的探索模型转化为产品模型。Keras是用Pyth...

974

扫码关注云+社区