首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >回归不起作用的简单案例(PyTorch)

回归不起作用的简单案例(PyTorch)
EN

Stack Overflow用户
提问于 2020-12-16 23:45:51
回答 1查看 43关注 0票数 1

这个来自https://medium.com/@benjamin.phillips22/simple-regression-with-neural-networks-in-pytorch-313f06910379的简单PyTorch代码没有从图1的数据中找到图2中预期的回归。我尝试了几个层(nb和size)、优化器、损失、学习率、时期。结果只是一条平坦的水平线。作为一个新手,我想我错过了一些东西。

数据:

大约预期结果:

代码语言:javascript
运行
复制
####
#code from https://medium.com/@benjamin.phillips22/simple-regression-with-neural-networks-in-pytorch-313f06910379
####

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as Data
import matplotlib.pyplot as plt
# %matplotlib inline
import numpy as np
import imageio
import pandas as pd

torch.manual_seed(1)    # reproducible

x=torch.from_numpy(pd.read_csv('./x.csv').to_numpy())
y=torch.from_numpy(pd.read_csv('./y.csv').to_numpy())
x=x.float()
y=y.float()

# torch can only train on Variable, so convert them to Variable
x, y = Variable(x), Variable(y)

# view data
plt.figure(figsize=(10,4))
plt.scatter(x.data.numpy(), y.data.numpy(), color = "orange")
plt.title('regression')
plt.xlabel('x input')
plt.ylabel('y output')
plt.show()


net = torch.nn.Sequential(
 torch.nn.Linear(1,20),
 torch.nn.Sigmoid(),
 torch.nn.Linear(20,20),
 torch.nn.Sigmoid(),
 torch.nn.Linear(20,1),
)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()  # this is for regression mean squared loss

my_images = []
fig, ax = plt.subplots(figsize=(12,7))

# train the network
for t in range(500):
  
    prediction = net(x)     # input x and predict based on x

    loss = loss_func(prediction, y)     # must be (1. nn output, 2. target)

    optimizer.zero_grad()   # clear gradients for next train
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients
    
    # plot and show learning process
    plt.cla()
    ax.set_xlabel('x input',fontsize=12)
    ax.set_ylabel('y output',fontsize=12)
    ax.set_xlim(min(x),max(x))
    ax.set_ylim(min(y),max(y))
    ax.scatter(x.data.numpy(), y.data.numpy(), color = "orange")
    ax.plot(x.data.numpy(), prediction.data.numpy(), 'g-', lw=3)
    ax.set_title('regression step='+str(t)+" loss="+str(loss.data.numpy()),fontsize=16)

    # Used to return the plot as an image array 
    # (https://ndres.me/post/matplotlib-animated-gifs-easily/)
    fig.canvas.draw()       # draw the canvas, cache the renderer
    image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
    image  = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))

    my_images.append(image)
    print(str(t)+" "+str(loss.data.numpy()))
   
# save images as a gif    
imageio.mimsave('./result.gif', my_images, fps=10)

x.csv数据:从0到191,步骤1

y.csv数据:

代码语言:javascript
运行
复制
0.316905
0.322015
0.332582
0.310465
0.250653
0.292722
0.297352
0.276525
0.308451
0.283465
0.276011
0.29934
0.307176
0.28573
0.248667
0.288614
0.249795
0.255556
0.258393
0.235972
0.225121
0.207828
0.190252
0.192881
0.204084
0.167646
0.155202
0.146516
0.162462
0.182906
0.160287
0.213769
0.186362
0.201151
0.186125
0.190625
0.146851
0.169204
0.207855
0.196557
0.208835
0.2244
0.206303
0.193485
0.185266
0.205616
0.229315
0.196254
0.219849
0.209988
0.197361
0.195402
0.210149
0.240754
0.210418
0.191776
0.189532
0.206153
0.165696
0.187938
0.157561
0.163148
0.19473
0.18966
0.162334
0.189277
0.166506
0.198193
0.157867
0.135192
0.152216
0.137521
0.142007
0.121252
0.136517
0.118812
0.126124
0.141713
0.13222
0.2032
0.156077
0.166526
0.167117
0.130817
0.167058
0.188566
0.178803
0.224779
0.217089
0.194542
0.199796
0.246194
0.249908
0.23034
0.204611
0.222958
0.24259
0.234767
0.278205
0.267297
0.275127
0.264059
0.25439
0.287421
0.267725
0.252964
0.256326
0.229031
0.276914
0.244985
0.273892
0.298103
0.256733
0.27219
0.301747
0.278291
0.274979
0.300091
0.310184
0.333836
0.297877
0.279405
0.278263
0.291442
0.278518
0.28268
0.321826
0.355584
0.315503
0.338342
0.39687
0.388692
0.353228
0.368169
0.328025
0.407137
0.38092
0.357814
0.362786
0.405149
0.354694
0.348222
0.295455
0.307671
0.290612
0.24626
0.229377
0.26535
0.217139
0.206268
0.230013
0.255796
0.27014
0.246626
0.224845
0.272181
0.201281
0.252555
0.270198
0.289443
0.243552
0.238465
0.207842
0.197373
0.238857
0.224703
0.259659
0.288809
0.24757
0.264744
0.250775
0.245659
0.193861
0.296178
0.234242
0.219704
0.264879
0.290614
0.296195
0.237291
0.208546
0.295197
0.272628
0.288054
0.293539
0.305374
0.328142
0.328574
0.307407
0.298107
0.315636
0.301924
EN

回答 1

Stack Overflow用户

发布于 2020-12-23 19:42:31

嗯,解决方案很简单:增加层数、层的节点数和时期数。示例: layer1=400节点、layer2=200节点、layer3=100节点、layer4=50节点、epochs=1500

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65326577

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档