首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

keras学习笔记-黑白照片自动着色的神经网络-Beta版

正文共3894个字,8张图,预计阅读时间11分钟。

Alpha版本不能很好地给未经训练的图像着色。接下来,我们将在Beta版本中做到这一点——将上面的将神经网络泛化。

以下是使用Beta版本对测试图像着色的结果。

特征提取器

我们的神经网络要做的是发现将灰度图像与其彩色版本相链接的特征。

试想,你必须给黑白图像上色,但一次只能看到9个像素。你可以从左上角到右下角扫描每个图像,并尝试预测每个像素应该是什么颜色。

例如,这9个像素就是上面那张女性人脸照片上鼻孔的边缘。要很好的着色几乎是不可能的,所以你必须把它分解成好几个步骤。

首先,寻找简单的模式:对角线,所有黑色像素等。在每个滤波器的扫描方块中寻找相同的精确的模式,并删除不匹配的像素。这样,就可以从64个迷你滤波器生成64个新图像。

如果再次扫描图像,你会看到已经检测到的相同的模式。要获得对图像更高级别的理解,你可以将图像尺寸减小一半。

你仍然只有3×3个滤波器来扫描每个图像。但是,通过将新的9个像素与较低级别的滤波器相结合,可以检测更复杂的图案。一个像素组合可能形成一个半圆,一个小点或一条线。再一次地,你从图像中反复提取相同的图案。这次,你会生成128个新的过滤图像。

经过几个步骤,生成的过滤图像可能看起来像这样:

这个过程就像大多数处理视觉的神经网络,也即卷积神经网络的行为。结合几个过滤图像了解图像中的上下文。

from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D

from keras.layers import Activation, Dense, Dropout, Flatten, InputLayer

from keras.layers.normalization import BatchNormalization

from keras.callbacks import TensorBoard

from keras.models import Sequential

from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

from skimage.color import rgb2lab, lab2rgb, rgb2gray

from skimage.io import imsave

import numpy as np

import osimport random

import tensorflow as tf

Using TensorFlow backend.

# Get imagesX = []

for filename in os.listdir('data/color/Train/'):

X.append(img_to_array(load_img('data/color/Train/'+filename)))X = np.array(X, dtype=float)

# Set up train and test data

split = int(0.95*len(X))Xtrain = X[:split]Xtrain = 1.0/255*Xtrainmodel = Sequential()

model.add(InputLayer(input_shape=(256, 256, 1)))model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))model.add(Conv2D(64, (3, 3), activation='relu', padding='same', strides=2))model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))model.add(Conv2D(128, (3, 3), activation='relu', padding='same', strides=2))model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))model.add(Conv2D(256, (3, 3), activation='relu', padding='same', strides=2))model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))model.add(UpSampling2D((2, 2)))model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))model.add(UpSampling2D((2, 2)))model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))model.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))model.add(UpSampling2D((2, 2)))model.compile(optimizer='rmsprop', loss='mse')

# Image transformerdatagen = ImageDataGenerator( shear_range=0.2, zoom_range=0.2, rotation_range=20, horizontal_flip=True)

# Generate training databatch_size = 10def image_a_b_gen(batch_size): for batch in datagen.flow(Xtrain, batch_size=batch_size): lab_batch = rgb2lab(batch) X_batch = lab_batch[:,:,:,0] Y_batch = lab_batch[:,:,:,1:] / 128 yield (X_batch.reshape(X_batch.shape+(1,)), Y_batch)

# Train model

tensorboard = TensorBoard(log_dir="data/color/output/first_run")model.fit_generator(image_a_b_gen(batch_size), callbacks=[tensorboard], epochs=1, steps_per_epoch=10)

Epoch 1/110/10 [==============================] - 178s - loss: 0.5208

# Save modelmodel_json = model.to_json()with open("model.json", "w") as json_file: json_file.write(model_json)model.save_weights("model.h5")

# Test imagesXtest = rgb2lab(1.0/255*X[split:])[:,:,:,0]

Xtest = Xtest.reshape(Xtest.shape+(1,))Ytest = rgb2lab(1.0/255*X[split:])[:,:,:,1:]Ytest = Ytest / 128

print(model.evaluate(Xtest, Ytest, batch_size=batch_size))

color_me = []for filename in os.listdir('data/color/Test/'): color_me.append(img_to_array(load_img('data/color/Test/'+filename)))color_me = np.array(color_me, dtype=float)color_me = rgb2lab(1.0/255*color_me)[:,:,:,0]color_me = color_me.reshape(color_me.shape+(1,))# Test modeloutput = model.predict(color_me)output = output * 128# Output colorizationsfor i in range(len(output)): cur = np.zeros((256, 256, 3)) cur[:,:,0] = color_me[i][:,:,0] cur[:,:,1:] = output[i] imsave("data/color/output/img1_"+str(i)+".png", lab2rgb(cur))

/usr/local/lib/python3.6/site-packages/skimage/util/dtype.py:122: UserWarning: Possible precision loss when converting from float64 to uint8 .format(dtypeobj_in, dtypeobj_out))

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180126A0VZ6100?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券