前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >gan训练

gan训练

作者头像
sofu456
发布2019-07-09 14:35:27
5160
发布2019-07-09 14:35:27
举报
文章被收录于专栏:sofu456

gan对mnist数据集训练

这里使用非卷积(全连接)神经网络:先把每张图片展平成一维向量再进行建模;卷积网络则直接对二维图像数据建模。

代码语言:python
复制
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import os
 
# Output directory for the real/fake sample grids written during training.
# makedirs(exist_ok=True) creates it in one call, avoiding the
# check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs('./img', exist_ok=True)
 
 
def to_img(x):
    """Map tensors from [-1, 1] back to [0, 1] and shape them as MNIST images.

    Inverts the ``Normalize([0.5], [0.5])`` preprocessing, clamps to the
    valid pixel range and regroups the flat 784-element rows into an
    (N, 1, 28, 28) batch.
    """
    rescaled = (x + 1).mul(0.5)
    return rescaled.clamp(0, 1).view(-1, 1, 28, 28)
 
 
# Training hyper-parameters.
batch_size = 128
num_epoch = 100
z_dimension = 100  # length of the noise vector fed to the generator

# Scale MNIST pixels from [0, 1] to [-1, 1], the range of the generator's Tanh
# output. MNIST is single-channel grayscale, so Normalize takes exactly one
# mean and one std; a three-element (RGB-style) list does not match the data.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# MNIST training split (downloaded into ``root`` if it is not there yet).
mnist = datasets.MNIST(
    root='./tensorflow/data_gan/',
    train=True,
    transform=img_transform,
    download=True,
)

# Shuffled mini-batches of (image, label) pairs.
dataloader = torch.utils.data.DataLoader(
    dataset=mnist,
    batch_size=batch_size,
    shuffle=True,
)
 
 
# Discriminator
class discriminator(nn.Module):
    """Fully connected discriminator for flattened 28x28 images.

    ``forward`` maps rows of 784 features to an (N, 1) score in (0, 1).
    Training pushes the score toward 1 for real images and 0 for generated ones.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # squash the final logit into (0, 1)
        ]
        # The attribute name ``dis`` and the layer order determine the
        # state_dict keys written by torch.save, so both are kept fixed.
        self.dis = nn.Sequential(*layers)

    def forward(self, x):
        return self.dis(x)
 
 
# Generator
class generator(nn.Module):
    """Fully connected generator mapping 100-d noise to flattened 28x28 images.

    ``forward`` returns an (N, 784) tensor in [-1, 1] because the last layer
    is Tanh, matching the normalized MNIST pixel range.
    """

    def __init__(self):
        super().__init__()
        # Layer order (and the ``gen`` attribute name) fixes the state_dict keys.
        self.gen = nn.Sequential(
            nn.Linear(100, 256), nn.ReLU(inplace=True),
            nn.Linear(256, 256), nn.ReLU(inplace=True),
            nn.Linear(256, 784), nn.Tanh(),
        )

    def forward(self, x):
        return self.gen(x)
 
 
# Run on the GPU when one is present. The models AND every tensor fed to them
# (real images, labels, noise) must live on the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

D = discriminator().to(device)
G = generator().to(device)
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

# Start training
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # ================= train discriminator
        # Flatten each image into a single row for the linear layers.
        real_img = img.view(num_img, -1).to(device)
        # D returns shape (N, 1); BCELoss needs targets of the same shape.
        real_label = torch.ones(num_img, 1, device=device)
        fake_label = torch.zeros(num_img, 1, device=device)

        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better

        # compute loss of fake_img; detach() keeps this step from
        # back-propagating into G, which is only updated below.
        z = torch.randn(num_img, z_dimension, device=device)
        fake_img = G(z).detach()
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better

        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # =============== train generator
        # G is rewarded when D scores its fresh samples as real (label 1).
        z = torch.randn(num_img, z_dimension, device=device)
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)

        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            # 1-based epoch so the log matches the fake_images-<n>.png names.
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'.format(
                      epoch + 1, num_epoch, d_loss.item(), g_loss.item(),
                      real_scores.mean().item(), fake_scores.mean().item()))
    if epoch == 0:
        real_images = to_img(real_img.detach().cpu())
        save_image(real_images, './img/real_images.png')

    fake_images = to_img(fake_img.detach().cpu())
    save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

单个图片训练(训练使用一张图片中的每个小图片)

代码语言:python
复制
import torch
from torch import nn
from torch import autograd
from PIL import Image
from torchvision import transforms,utils

class discriminator(nn.Module):
    """Fully connected discriminator for flattened image patches.

    Args:
        inputsize: number of features per sample (last dimension of the input).
        outputsize: width of both hidden layers. Despite the name, the network
            always ends in a single sigmoid unit.

    ``forward`` returns one score in (0, 1) per sample, shape (..., 1).
    """

    def __init__(self, inputsize, outputsize):
        super().__init__()
        hidden = outputsize
        self.dis = nn.Sequential(
            nn.Linear(inputsize, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.dis(x)

class generator(nn.Module):
    """Fully connected generator producing flattened image patches.

    Args:
        gsize: length of the input noise vector.
        inputsize: number of output features. This equals the discriminator's
            ``inputsize``, i.e. pixels per flattened patch.
        outputsize: width of both hidden layers.

    ``forward`` returns values in [-1, 1] because the last layer is Tanh.
    """

    def __init__(self, gsize, inputsize, outputsize):
        super().__init__()
        hidden = outputsize
        self.gen = nn.Sequential(
            nn.Linear(gsize, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, inputsize), nn.Tanh(),
        )

    def forward(self, x):
        return self.gen(x)

# The source image is resized to width x height pixels, then cut into a grid of
# 8 columns (along the width) by 12 rows (along the height).
width, height = 240, 360
x = width // 8    # patch width in pixels
y = height // 12  # patch height in pixels
z_dimension = 96  # length of the generator's input noise vector

def to_img(img):
    """Turn flattened patches in [-1, 1] back into an (N, 1, y, x) image batch.

    Each row of ``img`` is one patch cut from the source image. The patch was
    cut as ``height // 12`` rows by ``width // 8`` columns, i.e. (y, x), so the
    patch is restored in (height, width) order. Values are mapped from [-1, 1]
    back to [0, 1] (inverse of ``Normalize([0.5], [0.5])``) and clamped.
    """
    out = 0.5 * (img + 1)
    out = out.clamp(0, 1)
    # Height (y) comes before width (x). Using (x, y) only looks correct when
    # the patches happen to be square.
    return out.view(-1, 1, y, x)

import os

# Load the source image, resize it to height x width, convert it to a single
# grayscale channel and scale pixels to [-1, 1] (the generator's Tanh range).
# The context manager closes the file once the tensor has been built.
with Image.open('./tensorflow/jpg/text.png') as image:
    img = transforms.Compose([
        transforms.Resize((height, width)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])(image)

print(img.shape)
# Split the (1, height, width) tensor into 12 bands along the height, then
# each band into 8 pieces along the width, stacking all patches on dim 0.
img = torch.cat(img.chunk(12, 1))
img = torch.cat(img.chunk(8, 2))
img = img.view(12 * 8, -1)  # flatten each patch into one row
print(img.shape)

D = discriminator(x * y, 8 * 12)
G = generator(z_dimension, x * y, 8 * 12)

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

# Make sure the sample output directory exists before the first save.
os.makedirs('./tensorflow/gan-png', exist_ok=True)

epoch = 1000  # iterations between progress reports / sample dumps

# The training batch (every patch of the image) never changes, so the real
# batch and the label tensors are built once. D outputs shape (N, 1), and
# BCELoss needs targets of that same shape.
num_img = img.size(0)
real_img = img
real_label = torch.ones(num_img, 1)   # real patches are labelled 1
fake_label = torch.zeros(num_img, 1)  # generated patches are labelled 0

utils.save_image(to_img(real_img), './tensorflow/gan-png/real_images.png')

for i in range(epoch * 100):
    # compute loss of real_img
    real_out = D(real_img)
    d_loss_real = criterion(real_out, real_label)
    real_scores = real_out  # should approach 1 for real patches

    # compute loss of fake_img; detach() keeps the D step out of G's graph
    z = torch.randn(num_img, z_dimension)
    fake_img = G(z).detach()
    fake_out = D(fake_img)
    d_loss_fake = criterion(fake_out, fake_label)
    fake_scores = fake_out  # should approach 0 for generated patches

    # bp and optimize the discriminator on the summed real + fake loss
    d_loss = d_loss_real + d_loss_fake
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # generator step: G is rewarded when D labels fresh samples as real
    z = torch.randn(num_img, z_dimension)
    fake_img = G(z)
    output = D(fake_img)
    g_loss = criterion(output, real_label)

    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    if i % epoch == 0:
        print("real_scores:%f,fake_scores:%f,d_loss:%f,g_loss:%f" % (
            real_scores.mean().item(), fake_scores.mean().item(),
            d_loss.item(), g_loss.item()))
        print("-------------------------%d" % i)
        utils.save_image(to_img(fake_img.detach()),
                         './tensorflow/gan-png/fake_images_{}.png'.format(i // epoch))

原图

在这里插入图片描述
在这里插入图片描述

生成图片

在这里插入图片描述
在这里插入图片描述
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019年05月15日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • gan对mnist数据集训练
  • 单个图片训练(训练使用一张图片中的每个小图片)
  • 原图
  • 生成图片
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档