机器学习是什么?我认为其实就是统计学的另一种花里胡哨/故弄玄虚的说法!tensorflow.js是一个机器学习的框架:
Develop ML with JavaScript, Use flexible and intuitive APIs to build and train models from scratch using the low-level JavaScript linear algebra library or the high-level layers API https://js.tensorflow.org/
我们结合tensorflow.js与百度echarts做一个最小二乘法的经典案例,线性回归例子:
<!DOCTYPE html>
<html style="height: 100%">
<head><meta charset="utf-8"></head>
<body style="height: 100%; margin: 0">
<div id="container" style="height: 100%"></div>
<script type="text/javascript" src="http://echarts.baidu.com/gallery/vendors/echarts/echarts.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.2/dist/tf.min.js"></script>
<script type="text/javascript">
/**
* @license 修改自官方案例,特此说明。
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function generateData(numPoints, coeff, sigma = 0.04) {//产生伪随机数
return tf.tidy(() => {
const [k, b] = [tf.scalar(coeff.k),tf.scalar(coeff.b)];
const xs = tf.randomUniform([numPoints], -1, 1);//x坐标
const ys = k.mul(xs).add(b)//y坐标
.add(tf.randomNormal([numPoints], 0, sigma));//叠加噪声
return {xs, ys: ys};
})
}
// Step 1. 要回归的变量
const k = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));
// Step 2. 选取优化器、迭代次数等参数
const numIterations = 75;
const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);
// Step 3. 预测函数,定义为线性函数y = k * x + b
function predict(x) {// y = k * x + b
return tf.tidy(() => {return k.mul(x).add(b);});
}
// Step 4. 计算方差,方差越小说明预测值越精确
function loss(prediction, labels) {
const error = prediction.sub(labels).square().mean();
return error;
}
// Step 5. 训练函数
async function train(xs, ys, numIterations) {
for (let iter = 0; iter < numIterations; iter++) {
//优化并使方差最小
optimizer.minimize(() => {
const pred = predict(xs);//根据输入数据预测输出值
return loss(pred, ys);//计算预测值与训练数据间的方差
});
await tf.nextFrame();//
}
}
//机器学习
async function learnCoefficients() {//
const trueCoefficients = {k: 0.6, b: 0.8};//真实值
const trainingData = generateData(100, trueCoefficients);//用于模型训练的数据
await train(trainingData.xs, trainingData.ys, numIterations);// 模型训练
var xvals = await trainingData.xs.data();//训练数据的x坐标值
var yvals = await trainingData.ys.data();//训练数据的y坐标值
var sDatas = Array.from(yvals).map((y,i) => {return [xvals[i],yvals[i]]});//整理训练数据以便绘图
console.log("k&b:",k.dataSync()[0],b.dataSync()[0]);//经过训练后的系数
showResult(sDatas,k.dataSync()[0],b.dataSync()[0]);
}
//绘制结果
function showResult(scatterData,k,b){
var dom = document.getElementById("container");
var myChart = echarts.init(dom);
function realFun(x){return 0.6*x+0.8;}//理想曲线
function factFun(x){return k*x+b;}//回归后的曲线
var realData = [[-1,realFun(-1)],[1,realFun(1)]];
var factData = [[-1,factFun(-1)],[1,factFun(1)]];
var option = {
title: {text: '通过机器学习进行数据的线性回归',left: 'left'},
tooltip: {trigger: 'axis',axisPointer: {type: 'cross'}},
xAxis: {type: 'value',splitLine: {lineStyle: {type: 'dashed'}},},
yAxis: {type: 'value',splitLine: {lineStyle: {type: 'dashed'}}},
series: [{
name: '离散点',type: 'scatter',
label: {
emphasis: {
show: true,
position: 'left',
textStyle: {
color: 'blue',
fontSize: 16
}
}
},
data: scatterData
},
{name: '理想曲线',type: 'line',showSymbol: false,data: realData,},
{name: '回归曲线',type: 'line',showSymbol: false,data: factData,},],
legend: {data:['离散点','理想曲线','回归曲线']},//图例文字
};
myChart.setOption(option, true);
}
learnCoefficients();
</script>
</body>
</html>
结果如下,回归曲线非常接近实际情况啦:
参考文献:
[1] Deqing L , Honghui M , Yi S ,et al. ECharts: A declarative framework for rapid construction of web-basedvisualization[J]. Visual Informatics, 2018:S2468502X18300068-.
本文分享自 传输过程数值模拟学习笔记 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!