前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >AI绘画训练一个扩散模型-上集

AI绘画训练一个扩散模型-上集

作者头像
Dlimeng
发布2023-12-25 09:21:21
1250
发布2023-12-25 09:21:21
举报
文章被收录于专栏:开源心路开源心路

介绍

AI绘画,其中最常见方案基于扩散模型,Stable Diffusion 在此基础上,增加了 VAE 模块和 CLIP 模块,本文搞了一个测试Demo,分为上下两集,第一集是denoising_diffusion_pytorch ,第二集是diffusers。 对于专业的算法同学而言,我更推荐使用 diffusers 来训练。原因是 diffusers 工具包在实际的 AI 绘画项目中用得更多,并且也更易于我们修改代码逻辑,实现定制化功能。

https://arxiv.org/abs/2112.10752
https://arxiv.org/abs/2112.10752

基础模块

  • 创建UNet模型和高斯扩散模型(Gaussian Diffusion)。

UNet是一个编码器-解码器结构的全卷积神经网络。Gaussian Diffusion用于定义噪声过程和损失函数。

  • 将模型加载到GPU上(如果有GPU)。
  • 使用随机初始化的图片进行一次训练,计算损失并反向传播。

这一步的目的是对模型进行一次预热,更新权重。

  • 使用diffusion模型采样生成图片。

这里采样1000步,也就是将噪声逐步减少,每步用UNet预测下一步的图像,最终输出生成的图片。

  • 如果图片在GPU上,将其移回到CPU。
  • 可视化第一张生成图片。

plt.imshow显示图片。

这样通过DDPM框架,可以从随机噪声生成符合数据分布的新图片。每次训练会使模型逐步逼近真实数据分布,从而产生更高质量的图片。

代码语言:javascript
复制
# 创建UNet和扩散模型

from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import torch

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000   # number of steps
).cuda()

# 使用随机初始化的图片进行一次训练
training_images = torch.randn(8, 3, 128, 128)
loss = diffusion(training_images.cuda())
loss.backward()


# 采样1000步生成一张图片
sampled_images = diffusion.sample(batch_size = 4)
代码语言:javascript
复制
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torchvision.transforms as transforms


# 如果张量在 GPU上,需要移动到 CPU上
if sampled_images.is_cuda:
    sampled_images = sampled_images.cpu()

# 检查我们生成的一张图
img = sampled_images[0].clone().detach().permute(1, 2, 0)

plt.imshow(img)

数据集

  • 导入所需的库:PIL、io、datasets等。
  • 使用datasets库中的load_dataset方法加载Oxford Flowers数据集。
  • 创建一个目录来保存图片。
  • 遍历数据集的训练、验证、测试split,逐个图像获取图片bytes数据,并保存为PNG格式图片。
  • 使用PIL库的Image对象将bytes数据加载并保存为图片文件。
  • 使用tqdm显示循环进度。
代码语言:javascript
复制
# 数据集下载
from PIL import Image
from io import BytesIO
from datasets import load_dataset
import os
from tqdm import tqdm

dataset = load_dataset("nelorth/oxford-flowers")

# 创建一个用于保存图片的文件夹
images_dir = "./oxford-datasets/raw-images"
os.makedirs(images_dir, exist_ok=True)

# 遍历所有图片并保存,针对oxford-flowers,整个过程要持续15分钟左右
for split in dataset.keys():
    for index, item in enumerate(tqdm(dataset[split])):
        image = item['image']
        image.save(os.path.join(images_dir, f"{split}_image_{index}.png"))

模型训练

  • 定义Unet模型架构和Gaussian Diffusion过程。
  • 加载数据,指定训练参数:
    • 训练总步数20000
    • batch size 16
    • 学习率2e-5
    • 梯度累积步数2
    • EMA指数衰减参数0.995
    • 使用混合精度训练
    • 每2000步保存一次模型
    • 创建Trainer进行模型训练。Trainer封装了训练循环逻辑。
  • 调用trainer.train()进行训练。
代码语言:javascript
复制
# 模型训练
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000   # 加噪总步数
).cuda()

trainer = Trainer(
    diffusion,
    './oxford-datasets/raw-images',
    train_batch_size = 16,
    train_lr = 2e-5,
    train_num_steps = 20000,          # 总共训练20000步
    gradient_accumulate_every = 2,    # 梯度累积步数
    ema_decay = 0.995,                # 指数滑动平均decay参数
    amp = True,                       # 使用混合精度训练加速
    calculate_fid = False,            # 我们关闭FID评测指标计算(比较耗时)。FID用于评测生成质量。
    save_and_sample_every = 2000      # 每隔2000步保存一次模型
)

trainer.train()
代码语言:javascript
复制
# 你可以等待上面的模型训练完成后,查看生成结果

from glob import glob

result_images = glob(r"./results/*.png")
print(result_images)
代码语言:javascript
复制
# 可视化图像看看
from PIL import Image

img = Image.open("./results/sample-1.png")
img

测试

https://colab.research.google.com/github/NightWalker888/ai_painting_journey/blob/main/lesson12/train_diffusion_v2.ipynb#scrollTo=8BVjfBPI93Ar

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2023-12-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 介绍
  • 基础模块
  • 数据集
  • 模型训练
  • 测试
相关产品与服务
大模型图像创作引擎
大模型图像创作引擎是一款提供 AI 图像生成与处理能力的 API 技术服务,可以结合输入的文本或图片智能创作出与输入相关的图像内容,具有更强大的中文理解能力、更多样化的风格选择,更好支持中文场景下的建筑风景生成、古诗词理解、水墨剪纸等中国元素风格生成,以及各种动漫、游戏风格的高精度图像生成和风格转换,为高质量的内容创作、内容运营提供技术支持。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档