前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >对抗生成网络学习(七)——SRGAN生成超分辨率影像(tensorflow实现)「建议收藏」

对抗生成网络学习(七)——SRGAN生成超分辨率影像(tensorflow实现)「建议收藏」

作者头像
Java架构师必看
发布2022-06-13 08:34:35
5.7K1
发布2022-06-13 08:34:35
举报
文章被收录于专栏:Java架构师必看

大家好,我是架构君,一个会写代码吟诗的架构师。今天说一说对抗生成网络学习(七)——SRGAN生成超分辨率影像(tensorflow实现)「建议收藏」,希望能够帮助大家进步!!!

一、背景

SRGAN(Super-Resolution Generative Adversarial Network)即超分辨率GAN,是Christian Ledig等人于16年9月提出的一种对抗神经网络。利用卷积神经网络实现单影像的超分辨率,其瓶颈仍在于如何恢复图像的细微纹理信息。对于GAN而言,将一组随机噪声输入到生成器中,生成的图像质量往往较差。因此,作者提出了SRGAN,并定义一个loss函数以驱动模型,SRGAN最终可以生成一幅原始影像扩大4倍的高分辨率影像。

本文试验基于ImageNet数据,用尽可能少的代码,实现SRGAN生成超分辨率影像的过程。

1文章链接:https://arxiv.org/pdf/1609.04802.pdf

二、SRGAN原理

关于SRGAN,网上的介绍不是很多,这里先推荐一篇:

2(https://javajgs.com/go?url=https://blog.csdn.net/Cloudox_/article/details/78666910)

先给出SRGAN的效果图:

最左边是经过三次插值的处理结果(也就是原图resize4倍之后再插值得到的),第二个SRResNet是MSE网络处理的结果,第三个是感知loss驱动的SRGAN模型结果,第四个是原始高清影像。

文章中,作者的主要贡献在于:

Our main contributions are: • We set a new state of the art for image SR with high upscaling factors (4×) as measured by PSNR and structural similarity (SSIM) with our 16 blocks deep ResNet (SRResNet) optimized for MSE. (用PSNR和SSIM来评估4倍上采样的超分辨率影像。) • We propose SRGAN which is a GAN-based network optimized for a new perceptual loss. Here we replace the MSE-based content loss with a loss calculated on feature maps of the VGG network 49, which are more invariant to changes in pixel space . (提出了SRGAN,并用感知loss进行驱动。) • We confirm with an extensive mean opinion score (MOS) test on images from three public benchmark datasets that SRGAN is the new state of the art, by a large margin, for the estimation of photo-realistic SR images with high upscaling factors (4×).(证实了SRGAN对于4倍上采样获得超分辨率影像而言,是一个新阶段。)

对于网络结构,作者受到启发,generator采用了block layout,discriminator中作者使用了LeakyReLU而没有采用max-pooling,网络结构的示意图为:

但是,作者并没有局限在SRGAN的网络结构。后续试验中,作者发现,当网络结构大于16层的时候,可以带来比较好的试验效果:

We found that even deeper networks (B > 16) can further increase the performance of SRResNet, however, come at the cost of longer training and testing times.We further found SRGAN variants of deeper networks are increasingly difficult to train due to the appearance of high-frequency artifacts.

因此,作者进行了不同层数的网络结构试验:

根据试验结果,54层的VGG网络结构的效果更好。

网络层数越深,模型训练越慢。因此,本试验主要参考代码3,使用了MSE来训练SRGAN,并对其进行了删减。如果大家有兴趣,可以自行试验其他网络层数的SRGAN。顺便再推荐几个比较好的代码:

3(https://javajgs.com/go?url=https://github.com/nnUyi/SRGAN)

4(https://javajgs.com/go?url=https://github.com/tensorlayer/srgan)

5(https://javajgs.com/go?url=https://github.com/brade31919/SRGAN-tensorflow)

三、实现过程

1.文件结构

所有文件的结构如下:

代码语言:javascript
复制
-- utils.py                # image的操作文件
-- layer.py                # 图层的定义文件
-- SRGAN.py                # 模型定义文件
-- main.py                 # 主函数,包括训练及测试部分
-- data                    # 原始数据文件夹
    |------ train
            |------ ImageNet
                    |------ image01.png
                    |------ image02.png
                    |------ ......
    |------ test                
            |------ Set14
                    |------ image01.png
                    |------ image02.png
                    |------ ......

只听到从架构师办公室传来架构君的声音:

莫倚倾国貌,嫁取个,有情郎。有谁来对上联或下联?

2.数据准备

主要的训练数据集来自ImageNet,当然我是使用的3的数据集,即只使用了ImageNet中的3137张影像作为训练数据,训练数据的下载地址为:https://pan.baidu.com/s/1eSJC0lc,下载好之后,将所有图片解压,并放在路径'./data/train/ImageNet/'下。

测试数据集的下载地址为:https://pan.baidu.com/s/1nvmUkBn,下载好之后,解压将所有图片放在路径'./data/test/Set14/'下。

准备好数据之后,可以打开数据看看,所有图像的大小不一:

当然,这里的数据集只是为了训练模型学习图像的细节纹理,因此你也完全自己准备数据集。

3.image处理文件utils.py

首先是一些预处理文件,包括读取影像,存储影像等函数,utils中的代码为:

代码语言:javascript
复制
此代码由Java架构师必看网-架构君整理
import numpy as np
import scipy.misc


def get_images(filename, is_crop, fine_size, images_norm):
    img = scipy.misc.imread(filename, mode='RGB')
    if is_crop:
        size = img.shape
        start_h = int((size[0] - fine_size)/2)
        start_w = int((size[1] - fine_size)/2)
        img = img[start_h:start_h+fine_size, start_w:start_w+fine_size,:]
    img = np.array(img).astype(np.float32)
    if images_norm:
        img = (img-127.5)/127.5
    return img


def save_images(images, size, filename):
    return scipy.misc.imsave(filename, merge_images(images, size))


def merge_images(images, size):
    h,w = images.shape[1], images.shape[2]
    imgs = np.zeros((size[0]*h, size[1]*w, 3))
    
    for index, image in enumerate(images):
        i = index // size[1]
        j = index % size[0]
        imgs[i*h:i*h+h, j*w:j*w+w, :] = image

    return imgs

4.layer处理文件layer.py

前面也提到过,SRGAN使用了leakyRelu而没有采用池化,因此layer文件主要定义一些层函数,具体的代码为:

代码语言:javascript
复制
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np


def res_block(input_x, out_channels=64, k=3, s=1, scope='res_block'):
    with tf.variable_scope(scope):
        x = input_x
        input_x = slim.conv2d_transpose(input_x, out_channels, k, s)
        input_x = slim.batch_norm(input_x, scope='bn1')
        input_x = tf.nn.relu(input_x)
        input_x = slim.conv2d_transpose(input_x, out_channels, k, s)
        input_x = slim.batch_norm(input_x, scope='bn2')
    
    return x+input_x


def pixel_shuffle_layer(x, r, n_split):
    def PS(x, r):
        bs, a, b, c = x.get_shape().as_list()
        x = tf.reshape(x, (bs, a, b, r, r))
        x = tf.transpose(x, [0, 1, 2, 4, 3])
        x = tf.split(x, a, 1)
        x = tf.concat([tf.squeeze(x_) for x_ in x], 2)
        x = tf.split(x, b, 1)
        x = tf.concat([tf.squeeze(x_) for x_ in x], 2)
        return tf.reshape(x, (bs, a*r, b*r, 1))

    xc = tf.split(x, n_split, 3)
    return tf.concat([PS(x_, r) for x_ in xc], 3)


def down_sample_layer(input_x):
    K = 4
    arr = np.zeros([K, K, 3, 3])
    arr[:, :, 0, 0] = 1.0 / K ** 2
    arr[:, :, 1, 1] = 1.0 / K ** 2
    arr[:, :, 2, 2] = 1.0 / K ** 2
    weight = tf.constant(arr, dtype=tf.float32)
    downscaled = tf.nn.conv2d(
        input_x, weight, strides=[1, K, K, 1], padding='SAME')
    return downscaled


def leaky_relu(input_x, negative_slop=0.2):
    return tf.maximum(negative_slop*input_x, input_x)


def PSNR(real, fake):
    mse = tf.reduce_mean(tf.square(127.5*(real-fake)+127.5), axis=(-3, -2, -1))
    psnr = tf.reduce_mean(10 * (tf.log(255*255 / tf.sqrt(mse)) / np.log(10)))
    return psnr

5.SRGAN模型文件SRGAN.py

SRGAN的模型文件是这个程序的关键,这里我直接给出SRGAN的代码:

代码语言:javascript
复制
此代码由Java架构师必看网-架构君整理
from glob import glob
from skimage import transform
import time
import os

from layer import *
from utils import *


class SRGAN:
    model_name = 'SRGAN'

    def __init__(self, dataset_dir='./data/', is_crop=True,
                 batch_size=1, input_height=256, input_width=256, input_channels=3, sess=None):

        self.learning_rate = 0.0001
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.lambd = 0.001
        self.epoches = 200
        self.fine_size = 128

        self.checkpoint_dir = './checkpoint/'
        self.test_dir = './test'
        self.model_dirs = 'ImageNet/'
        self.train_set = 'ImageNet/'
        self.val_set = 'Set5/'
        self.test_set = 'Set14/'

        self.dataset_dir = dataset_dir
        self.is_crop = is_crop

        self.input_height = input_height
        self.input_width = input_width
        self.input_channels = input_channels
        self.batch_size = batch_size
        self.images_norm = True
        self.dataset_name = 'ImageNet'
        self.sess = sess
        self.check_dir()

    def check_dir():
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        # if not os.path.exists(sample_dir):
        #     os.makedirs(sample_dir)
        # if not os.path.exists(logs_dir):
        #     os.makedirs(logs_dir)
        if not os.path.exists(self.test_dir):
            os.makedirs(self.test_dir)
        
    def generator(self, input_x, reuse=False):
        with tf.variable_scope('generator') as scope:
            if reuse:
                scope.reuse_variables()

            
            with slim.arg_scope([slim.conv2d_transpose],
                                weights_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                weights_regularizer=None,
                                activation_fn=None,
                                normalizer_fn=None,
                                padding='SAME'):
                conv1 = tf.nn.relu(slim.conv2d_transpose(input_x, 64, 3, 1, scope='g_conv1'))
                shortcut = conv1
                # res_block(input_x, out_channels=64, k=3, s=1, scope='res_block'):
                res1 = res_block(conv1, 64, 3, 1, scope='g_res1')
                res2 = res_block(res1, 64, 3, 1, scope='g_res2')
                res3 = res_block(res2, 64, 3, 1, scope='g_res3')
                res4 = res_block(res3, 64, 3, 1, scope='g_res4')
                res5 = res_block(res4, 64, 3, 1, scope='g_res5')
                
                conv2 = slim.batch_norm(slim.conv2d_transpose(res5, 64, 3, 1, scope='g_conv2'), scope='g_bn_conv2')
                conv2_out = shortcut+conv2
                # pixel_shuffle_layer(x, r, n_split):
                conv3 = slim.conv2d_transpose(conv2_out, 256, 3, 1, scope='g_conv3')
                shuffle1 = tf.nn.relu(pixel_shuffle_layer(conv3, 2, 64)) #64*2*2
                conv4 = slim.conv2d_transpose(shuffle1, 256, 3, 1, scope='g_conv4')
                shuffle2 = tf.nn.relu(pixel_shuffle_layer(conv4, 2, 64))
                conv5 = slim.conv2d_transpose(shuffle2, 3, 3, 1, scope='g_conv5')
                self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
                return tf.nn.tanh(conv5)
            
    def discriminator(self, input_x, reuse=False):
        with tf.variable_scope('discriminator') as scope:
            if reuse:
                scope.reuse_variables()
            with slim.arg_scope([slim.conv2d, slim.fully_connected],
                                weights_initializer = tf.truncated_normal_initializer(stddev=0.02),
                                weights_regularizer = None,
                                activation_fn=None,
                                normalizer_fn=None):
                                
                conv1 = leaky_relu(slim.conv2d(input_x, 64, 3, 1, scope='d_conv1'))
                conv1_1 = leaky_relu(slim.batch_norm(slim.conv2d(conv1, 64, 3, 2, scope='d_conv1_1'), scope='d_bn_conv1_1'))

                conv2 = leaky_relu(slim.batch_norm(slim.conv2d(conv1_1, 128, 3, 1, scope='d_conv2'), scope='d_bn_conv2'))
                conv2_1 = leaky_relu(slim.batch_norm(slim.conv2d(conv2, 128, 3, 2, scope='d_conv2_1'), scope='d_bn_conv2_1'))
                
                conv3 = leaky_relu(slim.batch_norm(slim.conv2d(conv2_1, 256, 3, 1, scope='d_conv3'), scope='d_bn_conv3'))
                conv3_1 = leaky_relu(slim.batch_norm(slim.conv2d(conv3, 256, 3, 2, scope='d_conv3_1'), scope='d_bn_conv3_1'))

                conv4 = leaky_relu(slim.batch_norm(slim.conv2d(conv3_1, 512, 3, 1, scope='d_conv4'), scope='d_bn_conv4'))
                conv4_1 = leaky_relu(slim.batch_norm(slim.conv2d(conv4, 512, 3, 2, scope='d_conv4_1'), scope='d_bn_conv4_1'))

                conv_flat = tf.reshape(conv4_1, [self.batch_size, -1])
                dense1 = leaky_relu(slim.fully_connected(conv_flat, 1024, scope='d_dense1'))
                dense2 = slim.fully_connected(dense1, 1, scope='d_dense2')
                
                self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
                return dense2, tf.nn.sigmoid(dense2)

    def build_model(self):
        self.input_target = tf.placeholder(tf.float32, [self.batch_size, self.input_height,
                                                        self.input_width, self.input_channels], name='input_target')
        
        self.input_source = down_sample_layer(self.input_target)
        
        self.real = self.input_target
        self.fake = self.generator(self.input_source, reuse=False)
        self.psnr = PSNR(self.real, self.fake)
        self.d_loss, self.g_loss, self.content_loss = self.inference_loss(self.real, self.fake)
        self.d_optim = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.beta1,
                                              beta2=self.beta2).minimize(self.d_loss, var_list=self.d_vars)
        self.g_optim = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.beta1,
                                              beta2=self.beta2).minimize(self.g_loss, var_list=self.g_vars)
        self.srres_optim = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.beta1,
                                                  beta2=self.beta2).minimize(self.content_loss, var_list=self.g_vars)
        print('builded model...') 

    def inference_loss(self, real, fake):
        def inference_mse_content_loss(real, fake):
            return tf.reduce_mean(tf.square(real-fake))
            
        def inference_adversarial_loss(x, y, w=1, type_='gan'):
            if type_ == 'gan':
                return w * tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
            elif type_ == 'lsgan':
                return w*(x-y)**2
            else:
                raise ValueError('no {} loss type'.format(type_))
        
        content_loss = inference_mse_content_loss(real, fake)
        d_real_logits, d_real_sigmoid = self.discriminator(real, reuse=False)
        d_fake_logits, d_fake_sigmoid = self.discriminator(fake, reuse=True)
        d_fake_loss = tf.reduce_mean(inference_adversarial_loss(d_real_logits, tf.ones_like(d_real_sigmoid)))
        d_real_loss = tf.reduce_mean(inference_adversarial_loss(d_fake_logits, tf.zeros_like(d_fake_sigmoid)))
        g_fake_loss = tf.reduce_mean(inference_adversarial_loss(d_fake_logits, tf.ones_like(d_fake_sigmoid)))
        
        d_loss = self.lambd*(d_fake_loss+d_real_loss)
        g_loss = content_loss + self.lambd*g_fake_loss
        
        return d_loss, g_loss, content_loss
        
    def train(self):
        tf.global_variables_initializer().run()

        data = glob(os.path.join(self.dataset_dir, 'train', self.train_set, '*.*'))
        batch_idxs = int(len(data)/self.batch_size)
        bool_check, counter = self.load_model(self.checkpoint_dir)
        if bool_check:
            print('[!!!] load model successfully')
            counter = counter + 1
        else:
            print('[***] fail to load model')
            counter = 1
        
        print('total steps:{}'.format(self.epoches*batch_idxs))
        
        start_time = time.time()
        for epoch in range(self.epoches):
            np.random.shuffle(data)
            for idx in range(batch_idxs):
                batch_files = data[idx*self.batch_size:(idx+1)*self.batch_size]
                batch_x = [get_images(batch_file, self.is_crop, self.fine_size, self.images_norm) for batch_file in batch_files]
                batch_x = np.array(batch_x).astype(np.float32)
    
                if counter < 2e4:                      
                    _, content_loss, psnr = self.sess.run([self.srres_optim, self.content_loss, self.psnr], feed_dict={self.input_target:batch_x})
                    end_time = time.time()
                    print('epoch{}[{}/{}]:total_time:{:.4f},content_loss:{:4f},psnr:{:.4f}'.format(epoch, idx, batch_idxs, end_time-start_time, content_loss, psnr))
                else:
                    _, d_loss = self.sess.run([self.d_optim, self.d_loss,], feed_dict={self.input_target:batch_x})
                    _, g_loss, psnr = self.sess.run([self.g_optim, self.g_loss, self.psnr], feed_dict={self.input_target:batch_x})
                    end_time = time.time()
                    print('epoch{}[{}/{}]:total_time:{:.4f},d_loss:{:.4f},g_loss:{:4f},psnr:{:.4f}'.format(epoch, idx, batch_idxs, end_time-start_time, d_loss, g_loss, psnr))

                if np.mod(counter, 500)==0:
                    self.save_model(self.checkpoint_dir, counter)
                counter = counter+1
            
    def test(self):
        print('testing')
        bool_check, counter = self.load_model(self.checkpoint_dir)
        if bool_check:
            print('[!!!] load model successfully')
        else:
            print('[***] fail to load model')
        
        test = glob(os.path.join(self.dataset_dir, 'test', self.test_set, '*.*'))
        batch_files = test[:self.batch_size]
        batch_x = [get_images(batch_file, True, self.fine_size, self.images_norm) for batch_file in batch_files]
        batchs = np.array(batch_x).astype(np.float32)
        
        sample_images, input_sources = self.sess.run([self.fake, self.input_source],
                                                     feed_dict={self.input_target:batchs})
        #images = np.concatenate([sample_images, batchs], 2)
        for i in range(len(batch_x)):
            batch = np.expand_dims(batchs[i],0)
            sample_image = np.expand_dims(sample_images[i],0)
            input_source = np.expand_dims(input_sources[i],0)
            save_images(batch, [1,1], '{}/gt_hr_{}.png'.format(self.test_dir, i))
            save_images(sample_image, [1,1], '{}/test_hr_{}.png'.format(self.test_dir, i))
            save_images(input_source, [1,1], '{}/gt_lr_{}.png'.format(self.test_dir, i))

            resize_sample = merge_images(sample_image, [1,1])
            resize_sample = transform.resize(resize_sample, (self.input_height, self.input_width))
            scipy.misc.imsave('{}/skimage_Tran_{}.png'.format(self.test_dir, i), resize_sample)


    def save_model(self, checkpoint_dir, step):
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dirs, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)

    def load_model(self, checkpoint_dir):
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dirs, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

6.主函数文件main.py

主函数文件主要控制整个过程,首先创建各个需要的文件夹,其次运行试验,先进行train,然后再进行test,下面直接给出main的代码:

代码语言:javascript
复制
from SRGAN import *

# 是否需要执行的步骤
is_crop = True
is_testing = True
is_training = False


def main():
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9, allow_growth=True)
    config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
    with tf.Session(config=config) as sess:
        srgan = SRGAN(dataset_dir='data/',
                      is_crop=is_crop,
                      batch_size=8,
                      input_height=128, input_width=128, input_channels=3,
                      sess=sess)
        srgan.build_model()
        if is_training:
            srgan.train()
        if is_testing:
            srgan.test()


if __name__=='__main__':
    main()

四、试验结果

由于我的电脑内存比较小,如果设置影像为256*256,batch_size为8的情况下会提示内存不足,因此我将图像的大小设置为了128*128,设置200个epoch,训练的一晚上,大约只训练了140个epoch,后面我就没有再继续进行训练了。

训练140个epoch的试验结果为:

上图中,第一列表示输入的32*32尺寸的影像(即对原始影像进行下采样得到的),第二列表示SRGAN处理的结果,第三列表示利用skimage库的resize函数对输入的32*32影像的处理结果,最后一列是原始影像。

从结果可以看出,SRGAN能够实现图像的超分辨率,和原图比虽然一些细微纹理还无法较好的还原,但是相较于skimage的resize方法而言,结果已经非常好了。

五、分析

1.文件结构参见三.

2.SRGAN可以实现超分辨率过程。如果想获得更好的试验结果,可以考虑加深图像的网络层数。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-06-122,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、背景
  • 二、SRGAN原理
  • 三、实现过程
    • 1.文件结构
      • 2.数据准备
        • 3.image处理文件utils.py
          • 4.layer处理文件layer.py
            • 5.SRGAN模型文件SRGAN.py
              • 6.主函数文件main.py
              • 四、试验结果
              • 五、分析
              领券
              问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档