前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >自动编解码器的训练与数据识别效果解析

自动编解码器的训练与数据识别效果解析

作者头像
望月从良
发布2019-10-30 14:58:07
4570
发布2019-10-30 14:58:07
举报
文章被收录于专栏:Coding迪斯尼Coding迪斯尼

上一节我们构建了自动编解码器网络。我们将图片输入到编码器后,它将数据”压缩“成只包含2个分量的一维向量,该向量输入解码器后,又会被还原成与输入图片非常相似的图片,本节我们看看如何训练编解码器网络,同时看看它构造的中间向量到底意味着什么。

首先我们在上节代码基础上完成网络的训练代码:

代码语言:javascript
复制
from keras.callbacks import Callback, LearningRateScheduler
import os
import matplotlib.pyplot as plt
class CustomCallback(Callback): #在模型训练时通过回调函数输出相关信息
    def  __init__(self, run_folder, print_every_n_batches, initial_epoch, vae):
        self.epoch = initial_epoch
        self.run_folder = run_folder
        self.print_every_n_batches = print_every_n_batches
        self.vae = vae
    def  on_batch_end(self, batch_, logs = {}): #在一次训练循环结束时被调用
        if  batch % self.print_every_n_batches == 0:
            z_new = np.random.normal(size = (1, self.vae.z_dim))
            reconst = self.vae.decoder.predict(np.array(z_new))[0].squeeze() #使用关键向量重新构造图片
            filepath = os.path.join(self.run_folder, 'images', 'img_' + str(self.epoch).zfill(3) + '_' + str(batch) + '.jpg')
            if len(reconst.shape) == 2:
                plt.imsave(filepath, reconst, cmap = 'gray_r')
            else:
                plt.imsave(filepath, reconst)
    def  on_epoch_begin(self, epoch, logs = {}):
        self.epoch += 1
def  step_decay_schedule(initial_lr, decay_factor = 0.5, step_size = 1):
    def  schedule(epoch):
        new_lr = initial_lr * (decay_factor ** np.floor(epoch / step_size))
        return new_lr
    return LearningRateScheduler(schedule)

class AutoEncoder():
    ....
    def  compile(self, learning_rate):
        self.learning_rate = learning_rate
        optimizer = Adam(lr = learning_rate)
        def  r_loss(y_true, y_pred): #输出图像与输入图像像素点差值的平方和越小表示网络解码后恢复的图像越好
            return K.mean(K.square(y_true - y_pred), axis = [1,2,3])
        self.model.compile(optimizer = optimizer, loss = r_loss)
    def  train(self, x_train, batch_size,run_folder, epochs, print_every_n_batchs = 100,
               initial_epoch = 0, lr_decay = 1):
        custom_callback = CustomCallback(run_folder, print_every_n_batches, initial_epoch, self)
        lr_sched = step_decay_schedule(initial_lr = self.learning_rate, decay_factor = lr_decay, step_size = 1) #每训练一次学习率就适当减少一些
        checkpoint2 = ModelCheckpoint(os.path.join(run_folder, 'weights/weights.h5'), save_weights_only = True, verbose = 1)
        callback_list = [checkpoint2, custom_callback, lr_sched]
        self.model.fit(x_train, x_train, batch_size = batch_size, shuffle = True,
                       epochs = epochs, intial_epoch = initial_epoch,
                      callbacks = callback_list)
    ....

完成网络的训练代码后,我们就可以加载数据训练网络:

代码语言:javascript
复制
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1, ))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))
LEARNING_RATE = 0.0005
BATCH_SIZE = 32
INITIAL_EPOCH = 0
AE.compile(LEARNING_RATE)

AE.train(x_train[: 1000], batch_size = BATCH_SIZE, epochs = 200,
        run_folder = 'C:\\Users\\cheny\\Desktop\\VAE\\', initial_epoch = INITIAL_EPOCH)

在运行代码前,我们现在目录 ‘C:\Users\cheny\Desktop\VAE\‘创建两个文件夹名为’images’和’weights’,然后执行上面代码,启动训练流程,然后在images文件夹中会出现网络还原的数字图片,随着训练循环的增加,还原图片的显示效果会越来越好。

我们看看上面代码运行时解码器还原的图片状况,首先我们看看循环训练在63次时还原的图片情况如下:

训练循环经过200次后,解码器还原的图片情况如下:

这时我们看到解码器还原图片的质量有了很大提升,从上图我们知道他还原的图片是手写数字9,我们可以看到还原的图片质量不是很好,在后面我们会研究如何提升它的还原质量,从这里也可以看出编解码器就像一个有损压缩和解压器,它把输入图片压缩成只有2个元素的向量,然后再将向量还原成输入图片,但是还原的图片质量有所下降。

现在我们需要探究的是,编码器将图片转换成的含有2个元素的向量意味着什么,我们将图片输入编码器,得到他转换的向量,然后将向量绘制到二维平面上看看是什么情况,代码如下:

代码语言:javascript
复制
import matplotlib.pyplot as plt
n_to_show = 5000
grid_size = 15
figsize = 12
example_idx = np.random.choice(range(len(x_test)), n_to_show)
example_images = x_test[example_idx]
example_labels = y_test[example_idx]
z_points = AE.encoder.predict(example_images)
plt.figure(figsize=(figsize, figsize))
plt.scatter(z_points[:, 0] , z_points[:, 1] , cmap='rainbow' , c= example_labels
            , alpha=0.5, s=2)
plt.colorbar()
plt.show()

代码运行后给出结果如下图:

我们注意看图中点的分布。相同颜色的点会聚集在一起形成一个集团。事实上相同颜色的点对应同一种数字图片,也就是说相同数字的图片,即使图片内显示的内容有差异,导致编码器解读图片后生成的含有两个元素的一维向量不一样,但是他们在距离上非常接近,以至于能形成一个集合。

于是两张内容为6的手写数字图片所经过编码器后所产生的向量在距离上比内容为6和内容为4的图片所产生的向量之间距离更短。于是编码器将相同数字的手写图片转换为同一个区域内的二维点集合在一起。

从上图也能看出,有些颜色的点很聚集,有些颜色的点很分散,例如左下角紫色的点,这意味着编码器对紫色点对应的手写数字图片特征认识还不是很深入,这也是解码器生成图片质量不够好的原因。

现在编码器生成向量对应的二维点属于离散性质,这意味着假设坐标点(2, -2)能解码成图片4,但点(2.1, -2.1)可就不能保证也解码成图片4了,下一节我们就要处理这个问题,只有编码器将图片转换成的向量是”连续“性质,也就是两个点的距离只要足够小,它们就应该能解码成相同图片,即使图片的展现形式可能有所不同。而编码器将点转换成的向量具有“连续性”我们才有可能实现“变脸”功能。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-10-29,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Coding迪斯尼 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
图片处理
图片处理(Image Processing,IP)是由腾讯云数据万象提供的丰富的图片处理服务,广泛应用于腾讯内部各产品。支持对腾讯云对象存储 COS 或第三方源的图片进行处理,提供基础处理能力(图片裁剪、转格式、缩放、打水印等)、图片瘦身能力(Guetzli 压缩、AVIF 转码压缩)、盲水印版权保护能力,同时支持先进的图像 AI 功能(图像增强、图像标签、图像评分、图像修复、商品抠图等),满足多种业务场景下的图片处理需求。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档