前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >对 VAE 的理解与实现

对 VAE 的理解与实现

作者头像
为为为什么
发布2022-09-28 19:10:15
3120
发布2022-09-28 19:10:15
举报
文章被收录于专栏:又见苍岚又见苍岚

之前我们介绍过 ELBOVAE,本文记录我自己的理解与实现。

问题描述

  • 假设我们有来自某一未知分布 p 的随机变量观测样本集 X ,如何从 X 获取 p

构造生成器还是评估器

  • 对于某个分布,有两种方式可以描述这一分布 构造生成器:获取一个生产样本的生成器gg 生成的样本和 p 的样本来源相同(不可区分) 构造评估器:构造一个样本评估器 e ,对于给定的样本 xe 可以产生和 p 相同的概率密度, e(x) = p(x)
直接构造评估器路线
  • 解决上述问题的直观想法是构造一个关于参数 \alpha 的分 p_\alpha ,从其中选择与 p 最接近的,那么就让他们作为评估器,给定相同样本输出与 p 相同的概率密度就好啦~
  • 但问题是我们不知道样本 x 的真实概率密度 p(x) ,而且难以保证对于所有可能的 x 组成的集合 X_{all} ,我们的 \sum_{x\in {X_{all}}} p_\alpha(x) = 1 ,这应该是不可能完成的约束,因此直接构造评估器的路线并不现实
ELBO 路线
  • 如果我们有一组关于参数 \beta 的生成器族 g_\beta ,可以不断生成和 x 维度相同的数据, 优化 \beta 使得生成的数据和 p 生成的数据难以区分,我们就可以说得到了 p 的近似分布,GAN 基本上就延用了这个思路
  • 如果我们觉得直接用模型描述 X 分布困难或过于暴力,我们可以引入带有隐变量 z 的概率分布,也就走上了 ELBO 的生成模型 道路
  • 在ELBO 的生成模型中,我们为了描述复杂的概率分布引入了 z ,建立了 X,Z 的联合分布,但是这个 z 却是个大麻烦,因为我们的目标是 p ,这个分布和 Z 无关,仅和 X 有关,我们还得把 z 消掉
  • 直接的想法是对 z 积分, p_{\theta}(x)=\int p_{\theta}(x \mid z) p(z) d z ,可以蒙特卡洛积分计算,但是如果要求精度会很慢,因此我们转向贝叶斯的思路,也就走上了 ELBO 贝叶斯评估器 的道路
  • ELBO 的神奇之处在于同时结合了生成器和评估器的分布描述方式,在多处受阻的境况中巧妙运用贝叶斯公式找到了一种可以参数化、可以优化、贪心最大化变量 (ELBO) 的方法

VAE

  • 我理解 VAE 是对 ELBO 的直接实现
  • VAE 具象化了 ELBO 推导中的分布
p(z) = N(0,1)
p(z|x)=N(z;\mu (x), diag(\sigma(x)^2))
  • 直接优化 ELBO

  • 加入重参数化技巧实现训练过程
  • 当训练完成时,生成器(解码器)可以依赖 N(0,I) 上的采样生成近似 X 的样本,也就得到了近似 p 的生成器,以此近似描述 p 的分布

实现

  • 以瑞士卷(Swiss Roll) 数据作为目标分布 p
  • 瑞士卷数据集上实现 VAE,构造模仿瑞士卷分布的数据生成器
核心代码
代码语言:javascript
复制
class SimpleVAE(BaseVAE):
    def __init__(self, in_channels: int=2, latent_dim: int=2, hidden_dims: List = None) -> None:
        super(SimpleVAE, self).__init__()

        self.latent_dim = latent_dim

        if hidden_dims is None:
            hidden_dims = [128, 128]

        ori_in_channels = in_channels

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(in_channels, h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)

        # Build Decoder
        modules = []
        de_hidden_dims = [hidden_dims[-1]] + hidden_dims

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1])
        hidden_dims.reverse()

        for i in range(len(de_hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.Linear(de_hidden_dims[i], de_hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)
        self.final_layer = nn.Sequential(
                            nn.Linear(de_hidden_dims[-1], ori_in_channels))

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x in_channels]
        :return: (Tensor) List of latent codes [N x latent_dim]
        """
        result = self.encoder(input)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the data space.
        
        :param z: (Tensor) [N x latent_dim]
        :return: (Tensor) [N x in_channels]
        """
        result = self.decoder_input(z)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [N x latent_dim]
        :param logvar: (Tensor) Standard deviation of the latent Gaussian [N x latent_dim]
        :return: (Tensor) [N x latent_dim]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var, z]

    def loss_function(self, forward_res, kld_weight) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        """
        recons = forward_res[0]
        input = forward_res[1]
        mu = forward_res[2]
        log_var = forward_res[3]

        recons_loss =F.mse_loss(recons, input)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD':kld_loss.detach()}

代码仓库
效果展示

参考资料

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022年9月7日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 问题描述
  • 构造生成器还是评估器
    • 直接构造评估器路线
      • ELBO 路线
      • VAE
      • 实现
        • 核心代码
          • 代码仓库
            • 效果展示
            • 参考资料
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档