前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >100行代码手写diffusion

100行代码手写diffusion

作者头像
Tom2Code
发布2023-09-12 17:55:58
4700
发布2023-09-12 17:55:58
举报
文章被收录于专栏:TomTom

前向过程

在扩散模型的前向过程中,数据会经历一个连续的噪声添加过程,直到原始信号完全退化,从而得到一个良好的分布,通常是高斯分布。在每个时间步长 t,我们使用高斯分布 q(x_{t+1}|x_{t})从当前数据点 x_{t} 采样下一个数据点,表示为 x_{t+1}。该分布的均值和协方差由一组称为贝塔的超参数决定。

在对数据点 x_{t} 进行采样时,我们不需要使用在 t 个时间步长内进行迭代的天真解决方案,因为条件分布 q(x_t|x_0) 已被证明具有可行的解析形式。因此,我们可以直接将 x_t 作为 x_0 的函数进行采样,这在实现过程的第 7 行完成。这种方法无需耗时的迭代循环,可以根据 x_0 对 x_t 进行高效采样。

在第 7 行使用 q(x_t|x_0) 的均值和协方差对 x_t 进行采样时,我们需要 q(x_{t-1}|x_t, x_0) 的均值和标准差来计算损失。这些变量可以用贝叶斯定理分析计算,然后直接并入代码中

代码语言:javascript
复制
 def forward_process(self, x0, t):

        t = t - 1  # Start indexing at 0
        beta_forward = self.beta[t]
        alpha_forward = self.alpha[t]
        alpha_cum_forward = self.alpha_bar[t]
        xt = x0 * torch.sqrt(alpha_cum_forward) + torch.randn_like(x0) * torch.sqrt(1. - alpha_cum_forward)

        mu1_scl = torch.sqrt(alpha_cum_forward / alpha_forward)
        mu2_scl = 1. / torch.sqrt(alpha_forward)
        cov1 = 1. - alpha_cum_forward / alpha_forward
        cov2 = beta_forward / alpha_forward
        lam = 1. / cov1 + 1. / cov2
        mu = (x0 * mu1_scl / cov1 + xt * mu2_scl / cov2) / lam
        sigma = torch.sqrt(1. / lam)
        return mu, sigma, xt

逆过程

扩散模型中反向过程的目的是在每个时间步长 t 近似正向过程的逆过程。这意味着一旦模型训练完成,我们就可以直接从良好分布中采样噪声,并利用反向过程生成数据样本,如二维图像。

反向过程的实现非常简单。由于模型的构造,q(x_{t-1}|x_{t-1}) 是一个标准偏差很小的高斯分布,因此 p(x_{t-1}|x_t) 也是一个高斯分布。因此,可以通过其均值和协方差矩阵对其进行参数化。我们的目标是训练模型,使其在每个时间步精确估计这些参数,从而生成高质量的合成数据。

在模型逼近分布 p(x_{t-1}|x_t) 的参数(均值和协方差)后,就可以利用这些参数从分布中采样。

代码语言:javascript
复制
def reverse(self, xt, t):

        t = t - 1  # Start indexing at 0
        if t == 0: return None, None, xt
        mu, h = self.model(xt, t).chunk(2, dim=1)
        sigma = torch.sqrt(torch.exp(h))
        samples = mu + torch.randn_like(xt) * sigma
        return mu, sigma, samples

采样

反向过程一旦实施,就可以递归调用,首先从良好分布中抽取噪声。通过这种递归调用,反向过程生成的数据有望与训练数据相似。

代码语言:javascript
复制
def sample(self, size, device):
        noise = torch.randn((size, 2)).to(device)
        samples = [noise]
        for t in range(self.n_steps):
            _, _, x = self.reverse(samples[-1], self.n_steps - t - 1 + 1)
            samples.append(x)
        return samples

构造函数

为了完成模型,我们将创建构造函数。它需要几个输入,包括随时间变化的模型、扩散步数以及执行该函数的设备。在第 10 行,我们定义了超参数 beta,它代表前向过程中每个扩散步骤的方差。此外,在第 12 行和第 13 行,我们定义了 beta 的函数变量,这有助于使其他部分的实现更加简洁。

代码语言:javascript
复制
class DiffusionModel(nn.Module):

    def __init__(self, model: nn.Module, n_steps=40, device='cuda'):
        super().__init__()

        self.model = model
        self.device = device

        betas = torch.linspace(-18, 10, n_steps)
        self.beta = torch.sigmoid(betas) * (3e-1 - 1e-5) + 1e-5

        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.n_steps = n_steps
        self.sigma2 = self.beta

训练

在训练过程中,我们的目标是使模型的对数似然最大化,但遗憾的是,模型的对数似然无法通过分析计算出来。不过,另一种方法是计算模型似然的下限,正如本文所演示的那样。这个下限涉及计算 q(x_{t-1}|x_t) 和 p(x_{t-1}|x_t) 分布之间的库尔巴克-莱伯勒(KL)发散。由于两个分布都是高斯分布,KL 发散可以直接用著名的解析形式计算。通过优化下限,我们可以间接地最大化模型的对数似然,并在训练过程中提高模型的性能。

代码语言:javascript
复制
def train(model, optimizer, nb_epochs=150_000, batch_size=64_000):
    training_loss = []
    for _ in tqdm(range(nb_epochs)):
        x0 = torch.from_numpy(sample_batch(batch_size)).float().to(device)
        t = np.random.randint(2, 40 + 1)
        mu_posterior, sigma_posterior, xt = model.forward_process(x0, t)
        mu, sigma, _ = model.reverse(xt, t)

        KL = (torch.log(sigma) - torch.log(sigma_posterior) + (sigma_posterior ** 2 + (mu_posterior - mu) ** 2) / (
                2 * sigma ** 2) - 0.5)
        loss = KL.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        training_loss.append(loss.item())

组合在一起

最后,可以毫不费力地将所有组件组合在一起,形成完整的实施方案。经过几个小时的训练,该模型应该能够生成一个出色的瑞士卷分布生成模型。

代码语言:javascript
复制
if __name__ == "__main__":
    device = 'cuda'
    model_mlp = MLP(hidden_dim=128).to(device)
    model = DiffusionModel(model_mlp)
    optimizer = torch.optim.Adam(model_mlp.parameters(), lr=1e-4)
    train(model, optimizer)

完整代码:

代码语言:javascript
复制
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
import torch.utils.data
import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll


def sample_batch(size):
    x, _ = make_swiss_roll(size)
    return x[:, [2, 0]] / 10.0 * np.array([1, -1])


class MLP(nn.Module):

    def __init__(self, N=40, data_dim=2, hidden_dim=64):
        super(MLP, self).__init__()

        self.network_head = nn.Sequential(nn.Linear(data_dim, hidden_dim), nn.ReLU(),
                                          nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), )
        self.network_tail = nn.ModuleList([nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                                         nn.ReLU(), nn.Linear(hidden_dim, data_dim * 2)
                                                         ) for _ in range(N)])

    def forward(self, x, t: int):
        h = self.network_head(x)
        return self.network_tail[t](h)


class DiffusionModel(nn.Module):

    def __init__(self, model: nn.Module, n_steps=40, device='cuda'):
        super().__init__()

        self.model = model
        self.device = device

        betas = torch.linspace(-18, 10, n_steps)
        self.beta = torch.sigmoid(betas) * (3e-1 - 1e-5) + 1e-5

        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.n_steps = n_steps
        self.sigma2 = self.beta

    def forward_process(self, x0, t):

        t = t - 1  # Start indexing at 0
        beta_forward = self.beta[t]
        alpha_forward = self.alpha[t]
        alpha_cum_forward = self.alpha_bar[t]
        xt = x0 * torch.sqrt(alpha_cum_forward) + torch.randn_like(x0) * torch.sqrt(1. - alpha_cum_forward)
        # Retrieved from https://github.com/Sohl-Dickstein/Diffusion-Probabilistic-Models/blob/master/model.py#L203
        mu1_scl = torch.sqrt(alpha_cum_forward / alpha_forward)
        mu2_scl = 1. / torch.sqrt(alpha_forward)
        cov1 = 1. - alpha_cum_forward / alpha_forward
        cov2 = beta_forward / alpha_forward
        lam = 1. / cov1 + 1. / cov2
        mu = (x0 * mu1_scl / cov1 + xt * mu2_scl / cov2) / lam
        sigma = torch.sqrt(1. / lam)
        return mu, sigma, xt

    def reverse(self, xt, t):

        t = t - 1  # Start indexing at 0
        if t == 0: return None, None, xt
        mu, h = self.model(xt, t).chunk(2, dim=1)
        sigma = torch.sqrt(torch.exp(h))
        samples = mu + torch.randn_like(xt) * sigma
        return mu, sigma, samples

    def sample(self, size, device):
        noise = torch.randn((size, 2)).to(device)
        samples = [noise]
        for t in range(self.n_steps):
            _, _, x = self.reverse(samples[-1], self.n_steps - t - 1 + 1)
            samples.append(x)
        return samples


def plot(model):
    plt.figure(figsize=(10, 6))
    x0 = sample_batch(5000)
    x20 = model.forward_process(torch.from_numpy(x0).to(device), 20)[-1].data.cpu().numpy()
    x40 = model.forward_process(torch.from_numpy(x0).to(device), 40)[-1].data.cpu().numpy()
    data = [x0, x20, x40]
    for i, t in enumerate([0, 20, 39]):
        plt.subplot(2, 3, 1 + i)
        plt.scatter(data[i][:, 0], data[i][:, 1], alpha=.1, s=1)
        plt.xlim([-2, 2])
        plt.ylim([-2, 2])
        plt.gca().set_aspect('equal')
        if t == 0: plt.ylabel(r'$q(\mathbf{x}^{(0...T)})$', fontsize=17, rotation=0, labelpad=60)
        if i == 0: plt.title(r'$t=0$', fontsize=17)
        if i == 1: plt.title(r'$t=\frac{T}{2}$', fontsize=17)
        if i == 2: plt.title(r'$t=T$', fontsize=17)

    samples = model.sample(5000, device)
    for i, t in enumerate([0, 20, 40]):
        plt.subplot(2, 3, 4 + i)
        plt.scatter(samples[40 - t][:, 0].data.cpu().numpy(), samples[40 - t][:, 1].data.cpu().numpy(),
                    alpha=.1, s=1, c='r')
        plt.xlim([-2, 2])
        plt.ylim([-2, 2])
        plt.gca().set_aspect('equal')
        if t == 0: plt.ylabel(r'$p(\mathbf{x}^{(0...T)})$', fontsize=17, rotation=0, labelpad=60)
    plt.savefig(f"Imgs/diffusion_model.png", bbox_inches='tight')
    plt.close()


def train(model, optimizer, nb_epochs=150_000, batch_size=64_000):
    training_loss = []
    for _ in tqdm(range(nb_epochs)):
        x0 = torch.from_numpy(sample_batch(batch_size)).float().to(device)
        t = np.random.randint(2, 40 + 1)
        mu_posterior, sigma_posterior, xt = model.forward_process(x0, t)
        mu, sigma, _ = model.reverse(xt, t)

        KL = (torch.log(sigma) - torch.log(sigma_posterior) + (sigma_posterior ** 2 + (mu_posterior - mu) ** 2) / (
                2 * sigma ** 2) - 0.5)
        loss = KL.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        training_loss.append(loss.item())


if __name__ == "__main__":
    device = 'cuda'
    model_mlp = MLP(hidden_dim=128).to(device)
    model = DiffusionModel(model_mlp)
    optimizer = torch.optim.Adam(model_mlp.parameters(), lr=1e-4)
    train(model, optimizer)
    plot(model)

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-09-04 17:42,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Tom的小院 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档