这段来自 https://medium.com/@benjamin.phillips22/simple-regression-with-neural-networks-in-pytorch-313f06910379 的简单 PyTorch 代码,无法从图1的数据中拟合出图2所示的预期回归曲线。我尝试了不同的层数和每层大小、优化器、损失函数、学习率以及训练轮数(epochs),结果始终只是一条平坦的水平线。作为一个新手,我想我遗漏了一些东西。
数据:
大约预期结果:
####
#code from https://medium.com/@benjamin.phillips22/simple-regression-with-neural-networks-in-pytorch-313f06910379
####
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as Data
import matplotlib.pyplot as plt
# %matplotlib inline
import numpy as np
import imageio
import pandas as pd
# Seed PyTorch's RNG so repeated runs are reproducible.
torch.manual_seed(1)

# Load inputs and targets from CSV and convert them to float32 tensors.
# NOTE(review): pd.read_csv treats the first row as a header by default;
# confirm x.csv / y.csv really start with a header row, otherwise pass
# header=None so the first sample is not silently dropped.
x = torch.from_numpy(pd.read_csv('./x.csv').to_numpy()).float()
y = torch.from_numpy(pd.read_csv('./y.csv').to_numpy()).float()

# Plain tensors are used directly: torch.autograd.Variable has been a
# deprecated no-op wrapper since PyTorch 0.4, so wrapping x and y in it is
# unnecessary for training.

# Preview the raw data before training.
plt.figure(figsize=(10, 4))
plt.scatter(x.numpy(), y.numpy(), color="orange")
plt.title('regression')
plt.xlabel('x input')
plt.ylabel('y output')
plt.show()
# Fully connected regressor: 1 input feature -> 20 -> 20 -> 1 output, with a
# sigmoid after each of the two hidden layers. Layers are created in the same
# order as they appear in the network.
_hidden_width = 20
_layers = [
    torch.nn.Linear(1, _hidden_width),
    torch.nn.Sigmoid(),
    torch.nn.Linear(_hidden_width, _hidden_width),
    torch.nn.Sigmoid(),
    torch.nn.Linear(_hidden_width, 1),
]
net = torch.nn.Sequential(*_layers)

# Mean squared error is the regression objective.
loss_func = torch.nn.MSELoss()

# Plain stochastic gradient descent over every parameter of the network.
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
# Train `net` on (x, y) using the whole data set in every step, render one
# plot frame per step, and finally write all frames to ./result.gif.
# Relies on `x`, `y`, `net`, `optimizer` and `loss_func` being defined earlier
# in the script.

# Rescale the inputs to [-1, 1] before they reach the network. When x spans a
# wide range (e.g. 0..191), the first Linear layer produces large
# pre-activations, the sigmoids saturate, their gradients vanish, and the fit
# collapses to a flat horizontal line. The raw x is still used for plotting.
x_lo, x_hi = x.min(), x.max()
x_span = (x_hi - x_lo).clamp_min(1e-12)  # avoid dividing by zero if x is constant
x_scaled = 2.0 * (x - x_lo) / x_span - 1.0

# Loop-invariant plotting data, converted once instead of on every step.
x_np = x.detach().numpy()
y_np = y.detach().numpy()
x_limits = (x_np.min(), x_np.max())
y_limits = (y_np.min(), y_np.max())

my_images = []
fig, ax = plt.subplots(figsize=(12, 7))

# train the network
for t in range(500):
    prediction = net(x_scaled)       # forward pass on the rescaled inputs
    loss = loss_func(prediction, y)  # must be (1. nn output, 2. target)

    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()        # backpropagation, compute gradients
    optimizer.step()       # apply gradients

    # str() of the 0-d array keeps the same text the script always printed.
    loss_text = str(loss.detach().numpy())

    # plot and show learning process (prediction is from before this step's update)
    ax.clear()
    ax.set_xlabel('x input', fontsize=12)
    ax.set_ylabel('y output', fontsize=12)
    ax.set_xlim(*x_limits)
    ax.set_ylim(*y_limits)
    ax.scatter(x_np, y_np, color="orange")
    ax.plot(x_np, prediction.detach().numpy(), 'g-', lw=3)
    ax.set_title('regression step=' + str(t) + " loss=" + loss_text, fontsize=16)

    # Grab the rendered figure as an (H, W, 3) uint8 array.
    # canvas.tostring_rgb() was deprecated in Matplotlib 3.8 and later removed;
    # buffer_rgba() is the replacement. The copy is needed because the canvas
    # buffer is overwritten by the next draw.
    fig.canvas.draw()  # draw the canvas, cache the renderer
    image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    my_images.append(image)

    print(str(t) + " " + loss_text)

# save images as a gif
imageio.mimsave('./result.gif', my_images, fps=10)
x.csv数据:从0到191的整数,步长为1
y.csv数据:
0.316905
0.322015
0.332582
0.310465
0.250653
0.292722
0.297352
0.276525
0.308451
0.283465
0.276011
0.29934
0.307176
0.28573
0.248667
0.288614
0.249795
0.255556
0.258393
0.235972
0.225121
0.207828
0.190252
0.192881
0.204084
0.167646
0.155202
0.146516
0.162462
0.182906
0.160287
0.213769
0.186362
0.201151
0.186125
0.190625
0.146851
0.169204
0.207855
0.196557
0.208835
0.2244
0.206303
0.193485
0.185266
0.205616
0.229315
0.196254
0.219849
0.209988
0.197361
0.195402
0.210149
0.240754
0.210418
0.191776
0.189532
0.206153
0.165696
0.187938
0.157561
0.163148
0.19473
0.18966
0.162334
0.189277
0.166506
0.198193
0.157867
0.135192
0.152216
0.137521
0.142007
0.121252
0.136517
0.118812
0.126124
0.141713
0.13222
0.2032
0.156077
0.166526
0.167117
0.130817
0.167058
0.188566
0.178803
0.224779
0.217089
0.194542
0.199796
0.246194
0.249908
0.23034
0.204611
0.222958
0.24259
0.234767
0.278205
0.267297
0.275127
0.264059
0.25439
0.287421
0.267725
0.252964
0.256326
0.229031
0.276914
0.244985
0.273892
0.298103
0.256733
0.27219
0.301747
0.278291
0.274979
0.300091
0.310184
0.333836
0.297877
0.279405
0.278263
0.291442
0.278518
0.28268
0.321826
0.355584
0.315503
0.338342
0.39687
0.388692
0.353228
0.368169
0.328025
0.407137
0.38092
0.357814
0.362786
0.405149
0.354694
0.348222
0.295455
0.307671
0.290612
0.24626
0.229377
0.26535
0.217139
0.206268
0.230013
0.255796
0.27014
0.246626
0.224845
0.272181
0.201281
0.252555
0.270198
0.289443
0.243552
0.238465
0.207842
0.197373
0.238857
0.224703
0.259659
0.288809
0.24757
0.264744
0.250775
0.245659
0.193861
0.296178
0.234242
0.219704
0.264879
0.290614
0.296195
0.237291
0.208546
0.295197
0.272628
0.288054
0.293539
0.305374
0.328142
0.328574
0.307407
0.298107
0.315636
0.301924
发布于 2020-12-23 19:42:31
嗯,解决方案很简单:增加层数、每层的节点数和训练轮数(epochs)。示例:layer1=400 个节点、layer2=200 个节点、layer3=100 个节点、layer4=50 个节点、epochs=1500。
https://stackoverflow.com/questions/65326577
复制相似问题