前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >[机器学习]线性回归-基于tensorflow.js

[机器学习]线性回归-基于tensorflow.js

作者头像
周星星9527
发布2021-07-20 14:52:29
7920
发布2021-07-20 14:52:29
举报

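在《传热学》横掠管外对流换热系数测定实验中,努塞尔数 Nu 与雷诺数 Re 满足关系式 $\mathrm{Nu} = C\,\mathrm{Re}^{n}$,需要根据实验数据确定式中的系数 C 和指数 n。两边取自然对数得 $\ln \mathrm{Nu} = \ln C + n \ln \mathrm{Re}$,即 $\ln \mathrm{Nu}$ 与 $\ln \mathrm{Re}$ 呈线性关系:拟合直线的斜率就是 n,截距就是 $\ln C$(因此 $C = e^{\text{截距}}$)。这里使用机器学习(tensorflow.js 梯度下降)进行线性回归;使用下面的页面时,应以 $\ln \mathrm{Re}$ 作为 x 坐标、$\ln \mathrm{Nu}$ 作为 y 坐标输入。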

功能:输入 x 坐标和 y 坐标,用直线 y = kx + b 进行线性拟合,给出斜率和截距,并绘制数据点与拟合直线。

<!doctype html>
<html lang="zh-CN">
  <head>
    <!-- Required meta tags -->
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- Third-party assets from jsDelivr: Bootstrap 5.0.1 CSS, ECharts (charting), echarts-stat
         (not referenced by the inline script below) and TensorFlow.js 0.14.2 (model training). -->
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.1/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
    <script type="text/javascript" src="https://cdn.jsdelivr.net/npm/echarts/dist/echarts.min.js"></script>
    <script type="text/javascript" src="https://cdn.jsdelivr.net/npm/echarts-stat/dist/ecStat.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.2/dist/tf.min.js"></script>
    <title>Hello, world!</title>
  </head>
  <body>
    <div class="container">
      <h5>分别输入x坐标和y坐标,并用英文逗号隔开</h5>
      <!-- The #xs / #ys values are read and parsed by generateData(). -->
      <form class="row g-3">
        <div class="input-group mb-3">
          <span class="input-group-text">x坐标</span>
          <input type="text" class="form-control" id="xs" placeholder="x坐标,如填写:1,2,3" value="1000,2000,3000" aria-label="xs" aria-describedby="basic-addon1">
        </div>
        <div class="input-group mb-3">
          <span class="input-group-text">y坐标</span>
          <input type="text" class="form-control" id="ys" placeholder="y坐标,如填写:2,4,8" value="1008,2123,2899" aria-label="ys" aria-describedby="basic-addon2">
        </div>
        <div class="col-12">
          <!-- <input> is a void element, so it takes no closing tag. -->
          <input type="button" class="btn btn-primary" onclick="learnCoefficients()" value="计算">
        </div>
      </form>
    </div>

    <!-- Drawing area for the ECharts plot rendered by showResult(). -->
    <div class="row w100" id="container" style="height: 500px;border-color: black;"></div>

    <!-- Result dialog: learnCoefficients() writes the slope/intercept into #resultDiv and opens it. -->
    <div class="modal fade" id="myModal" tabindex="-1" aria-labelledby="exampleModalLabel" aria-hidden="true">
      <div class="modal-dialog">
        <div class="modal-content">
          <div class="modal-header">
            <h5 class="modal-title" id="exampleModalLabel">计算结果</h5>
            <button type="button" class="btn-close" data-bs-dismiss="modal" aria-label="Close"></button>
          </div>
          <div class="modal-body" id="resultDiv">
            hi
          </div>
          <div class="modal-footer">
            <button type="button" class="btn btn-secondary" data-bs-dismiss="modal">关闭</button>
            <!-- NOTE(review): this button has no click handler, so it currently does nothing. -->
            <button type="button" class="btn btn-primary">重新计算</button>
          </div>
        </div>
      </div>
    </div>

    <script type="text/javascript">
      // Page-level state written by generateData() / learnCoefficients() and read by showResult().
      let m;      // slope of the scaling x' = m * x + n that maps [min, max] onto [-1, 1]
      let n;      // offset of that scaling
      let max;    // largest raw x value
      let min;    // smallest raw x value
      let realK;  // fitted slope expressed in raw x units
      let realB;  // fitted intercept expressed in raw x units
      let sx;     // parsed raw x values
      let sy;     // parsed raw y values

      /**
       * Splits a user-typed list of numbers.
       * Both the ASCII comma and the full-width comma (U+FF0C) are accepted as separators.
       * Surrounding whitespace is ignored, and empty entries (e.g. from a trailing comma) are dropped.
       * @param {string} text - raw value of an input field
       * @returns {number[]} parsed values; an entry that is not a number becomes NaN
       */
      function parseNumberList(text) {
        return String(text)
          .split(/[,\uFF0C]/)
          .map((part) => part.trim())
          .filter((part) => part !== "")
          .map((part) => parseFloat(part));
      }

      /**
       * Reads the x / y inputs, min-max scales x onto [-1, 1] and returns both as tensors.
       * Scaling x keeps the gradient-descent step size independent of the magnitude of x.
       * Side effects: sets the page-level sx, sy, min, max, m and n.
       * The parameters numPoints, coeff and sigma are unused; they are kept only for call compatibility.
       * @returns {{xs: tf.Tensor, ys: tf.Tensor}} scaled x values and raw y values
       * @throws {Error} if an input is empty, the two lists differ in length, an entry is
       *   not a finite number, or all x values are equal (the scaling would divide by zero)
       */
      function generateData(numPoints, coeff, sigma = 0.04) {
        const x = parseNumberList(document.getElementById("xs").value);
        const y = parseNumberList(document.getElementById("ys").value);

        if (x.length === 0 || y.length === 0) {
          throw new Error("x坐标和y坐标不能为空");
        }
        if (x.length !== y.length) {
          throw new Error(`x坐标(${x.length}个)与y坐标(${y.length}个)数量不一致`);
        }
        if (![...x, ...y].every((v) => Number.isFinite(v))) {
          throw new Error("坐标中包含无法识别的数值");
        }

        const xMax = Math.max(...x);
        const xMin = Math.min(...x);
        if (xMax === xMin) {
          // All x values are equal: the range (max - min) would be zero.
          throw new Error("x坐标不能全部相同");
        }

        sx = x;
        sy = y;
        console.log("Original xy data=", x, y);

        max = xMax;
        min = xMin;
        console.log("min & max:", min, max);

        const scaledX = x.map((xi) => 2 * (xi - min) / (max - min) - 1);
        console.log("scaledX:", scaledX);

        // Same scaling written as x' = m * x + n; learnCoefficients() uses m and n
        // to convert the fitted coefficients back to raw x units.
        m = 2 / (max - min);
        n = -2 * min / (max - min) - 1;
        console.log("m,n:", m, n);

        return tf.tidy(() => ({ xs: tf.tensor(scaledX), ys: tf.tensor(y) }));
      }
      
      // Step 1. Trainable parameters of the model y = k * x + b. Each starts from
      // Math.random() (a value in [0, 1)) and is adjusted by the optimizer in train().
      const [k, b] = [Math.random(), Math.random()].map((v) => tf.variable(tf.scalar(v)));

      // Step 2. Optimizer: plain stochastic gradient descent with a fixed learning rate.
      // numIterations is the number of gradient steps taken on each click of "计算".
      const learningRate = 0.5;
      const numIterations = 75;
      const optimizer = tf.train.sgd(learningRate);
      
      // Step 3. Model, loss and training loop.

      /**
       * Evaluates the linear model y = k * x + b element-wise on x.
       * The computation runs inside tf.tidy, so intermediate tensors are released
       * and only the returned prediction survives.
       * @param {tf.Tensor} x - (scaled) inputs
       * @returns {tf.Tensor} the predictions
       */
      function predict(x) {
        const forward = () => tf.add(tf.mul(k, x), b);
        return tf.tidy(forward);
      }
      
      /**
       * Mean squared error between predictions and targets.
       * @param {tf.Tensor} prediction - model output
       * @param {tf.Tensor} labels - target values, same shape as prediction
       * @returns {tf.Tensor} scalar loss
       */
      function loss(prediction, labels) {
        const residual = prediction.sub(labels);
        const squared = residual.square();
        return squared.mean();
      }
      
      /**
       * Runs numIterations optimisation steps of the shared optimizer on
       * loss(predict(xs), ys). It awaits tf.nextFrame() after every step so the page can
       * repaint while training.
       * @param {tf.Tensor} xs - (scaled) inputs
       * @param {tf.Tensor} ys - target values
       * @param {number} numIterations - number of optimisation steps
       * @returns {Promise<void>} resolves when all steps are done
       */
      async function train(xs, ys, numIterations) {
        const objective = () => loss(predict(xs), ys);
        let remaining = numIterations;
        while (remaining-- > 0) {
          optimizer.minimize(objective);
          await tf.nextFrame();
        }
      }
      
      
      /**
       * Click handler for the "计算" button. It fits y = k * x + b to the entered points,
       * converts the coefficients back to raw x units, plots the result and shows the
       * slope/intercept in the #myModal dialog. If anything fails (e.g. invalid input),
       * it shows the error message in the dialog instead.
       * Side effects: updates k and b (through train) and the page-level realK / realB.
       * @returns {Promise<void>}
       */
      async function learnCoefficients() {
        const resultDiv = document.getElementById("resultDiv");
        let trainingData = null;
        try {
          // generateData() does not use its (numPoints, coeff, sigma) parameters.
          trainingData = generateData();

          // Train the model on the scaled inputs.
          await train(trainingData.xs, trainingData.ys, numIterations);

          const kk = k.dataSync()[0];
          const bb = b.dataSync()[0];
          console.log("kk & bb:", kk, bb);
          // The fit used x' = m * x + n, so y = kk * (m * x + n) + bb
          //                                    = (kk * m) * x + (kk * n + bb).
          realK = kk * m;
          realB = kk * n + bb;
          console.log("realK & realB:", realK, realB);

          const sDatas = sx.map((xi, i) => [xi, sy[i]]);
          showResult(sDatas);

          resultDiv.innerText = "斜率:" + realK.toFixed(3) + "\n" + "截距:" + realB.toFixed(3);
        } catch (err) {
          console.error(err);
          resultDiv.innerText = "计算失败:" + (err instanceof Error ? err.message : String(err));
        } finally {
          // A new pair of input tensors is created on every click; free it so repeated runs do not leak.
          if (trainingData) {
            trainingData.xs.dispose();
            trainingData.ys.dispose();
          }
        }

        const modalEl = document.getElementById('myModal');
        // Reuse the Modal bound to the element instead of attaching a new instance per click.
        const myModal = bootstrap.Modal.getInstance(modalEl) || new bootstrap.Modal(modalEl, {keyboard: false});
        myModal.show();
      }
      
      /**
       * Draws the raw data points and the fitted line into the #container element.
       * Reads the page-level realK, realB (fitted coefficients in raw x units) and
       * min, max (x range of the data).
       * @param {Array<Array<number>>} scatterData - [x, y] pairs in raw units
       */
      function showResult(scatterData){
        const dom = document.getElementById("container");
        // Reuse the chart already bound to this node on repeated runs; calling
        // echarts.init() again on the same node only warns and returns that instance.
        const myChart = echarts.getInstanceByDom(dom) || echarts.init(dom);
        // Fitted line y = realK * x + realB.
        // Reference: https://wenku.baidu.com/view/ed39e43ff12d2af90242e6fc.html
        const realFun = (x) => realK * x + realB;
        // A straight line only needs its two end points over the x range of the data.
        const realData = [[min, realFun(min)], [max, realFun(max)]];

        const option = {
            title: {
                text: '线性函数拟合',
                left: 'center'
            },
            tooltip: {trigger: 'axis', axisPointer: {type: 'cross'}},
            xAxis: {
                type: 'value',
                splitLine: {lineStyle: {type: 'dashed'}},
            },
            yAxis: {
                type: 'value',
                splitLine: {lineStyle: {type: 'dashed'}}
            },
            series: [{
                name: 'scatter',
                type: 'scatter',
                // Label shown on hover; `emphasis.label` with flat text-style properties
                // replaces the deprecated `label.emphasis.textStyle` nesting.
                emphasis: {
                    label: {
                        show: true,
                        position: 'left',
                        color: 'blue',
                        fontSize: 16
                    }
                },
                data: scatterData
            }, {
                name: 'line',
                type: 'line',
                showSymbol: false,
                smooth: true,
                data: realData,
            }]
        };

        // notMerge = true: drop the previous run's series instead of merging into them.
        myChart.setOption(option, true);
      }
</script>
    <!-- Bootstrap JS bundle. It loads after the inline script above; that works because
         bootstrap.Modal is only referenced inside learnCoefficients(), i.e. after a click. -->
    <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.0.1/dist/js/bootstrap.bundle.min.js" integrity="sha384-gtEjrD/SeCtmISkJkNUaaKMoLD0//ElJ19smozuHV6z3Iehds+3Ulb9Bn9Plx0x4" crossorigin="anonymous"></script>
  </body>
</html>

运行结果:

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-06-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 传输过程数值模拟学习笔记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档