前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >ac GAN MNIST代码详细解析

ac GAN MNIST代码详细解析

作者头像
羽翰尘
修改2019-11-26 16:40:07
7980
修改2019-11-26 16:40:07
举报
文章被收录于专栏:技术向技术向

本文由腾讯云+社区自动同步,原文地址 https://stackoverflow.club/article/acgan_reading_understanding/

背景

最近希望对UCR Time Series的时间序列进行数据扩增,考虑用GAN实现。但是GAN的种类似乎很多,各种教程也只是粗浅说说原理,贴贴代码,达不到能够自己写代码的程度。所以打算花点时间对AC GAN的代码剖析一下,使其与原理较好地对应起来。

原理

生成器

  1. 根据latent生成原始3x3图像
代码语言:txt
复制
   def build_generator(latent_size):
    # we will map a pair of (z, L), where z is a latent vector and L is a
    # label drawn from P_c, to image space (..., 28, 28, 1)
    cnn = Sequential()
    cnn.add(Dense(3 * 3 * 384, input_dim=latent_size, activation='relu'))
    cnn.add(Reshape((3, 3, 384)))
  1. 上采样到7x7

这里提到上采样到7x7, 为什么就能实现上采样呢?因为这里是反卷积,filter数目是192, kernel size是5,并且padding是valid,表示仅对有效像素卷积。如果原始图像是7x7,那么经过这样的卷积之后,输出3x3的图像。

代码语言:txt
复制
# upsample to (7, 7, ...)
cnn.add(Conv2DTranspose(192, 5, strides=1, padding='valid',
                        activation='relu',
                        kernel_initializer='glorot_normal'))
cnn.add(BatchNormalization())
  1. 上采样到14x14

只有原始大小为14x14,经过kernel size=5,且步长为2,有padding的卷积之后,输出为7x7.

由此观察可得,如果希望反卷积之后的大小不成比例,就用no padding, 步长为1;反之用padding,步长为2.

代码语言:txt
复制
# upsample to (14, 14, ...)
cnn.add(Conv2DTranspose(96, 5, strides=2, padding='same',
                        activation='relu',
                        kernel_initializer='glorot_normal'))
cnn.add(BatchNormalization())
  1. 继续上采样到28x28,且输出一个通道

显然,只有filter的数目为1,我们才能得到单独的一张图片。

代码语言:txt
复制
# upsample to (28, 28, ...)
cnn.add(Conv2DTranspose(1, 5, strides=2, padding='same',
                        activation='tanh',
                        kernel_initializer='glorot_normal'))
# this is the z space commonly referred to in GAN papers
latent = Input(shape=(latent_size, ))
# this will be our label
image_class = Input(shape=(1,), dtype='int32')
  1. 嵌入层的用处?

Keras的嵌入层官方文档

文档上说,Embedding层的输入是(batch_size, sequence_length), 输出是(batch_size, sequence_length, output_dim)

官方说嵌入层的作用是Turns positive integers (indexes) into dense vectors of fixed size.

eg1. 4, 20 -> 0.25, 0.1, 0.6, -0.2

eg2. 32x10(1000内的数字) -> 32x10x64(64是output_dim)

代码语言:txt
复制
cls = Embedding(num_classes, latent_size,
                embeddings_initializer='glorot_normal')(image_class)
# hadamard product between z-space and a class conditional embedding
h = layers.multiply([latent, cls])
fake_image = cnn(h)
return Model([latent, image_class], fake_image)
  1. 反卷积层
代码语言:txt
复制
keras.layers.convolutional.Conv2DTranspose(filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None)

该层是转置的卷积操作(反卷积)。需要反卷积的情况通常发生在用户想要对一个普通卷积的结果做反方向的变换。例如,将具有该卷积层输出shape的tensor转换为具有该卷积层输入shape的tensor。同时保留与卷积层兼容的连接模式。

当使用该层作为第一层时,应提供input_shape参数。例如input_shape = (3,128,128)代表128*128的彩色RGB图像

padding:补0策略,为“valid”, “same” 。“valid”代表只进行有效的卷积,即对边界数据不处理。“same”代表保留边界处的卷积结果,通常会导致输出shape与输入shape相同。

参考官方教程

训练代码

此处使用的代码是Keras的官方示例代码,地址

重要的部分的代码由两个for循环构成,外层for循环控制Epoch,内层for循环控制batch,从内层for循环开始剖析。

  1. 获取训练集的一部分,即真实图片 12image_batch = x_trainindex * batch_size:(index + 1) * batch_sizelabel_batch = y_trainindex * batch_size:(index + 1) * batch_size
  2. 生成noise,这个noise对生成假图片有重要作用 1noise = np.random.uniform(-1, 1, (len(image_batch), latent_size))
  3. 随机生成一些标签 1sampled_labels = np.random.randint(0, num_classes, len(image_batch))
  4. 生成假图片

此处用到了noise和假标签。 generator接受两个输入,一个是随机噪声,另一个是label,并且根据label生成对应的图片。

代码语言:txt
复制
generated_images = generator.predict(
                [noise, sampled_labels.reshape((-1, 1))], verbose=0)
  1. 连接真假图片

这里的image_batch是真图片,generated_images是假图片。 np.concatenate()将两个矩阵连在一起,如 1,2,3,4 连接 5,6之后就是1,2,3,4,5,6.

代码语言:txt
复制
x = np.concatenate((image_batch, generated_images))
  1. 生成标签?

生成了两种标签,一种是代表真和假的y,另一种是代表具体类别的aux_y.

最后生成的y类似于这种0.95,0.95,0.95,…(长度是len(image_batch)),0,0,0,…(长度是len(image_batch))

代码语言:txt
复制
soft_zero, soft_one = 0, 0.95
y = np.array([soft_one] * len(image_batch) + [soft_zero] * len(image_batch))
aux_y = np.concatenate((label_batch, sampled_labels), axis=0)
代码语言:txt
复制
# we don't want the discriminator to also maximize the classification
# accuracy of the auxiliary classifier on generated images, so we
# don't train discriminator to produce class labels for generated
# images (see https://openreview.net/forum?id=rJXTf9Bxg).
# To preserve sum of sample weights for the auxiliary classifier,
# we assign sample weight of 2 to the real images.
  1. 分配权重并训练分类器

TODO 这里的权重有点看不太懂

分类器在训练的时候输入的x是真假图片都有,标签由两部分构成,一部分是y,代表是不是真的,一部分是aux_y, 代表对应的类别。

代码语言:txt
复制
disc_sample_weight = [np.ones(2 * len(image_batch)),np.concatenate((np.ones(len(image_batch)) * 2,
                     np.zeros(len(image_batch))))]   
epoch_disc_loss.append(discriminator.train_on_batch(x, [y, aux_y], sample_weight=disc_sample_weight))
代码语言:txt
复制
# make new noise. we generate 2 * batch size here such that we have
# the generator optimize over an identical number of images as the
# discriminator
  1. 再次生成假数据

注意这里的都乘了2,是希望生成器获得更多的训练,从而比分类器性能更好。

代码语言:txt
复制
noise = np.random.uniform(-1, 1, (2 * len(image_batch), latent_size))
sampled_labels = np.random.randint(0, num_classes, 2 * len(image_batch))
代码语言:txt
复制
# we want to train the generator to trick the discriminator
# For the generator, we want all the {fake, not-fake} labels to say
# not-fake
  1. 训练combined模型

combined模型由分类器和生成器构成,分类器的权重被限制不能更改。

注意这里训练时输入的第二组数[trick, sampled_labels], trick代表是否是真图片,sampled_labels代表图片的分类。这里为了在训练时让生成器更倾向于生成更像真图片的假图片,人为让trick的值全部为soft_one,即非常接近1.

代码语言:txt
复制
trick = np.ones(2 * len(image_batch)) * soft_one
epoch_gen_loss.append(combined.train_on_batch(
                [noise, sampled_labels.reshape((-1, 1))],
                [trick, sampled_labels]))
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019-06-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 背景
  • 原理
  • 生成器
  • 训练代码
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档