首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Tensorflow中的随机梯度下降在概念上似乎是错误的

Tensorflow中的随机梯度下降在概念上似乎是错误的
EN

Stack Overflow用户
提问于 2018-07-14 05:33:03
回答 2查看 624关注 0票数 1

我正在使用Tensorflow探索线性回归。这是我来自this notebook的代码。

代码语言:javascript
运行
复制
import tensorflow as tf
import numpy as np
learning_rate = 0.01

x_train = np.linspace(-1,1,101)
y_train = 2*x_train + np.random.randn(*x_train.shape) * 0.33

X = tf.placeholder("float")
Y = tf.placeholder("float")
def model(X, w):
    return tf.multiply(X,w)
w = tf.Variable(0.0, name = "weights")

training_epochs = 100
y_model = model(X,w)
cost = tf.reduce_mean(tf.square(Y-y_model))
train_op = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    for epoch in range(training_epochs):
        for (x,y) in zip(x_train,y_train):
            sess.run(train_op, feed_dict = {X:x, Y: y})
        print(sess.run(w))

它试图最小化一个成本函数。根据这个问题的answers的说法,我认为tf.reduce_mean()会像np.mean()一样工作。

然而,每次将一对(x,y)馈送到train_op时,权重w似乎不会根据该对更新,而是根据之前的所有对更新。

对此有何解释?这与与优化器一起工作有关吗?

EN

回答 2

Stack Overflow用户

发布于 2018-07-14 06:46:37

我想回答我自己的问题。如果你认为这是精确的线性回归,那么这不是一个微不足道的问题。

  1. 我误解了tf.train.GradientDescentOptimizer的性能。它只运行一步来最小化损失函数,而不是最小值。因此,将数据提供给优化器的顺序很重要。因此,下面的代码将给出不同的答案。

for (x,y) in list(zip(x_train,y_train))[::-1]: sess.run(train_op, feed_dict = {X:x, Y: y})

总之,这段代码运行的不是严格的线性回归,而是它的近似值。

票数 0
EN

Stack Overflow用户

发布于 2018-07-14 06:48:50

如果您更改了这段代码

代码语言:javascript
运行
复制
for epoch in range(training_epochs):
    for (x,y) in zip(x_train,y_train):
        sess.run(train_op, feed_dict = {X:x, Y: y})

通过这个

代码语言:javascript
运行
复制
for (x,y) in zip(x_train,y_train):
    for epoch in range(training_epochs):
        sess.run(train_op, feed_dict = {X:x, Y: y})

你得到你想要的了吗?

在您的原始代码中,第一个循环引用迭代,因此您将修复梯度下降的第一个迭代,然后将其应用于所有先前的对(因为第二个循环引用所有先前的对),然后修复第二个迭代,并再次针对所有先前的对应用梯度下降,依此类推。

如果你像上面那样交换你的循环,那么你就是在修复一对,然后将所有的梯度下降迭代应用于这一对。我不确定这是不是你想要的。

票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51333061

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档