前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >在MNIST数据集上使用Pytorch中的Autoencoder进行维度操作

在MNIST数据集上使用Pytorch中的Autoencoder进行维度操作

作者头像
代码医生工作室
发布2019-09-10 15:20:08
3.4K0
发布2019-09-10 15:20:08
举报
文章被收录于专栏:相约机器人

通过理论与代码的联系来学习!

现在根据深度学习书,自动编码器是一种神经网络,经过训练旨在将其输入复制到其输出。在内部,它有一个隐藏层,用于描述用于表示输入的代码。网络可被视为由两部分组成:编码器功能“h = f(x)”和产生重建“r = g(h)”的解码器。

好的,知道你在想什么!只是另一篇没有正确解释的帖子?没有!那不是将如何进行的。将理论知识与代码逐步联系起来!这将有助于更好地理解并帮助在将来为任何ML问题建立直觉。

首先构建一个简单的自动编码器来压缩MNIST数据集。使用自动编码器,通过编码器传递输入数据,该编码器对输入进行压缩表示。然后该表示通过解码器以重建输入数据。通常,编码器和解码器将使用神经网络构建,然后在示例数据上进行训练。

但这些编码器和解码器到底是什么?

自动编码器的一般结构,通过内部表示或代码“h”将输入x映射到输出(称为重建)“r”。

自动

编码器有两个组成部分:编码器:它具有从x到h的映射,即f(映射x到h)

解码器:它具有从h到r的映射(即映射h到r)。

将了解如何连接此信息并在几段后将其应用于代码。

那么,这个“压缩表示”实际上做了什么呢?

压缩表示通常包含有关输入图像的重要信息,可以将其用于去噪图像或其他类型的重建和转换!它可以以比存储原始数据更实用的方式存储和共享任何类型的数据。

为编码器和解码器构建简单的网络架构,以了解自动编码器。

  • 总是首先导入我们的库并获取数据集。
  1. 将数据转换为torch.FloatTensor
  2. 加载训练和测试数据集
代码语言:javascript
复制
# 5
output = output.detach().numpy()
# 6
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 7
for images, row in zip([images, output], axes):
 for img, ax in zip(images, row):
 ax.imshow(np.squeeze(img), cmap='gray')
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
  • 然后像往常一样创建训练和测试数据加载器
  1. 用于数据加载的子进程数
  2. 每批加载多少个样品
  3. 准备数据加载器,现在如果自己想要尝试自动编码器的数据集,则需要创建一个特定于此目的的数据加载器。
代码语言:javascript
复制
# 5
output = output.detach().numpy()
# 6
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 7
for images, row in zip([images, output], axes):
 for img, ax in zip(images, row):
 ax.imshow(np.squeeze(img), cmap='gray')
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
  • 可视化数据:现在,这是可选的,但查看数据是否已正确加载始终是一个好习惯。可以通过
  1. 获得一批训练图像
  2. 然后从批处理中获取一个图像
代码语言:javascript
复制
# 5
output = output.detach().numpy()
# 6
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 7
for images, row in zip([images, output], axes):
 for img, ax in zip(images, row):
 ax.imshow(np.squeeze(img), cmap='gray')
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)

在案例中,应该看到类似的东西

现在,由于正在尝试学习自动编码器背后的概念,将从线性自动编码器开始,其中编码器和解码器应由一个线性层组成。连接编码器和解码器的单元将是压缩表示。

请注意,MNIST数据集的图像尺寸为28 * 28,因此将通过将这些图像展平为784(即28 * 28 = 784)长度向量来训练自动编码器。此外,来自此数据集的图像已经标准化,使得值介于0和1之间。

由于图像在0和1之间归一化,我们需要在输出层上使用sigmoid激活来获得与此输入值范围匹配的值。

  • 模型架构:这是自动编码器最重要的一步,因为试图实现与输入相同的目标!
  1. 定义NN架构:
  • 编码器:编码器将由一个线性层组成,其深度尺寸应如下变化:784输入 - > encoding_dim。

现在对于那些对编码维度(encoding_dim)有点混淆的人,将其视为输入和输出之间的中间维度,可根据需要进行操作,但其大小必须保持在输入和输出维度之间。

在下面的代码中,选择了encoding_dim = 32,这基本上就是压缩表示!

代码语言:javascript
复制
# 5
output = output.detach().numpy()
# 6
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 7
for images, row in zip([images, output], axes):
 for img, ax in zip(images, row):
 ax.imshow(np.squeeze(img), cmap='gray')
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
 
# 5
output = output.detach().numpy()
# 6
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 7
for images, row in zip([images, output], axes):
 for img, ax in zip(images, row):
 ax.imshow(np.squeeze(img), cmap='gray')
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
  • 训练:在这里,我将编写一些代码来训练网络。我对这里的验证不太感兴趣,所以让我们稍后观察训练损失和测试损失。

也不关心标签,在这种情况下,只是图像可以从train_loader获取。由于要比较输入和输出图像中的像素值,因此使用适用于回归任务的损失将是最有益的。回归就是比较数量而不是概率值。

代码语言:javascript
复制
# 5
output = output.detach().numpy()
# 6
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 7
for images, row in zip([images, output], axes):
 for img, ax in zip(images, row):
 ax.imshow(np.squeeze(img), cmap='gray')
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)

这是PyTorch非常简单的训练。将图像展平,将它们传递给自动编码器,然后记录训练损失。

代码语言:javascript
复制
# 5
output = output.detach().numpy()
# 6
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 7
for images, row in zip([images, output], axes):
 for img, ax in zip(images, row):
 ax.imshow(np.squeeze(img), cmap='gray')
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
 
# 5
output = output.detach().numpy()
# 6
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 7
for images, row in zip([images, output], axes):
 for img, ax in zip(images, row):
 ax.imshow(np.squeeze(img), cmap='gray')
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
 
# 5
output = output.detach().numpy()
# 6
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 7
for images, row in zip([images, output], axes):
 for img, ax in zip(images, row):
 ax.imshow(np.squeeze(img), cmap='gray')
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
 
# 5
output = output.detach().numpy()
# 6
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 7
for images, row in zip([images, output], axes):
 for img, ax in zip(images, row):
 ax.imshow(np.squeeze(img), cmap='gray')
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)

检查结果:

  1. 获得一批测试图像
  2. 获取样本输出
  3. 准备要显示的图像
  4. 输出大小调整为一批图像
  5. 当它是requires_grad的输出时使用detach
  6. 绘制前十个输入图像,然后重建图像
  7. 在顶行输入图像,在底部输入重建
代码语言:javascript
复制
# 5
output = output.detach().numpy()
# 6
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 7
for images, row in zip([images, output], axes):
 for img, ax in zip(images, row):
 ax.imshow(np.squeeze(img), cmap='gray')
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
 
# 5
output = output.detach().numpy()
# 6
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))
# 7
for images, row in zip([images, output], axes):
 for img, ax in zip(images, row):
 ax.imshow(np.squeeze(img), cmap='gray')
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)

如果自动编码器成功地只是学习在任何地方设置g(f(x))= x,那么它就不是特别有用。相反,自动编码器被设计为无法学习完美复制。通常,它们的限制方式只允许它们大约复制,并且只复制类似于训练数据的输入。因为模型被迫优先考虑应该复制输入的哪些方面,所以它通常会学习数据的有用属性。

由于在这里处理图像,可以(通常)使用卷积层获得更好的性能。因此接下来可以做的是用卷积层构建一个更好的自动编码器。可以使用此处学到的基础知识作为带卷积层的自动编码器的基础。

如果想自己试用代码,请参考:线性自动编码器

https://github.com/Garima13a/Linear-Autoencoders

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-09-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 相约机器人 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
文件存储
文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档