前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >无监督学习神经网络——自编码

无监督学习神经网络——自编码

作者头像
企鹅号小编
发布2018-02-07 10:29:39
3K0
发布2018-02-07 10:29:39
举报
文章被收录于专栏:企鹅号快讯企鹅号快讯

自编码是一种无监督学习的神经网络,主要应用在特征提取,对象识别,降维等。自编码器将神经网络的隐含层看成是一个编码器和解码器,输入数据经过隐含层的编码和解码,到达输出层时,确保输出的结果尽量与输入数据保持一致。也就是说,隐含层是尽量保证输出数据等于输入数据的。 这样做的一个好处是,隐含层能够抓住输入数据的特点,使其特征保持不变。例如,假设输入层有100个神经元,隐含层只有50个神经元,输出层有100个神经元,通过自动编码器算法,只用隐含层的50个神经元就找到了100个输入层数据的特点,能够保证输出数据和输入数据大致一致,就大大降低了隐含层的维度。 既然隐含层的任务是尽量找输入数据的特征,也就是说,尽量用最少的维度来代表输入数据,因此,隐含层各层之间的参数构成的参数矩阵,应该尽量是个稀疏矩阵,即各层之间有越多的参数为0就越好。

代码语言:python
复制
fromtorchimportoptim
fromtorchimportnnasnn
fromtorch.autogradimportVariable
fromtorch.utilsimportdata
fromtorchvisionimportdatasets
fromtorchvisionimporttransforms
#超参数
epochs =10
batch_size =64
lr =0.005
n_test_img =5
classAutoEncoder(nn.Module):
def__init(self):
    super(AutoEncoder, self).__init__()
# 压缩
self.encoder = nn.Sequential(
    nn.Linear(28*28,128),
    nn.Tanh(),
    nn.Linear(128,64),
    nn.Tanh(),
    nn.Linear(64,12),
    nn.Tanh(),
    nn.Linear(12,3)
)
# 解压
self.decoder = nn.Sequential(
    nn.Linear(3,12),
    nn.Tanh(),
    nn.Linear(12,64),
    nn.Tanh(),
    nn.Linear(64,128),
    nn.Tanh(),
    nn.Linear(128,28*28),
    nn.Sigmoid()
)
defforward(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    returnencoded, decoded
    img_transform = transforms.Compose([transforms.ToTensor()])
if__name__ =='__main__':
    train_data = datasets.MNIST(root='./data', train=True,transform=img_transform, download=True)
    train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    autoencoder = AutoEncoder()
# 训练模型
optimizer = optim.Adam(autoencoder.parameters(), lr=lr)
loss = nn.MSELoss()
forepochinrange(epochs):
    forsetp, (x, y)inenumerate(train_loader):
        b_x = Variable(x.view(-1,28*28))
        b_y = Variable(x.view(-1,28*28))
        b_label = Variable(y)
        encoded, decoded = autoencoder(b_x)
        loss_data = loss(decoded, b_y)
        optimizer.zero_grad()
    loss.backward()
optimizer.step()

本文来自企鹅号 - 知之Python媒体

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文来自企鹅号 - 知之Python媒体

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档