TensorFlow应用实战-14-编写训练的python文件

编写训练的python文件

# -*- coding: UTF-8 -*-

"""
训练 DCGAN
"""

import os
import glob
import numpy as np
from scipy import misc
import keras as tf.keras

from network import *


def train():

if __name__ == "__main__":
    train()

获取训练数据

# 获取训练数据
    data = []
    for image in glob.glob("images/*"):
        # 读取图片,返回一个数组对象
        image_data = misc.imread(image)  # imread 利用 PIL 来读取图片数据
        data.append(image_data)
    input_data = np.array(data)

将数据进行标准化

 # 将数据标准化成 [-1, 1] 的取值, 这也是 Tanh 激活函数的输出范围
    input_data = (input_data.astype(np.float32) - 127.5) / 127.5

tanh的取值范围是-1 到 1

像素值最大255 减去一半127.5 再除以 127.5 被限制到-1到1之间。

构造生成器和判别器

    # 构造 生成器 和 判别器
    g = generator_model()
    d = discriminator_model()

构建生成器和判别器组成的网络模型

# 构建 生成器 和 判别器 组成的网络模型
    d_on_g = generator_containing_discriminator(g, d)

里面的参数传入g和d

优化器使用Adam optimizers

 # 优化器用 Adam Optimizer
    g_optimizer = tf.keras.optimizers.Adam(lr=LEARNING_RATE, beta_1=BETA_1)
    d_optimizer = tf.keras.optimizers.Adam(lr=LEARNING_RATE, beta_1=BETA_1)

学习率是我们之前定义的学习率。beta_1 参数。

使用compile方法对于神经网络进行配置 生成器 和 判别器

# 配置 生成器 和 判别器
    g.compile(loss="binary_crossentropy", optimizer=g_optimizer)
    d_on_g.compile(loss="binary_crossentropy", optimizer=g_optimizer)
    d.trainable = True
    d.compile(loss="binary_crossentropy", optimizer=d_optimizer)

交叉熵损失函数。固定住判别器去优化生成器。相反固定一方优化另一方。

开始训练

# 开始训练
    for epoch in range(EPOCHS):
        # 每经过一个batchsize大小训练一下
        for index in range(int(input_data.shape[0] / BATCH_SIZE)):
            # 数据切片
            input_batch = input_data[index * BATCH_SIZE : (index + 1) * BATCH_SIZE]

            # 连续型均匀分布的随机数据(噪声)
            random_data = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
            # 生成器 生成的图片数据
            generated_images = g.predict(random_data, verbose=0)
            # 首尾相连,输入自身以及产生的图片
            input_batch = np.concatenate((input_batch, generated_images))
            # 输出的数据要么是0 要么是 1。1就是通过检测,跟真实图片一致。
            output_batch = [1] * BATCH_SIZE + [0] * BATCH_SIZE

            # 训练 判别器,让它具备识别不合格生成图片的能力
            d_loss = d.train_on_batch(input_batch, output_batch)

            # 当训练 生成器 时,让 判别器 不可被训练
            d.trainable = False

            # 重新生成随机数据。很关键
            random_data = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))

            # 训练 生成器,并通过不可被训练的 判别器 去判别
            g_loss = d_on_g.train_on_batch(random_data, [1] * BATCH_SIZE)

            # 恢复 判别器 可被训练
            d.trainable = True

            # 打印损失
            print("Epoch {}, 第 {} 步, 生成器的损失: {:.3f}, 判别器的损失: {:.3f}".format(epoch, index, g_loss, d_loss))

保存生成器和判别器的参数

        # 保存 生成器 和 判别器 的参数
        # 大家也可以设置保存时名称不同(比如后接 epoch 的数字),参数文件就不会被覆盖了
        if epoch % 10 == 9:
            g.save_weights("generator_weight", True)
            d.save_weights("discriminator_weight", True)

当我们训练完成,会生成一个generator_weight文件

它是一个h5py的文件。

pip install h5py

编写神经网络生成图片的方法

# -*- coding: UTF-8 -*-

"""
用 DCGAN 的生成器模型 和 训练得到的生成器参数文件 来生成图片
"""

import numpy as np
from PIL import Image
import keras as tf.keras

from network import *


def generate():
    # 构造生成器
    g = generator_model()

    # 配置 生成器
    g.compile(loss="binary_crossentropy", optimizer=tf.keras.optimizers.Adam(lr=LEARNING_RATE, beta_1=BETA_1))

    # 加载训练好的 生成器 参数
    g.load_weights("generator_weight")

    # 连续型均匀分布的随机数据(噪声)
    random_data = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))

    # 用随机数据作为输入,生成器 生成图片数据
    images = g.predict(random_data, verbose=1)

    # 用生成的图片数据生成 PNG 图片
    for i in range(BATCH_SIZE):
        # 将被限制到-1到1之间的数据进行还原
        image = images[i] * 127.5 + 127.5
        Image.fromarray(image.astype(np.uint8)).save("image-%s.png" % i)


if __name__ == "__main__":
    generate()

代码完成与测试模型

一个错误的个人使用,因为我的TensorFlow版本较老。keras并没有被集成进来。

我以为可以

import keras as tf.keras

但是测试失败了,直接把全部的tf.keras全部替换为keras

新的风暴

throws OOM when allocating tensor with shape

又是穷人才会遇到的问题

将batch_size大小从128改为64可以正常训练

mark

然后使用generator.py生成图片。

 # 配置 生成器 和 判别器
    g.compile(loss="binary_crossentropy", optimizer=g_optimizer)
    d_on_g.compile(loss="binary_crossentropy", optimizer=g_optimizer)
    d.trainable = True
    d.compile(loss="binary_crossentropy", optimizer=d_optimizer)

让判别器先可以训练,再设置。我们训练生成器的随机数据不应该和训练整个dong 的一样,不然不够随机化。

基本都得训练好几个小时。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

用深度学习每次得到的结果都不一样,怎么办?

AI研习社按:本文作者 Jason Brownlee 为澳大利亚知名机器学习专家、教育者,对时间序列预测尤有心得。原文发布于其博客。AI研习社崔静闯、朱婷编译。...

5353
来自专栏崔庆才的专栏

TensorFlow RNN Cell源码解析

本文介绍下 RNN 及几种变种的结构和对应的 TensorFlow 源码实现,另外通过简单的实例来实现 TensorFlow RNN 相关类的调用。 RNN R...

3875
来自专栏简书专栏

基于tensorflow+RNN的MNIST数据集手写数字分类

tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流。 RNN是recurrent neural network的简...

683
来自专栏有趣的Python

TensorFlow应用实战-7-工具类与网络模型

编写转换midi到mp3的方法 # -*- coding: UTF-8 -*- import os import subprocess def conver...

2755
来自专栏Soul Joy Hub

TensorFlow指南(一)——上手TensorFlow

http://blog.csdn.net/u011239443/article/details/79066094 TensorFlow是谷歌开源的深度学习库...

4155
来自专栏智能算法

深度学习三人行(第1期)---- TensorFlow爱之初体验

前面十个系列,我们一起学习了机器学习的相关知识,详情可在“智能算法”微信公众号中回复“机器学习”进行查看学习及代码实战。从该期开始,我们将一起学习深度学习相关知...

37214
来自专栏人工智能LeadAI

深度学习框架之一:Theano | Lasagne简单教程

参考Lasagne官网(http://lasagne.readthedocs.io/en/latest/)tutorial进行总结而来。 01 简介 Lasag...

4085
来自专栏悦思悦读

决策树告诉你Hello Kitty到底是人是猫

Hello Kitty,一只以无嘴造型40年来风靡全球的萌萌猫,在其40岁生日时,居然被其形象拥有者宣称:HelloKitty不是猫! 2014年八月,研究 H...

3127
来自专栏ATYUN订阅号

【学术】在C ++中使用TensorFlow训练深度神经网络

你可能知道TensorFlow的核心是用C++构建的,然而只有python的API才能获得多种便利。 当我写上一篇文章时,目标是仅使用TensorFlow的C ...

50311
来自专栏深度学习思考者

一文解决OpenCV训练分类器制作xml文档的所有问题

一 前言 关于训练分类器制作XML文档时需要的两个exe应用程序的解释。   opencv_createsamples :用来准备训练用的正样本数据和测试数据...

4206

扫码关注云+社区