专栏首页AI研习社只要130 行代码即可生成二维样本,心动了吗?

只要130 行代码即可生成二维样本,心动了吗?

AI 研习社按:此前我们曾编译了一篇英文教程,详细介绍了如何基于 PyTorch 平台用 50 行代码实现 GAN(生成对抗网络),详情参见:在 PyTorch 上跑 GAN 只需要 50 行代码,不试试?近期,针对文中介绍的“50 行代码 GAN 模型”,有开发者指出了局限性,并基于此模型给出了改进版本,也就是本文将要介绍的“130 行代码实现 GAN 二维样本”。本文原载于知乎专栏,作者達聞西,雷锋网 AI 研习社经授权发布。文中相关链接详见文末“阅读原文”。

50行GAN代码的问题

Dev Nag 写的 50 行代码的 GAN,大概是网上流传最广的,关于GAN最简单的小例子。这是一份用一维均匀样本作为特征空间(latent space)样本,经过生成网络变换后,生成高斯分布样本的代码。结构非常清晰,却有一个奇怪的问题,就是判别器(Discriminator)的输入不是2维样本,而是把整个mini-batch整体作为一个维度是batch size(代码中batch size等于cardinality)那么大的样本。也就是说判别网络要判别的不是一个一维的目标分布,而是batch size那么大维度的分布:

... d_input_size = 100 # Minibatch size - cardinality of distributions ... class Discriminator(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(Discriminator, self).__init__() self.map1 = nn.Linear(input_size, hidden_size) self.map2 = nn.Linear(hidden_size, hidden_size) self.map3 = nn.Linear(hidden_size, output_size) def forward(self, x): x = F.elu(self.map1(x)) x = F.elu(self.map2(x)) return F.sigmoid(self.map3(x)) ... D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size) ... for epoch in range(num_epochs): for d_index in range(d_steps): # 1. Train D on real+fake D.zero_grad() # 1A: Train D on real d_real_data = Variable(d_sampler(d_input_size)) d_real_decision = D(preprocess(d_real_data)) d_real_error = criterion(d_real_decision, Variable(torch.ones(1))) # ones = true d_real_error.backward() # compute/store gradients, but don't change params # 1B: Train D on fake d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size)) d_fake_data = G(d_gen_input).detach() # detach to avoid training G on these labels d_fake_decision = D(preprocess(d_fake_data.t())) d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1))) # zeros = fake d_fake_error.backward() d_optimizer.step() # Only optimizes D's parameters; changes based on stored gradients from backward() for g_index in range(g_steps): # 2. Train G on D's response (but DO NOT train D on these labels) G.zero_grad() gen_input = Variable(gi_sampler(minibatch_size, g_input_size)) g_fake_data = G(gen_input) dg_fake_decision = D(preprocess(g_fake_data.t())) g_error = criterion(dg_fake_decision, Variable(torch.ones(1))) # we want to fool, so pretend it's all genuine g_error.backward() g_optimizer.step() # Only optimizes G's parameters ...

不知作者是疏忽了还是有意为之,总之这么做的结果就是如此简单的例子收敛都好。可能作者自己也察觉了收敛问题,就想把方差信息也放进来,于是又写了个预处理函数(decorate_with_diffs)计算出每个样本距离一批样本中心的距离平方,作为给判别网络的额外输入,其实这样还增加了输入维度。结果当然是加不加这个方差信息都能勉强收敛,但是都不稳定。甚至作者自己贴出来的生成样本分布(下图)都不令人满意:

如果直接把这份代码改成二维的,就会发现除了简单的对称分布以外,其他分布基本都无法生成。

理论上讲神经网络作为一种通用的近似函数,只要capacity够,学习多少维分布都不成问题,但是这样写法显然极大增加了收敛难度。更自然的做法应该是:判别网络只接受单个二维样本,通过batch size或是多步迭代学习分布信息。

另:这份代码其实有130行。

从自定义的二维分布采样

不管怎样Dev Nag的代码还是提供了一个用于理解和试验GAN的很好的框架,做一些修改就可以得到一份更适合直观演示,且更容易收敛的代码,也就是本文的例子。

从可视化的角度二维显然比一维更直观,所以我们采用二维样本。第一步,当然是要设定一个目标分布,作为二维的例子,分布的定义方式应该尽量自由,这个例子中我们的思路是通过灰度图像定义的概率密度,进而来产生样本,比如下面这样:

二维情况下,这种采样的一个实现方法是:求一个维度上的边缘(marginal)概率+另一维度上近似的条件概率。比如把图像中白色像素的值作为概率密度的相对大小,然后沿着x求和,然后在y轴上求出marginal probability density,接着再根据y的位置,近似得到对应x关于y的条件概率。采样的时候先采y的值,再采x的值就能近似得到符合图像描述的分布的样本。具体细节就不展开讲解了,看代码:

from functools import partial import numpy from skimage import transform EPS = 1e-6 RESOLUTION = 0.001 num_grids = int(1/RESOLUTION+0.5) def generate_lut(img): """ linear approximation of CDF & marginal :param density_img: :return: lut_y, lut_x """ density_img = transform.resize(img, (num_grids, num_grids)) x_accumlation = numpy.sum(density_img, axis=1) sum_xy = numpy.sum(x_accumlation) y_cdf_of_accumulated_x = [[0., 0.]] accumulated = 0 for ir, i in enumerate(range(num_grids-1, -1, -1)): accumulated += x_accumlation[i] if accumulated == 0: y_cdf_of_accumulated_x[0][0] = float(ir+1)/float(num_grids) elif EPS < accumulated < sum_xy - EPS: y_cdf_of_accumulated_x.append([float(ir+1)/float(num_grids), accumulated/sum_xy]) else: break y_cdf_of_accumulated_x.append([float(ir+1)/float(num_grids), 1.]) y_cdf_of_accumulated_x = numpy.array(y_cdf_of_accumulated_x) x_cdfs = [] for j in range(num_grids): x_freq = density_img[num_grids-j-1] sum_x = numpy.sum(x_freq) x_cdf = [[0., 0.]] accumulated = 0 for i in range(num_grids): accumulated += x_freq[i] if accumulated == 0: x_cdf[0][0] = float(i+1) / float(num_grids) elif EPS < accumulated < sum_xy - EPS: x_cdf.append([float(i+1)/float(num_grids), accumulated/sum_x]) else: break x_cdf.append([float(i+1)/float(num_grids), 1.]) if accumulated > EPS: x_cdf = numpy.array(x_cdf) x_cdfs.append(x_cdf) else: x_cdfs.append(None) y_lut = partial(numpy.interp, xp=y_cdf_of_accumulated_x[:, 1], fp=y_cdf_of_accumulated_x[:, 0]) x_luts = [partial(numpy.interp, xp=x_cdfs[i][:, 1], fp=x_cdfs[i][:, 0]) if x_cdfs[i] is not None else None for i in range(num_grids)] return y_lut, x_luts def sample_2d(lut, N): y_lut, x_luts = lut u_rv = numpy.random.random((N, 2)) samples = numpy.zeros(u_rv.shape) for i, (x, y) in enumerate(u_rv): ys = y_lut(y) x_bin = int(ys/RESOLUTION) xs = x_luts[x_bin](x) samples[i][0] = xs samples[i][1] = ys return samples if __name__ == '__main__': from skimage import io density_img = io.imread('batman.jpg', True) lut_2d = generate_lut(density_img) samples = sample_2d(lut_2d, 10000) from matplotlib import pyplot fig, (ax0, ax1) = pyplot.subplots(ncols=2, figsize=(9, 4)) fig.canvas.set_window_title('Test 2D Sampling') ax0.imshow(density_img, cmap='gray') ax0.xaxis.set_major_locator(pyplot.NullLocator()) ax0.yaxis.set_major_locator(pyplot.NullLocator()) ax1.axis('equal') ax1.axis([0, 1, 0, 1]) ax1.plot(samples[:, 0], samples[:, 1], 'k,') pyplot.show()

二维GAN的小例子

虽然网上到处都有,这里还是贴一下GAN的公式:

就是一个你追我赶的零和博弈,这在Dev Nag的代码里体现得很清晰:判别网络训一拨,然后生成网络训一拨,不断往复。按照上节所述,本文例子在Dev Nag代码的基础上,把判别网络每次接受一个batch作为输入的方式变成了:每次接受一个二维样本,通过每个batch的多个样本计算loss。GAN部分的训练代码如下:

DIMENSION = 2 ... generator = SimpleMLP(input_size=z_dim, hidden_size=args.g_hidden_size, output_size=DIMENSION) discriminator = SimpleMLP(input_size=DIMENSION, hidden_size=args.d_hidden_size, output_size=1) ... for train_iter in range(args.iterations): for d_index in range(args.d_steps): # 1. Train D on real+fake discriminator.zero_grad() # 1A: Train D on real real_samples = sample_2d(lut_2d, bs) d_real_data = Variable(torch.Tensor(real_samples)) d_real_decision = discriminator(d_real_data) labels = Variable(torch.ones(bs)) d_real_loss = criterion(d_real_decision, labels) # ones = true # 1B: Train D on fake latent_samples = torch.randn(bs, z_dim) d_gen_input = Variable(latent_samples) d_fake_data = generator(d_gen_input).detach() # detach to avoid training G on these labels d_fake_decision = discriminator(d_fake_data) labels = Variable(torch.zeros(bs)) d_fake_loss = criterion(d_fake_decision, labels) # zeros = fake d_loss = d_real_loss + d_fake_loss d_loss.backward() d_optimizer.step() # Only optimizes D's parameters; changes based on stored gradients from backward() for g_index in range(args.g_steps): # 2. Train G on D's response (but DO NOT train D on these labels) generator.zero_grad() latent_samples = torch.randn(bs, z_dim) g_gen_input = Variable(latent_samples) g_fake_data = generator(g_gen_input) g_fake_decision = discriminator(g_fake_data) labels = Variable(torch.ones(bs)) g_loss = criterion(g_fake_decision, labels) # we want to fool, so pretend it's all genuine g_loss.backward() g_optimizer.step() # Only optimizes G's parameters ... ...

和Dev Nag的版本比起来除了上面提到的判别网络,和样本维度的修改,还加了可视化方便直观演示和理解,比如用一个二维高斯分布产生一个折线形状的分布,执行:

python gan_demo.py inputs/zig.jpg

训练过程的可视化如下:

更多可视化例子可以参考如下链接:

http://t.cn/Ro8aNJz

Conditional GAN

对于一些复杂的分布,原始的GAN就会很吃力,比如用一个二维高斯分布产生两坨圆形的分布:

因为latent space的分布就是一坨二维的样本,所以即使模型有很强的非线性,也难以把这个分布“切开”并变换成两个很好的圆形分布。因此在上面的动图里能看到生成的两坨样本中间总是有一些残存的样本,像是两个天体在交换物质。要改进这种情况,比较直接的想法是增加模型复杂度,或是提高latent space维度。也许模型可以学习到用其中部分维度产生一个圆形,用另一部分维度产生另一个圆形。不过我自己试了下,效果都不好。

其实这个例子人眼一看就知道是两个分布在一个图里,假设我们已经知道这个信息,那么生成依据的就是个条件概率。把这个条件加到GAN里,就是Conditional GAN,公式如下:

示意图如下:

条件信息变相降低了生成样本的难度,所以生成的样本效果好很多。

在网络中加入条件的方式没有固定的原则,这里我们采用的是可能最常见的方法:用one-hot方式将条件编码成一个向量,然后和原始的输入拼一下。注意对于判别网络和生成网络都要这么做,所以上面公式和C-GAN原文简化过度的公式比起来多了两个y,避免造成迷惑。

C-GAN的代码实现就是GAN的版本基础上,利用pytorch的torch.cat()对条件和输入进行拼接。其中条件的输入就是多张图片,每张定义一部分分布的PDF。比如对于上面两坨分布的例子,就拆成两张图像来定义PDF:

具体实现就不贴这里了,参考本文的Github页面:

http://t.cn/Ro8Svq4

加入条件信息后,两坨分布的生成就轻松搞定了,执行:

python cgan_demo.py inputs/binary

得到下面的训练过程可视化:

对于一些更复杂的分布也不在话下,比如:

这两个图案对应的原始GAN和C-GAN的训练可视化对比可以在这里看到。

本文分享自微信公众号 - AI研习社(okweiwu),作者:達聞西

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2017-07-09

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 只需 130 行代码!用 GAN 生成二维样本的小例子

    50行GAN代码的问题 Dev Nag 写的 50 行代码的 GAN,大概是网上流传最广的,关于GAN最简单的小例子。这是一份用一维均匀样本作为特征空间(l...

    AI研习社
  • 微软推出开源自动驾驶仿真平台 AirSim 教程,机器学习新手也能快速上手自动驾驶

    AirSim 是微软推出的开源项目,用于测试人工智能系统的安全性,该系统提供仿真的环境、车辆动力和感知能力,促进无人驾驶技术的发展。近期,微软的工程师推出了一...

    AI研习社
  • 数据分析从业者必看!10 个加速 python 数据分析的简易小技巧

    AI 开发者按,一些小的技巧在编程领域可能会非常有用,在数据科学领域同样如此。数据科学爱好者 Parul Pandey 在近日发表了一篇博文,分享了在数据科学中...

    AI研习社
  • REdis之maxmemory解读

    redis.conf中的maxmemory定义REdis可用最大物理内存,有多种书写方式,以下均为合法:

    一见
  • 小程序案例丨【明星在线】互联网+明星经济,实现资源高效对接

    明星经纪、IP经济已经成为娱乐向产品买卖主流收益,通过明星影响力、流量数据、衍生产品等维度形成to C经济效益圈,大批商业主已然意识到大浪淘金的年代里,想凝汇“...

    齿轮易创说互联网
  • Human Interface Guidelines — Requesting Permission

    霖酱
  • WebSocket 简介及应用实例

    HTML5 的出现,标志着后 Flash 时代各种现代浏览器的集体爆发,也是谨防 Adobe 一家独大的各家厂商们,历经多年各自为战,想换个活法儿并终于达成一定...

    江米小枣
  • 小程序问世 微信的一手好牌被打烂?

    T客汇官网:tikehui.com 撰文 | 人称T客 ? 昨晚小程序如约而至,被朋友圈刷必不可,有人惊呼狼来了,小程序这是要干掉所有 App 的节奏,但也有人...

    人称T客
  • Flask中无法在其他函数中查询Sqlachemy的解决办法

    方式一 直接实例化app 不要写create_app函数了,在启动文件中直接导入app对象:

    用户1214487
  • 今天,你过节了吗?

    >>> 1024 == 2 ** 10 True 1024 是个神奇的数字。它是 2 的 10 次方,在以二进制为基础的计算机世界里有着特殊的含义。比如文件小的...

    Crossin先生

扫码关注云+社区

领取腾讯云代金券