首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Tensorflow.js最小二乘法时间序列预测

Tensorflow.js最小二乘法时间序列预测
EN

Stack Overflow用户
提问于 2018-06-07 04:04:22
回答 1查看 3.6K关注 0票数 4

我正在尝试使用LSTM RNN在Tensorflow.js中构建一个简单的时间序列预测脚本。我显然是ML的新手。我一直在尝试从Keras RNN/LSTM层api改编我的JS代码,这显然是同一件事。从我收集的图层,形状等都是正确的。对我做错了什么有什么想法吗?

代码语言:javascript
运行
复制
async function predictfuture(){

    ////////////////////////
    // create fake data
    ///////////////////////

    var xs = tf.tensor3d([
        [[1],[1],[0]],
        [[1],[1],[0]],
        [[1],[1],[0]],
        [[1],[1],[0]],
        [[1],[1],[0]],
        [[1],[1],[0]]
    ]);
    xs.print();

    var ys = tf.tensor3d([
        [[1],[1],[0]],
        [[1],[1],[0]],
        [[1],[1],[0]],
        [[1],[1],[0]],
        [[1],[1],[0]],
        [[1],[1],[0]]
    ]);
    ys.print();


    ////////////////////////
    // create model w/ layers api
    ///////////////////////

    console.log('Creating Model...');

    /*

    model design:

                    i(xs)   h       o(ys)
    batch_size  ->  *       *       * -> batch_size
    timesteps   ->  *       *       * -> timesteps
    input_dim   ->  *       *       * -> input_dim


    */

    const model = tf.sequential();

    //hidden layer
    const hidden = tf.layers.lstm({
        units: 3,
        activation: 'sigmoid',
        inputShape: [3 , 1]
    });
    model.add(hidden);

    //output layer
    const output = tf.layers.lstm({
        units: 3,
        activation: 'sigmoid',
        inputShape: [3] //optional
    });
    model.add(output);

    //compile
    const sgdoptimizer = tf.train.sgd(0.1)
    model.compile({
        optimizer: sgdoptimizer,
        loss: tf.losses.meanSquaredError
    });

    ////////////////////////
    // train & predict
    ///////////////////////

    console.log('Training Model...');

    await model.fit(xs, ys, { epochs: 200 }).then(() => {

        console.log('Training Complete!');
        console.log('Creating Prediction...');

        const inputs = tf.tensor2d( [[1],[1],[0]] );
        let outputs = model.predict(inputs);
        outputs.print();

    });

}

predictfuture();

我的错误是:

EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50728673

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档