前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【生成模型】浅析玻尔兹曼机的原理和实践

【生成模型】浅析玻尔兹曼机的原理和实践

作者头像
用户1508658
发布2020-11-19 17:21:37
1.2K0
发布2020-11-19 17:21:37
举报
文章被收录于专栏:有三AI

这一期将介绍另一种生成模型—玻尔兹曼机,虽然它现在已经较少被提及和使用,但其对概率密度函数的处理方式能加深我们对生成模型的理解。

作者&编辑 | 小米粥

1 玻尔兹曼机

玻尔兹曼机属于另一种显式概率模型,它是一种基于能量的模型。训练玻尔兹曼机同样需要基于极大似然的思想,但在计算极大似然的梯度时,运用了一种不同于变分法的近似算法。玻尔兹曼机已经较少引起关注,故在此我们只简述。

在能量模型中,通常将样本的概率p(x)建模成如下形式:

其中,Z为配分函数。为了增强模型的表达能力,通常会在可见变量h的基础上增加隐变量v,以最简单的受限玻尔兹曼机RBM为例,RBM中的可见变量和隐变量均为二值离散随机变量(当然也可推广至实值)。它定义了一个无向概率图模型,并且为二分图,其中可见变量v组成一部分,隐藏变量h组成另一部分,可见变量之间不存在连接,隐藏变量之间也不存在连接(“受限”即来源于此),可见变量与隐藏变量之间实行全连接,结构如下图所示:

在RBM中,可见变量和隐藏变量的联合概率分布由能量函数给出,即

其中能量函数的表达式为

配分函数Z可写为

考虑到二分图的特殊结构,发现在隐藏变量已知时,可见变量之间彼此独立;当可见变量已知时,隐藏变量之间也彼此独立,即有

以及

进一步地,可得到离散概率的具体表达式:

为了使得RBM与能量模型有一致的表达式,定义可见变量v的自由能f(v)为

其中hi为第i个隐藏变量,此时可见变量的概率为

配分函数Z。使用极大似然法训练RBM模型时,需要计算似然函数的梯度,记模型的的参数为θ ,则

可以看出,RBM明确定义了可见变量的概率密度函数,但它并不易求解,因为计算配分函数 Z 需要对所有的可见变量v和隐藏变量h求积分,所以对数似然log p(v)也无法直接求解,故无法直接使用极大似然的思想训练模型。但是,若跳过对数似然函数的求解而直接求解对数似然函数的梯度,也可完成模型的训练。对于其中的权值、偏置参数有:

分析其梯度表达式,其中不易计算的部分在于对可见变量v的期望的计算。RBM通过采样的方法来对梯度进行近似,然后使用近似得到的梯度进行权值更新。为了采样得到可见变量v,可构建一个马尔科夫链并使其最终收敛到p(v),即马尔科夫链的平稳分布为p(v)。初始随机给定样本,迭代运行足够次数后达到平稳分布,这时可根据转移矩阵从模型分布p(v)连续采样得到样本。我们可使用吉布斯采样方法完成该过程,由于两部分变量的独立性,当固定可见变量(或隐藏变量)时,隐藏变量(可见变量)的分布分别为h(n+1) ~sigmoid(WTv(n)+c)和 v(n+1)~sigmoid(Wv(n+1)+b) ,即先采样得到隐藏变量,再采样得到可见变量,这样,我们便可以使用“随机最大似然”完成生成模型的训练了。

玻尔兹曼机依赖马尔可夫链来训练模型或者使用模型生成样本,但是这种技术现在已经很少被使用了,很可能是因为马尔可夫链近似技术不能被适用于像ImageNet的生成问题。并且,即便是马尔可夫链方法可以很好的用于训练,但是使用一个基于马尔可夫链的模型生成样本是需要花费很大计算代价。

2 玻尔兹曼机代码

代码语言:javascript
复制
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
 batch_size = 64
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=transforms.Compose([      transforms.ToTensor()])),batch_size=batch_size)
 test_loader = torch.utils.data.DataLoader(    datasets.MNIST('../data', train=False, transform=transforms.Compose([                       transforms.ToTensor()])),batch_size=batch_size)
 class RBM(nn.Module):
    def __init__(self, n_vis=784, n_hin=500, k=5):
        super(RBM, self).__init__()
        self.W = nn.Parameter(torch.randn(n_hin, n_vis) * 1e-2)
        self.v_bias = nn.Parameter(torch.zeros(n_vis))
        self.h_bias = nn.Parameter(torch.zeros(n_hin))
        self.k = k
    def sample_from_p(self, p):
        return F.relu(torch.sign(p - Variable(torch.rand(p.size()))))
    def v_to_h(self, v):
        p_h = F.sigmoid(F.linear(v, self.W, self.h_bias))
        sample_h = self.sample_from_p(p_h)
        return p_h, sample_h
    def h_to_v(self, h):
        p_v = F.sigmoid(F.linear(h, self.W.t(), self.v_bias))
        sample_v = self.sample_from_p(p_v)
        return p_v, sample_v
    def forward(self, v):
        pre_h1, h1 = self.v_to_h(v)
        h_ = h1
        for _ in range(self.k):
            pre_v_, v_ = self.h_to_v(h_)
            pre_h_, h_ = self.v_to_h(v_)
        return v, v_
    def free_energy(self, v):
        vbias_term = v.mv(self.v_bias)
        wx_b = F.linear(v, self.W, self.h_bias)
        hidden_term = wx_b.exp().add(1).log().sum(1)
        return (-hidden_term - vbias_term).mean()
 rbm = RBM(k=1)
 train_op = optim.SGD(rbm.parameters(),0.1)
 for epoch in range(10):
    loss_ = []
    for _, (data, target) in enumerate(train_loader):
        data = Variable(data.view(-1, 784))
        sample_data = data.bernoulli()
        print(sample_data[0])
        v, v1 = rbm(sample_data)
        loss = rbm.free_energy(v) - rbm.free_energy(v1)
        #loss_.append(loss.data[0])
        train_op.zero_grad() 
        loss.backward()
        train_op.step()
    print np.mean(loss_)
 show_adn_save("real",make_grid(v.view(32,1,28,28).data))
 show_adn_save("generate",make_grid(v1.view(32,1,28,28).data))
 def show_adn_save(file_name,img):
    npimg = np.transpose(img.numpy(),(1,2,0))
    f = "./%s.png" % file_name
    plt.imshow(npimg)
    plt.imsave(f, npimg)

[1] 伊恩·古德费洛, 约书亚·本吉奥, 亚伦·库维尔. 深度学习

[2]李航. 统计机器学习

总结

本期带大家学习了玻尔兹曼机,至此几种显式生成模型都介绍完了,除了显式模型就是大家非常熟悉的隐式生成模型了,其主要的代表是GAN,我们生态已经介绍过许多内容,大家可以去学习。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-11-16,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 有三AI 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档