前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >task8 GAN text-to-image

task8 GAN text-to-image

作者头像
平凡的学生族
发布2019-05-25 09:58:18
6610
发布2019-05-25 09:58:18
举报
文章被收录于专栏:后端技术后端技术

1. tensorflow API学习

1.1 batch normalization

https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization https://www.tensorflow.org/programmers_guide/variableshttps://www.tensorflow.org/programmers_guide/variables https://www.tensorflow.org/api_guides/python/reading_data#Multiple_input_pipelines

1.2 variables

主要是tf.name_scope, tf.variable_scope, tf.get_variable, tf.Variable,tf.op_scope

op_scope和variable_op_scope: 在r0.11版本已经废弃,并且替换为name_scopevariable_scope. 查阅官方API文档https://www.tensorflow.org/api_docs/python/tf/variable_op_scopehttps://www.tensorflow.org/api_docs/python/tf/op_scope 确实如此。

tf.name_scope, tf.variable_scope, tf.get_variable, tf.Variable: What's the difference of name scope and a variable scope in tensorflow? tensorflow scope命名方法(variable_scope()与name_scope()解析 看完这两篇文章就知道怎么用了。简单来讲就是下面三段代码:

代码语言:javascript
复制
with tf.name_scope("my_scope"):
    v1 = tf.get_variable("var1", [1], dtype=tf.float32)
    v2 = tf.Variable(1, name="var2", dtype=tf.float32)
    a = tf.add(v1, v2)

print(v1.name)  # var1:0
print(v2.name)  # my_scope/var2:0
print(a.name)   # my_scope/Add:0
代码语言:javascript
复制
with tf.variable_scope("my_scope"):
    v1 = tf.get_variable("var1", [1], dtype=tf.float32)
    v2 = tf.Variable(1, name="var2", dtype=tf.float32)
    a = tf.add(v1, v2)

print(v1.name)  # my_scope/var1:0
print(v2.name)  # my_scope/var2:0
print(a.name)   # my_scope/Add:0
代码语言:javascript
复制
with tf.name_scope("foo"):
    with tf.variable_scope("var_scope"):
        v = tf.get_variable("var", [1])
with tf.name_scope("bar"):
    with tf.variable_scope("var_scope", reuse=True):
        v1 = tf.get_variable("var", [1])
assert v1 == v
print(v.name)   # var_scope/var:0
print(v1.name)  # var_scope/var:0
  1. tf.name_scope下,tf.Variable受影响,tf.get_variable不受影响。
  2. tf.variable_scope下,tf.Variabletf.get_variable都受影响。
  3. 因此可以在不同的name_scope下轻松地共享同一变量。另外就是,tf.Variable每次调用都会创建新的变量,而tf.get_variable会在没有该变量时创建;而在已创建该变量时,调用reuse就可以复用该变量(如果不调用reuse就会报错)。

2. GAN与DCGAN

此次项目的构造是基于DCGAN的。

本节会依次讲解GAN和DCGAN。在下一节介绍项目所采用的,基于DCGAN改造的网络。

2.1 GAN

GAN构造图

另一个GAN的示例

GAN即生成对抗网络,分为生成器和辨别器两个网络:

  1. 生成器接受一个随机生成的D维“噪声”向量,生成一张图片。
  2. 辨别器则接受真实图片和虚假的图片,并预测所接受的图片是真实的还是生成的。

整个网络的训练需要同时训练两个网络,这两个网络的目标是相互对立的:

  1. 训练生成器以让其生成更加逼真的图片,尽可能迷惑辨别器
  2. 训练辨别器以让其有更强的辨认能力,尽可能识别出生成器生成的假图片。

训练目标: 经过多次的“攻防”训练,生成器和辨别器都会达到很高的水平,最终当生成器生成的图片难以被辨别器识别,辨别器只能靠1/2的概率瞎猜时,训练就完成了。

2.2 DCGAN

DCGAN结构图

DCGAN(Deep Convolutional Generative Adversarial Networks)原理是和GAN一样的,只是把G和D用卷积神经网络实现而已。但DCGAN的网络具有以下特点:

  1. 不使用pooling层,取而代之的是生成器在upsizing使用带有stride的逆卷积层,辨别器在downsizing使用带有stride的卷积层
  2. 数据流每次卷积/逆卷积之后,要经过batch normalization层
  3. 不使用全连接层
  4. 生成器的激活函数使用relu(输出层使用tanh)
  5. 辨别器的激活函数使用leaky_relu

总的来说,数据的流动如下:

  1. 随机生成一个[1, n]的向量,称为噪声,将其reshape成[n, n]的方阵
  2. 经过生成器。在生成器内经过若干个deconvolution层-batch_normalization层-relu层(除了在输出层里,激活层是tanh层),每经过一次这样的组合层,矩阵的尺寸会放大(一般是变大到两倍)。
  3. 经过辨别器。在辨别器内经过若干个convolution层-batch_normalization层-leaky_relu层,每经过一次这样的组合层,矩阵的尺寸会缩小(一般是缩小到二分之一)。
  4. 最终将矩阵展开为[1, n]的向量,并通过矩阵相乘转化为[1, 1]的输出(1就代表True,0就代表False)

3. 论文分析

3.1 GAN-CLS

如果是普通的GAN,那么D只要负责判断G的生成图片是真是假就行。 D只需接受两种数据:

  1. 符合描述的真实图像
  2. 搭配任意描述的生成图像

但是在这个任务中,G要生成符合t描述的图像,就要使用GAN-CLS。在GAN-CLS中,D应当判断出真实图像,且图像要符合文字描述。 此时D需要接受三种数据:

  1. 符合描述的真实图像
  2. 搭配任意描述的生成图像
  3. 搭配错误描述的真实图像

这样训练出来的辨别器不仅能判断图像是否真实,还能判断图像与文字描述是否吻合。

3.2 GAN-INT

论文认为,深度网络学习到的特征表示其实具有可插值性——如果有一对文字描述的深度特征和它们对应的两张图片,那么我们经过差值合成一个新的特征,这个特征所生成的图片也应当与那两张图片相近。 也就是说,”一只小鸟在天上飞”和”黑头乌龟在地上晒太阳”,它们的深度特征插值后,没准就成了“一只黑头小鸟在地上晒太阳了”。

于是论文提出了GAN-INT:

GAN-INT的生成器的目标

其中t1和t2是两个文本,β是差值比例。生成器的目标是:用某种比例合成两个文本,并根据合成文本生成图像,图像要尽可能贴合这个合成文本。通过合成文本,并调整β参数,我们可以得到理论上无数的文本数据,并用它训练D和G。(根据经验,β=0.5时效果不错)

所以,尽管差值后的合成特征没有对应的真实图片(当然没有,因为特征本来都是合成的),我们仍然可以用合成特征搭配真实图片来训练辨别器D,在训练过程中,辨别器将更善于判断图片与文字描述是否匹配反过来,能够促使生成器G生成更符合文字描述的图片

3.3 噪声向量z的作用

很多时候,文本描述只是定义了图片的content信息,而没涉及style信息(比如鸟的姿势,背景颜色)。所以我们就希望z能起到调整style的作用。于是论文就训练了一个invertG:G将文本转换为图片,而invertG把图片再转换回文本。

style的损失函数

然后作者把图片按照style进行聚类(平均背景颜色,鸟类姿态等),然后取属于同一聚类的样本,求得z。通过求解相同聚类的图片的style表示z之间的consine距离,从而确定不同训练方式,z所起到的style作用的强弱。

上图显示,不管是在姿势还是在背景颜色上,GAN-INT-CLS和GAN-INT的噪声向量z都起到更大的作用。而在整体表现上,使用了文本差值方法的GAN(GAN-INT、GAN-INT-CLS)都表现得更好

4. 项目介绍

4.1 dependencies

4.2 h5py

h5py结构

h5py把存储的数据看做两种:dataset和group。dataset就相当于文件,group就相当于文件夹。group可以包含dataset和其它group,就好比文件夹里可以包含文件和其它文件夹。 h5py文件的读写,只需要类似如下操作即可:

代码语言:javascript
复制
import h5py
h = h5py.File("myh5py.hdf5","w")
h.create_dataset('any_key', data='123') # 保存数据“123”到“any_key”的索引下
h.close()

f = h5py.File("myh5py.hdf5","r")
value = f['any_key']
print(value) # 输出123

4.3 skip thought vectors

skip-thought vectors是一个github上的自然语言项目它的作用是把一个句子转换成固定维度的向量(长度为4800),并且意思相近的句子所转换成的向量也相距较近。 转换成的向量都是长度为4800的combine-skip model,其中前2400称为uni-skip model,后2400称为bi-skip model作者推荐使用完整的4800长度向量进行计算,因为这样能达到更好的效果

4.4 网络结构

text-to-image的网络结构

网络的实现是如下形式的:

  1. 生成网络的输入
    1. 生成一个随机产生的向量。
    2. 利用skip-thoughts项目处理描述文字,生成一个长度为4800的向量
    3. 通过矩阵乘法将2.中的向量通过全连接层转化为一个长度较短(比如256)的特征向量。
    4. 将随即向量与特征向量拼接成一个更长的向量。然后通过矩阵乘法让向量长度变化,使之能在之后reshape成一个带有channel维度的方阵(当以一个batch的角度看时,方阵维度为[batch_size, size, size, channels])
  2. 通过生成器。行为与上文的DCGAN一致。
  3. 通过辨别器。
    1. 前半部分与DCGAN一致。
    2. 当把64 * 64的图片处理到维度是[batch_size, 4, 4, channels]时,需要把描述文字生成的特征向量复制4份,分别拼接到矩阵的后边,就像图中描述的一样。
    3. 再进行一次卷积层-批标准化层-lrelu层操作, 最后拉扯成向量,再通过矩阵相乘得到[batch_size, 1]的结果。值为1代表预测这张图为真实的,否则为生成的。

而两个网络的目标也不同于传统的DCGAN:

  1. 生成器要尽量生成符合文字描述的,尽可能真实的图片。
  2. 辨别器要辨别出真实且符合文字描述的图片。

也就是说,辨别器的训练会接受三种数据:

  1. 生成器根据文字描述生成的图片。辨别器要识别并给出0的输出。
  2. 真实但不符合文字描述的图片。辨别器要识别并给出0的输出。
  3. 符合文字描述的真实图片。辨别器要识别并给出1的输出。

我在参考了文章后,主要做了以下改造:

  1. 原本生成器的激活层主要用relu,辨别器的激活层主要用leaky_relu。现在tanh层不变以外,其它的激活层都用leaky_relu(alpha=0.2)代替。
  2. 原本文字描述转换为长度为4800的向量后,要先压缩到长度为256的向量后才拼接到“噪声”向量上。现在由于显卡内存充裕,我将长度为4800的向量直接拼接到“噪声”向量上。
  3. 将tensorflow各种网络层的api的调用方式替换为封装程度更高的调用方式。

4. 代码实现

参考

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018.06.22 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. tensorflow API学习
    • 1.1 batch normalization
      • 1.2 variables
      • 2. GAN与DCGAN
        • 2.1 GAN
          • 2.2 DCGAN
          • 3. 论文分析
            • 3.1 GAN-CLS
              • 3.2 GAN-INT
                • 3.3 噪声向量z的作用
                • 4. 项目介绍
                  • 4.1 dependencies
                    • 4.2 h5py
                      • 4.3 skip thought vectors
                        • 4.4 网络结构
                        • 4. 代码实现
                        • 参考
                        相关产品与服务
                        图片处理
                        图片处理(Image Processing,IP)是由腾讯云数据万象提供的丰富的图片处理服务,广泛应用于腾讯内部各产品。支持对腾讯云对象存储 COS 或第三方源的图片进行处理,提供基础处理能力(图片裁剪、转格式、缩放、打水印等)、图片瘦身能力(Guetzli 压缩、AVIF 转码压缩)、盲水印版权保护能力,同时支持先进的图像 AI 功能(图像增强、图像标签、图像评分、图像修复、商品抠图等),满足多种业务场景下的图片处理需求。
                        领券
                        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档