[MXNet逐梦之旅]练习二·使用MXNet拟合直线简洁实现

[MXNet逐梦之旅]练习二·使用MXNet拟合直线简洁实现

  • code
#%%
#%matplotlib inline
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random

#%%
num_inputs = 1
num_examples = 100
true_w = 1.56
true_b = 1.24
features = nd.arange(0,10,0.1).reshape((-1, 1))
labels = true_w * features + true_b
labels += nd.random.normal(scale=0.2, shape=labels.shape)
features[0], labels[0]



#%%
# 本函数已保存在d2lzh包中方便以后使用
from mxnet import gluon as gl
from mxnet.gluon import data as gdata

batch_size = 10
# 将训练数据的特征和标签组合
dataset = gl.data.ArrayDataset(features, labels)
# 随机读取小批量
data_iter = gl.data.DataLoader(dataset, batch_size, shuffle=True)


for X, y in data_iter:
    print(X, y)
    break



#%%
model = gl.nn.Sequential()

#%%

model.add(gl.nn.Dense(1))

model
#%%
import mxnet as mx

model.initialize(mx.init.Normal(sigma=0.01))

#%%



loss = gl.loss.L2Loss()  # 平方损失又称L2范数损失


#%%
trainer = gl.Trainer(model.collect_params(), 'adam', {'learning_rate': 0.5})
#%%

num_epochs = 10
for epoch in range(1, num_epochs + 1):
    for X, y in data_iter:
        with autograd.record():
            l = loss(model(X), y)
        l.backward()
        trainer.step(batch_size)
    l = loss(model(features), labels)
    print('epoch %d, loss: %f' % (epoch, l.mean().asnumpy()))




#%%
pre = model(features)
pre

plt.scatter(features.asnumpy(), labels.asnumpy(), 1)
plt.scatter(features.asnumpy(), pre.asnumpy(), 1)
plt.show()

#%%

print(model)

print("w:",model.collect_params()["dense0_weight"].data())
print("b:",model.collect_params()["dense0_bias"].data())
  • out
<NDArray 10x1 @cpu(0)>
epoch 1, loss: 5.570210
epoch 2, loss: 2.831637
epoch 3, loss: 0.995476
epoch 4, loss: 0.332262
epoch 5, loss: 0.060224
epoch 6, loss: 0.027413
epoch 7, loss: 0.031316
epoch 8, loss: 0.030222
epoch 9, loss: 0.027907
epoch 10, loss: 0.032840
Sequential(
  (0): Dense(1 -> 1, linear)
)
w:
[[1.5745053]]
<NDArray 1x1 @cpu(0)>
b:
[1.2476798]
<NDArray 1 @cpu(0)>

蓝色是原始数据

黄色为拟合数据

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Node开发

排序算法(1)---基本概念

排序与我们日常生活中息息相关,比如,我们要从电话簿中找到某个联系人首先会按照姓氏排序、买火车票会按照出发时间、买东西会按照销量排序、查找文件会按照最近修改时间排...

12120
来自专栏数据分析1480

如何用Python解决最优化问题?

现有5个广告投放渠道,分别是日间电视、夜间电视、网络媒体、平面媒体、户外广告,每个渠道的效果、费用及限制如下表所示:

93020
来自专栏Node开发

简单理解Token机制

互联网发展到现在已经到了一个非常成熟的时代,所以不再是一个你写一个静态网站就可以进行疯狂盈利的时代了。现在对产品有着很多的要求,健壮性,安全性这...

46110
来自专栏ChaMd5安全团队

QWB WriteUp

register的析构函数, 调用profile的__call函数,进而调用profile的upload_img函数

27320
来自专栏资深Tester

软件测试新人问题回复(一)

今天的文章是一个新入行的小伙伴咨询的一些问题,问题有点多,所以分成二次回复,针对这些问题,王豆豆觉得很适合刚入行、未对软件测试有过深了解的小伙伴们学习,故分享出...

19830
来自专栏Node开发

前后端分离之交互(1)

之前写过一篇文章讲到我对目前技术发展趋势的一些看法:我理解的技术发展趋势,里面其实有提到,现在比较流行MVVM,越来越多的公司开始采用前后端分...

77410
来自专栏Node开发

第三方登录(2)---GitHub登录

上一篇介绍了如何实现第三方QQ登录,其实都不涉及后端。在前端使用js就可以实现第三方QQ登录。然后如果有数据库操作可以发起ajax请求将登录得到的用户信息发给后...

21520
来自专栏数值分析与有限元编程

《Introduction to Programming with Fortran(2018)》 4th edition

本书官方网站: https://www.fortranplus.co.uk/ 提供相关源代码下载。

14120
来自专栏开源优测

数据库测试的重要性、组件和过程

企业级的服务系统通常是复杂的,一般都是多层设计,包括用户界面、业务逻辑、数据访问层和数据库层等。要确保服务按预期运行,所有这些层都需要一致且准确的协同工作。

20910
来自专栏Node开发

直接插入排序和直接选择排序

直接插入排序的基本思想是:每次将一个待排序的记录,按其关键字大小插入到前面已经排好序的序列中的适当位置,直到全部记录插入完成为止。

67610

扫码关注云+社区

领取腾讯云代金券

年度创作总结 领取年终奖励