前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >基于tensorflow.js的机器学习

基于tensorflow.js的机器学习

作者头像
周星星9527
发布2019-08-14 16:22:43
6590
发布2019-08-14 16:22:43
举报

机器学习是什么?我认为其实就是统计学的另一种花里胡哨/故弄玄虚的说法!tensorflow.js是一个机器学习的框架:

Develop ML with JavaScript, Use flexible and intuitive APIs to build and train models from scratch using the low-level JavaScript linear algebra library or the high-level layers API https://js.tensorflow.org/

我们结合tensorflow.js与百度echarts做一个最小二乘法的经典案例,线性回归例子:

代码语言:javascript
复制
<!DOCTYPE html>
<html style="height: 100%">
   <head><meta charset="utf-8"></head>
   <body style="height: 100%; margin: 0">
       <div id="container" style="height: 100%"></div>
       <script type="text/javascript" src="http://echarts.baidu.com/gallery/vendors/echarts/echarts.min.js"></script>
       <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.2/dist/tf.min.js"></script>
       <script type="text/javascript">
/**
 * @license 修改自官方案例,特此说明。
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

function generateData(numPoints, coeff, sigma = 0.04) {//产生伪随机数
  return tf.tidy(() => {
    const [k, b] = [tf.scalar(coeff.k),tf.scalar(coeff.b)];

    const xs = tf.randomUniform([numPoints], -1, 1);//x坐标
    const ys = k.mul(xs).add(b)//y坐标
      .add(tf.randomNormal([numPoints], 0, sigma));//叠加噪声

    return {xs, ys: ys};
  })
}

// Step 1. 要回归的变量
const k = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));
// Step 2. 选取优化器、迭代次数等参数
const numIterations = 75;
const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);
// Step 3. 预测函数,定义为线性函数y = k * x + b
function predict(x) {// y = k * x + b
  return tf.tidy(() => {return k.mul(x).add(b);});
}
// Step 4. 计算方差,方差越小说明预测值越精确
function loss(prediction, labels) {
  const error = prediction.sub(labels).square().mean();
  return error;
}
// Step 5. 训练函数
async function train(xs, ys, numIterations) {
  for (let iter = 0; iter < numIterations; iter++) {
    //优化并使方差最小
    optimizer.minimize(() => {
      const pred = predict(xs);//根据输入数据预测输出值
      return loss(pred, ys);//计算预测值与训练数据间的方差
    });

    await tf.nextFrame();//
  }
}
//机器学习
async function learnCoefficients() {//
  const trueCoefficients = {k: 0.6, b: 0.8};//真实值
  const trainingData = generateData(100, trueCoefficients);//用于模型训练的数据
  
  await train(trainingData.xs, trainingData.ys, numIterations);// 模型训练

   var xvals = await trainingData.xs.data();//训练数据的x坐标值
   var yvals = await trainingData.ys.data();//训练数据的y坐标值
   var sDatas = Array.from(yvals).map((y,i) => {return [xvals[i],yvals[i]]});//整理训练数据以便绘图
 
  console.log("k&b:",k.dataSync()[0],b.dataSync()[0]);//经过训练后的系数
  showResult(sDatas,k.dataSync()[0],b.dataSync()[0]);
}
//绘制结果
function showResult(scatterData,k,b){
  var dom = document.getElementById("container");
  var myChart = echarts.init(dom);
  function realFun(x){return 0.6*x+0.8;}//理想曲线
  function factFun(x){return k*x+b;}//回归后的曲线
  var realData = [[-1,realFun(-1)],[1,realFun(1)]];
  var factData = [[-1,factFun(-1)],[1,factFun(1)]];

  var option = {
      title: {text: '通过机器学习进行数据的线性回归',left: 'left'},
      tooltip: {trigger: 'axis',axisPointer: {type: 'cross'}},
      xAxis: {type: 'value',splitLine: {lineStyle: {type: 'dashed'}},},
      yAxis: {type: 'value',splitLine: {lineStyle: {type: 'dashed'}}},
      series: [{
          name: '离散点',type: 'scatter',
          label: {
              emphasis: {
                  show: true,
                  position: 'left',
                  textStyle: {
                      color: 'blue',
                      fontSize: 16
                  }
              }
          },
          data: scatterData
      }, 
      {name: '理想曲线',type: 'line',showSymbol: false,data: realData,},
      {name: '回归曲线',type: 'line',showSymbol: false,data: factData,},],
      legend: {data:['离散点','理想曲线','回归曲线']},//图例文字
  };

  myChart.setOption(option, true);
}

learnCoefficients();
       </script>
   </body>
</html>

结果如下,回归曲线非常接近实际情况啦:

参考文献:

[1] Deqing L , Honghui M , Yi S ,et al. ECharts: A declarative framework for rapid construction of web-basedvisualization[J]. Visual Informatics, 2018:S2468502X18300068-.

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-08-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 传输过程数值模拟学习笔记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档