如何捕获一只彩色卓别林?黑白照片AI上色教程很友好 | 哈佛大触

方栗子 编译自 GitHub 量子位 出品 | 公众号 QbitAI
老照片的手动着色魔法

妈妈小时候已经有彩色照片了,不过那些照片,还是照相馆的人类手动上色的。

几十年之后,人们已经开始培育深度神经网络,来给老照片和老电影上色了。

来自哈佛大学的Luke Melas-Kyriazi (我叫他卢克吧) ,用自己训练的神经网络,把卓别林变成了彩色的卓别林,清新自然。

视频内容

作为一只哈佛学霸,卢克还为钻研机器学习的小伙伴们写了一个基于PyTorch的教程。

虽然教程里的模型比给卓别林用的模型要简约一些,但效果也是不错了。

问题是什么?

卢克说,给黑白照片上色这个问题的难点在于,它是多模态的——与一幅灰度图像对应的合理彩色图像,并不唯一。

这并不是正确示范

传统模型需要输入许多额外信息,来辅助上色。

而深度神经网络,除了灰度图像之外,不需要任何额外输入,就可以完成上色。

在彩色图像里,每个像素包含三个值,即亮度饱和度以及色调

而灰度图像,并无饱和度色调可言,只有亮度一个值。

所以,模型要用一组数据,生成另外两足数据。换句话说,以灰度图像为起点,推断出对应的彩色图像。

为了简单,这里只做了256 x 256像素的图像上色。输出的数据量则是256 x 256 x 2。

关于颜色表示,卢克用的是LAB色彩空间,它跟RGB系统包含的信息是一样的。

但对程序猿来说,前者比较方便把亮度和其他两项分离开来。

数据也不难获得,卢克用了MIT Places数据集,中的一部分。内容就是校园里的一些地标和风景。然后转换成黑白图像,就可以了。以下为数据搬运代码——

1# Download and unzip (2.2GB)
2!wget http://data.csail.mit.edu/places/places205/testSetPlaces205_resize.tar.gz
3!tar -xzf testSetPlaces205_resize.tar.gz
1# Move data into training and validation directories
2import os
3os.makedirs('images/train/class/', exist_ok=True) # 40,000 images
4os.makedirs('images/val/class/', exist_ok=True)   #  1,000 images
5for i, file in enumerate(os.listdir('testSet_resize')):
6  if i < 1000: # first 1000 will be val
7    os.rename('testSet_resize/' + file, 'images/val/class/' + file)
8  else: # others will be val
9    os.rename('testSet_resize/' + file, 'images/train/class/' + file)
1# Make sure the images are there
2from IPython.display import Image, display
3display(Image(filename='images/val/class/84b3ccd8209a4db1835988d28adfed4c.jpg'))

好用的工具有哪些?

搭建模型和训练模型是在PyTorch里完成的。

还用了torchvishion,这是一套在PyTorch上处理图像和视频的工具。

另外,scikit-learn能完成图片在RGB和LAB色彩空间之间的转换。

1# Download and import libraries
2!pip install torch torchvision matplotlib numpy scikit-image pillow==4.1.1
1# For plotting
 2import numpy as np
 3import matplotlib.pyplot as plt
 4%matplotlib inline
 5# For conversion
 6from skimage.color import lab2rgb, rgb2lab, rgb2gray
 7from skimage import io
 8# For everything
 9import torch
10import torch.nn as nn
11import torch.nn.functional as F
12# For our model
13import torchvision.models as models
14from torchvision import datasets, transforms
15# For utilities
16import os, shutil, time
1# Check if GPU is available
2use_gpu = torch.cuda.is_available()

模型长什么样?

神经网络里面,第一部分是几层用来提取图像特征;第二部分是一些反卷积层 (Deconvolutional Layers) ,用来给那些特征增加分辨率。

具体来说,第一部分用的是ResNet-18,这是一个图像分类网络,有18层,以及一些残差连接 (Residual Connections) 。

给第一层做些修改,它就可以接受灰度图像输入了。然后把第6层之后的都去掉。

然后,用代码来定义一下这个模型。

从神经网络的第二部分 (就是那些上采样层) 开始。

 1class ColorizationNet(nn.Module):
 2  def __init__(self, input_size=128):
 3    super(ColorizationNet, self).__init__()
 4    MIDLEVEL_FEATURE_SIZE = 128
 5
 6    ## First half: ResNet
 7    resnet = models.resnet18(num_classes=365) 
 8    # Change first conv layer to accept single-channel (grayscale) input
 9    resnet.conv1.weight = nn.Parameter(resnet.conv1.weight.sum(dim=1).unsqueeze(1)) 
10    # Extract midlevel features from ResNet-gray
11    self.midlevel_resnet = nn.Sequential(*list(resnet.children())[0:6])
12
13    ## Second half: Upsampling
14    self.upsample = nn.Sequential(     
15      nn.Conv2d(MIDLEVEL_FEATURE_SIZE, 128, kernel_size=3, stride=1, padding=1),
16      nn.BatchNorm2d(128),
17      nn.ReLU(),
18      nn.Upsample(scale_factor=2),
19      nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
20      nn.BatchNorm2d(64),
21      nn.ReLU(),
22      nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
23      nn.BatchNorm2d(64),
24      nn.ReLU(),
25      nn.Upsample(scale_factor=2),
26      nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
27      nn.BatchNorm2d(32),
28      nn.ReLU(),
29      nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),
30      nn.Upsample(scale_factor=2)
31    )
32
33  def forward(self, input):
34
35    # Pass input through ResNet-gray to extract features
36    midlevel_features = self.midlevel_resnet(input)
37
38    # Upsample to get colors
39    output = self.upsample(midlevel_features)
40    return output

下一步,创建模型吧。

1model = ColorizationNet()

它是怎么训练的?

预测每个像素的色值,用的是回归 (Regression) 的方法。

损失函数 (Loss Function)

所以,用了一个均方误差 (MSE) 损失函数——让预测的色值与参考标准 (Ground Truth) 之间的距离平方最小化。

1criterion = nn.MSELoss()

优化损失函数

这里是用Adam Optimizer优化的。

1optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0.0)

加载数据

用torchtext加载数据。首先定义一个专属的数据加载器 (DataLoader) ,来完成RGB到LAB空间的转换。

 1class GrayscaleImageFolder(datasets.ImageFolder):
 2  '''Custom images folder, which converts images to grayscale before loading'''
 3  def __getitem__(self, index):
 4    path, target = self.imgs[index]
 5    img = self.loader(path)
 6    if self.transform is not None:
 7      img_original = self.transform(img)
 8      img_original = np.asarray(img_original)
 9      img_lab = rgb2lab(img_original)
10      img_lab = (img_lab + 128) / 255
11      img_ab = img_lab[:, :, 1:3]
12      img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
13      img_original = rgb2gray(img_original)
14      img_original = torch.from_numpy(img_original).unsqueeze(0).float()
15    if self.target_transform is not None:
16      target = self.target_transform(target)
17    return img_original, img_ab, target

再来,就是定义训练数据验证数据的转换。

1# Training
2train_transforms = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()])
3train_imagefolder = GrayscaleImageFolder('images/train', train_transforms)
4train_loader = torch.utils.data.DataLoader(train_imagefolder, batch_size=64, shuffle=True)
5
6# Validation 
7val_transforms = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])
8val_imagefolder = GrayscaleImageFolder('images/val' , val_transforms)
9val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=64, shuffle=False)

辅助函数 (Helper Function)

训练开始之前,要把辅助函数写好,来追踪训练损失,并把图像转回RGB形式。

 1class AverageMeter(object):
 2  '''A handy class from the PyTorch ImageNet tutorial''' 
 3  def __init__(self):
 4    self.reset()
 5  def reset(self):
 6    self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
 7  def update(self, val, n=1):
 8    self.val = val
 9    self.sum += val * n
10    self.count += n
11    self.avg = self.sum / self.count
12
13def to_rgb(grayscale_input, ab_input, save_path=None, save_name=None):
14  '''Show/save rgb image from grayscale and ab channels
15     Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
16  plt.clf() # clear matplotlib 
17  color_image = torch.cat((grayscale_input, ab_input), 0).numpy() # combine channels
18  color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
19  color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
20  color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
21  color_image = lab2rgb(color_image.astype(np.float64))
22  grayscale_input = grayscale_input.squeeze().numpy()
23  if save_path is not None and save_name is not None: 
24    plt.imsave(arr=grayscale_input, fname='{}{}'.format(save_path['grayscale'], save_name), cmap='gray')
25    plt.imsave(arr=color_image, fname='{}{}'.format(save_path['colorized'], save_name))

验证

不用反向传播 (Back Propagation),直接用torch.no_grad() 跑模型。

 1def validate(val_loader, model, criterion, save_images, epoch):
 2  model.eval()
 3
 4  # Prepare value counters and timers
 5  batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
 6
 7  end = time.time()
 8  already_saved_images = False
 9  for i, (input_gray, input_ab, target) in enumerate(val_loader):
10    data_time.update(time.time() - end)
11
12    # Use GPU
13    if use_gpu: input_gray, input_ab, target = input_gray.cuda(), input_ab.cuda(), target.cuda()
14
15    # Run model and record loss
16    output_ab = model(input_gray) # throw away class predictions
17    loss = criterion(output_ab, input_ab)
18    losses.update(loss.item(), input_gray.size(0))
19
20    # Save images to file
21    if save_images and not already_saved_images:
22      already_saved_images = True
23      for j in range(min(len(output_ab), 10)): # save at most 5 images
24        save_path = {'grayscale': 'outputs/gray/', 'colorized': 'outputs/color/'}
25        save_name = 'img-{}-epoch-{}.jpg'.format(i * val_loader.batch_size + j, epoch)
26        to_rgb(input_gray[j].cpu(), ab_input=output_ab[j].detach().cpu(), save_path=save_path, save_name=save_name)
27
28    # Record time to do forward passes and save images
29    batch_time.update(time.time() - end)
30    end = time.time()
31
32    # Print model accuracy -- in the code below, val refers to both value and validation
33    if i % 25 == 0:
34      print('Validate: [{0}/{1}]\t'
35            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
36            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
37             i, len(val_loader), batch_time=batch_time, loss=losses))
38
39  print('Finished validation.')
40  return losses.avg

训练

用loss.backward(),用上反向传播。写一下训练数据跑一遍 (one epoch) 用的函数。

 1def train(train_loader, model, criterion, optimizer, epoch):
 2  print('Starting training epoch {}'.format(epoch))
 3  model.train()
 4
 5  # Prepare value counters and timers
 6  batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
 7
 8  end = time.time()
 9  for i, (input_gray, input_ab, target) in enumerate(train_loader):
10
11    # Use GPU if available
12    if use_gpu: input_gray, input_ab, target = input_gray.cuda(), input_ab.cuda(), target.cuda()
13
14    # Record time to load data (above)
15    data_time.update(time.time() - end)
16
17    # Run forward pass
18    output_ab = model(input_gray) 
19    loss = criterion(output_ab, input_ab) 
20    losses.update(loss.item(), input_gray.size(0))
21
22    # Compute gradient and optimize
23    optimizer.zero_grad()
24    loss.backward()
25    optimizer.step()
26
27    # Record time to do forward and backward passes
28    batch_time.update(time.time() - end)
29    end = time.time()
30
31    # Print model accuracy -- in the code below, val refers to value, not validation
32    if i % 25 == 0:
33      print('Epoch: [{0}][{1}/{2}]\t'
34            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
35            'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
36            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
37              epoch, i, len(train_loader), batch_time=batch_time,
38             data_time=data_time, loss=losses)) 
39
40  print('Finished training epoch {}'.format(epoch))

然后,定义一个训练回路 (Training Loop) ,跑一百遍训练数据。从Epoch 0开始训练。

1# Move model and loss function to GPU
2if use_gpu: 
3  criterion = criterion.cuda()
4  model = model.cuda()
1# Make folders and set parameters
2os.makedirs('outputs/color', exist_ok=True)
3os.makedirs('outputs/gray', exist_ok=True)
4os.makedirs('checkpoints', exist_ok=True)
5save_images = True
6best_losses = 1e10
7epochs = 100
 1# Train model
 2for epoch in range(epochs):
 3  # Train for one epoch, then validate
 4  train(train_loader, model, criterion, optimizer, epoch)
 5  with torch.no_grad():
 6    losses = validate(val_loader, model, criterion, save_images, epoch)
 7  # Save checkpoint and replace old best model if current model is better
 8  if losses < best_losses:
 9    best_losses = losses
10    torch.save(model.state_dict(), 'checkpoints/model-epoch-{}-losses-{:.3f}.pth'.format(epoch+1,losses))

训练结果什么样?

是时候看看修炼成果了,所以,复制一下这段代码。

 1# Show images 
 2import matplotlib.image as mpimg
 3image_pairs = [('outputs/color/img-2-epoch-0.jpg', 'outputs/gray/img-2-epoch-0.jpg'),
 4               ('outputs/color/img-7-epoch-0.jpg', 'outputs/gray/img-7-epoch-0.jpg')]
 5for c, g in image_pairs:
 6  color = mpimg.imread(c)
 7  gray  = mpimg.imread(g)
 8  f, axarr = plt.subplots(1, 2)
 9  f.set_size_inches(15, 15)
10  axarr[0].imshow(gray, cmap='gray')
11  axarr[1].imshow(color)
12  axarr[0].axis('off'), axarr[1].axis('off')
13  plt.show()

效果还是很自然的,虽然生成的彩色图像不是那么明丽。

卢克说,问题是多模态的,所以损失函数还是值得推敲。

比如,一条灰色裙子可以是蓝色也可以是红色。如果模型选择的颜色和参考标准不同,就会受到严厉的惩罚。

这样一来,模型就会选择哪些不会被判为大错特错的颜色,而不太选择非常显眼明亮的颜色。

没时间怎么办?

卢克还把一只训练好的AI放了出来,不想从零开始训练的小伙伴们,也可以直接感受他的训练成果,只要用以下代码下载就好了。

1# Download pretrained model
2!wget https://www.dropbox.com/s/kz76e7gv2ivmu8p/model-epoch-93.pth
3#https://www.dropbox.com/s/9j9rvaw2fo1osyj/model-epoch-67.pth
1# Load model
2pretrained = torch.load('model-epoch-93.pth', map_location=lambda storage, loc: storage)
3model.load_state_dict(pretrained)
1# Validate
2save_images = True
3with torch.no_grad():
4  validate(val_loader, model, criterion, save_images, 0)

彩色老电影?

如果想要更加有声有色的结局,就不能继续偷懒了。卢克希望大家沿着他精心铺就的路,走到更远的地方。

要替换当前的损失函数,可以参考Zhang et al. (2017): https://richzhang.github.io/ideepcolor/

无监督学习的上色大法,可以参考Larsson et al. (2017): http://people.cs.uchicago.edu/~larsson/color-proxy/

另外,可以做个手机应用,就像谷歌在I/O大会上发布的着色软件那样。

黑白电影,也可以自己去尝试,一帧一帧地上色。

这里有卓别林用到的完整代码

https://github.com/lukemelas/Automatic-Image-Colorization/

原文发布于微信公众号 - 量子位(QbitAI)

原文发表时间:2018-05-29

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏从流域到海域

Python机器学习中的特征选择

原文地址:https://machinelearningmastery.com/feature-selection-machine-learning-pytho...

2K70
来自专栏新智元

【资源】17个最受欢迎的机器学习应用标准数据集

【新智元导读】学好机器学习的关键是用许多不同的数据集来实践。本文介绍了10个最受欢迎的标准机器学习数据集和7个时间序列数据集,既有回归问题也有分类问题,并提供了...

949150
来自专栏AI派

如何使用sklearn加载和下载机器学习数据集

sklearn 中提供了很多常用(或高级)的模型和算法,但是真正决定一个模型效果的最后还是取决于训练(喂养)模型时所用的数据。sklearn 中的 sklear...

83450
来自专栏深度学习自然语言处理

2018 NLPCC Chinese Grammatical Error Correction 论文小结

这一段时间,笔者一直在研究语音识别后的文本纠错,而就在八月26-30日,CCF的自然语言处理和中文计算会议召开了,笔者也从师兄那里拿到了新鲜出炉的会议论文集,其...

48730
来自专栏闪电gogogo的专栏

SAMP论文学习

SAMP:稀疏度自适应匹配追踪 实际应用中信号通常是可压缩的而不一定为稀疏的,而且稀疏信号的稀疏度我们通常也会不了解的。论文中提到过高或者过低估计了信号的稀疏度...

399120
来自专栏从流域到海域

Feature Selection For Machine Learning in Python (Python机器学习中的特征选择)

Feature Selection For Machine Learning in Python 原文作者:Jason Brownlee 原文地址:http...

45760
来自专栏深度学习与数据挖掘实战

【今日热门】优秀资源

11420
来自专栏大数据挖掘DT机器学习

数字识别,从KNN,LR,SVM,RF到深度学习

@蜡笔小轩V 原文:http://blog.csdn.net/Dinosoft/article/details/50734539 之前看了很多入门的资料,如果...

56450
来自专栏iOSDevLog

决策树模型概述

23850
来自专栏小鹏的专栏

机器学习中的Bias(偏差),Error(误差),和Variance(方差)有什么区别和联系?

首先 Error = Bias + Variance + Noise Error反映的是整个模型的准确度,Bias反映的是模型在样本上的输出与真实值之间的误差...

34680

扫码关注云+社区

领取腾讯云代金券