基于tensorflow.js的机器学习

机器学习是什么?我认为其实就是统计学的另一种花里胡哨/故弄玄虚的说法!tensorflow.js是一个机器学习的框架:

Develop ML with JavaScript, Use flexible and intuitive APIs to build and train models from scratch using the low-level JavaScript linear algebra library or the high-level layers API https://js.tensorflow.org/

我们结合tensorflow.js与百度echarts做一个最小二乘法的经典案例,线性回归例子:

<!DOCTYPE html>
<html style="height: 100%">
   <head><meta charset="utf-8"></head>
   <body style="height: 100%; margin: 0">
       <div id="container" style="height: 100%"></div>
       <script type="text/javascript" src="http://echarts.baidu.com/gallery/vendors/echarts/echarts.min.js"></script>
       <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.2/dist/tf.min.js"></script>
       <script type="text/javascript">
/**
 * @license 修改自官方案例,特此说明。
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

function generateData(numPoints, coeff, sigma = 0.04) {//产生伪随机数
  return tf.tidy(() => {
    const [k, b] = [tf.scalar(coeff.k),tf.scalar(coeff.b)];

    const xs = tf.randomUniform([numPoints], -1, 1);//x坐标
    const ys = k.mul(xs).add(b)//y坐标
      .add(tf.randomNormal([numPoints], 0, sigma));//叠加噪声

    return {xs, ys: ys};
  })
}

// Step 1. 要回归的变量
const k = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));
// Step 2. 选取优化器、迭代次数等参数
const numIterations = 75;
const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);
// Step 3. 预测函数,定义为线性函数y = k * x + b
function predict(x) {// y = k * x + b
  return tf.tidy(() => {return k.mul(x).add(b);});
}
// Step 4. 计算方差,方差越小说明预测值越精确
function loss(prediction, labels) {
  const error = prediction.sub(labels).square().mean();
  return error;
}
// Step 5. 训练函数
async function train(xs, ys, numIterations) {
  for (let iter = 0; iter < numIterations; iter++) {
    //优化并使方差最小
    optimizer.minimize(() => {
      const pred = predict(xs);//根据输入数据预测输出值
      return loss(pred, ys);//计算预测值与训练数据间的方差
    });

    await tf.nextFrame();//
  }
}
//机器学习
async function learnCoefficients() {//
  const trueCoefficients = {k: 0.6, b: 0.8};//真实值
  const trainingData = generateData(100, trueCoefficients);//用于模型训练的数据
  
  await train(trainingData.xs, trainingData.ys, numIterations);// 模型训练

   var xvals = await trainingData.xs.data();//训练数据的x坐标值
   var yvals = await trainingData.ys.data();//训练数据的y坐标值
   var sDatas = Array.from(yvals).map((y,i) => {return [xvals[i],yvals[i]]});//整理训练数据以便绘图
 
  console.log("k&b:",k.dataSync()[0],b.dataSync()[0]);//经过训练后的系数
  showResult(sDatas,k.dataSync()[0],b.dataSync()[0]);
}
//绘制结果
function showResult(scatterData,k,b){
  var dom = document.getElementById("container");
  var myChart = echarts.init(dom);
  function realFun(x){return 0.6*x+0.8;}//理想曲线
  function factFun(x){return k*x+b;}//回归后的曲线
  var realData = [[-1,realFun(-1)],[1,realFun(1)]];
  var factData = [[-1,factFun(-1)],[1,factFun(1)]];

  var option = {
      title: {text: '通过机器学习进行数据的线性回归',left: 'left'},
      tooltip: {trigger: 'axis',axisPointer: {type: 'cross'}},
      xAxis: {type: 'value',splitLine: {lineStyle: {type: 'dashed'}},},
      yAxis: {type: 'value',splitLine: {lineStyle: {type: 'dashed'}}},
      series: [{
          name: '离散点',type: 'scatter',
          label: {
              emphasis: {
                  show: true,
                  position: 'left',
                  textStyle: {
                      color: 'blue',
                      fontSize: 16
                  }
              }
          },
          data: scatterData
      }, 
      {name: '理想曲线',type: 'line',showSymbol: false,data: realData,},
      {name: '回归曲线',type: 'line',showSymbol: false,data: factData,},],
      legend: {data:['离散点','理想曲线','回归曲线']},//图例文字
  };

  myChart.setOption(option, true);
}

learnCoefficients();
       </script>
   </body>
</html>

结果如下,回归曲线非常接近实际情况啦:

参考文献:

[1] Deqing L , Honghui M , Yi S ,et al. ECharts: A declarative framework for rapid construction of web-basedvisualization[J]. Visual Informatics, 2018:S2468502X18300068-.

原文发布于微信公众号 - 传输过程数值模拟学习笔记(SongSimStudio)

原文发表时间:2019-08-08

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

扫码关注云+社区

领取腾讯云代金券