前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >生成一切的基础,DiT复现

生成一切的基础,DiT复现

作者头像
Srlua
发布2024-11-30 10:55:10
发布2024-11-30 10:55:10
64500
代码可运行
举报
文章被收录于专栏:CSDN社区搬运CSDN社区搬运
运行总次数:0
代码可运行

概述

DiT的全称是Diffusion in Transformer,它将Transformer引入到扩散模型中,替换了Stable Diffusion中常用的U-Net主干。通过增加Transformer深度/宽度或增加输入令牌数量,具有较高Gflops(浮点数运算次数)的DiT始终具有较低的FID(Fréchet初始距离,用于描述生成的图片和原始的图片之间的距离)。本文会讲解扩散模型的原理,从零开始逐步复现DiT模型。

演示效果

如下为DiT模型生成的图片效果

扩散模型

扩散模型(Diffusion Models)是一类用于生成数据(如图像、音频等)的深度学习模型,它们在近年来逐渐成为生成模型领域的重要研究方向,尤其是在图像生成任务中表现出与生成对抗网络(GANs)相媲美甚至更优的性能。

基本原理

扩散模型的基本原理是模拟数据从真实分布逐渐转变到已知分布(通常是高斯噪声分布)的过程,然后再从这个已知分布逆向恢复出数据。这个过程可以分为两个主要阶段:扩散(或正向过程)和逆扩散(或反向过程)。

扩散(正向过程):

从原始数据开始,逐步添加随机噪声,直到数据变成不包含任何原始信息的高斯噪声。 这个过程通常是固定的,不需要学习。每一步添加的噪声量可以是一个预定义的 schedule。 扩散过程可以用一个马尔可夫链来表示,每一步都依赖于前一步的状态。

具体过程

初始化:从真实数据 x0x0​,开始,它服从数据分布PdataPdata​。 逐步添加噪声:在时间步 t(从 0 到 T),逐步向数据中添加噪声,直到数据变成不包含任何原始信息的高斯噪声。这个过程可以用以下公式表示: q(xt∣x0)=N(xt;αtx0,(1−αt)I)q(xt​∣x0​)=N(xt​;αt​​x0​,(1−αt​)I) 累积噪声:在每一步,我们都有xtxt​的条件分布,但通常我们需要得到xtxt​关于x0x0​的分布。这可以通过累积每一步的噪声来实现: q(xt∣x0)=N(xt;αtx0,(1−αt)I)q(xt​∣x0​)=N(xt​;αt​​x0​,(1−αt​)I) 其中,αt=∏i=1t(1−βi)αt​=∏i=1t​(1−βi​)是累积的噪声比例。

逆扩散(反向过程):

从高斯噪声开始,通过学习到的模型逐步去除噪声,恢复出原始数据。 这个过程需要通过神经网络来学习,网络需要预测在每一步应该如何去除噪声,以恢复出接近原始数据的状态。 逆扩散过程通常涉及到条件概率的估计,即给定当前噪声数据,预测原始数据在该步骤的潜在状态。

具体过程

初始化:从高斯噪声 xTxT​开始,它服从标准高斯分布N(xT;0,I)N(xT​;0,I)。 逐步去噪:在时间步 t(从 T 到 0),逐步去除噪声以恢复原始数据。这个过程通常是通过神经网络pθ(xt−1∣xt)pθ​(xt−1​∣xt​)来预测的。神经网络的训练目标是最大化以下似然函数的对数:log⁡pθ(x0)≈∑t=1Tlog⁡pθ(xt−1∣xt)logpθ​(x0​)≈∑t=1T​logpθ​(xt−1​∣xt​) 其中,pθ(xt−1∣xt)pθ​(xt−1​∣xt​)是神经网络预测的xt−1xt−1​给定的xtxt​的条件概率。 预测和重参数化:在每一步,神经网络预测xt−1xt−1​的均值和方差,然后通过重参数化技巧从这些参数中采样。这个过程可以用以下公式表示: xt−1=μθ(xt,t)+σθ(xt,t)⋅ϵxt−1​=μθ​(xt​,t)+σθ​(xt​,t)⋅ϵ 其中,ϵϵ是从标准高斯分布中采样的噪声,μθμθ​和σθσθ​是神经网络预测的均值和标准差。

DiT架构综述

DiT的核心特点:

结合Transformer架构的扩散模型:DiT使用Transformer作为扩散模型的骨干网络,而不是传统的卷积神经网络(如U-Net)。这使得模型能够通过自注意力机制捕捉图像中的长距离依赖关系。 潜在空间操作:DiT在潜在空间中训练,通常比直接在像素空间训练更高效。通过使用变分自编码器(VAE)将图像编码到潜在空间,DiT减少了计算复杂度。 可扩展性:DiT展示了出色的可扩展性,通过增加模型的计算量(以Gflops衡量),可以显著提高生成图像的质量。这种可扩展性允许DiT在不同的分辨率和复杂度下生成图像。

DiT的工作原理:

数据准备:使用预训练的变分自编码器(VAE)将输入图像编码成潜在空间的表示。 分块化(Patchification):将输入的潜在表示分割成一系列的小片段(patches),每个片段对应于Transformer模型的一个输入标记(token)。 Transformer Blocks模块:输入的标记序列通过一系列的Transformer块进行处理,这些块包括自注意力层、前馈神经网络以及层归一化等组件。 条件扩散过程:在训练过程中,DiT模型学习逆向扩散过程,即从噪声数据中恢复出清晰的图像。 样本生成:在训练完成后,可以通过DiT模型生成新的图像。首先,从标准正态分布中采样一个潜在表示,然后通过DiT模型逆向扩散过程,逐步去除噪声,最终解码回像素空间,得到生成的图像。

代码

前向加噪

代码语言:javascript
代码运行次数:0
运行
复制
import torch
from config import *

# 前向diffusion计算参数
betas=torch.linspace(0.0001,0.02,T) # (T,)
alphas=1-betas  # (T,)
alphas_cumprod=torch.cumprod(alphas,dim=-1) # alpha_t累乘 (T,)    [a1,a2,a3,....] ->  [a1,a1*a2,a1*a2*a3,.....]
alphas_cumprod_prev=torch.cat((torch.tensor([1.0]),alphas_cumprod[:-1]),dim=-1) # alpha_t-1累乘 (T,),  [1,a1,a1*a2,a1*a2*a3,.....]
variance=(1-alphas)*(1-alphas_cumprod_prev)/(1-alphas_cumprod)  # denoise用的方差   (T,)

# 执行前向加噪
def forward_add_noise(x,t): # batch_x: (batch,channel,height,width), batch_t: (batch_size,)
    noise=torch.randn_like(x)   # 为每张图片生成第t步的高斯噪音   (batch,channel,height,width)
    batch_alphas_cumprod=alphas_cumprod[t].view(x.size(0),1,1,1) 
    x=torch.sqrt(batch_alphas_cumprod)*x+torch.sqrt(1-batch_alphas_cumprod)*noise # 基于公式直接生成第t步加噪后图片
    return x,noise

DiT模型

第一步:定义PatchEmbedding

代码语言:javascript
代码运行次数:0
运行
复制
class PatchEmbedding(nn.Module):
    def __init__(self,channel,patch_size,emb_size):
        super().__init__()
        self.patch_size=patch_size
        self.patch_count=img_size//self.patch_size
        self.channel=channel
        self.conv=nn.Conv2d(in_channels=channel,out_channels=channel*patch_size**2,kernel_size=patch_size,padding=0,stride=patch_size) 
        self.patch_emb=nn.Linear(in_features=channel*patch_size**2,out_features=emb_size) 
        self.patch_pos_emb=nn.Parameter(torch.rand(1,self.patch_count**2,emb_size))
    def forward(self,x):
        x=self.conv(x)  # (batch,new_channel,patch_count,patch_count)
        x=x.permute(0,2,3,1)    # (batch,patch_count,patch_count,new_channel)
        x=x.view(x.size(0),self.patch_count*self.patch_count,x.size(3)) # (batch,patch_count**2,new_channel)
        
        x=self.patch_emb(x) # (batch,patch_count**2,emb_size)
        x=x+self.patch_pos_emb # (batch,patch_count**2,emb_size)
        return x

第二步:定义time_embedding

代码语言:javascript
代码运行次数:0
运行
复制
class TimeEmbedding(nn.Module):
    def __init__(self,emb_size):
        super().__init__()
        self.half_emb_size=emb_size//2
        half_emb=torch.exp(torch.arange(self.half_emb_size)*(-1*math.log(10000)/(self.half_emb_size-1)))
        self.register_buffer('half_emb',half_emb)

    def forward(self,t):
        t=t.view(t.size(0),1)
        half_emb=self.half_emb.unsqueeze(0).expand(t.size(0),self.half_emb_size)
        half_emb_t=half_emb*t
        embs_t=torch.cat((half_emb_t.sin(),half_emb_t.cos()),dim=-1)
        return embs_t

第三步:定义DiT Block

代码语言:javascript
代码运行次数:0
运行
复制
class DiTBlock(nn.Module):
    def __init__(self,emb_size,nhead=4):
        super().__init__()
        # 定义各种参数,包括线形层,Transformer
    def forward(self,x,condition):
        # 计算

第四步:定义DiT

代码语言:javascript
代码运行次数:0
运行
复制
class DiT(nn.Module):
    def __init__(self,img_size,patch_size,channel,emb_size,label_num,dit_num,head):
        super().__init__()
        # 定义标签的Embedding,dit_num个DiT block,Layernorm等
    def forward(self,x,label):
        # 计算

开始训练

代码语言:javascript
代码运行次数:0
运行
复制
for epoch in range(EPOCH):
    for imgs,labels in dataloader:
        x=imgs*2-1 # 图像的像素范围从[0,1]转换到[-1,1],和噪音高斯分布范围对应
        t=torch.randint(0,T,(imgs.size(0),))  # 为每张图片生成随机t时刻
        y=labels
    
        x,noise=forward_add_noise(x,t) # x:加噪图 noise:噪音
        pred_noise=model(x.to(DEVICE),t.to(DEVICE),y.to(DEVICE))

        loss=loss_fn(pred_noise,noise.to(DEVICE))
    
        optimzer.zero_grad()
        loss.backward()
        optimzer.step()

开始推理(推理代码详见附件)

计算FID分数值

代码语言:javascript
代码运行次数:0
运行
复制
# 装包
pip install pytorch_fid
代码语言:javascript
代码运行次数:0
运行
复制
from pytorch_fid import fid_score
import torch

torch.multiprocessing.set_start_method('spawn',force=True)

# 定义真实图像和生成图像的路径
path_real_images = './images/real'
path_fake_images = './images/fake'
def value():
    # 计算FID分数
    fid_value = fid_score.calculate_fid_given_paths(
        paths=[path_real_images, path_fake_images],
        batch_size=1,  # 根据你的GPU内存调整批处理大小
        device='cuda',  # 或 'cpu',取决于你的硬件
        dims=2048,       # InceptionV3的特征维度
    )

    print(f"FID score: {fid_value}")
if __name__=='__main__':
    value()

DiT模型延申Sora

Sora模型的诠释:Sora模型代表了一种尖端的视觉技术,它采用独树一帜的方法来生成视频,通过逐步消除噪声来构建出精细的最终画面,从而具备捕捉复杂动态场景的能力。

Sora模型的基石:Sora模型的关键组成部分包括Diffusion Transformer(DiT)、Variational Autoencoder(VAE)和Vision Transformer(ViT)。

DiT专注于从噪声中还原出原始视频数据,VAE则负责将视频数据压缩成潜在的空间表示,而ViT的作用是将视频帧转换成特征向量,供DiT进一步处理。

Diffusion Transformer(DiT):DiT集成了扩散模型和Transformer架构的优点,通过模拟从噪声到清晰数据的过程,DiT能够创造出高质量、逼真的视频内容。在Sora模型中,DiT的任务是从噪声数据中恢复出原始视频。

Variational Autoencoder(VAE):VAE是一种生成模型,它能够将视频或图像数据压缩成低维度的潜在表示,并通过解码器将这些表示恢复成原始数据。在Sora模型中,VAE作为编码器,将视频数据压缩后提供给DiT,以指导其生成与原始视频相似的内容。

Vision Transformer(ViT):ViT是一种基于Transformer的图像处理模型,它将图像分割成多个小块(patches),并将这些小块转换为特征向量作为Transformer的输入。在Sora模型中,ViT可能用于预处理或作为模型的一个组成部分

代码语言:javascript
代码运行次数:0
运行
复制
附件结构
├───DIT(项目名称)
    | ├───images(图片文件夹)
        | └───fake(假图片文件夹)
            | └───real(真图片文件夹)
    | ├───MNIST(MNIST数据集文件夹)
        | ├───test(测试数据文件夹)
        | └───train(训练数据文件夹)
    | ├───config.py(配置文件)
    | ├───dataset.py(数据集处理文件)
    | ├───diffusion.py(扩散模型文件)
    | ├───dit_block.py(块状结构文件)
    | ├───dit.png( dit图片文件)
    | ├───dit.py( ditPython文件)
    | ├───fid_computer.py( fid计算机文件)
    | ├───generator.py(生成器文件)
    | ├───inference.py(推理文件)
    | ├───model.pth(模型文件)
    | ├───time_emb.py(时间嵌入文件)
    | ├───Readme.md
    | └───train.py(训练文件)

​​

希望对你有帮助!加油!

若您认为本文内容有益,请不吝赐予赞同并订阅,以便持续接收有价值的信息。衷心感谢您的关注和支持!

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-11-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 概述
  • 演示效果
  • 扩散模型
    • 基本原理
      • 扩散(正向过程):
      • 逆扩散(反向过程):
  • DiT架构综述
    • DiT的核心特点:
    • DiT的工作原理:
  • 代码
    • 前向加噪
    • DiT模型
  • 计算FID分数值
  • DiT模型延申Sora
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档