DiT的全称是Diffusion in Transformer,它将Transformer引入到扩散模型中,替换了Stable Diffusion中常用的U-Net主干。通过增加Transformer深度/宽度或增加输入令牌数量,具有较高Gflops(浮点数运算次数)的DiT始终具有较低的FID(Fréchet初始距离,用于描述生成的图片和原始的图片之间的距离)。本文会讲解扩散模型的原理,从零开始逐步复现DiT模型。
如下为DiT模型生成的图片效果
扩散模型(Diffusion Models)是一类用于生成数据(如图像、音频等)的深度学习模型,它们在近年来逐渐成为生成模型领域的重要研究方向,尤其是在图像生成任务中表现出与生成对抗网络(GANs)相媲美甚至更优的性能。
扩散模型的基本原理是模拟数据从真实分布逐渐转变到已知分布(通常是高斯噪声分布)的过程,然后再从这个已知分布逆向恢复出数据。这个过程可以分为两个主要阶段:扩散(或正向过程)和逆扩散(或反向过程)。
从原始数据开始,逐步添加随机噪声,直到数据变成不包含任何原始信息的高斯噪声。 这个过程通常是固定的,不需要学习。每一步添加的噪声量可以是一个预定义的 schedule。 扩散过程可以用一个马尔可夫链来表示,每一步都依赖于前一步的状态。
初始化:从真实数据 x0x0,开始,它服从数据分布PdataPdata。 逐步添加噪声:在时间步 t(从 0 到 T),逐步向数据中添加噪声,直到数据变成不包含任何原始信息的高斯噪声。这个过程可以用以下公式表示: q(xt∣x0)=N(xt;αtx0,(1−αt)I)q(xt∣x0)=N(xt;αtx0,(1−αt)I) 累积噪声:在每一步,我们都有xtxt的条件分布,但通常我们需要得到xtxt关于x0x0的分布。这可以通过累积每一步的噪声来实现: q(xt∣x0)=N(xt;αtx0,(1−αt)I)q(xt∣x0)=N(xt;αtx0,(1−αt)I) 其中,αt=∏i=1t(1−βi)αt=∏i=1t(1−βi)是累积的噪声比例。
从高斯噪声开始,通过学习到的模型逐步去除噪声,恢复出原始数据。 这个过程需要通过神经网络来学习,网络需要预测在每一步应该如何去除噪声,以恢复出接近原始数据的状态。 逆扩散过程通常涉及到条件概率的估计,即给定当前噪声数据,预测原始数据在该步骤的潜在状态。
初始化:从高斯噪声 xTxT开始,它服从标准高斯分布N(xT;0,I)N(xT;0,I)。 逐步去噪:在时间步 t(从 T 到 0),逐步去除噪声以恢复原始数据。这个过程通常是通过神经网络pθ(xt−1∣xt)pθ(xt−1∣xt)来预测的。神经网络的训练目标是最大化以下似然函数的对数:logpθ(x0)≈∑t=1Tlogpθ(xt−1∣xt)logpθ(x0)≈∑t=1Tlogpθ(xt−1∣xt) 其中,pθ(xt−1∣xt)pθ(xt−1∣xt)是神经网络预测的xt−1xt−1给定的xtxt的条件概率。 预测和重参数化:在每一步,神经网络预测xt−1xt−1的均值和方差,然后通过重参数化技巧从这些参数中采样。这个过程可以用以下公式表示: xt−1=μθ(xt,t)+σθ(xt,t)⋅ϵxt−1=μθ(xt,t)+σθ(xt,t)⋅ϵ 其中,ϵϵ是从标准高斯分布中采样的噪声,μθμθ和σθσθ是神经网络预测的均值和标准差。
结合Transformer架构的扩散模型:DiT使用Transformer作为扩散模型的骨干网络,而不是传统的卷积神经网络(如U-Net)。这使得模型能够通过自注意力机制捕捉图像中的长距离依赖关系。 潜在空间操作:DiT在潜在空间中训练,通常比直接在像素空间训练更高效。通过使用变分自编码器(VAE)将图像编码到潜在空间,DiT减少了计算复杂度。 可扩展性:DiT展示了出色的可扩展性,通过增加模型的计算量(以Gflops衡量),可以显著提高生成图像的质量。这种可扩展性允许DiT在不同的分辨率和复杂度下生成图像。
数据准备:使用预训练的变分自编码器(VAE)将输入图像编码成潜在空间的表示。 分块化(Patchification):将输入的潜在表示分割成一系列的小片段(patches),每个片段对应于Transformer模型的一个输入标记(token)。 Transformer Blocks模块:输入的标记序列通过一系列的Transformer块进行处理,这些块包括自注意力层、前馈神经网络以及层归一化等组件。 条件扩散过程:在训练过程中,DiT模型学习逆向扩散过程,即从噪声数据中恢复出清晰的图像。 样本生成:在训练完成后,可以通过DiT模型生成新的图像。首先,从标准正态分布中采样一个潜在表示,然后通过DiT模型逆向扩散过程,逐步去除噪声,最终解码回像素空间,得到生成的图像。
import torch
from config import *
# 前向diffusion计算参数
betas=torch.linspace(0.0001,0.02,T) # (T,)
alphas=1-betas # (T,)
alphas_cumprod=torch.cumprod(alphas,dim=-1) # alpha_t累乘 (T,) [a1,a2,a3,....] -> [a1,a1*a2,a1*a2*a3,.....]
alphas_cumprod_prev=torch.cat((torch.tensor([1.0]),alphas_cumprod[:-1]),dim=-1) # alpha_t-1累乘 (T,), [1,a1,a1*a2,a1*a2*a3,.....]
variance=(1-alphas)*(1-alphas_cumprod_prev)/(1-alphas_cumprod) # denoise用的方差 (T,)
# 执行前向加噪
def forward_add_noise(x,t): # batch_x: (batch,channel,height,width), batch_t: (batch_size,)
noise=torch.randn_like(x) # 为每张图片生成第t步的高斯噪音 (batch,channel,height,width)
batch_alphas_cumprod=alphas_cumprod[t].view(x.size(0),1,1,1)
x=torch.sqrt(batch_alphas_cumprod)*x+torch.sqrt(1-batch_alphas_cumprod)*noise # 基于公式直接生成第t步加噪后图片
return x,noise
第一步:定义PatchEmbedding
class PatchEmbedding(nn.Module):
def __init__(self,channel,patch_size,emb_size):
super().__init__()
self.patch_size=patch_size
self.patch_count=img_size//self.patch_size
self.channel=channel
self.conv=nn.Conv2d(in_channels=channel,out_channels=channel*patch_size**2,kernel_size=patch_size,padding=0,stride=patch_size)
self.patch_emb=nn.Linear(in_features=channel*patch_size**2,out_features=emb_size)
self.patch_pos_emb=nn.Parameter(torch.rand(1,self.patch_count**2,emb_size))
def forward(self,x):
x=self.conv(x) # (batch,new_channel,patch_count,patch_count)
x=x.permute(0,2,3,1) # (batch,patch_count,patch_count,new_channel)
x=x.view(x.size(0),self.patch_count*self.patch_count,x.size(3)) # (batch,patch_count**2,new_channel)
x=self.patch_emb(x) # (batch,patch_count**2,emb_size)
x=x+self.patch_pos_emb # (batch,patch_count**2,emb_size)
return x
第二步:定义time_embedding
class TimeEmbedding(nn.Module):
def __init__(self,emb_size):
super().__init__()
self.half_emb_size=emb_size//2
half_emb=torch.exp(torch.arange(self.half_emb_size)*(-1*math.log(10000)/(self.half_emb_size-1)))
self.register_buffer('half_emb',half_emb)
def forward(self,t):
t=t.view(t.size(0),1)
half_emb=self.half_emb.unsqueeze(0).expand(t.size(0),self.half_emb_size)
half_emb_t=half_emb*t
embs_t=torch.cat((half_emb_t.sin(),half_emb_t.cos()),dim=-1)
return embs_t
第三步:定义DiT Block
class DiTBlock(nn.Module):
def __init__(self,emb_size,nhead=4):
super().__init__()
# 定义各种参数,包括线形层,Transformer
def forward(self,x,condition):
# 计算
第四步:定义DiT
class DiT(nn.Module):
def __init__(self,img_size,patch_size,channel,emb_size,label_num,dit_num,head):
super().__init__()
# 定义标签的Embedding,dit_num个DiT block,Layernorm等
def forward(self,x,label):
# 计算
开始训练
for epoch in range(EPOCH):
for imgs,labels in dataloader:
x=imgs*2-1 # 图像的像素范围从[0,1]转换到[-1,1],和噪音高斯分布范围对应
t=torch.randint(0,T,(imgs.size(0),)) # 为每张图片生成随机t时刻
y=labels
x,noise=forward_add_noise(x,t) # x:加噪图 noise:噪音
pred_noise=model(x.to(DEVICE),t.to(DEVICE),y.to(DEVICE))
loss=loss_fn(pred_noise,noise.to(DEVICE))
optimzer.zero_grad()
loss.backward()
optimzer.step()
开始推理(推理代码详见附件)
# 装包
pip install pytorch_fid
from pytorch_fid import fid_score
import torch
torch.multiprocessing.set_start_method('spawn',force=True)
# 定义真实图像和生成图像的路径
path_real_images = './images/real'
path_fake_images = './images/fake'
def value():
# 计算FID分数
fid_value = fid_score.calculate_fid_given_paths(
paths=[path_real_images, path_fake_images],
batch_size=1, # 根据你的GPU内存调整批处理大小
device='cuda', # 或 'cpu',取决于你的硬件
dims=2048, # InceptionV3的特征维度
)
print(f"FID score: {fid_value}")
if __name__=='__main__':
value()
Sora模型的诠释:Sora模型代表了一种尖端的视觉技术,它采用独树一帜的方法来生成视频,通过逐步消除噪声来构建出精细的最终画面,从而具备捕捉复杂动态场景的能力。
Sora模型的基石:Sora模型的关键组成部分包括Diffusion Transformer(DiT)、Variational Autoencoder(VAE)和Vision Transformer(ViT)。
DiT专注于从噪声中还原出原始视频数据,VAE则负责将视频数据压缩成潜在的空间表示,而ViT的作用是将视频帧转换成特征向量,供DiT进一步处理。
Diffusion Transformer(DiT):DiT集成了扩散模型和Transformer架构的优点,通过模拟从噪声到清晰数据的过程,DiT能够创造出高质量、逼真的视频内容。在Sora模型中,DiT的任务是从噪声数据中恢复出原始视频。
Variational Autoencoder(VAE):VAE是一种生成模型,它能够将视频或图像数据压缩成低维度的潜在表示,并通过解码器将这些表示恢复成原始数据。在Sora模型中,VAE作为编码器,将视频数据压缩后提供给DiT,以指导其生成与原始视频相似的内容。
Vision Transformer(ViT):ViT是一种基于Transformer的图像处理模型,它将图像分割成多个小块(patches),并将这些小块转换为特征向量作为Transformer的输入。在Sora模型中,ViT可能用于预处理或作为模型的一个组成部分
附件结构
├───DIT(项目名称)
| ├───images(图片文件夹)
| └───fake(假图片文件夹)
| └───real(真图片文件夹)
| ├───MNIST(MNIST数据集文件夹)
| ├───test(测试数据文件夹)
| └───train(训练数据文件夹)
| ├───config.py(配置文件)
| ├───dataset.py(数据集处理文件)
| ├───diffusion.py(扩散模型文件)
| ├───dit_block.py(块状结构文件)
| ├───dit.png( dit图片文件)
| ├───dit.py( ditPython文件)
| ├───fid_computer.py( fid计算机文件)
| ├───generator.py(生成器文件)
| ├───inference.py(推理文件)
| ├───model.pth(模型文件)
| ├───time_emb.py(时间嵌入文件)
| ├───Readme.md
| └───train.py(训练文件)
希望对你有帮助!加油!
若您认为本文内容有益,请不吝赐予赞同并订阅,以便持续接收有价值的信息。衷心感谢您的关注和支持!