前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >用Keras+TensorFlow,实现ImageNet数据集日常对象的识别

用Keras+TensorFlow,实现ImageNet数据集日常对象的识别

作者头像
量子位
发布2018-03-30 16:23:44
1.9K0
发布2018-03-30 16:23:44
举报
文章被收录于专栏:量子位量子位
王新民 编译自 Deep Learning Sandbox博客 量子位 出品 | 公众号 QbitAI

在计算机视觉领域里,有3个最受欢迎且影响非常大的学术竞赛:ImageNet ILSVRC(大规模视觉识别挑战赛),PASCAL VOC(关于模式分析,统计建模和计算学习的研究)和微软COCO图像识别大赛。这些比赛大大地推动了在计算机视觉研究中的多项发明和创新,其中很多都是免费开源的。

博客Deep Learning Sandbox作者Greg Chu打算通过一篇文章,教你用Keras和TensorFlow,实现对ImageNet数据集中日常物体的识别。

量子位翻译了这篇文章:

你想识别什么?

看看ILSVRC竞赛中包含的物体对象。如果你要研究的物体对象是该列表1001个对象中的一个,运气真好,可以获得大量该类别图像数据!以下是这个数据集包含的部分类别:

椅子

汽车

键盘

箱子

婴儿床

旗杆

iPod播放器

轮船

面包车

项链

降落伞

枕头

桌子

钱包

球拍

步枪

校车

萨克斯管

足球

袜子

舞台

火炉

火把

吸尘器

自动售货机

眼镜

红绿灯

菜肴

盘子

西兰花

红酒

表1 ImageNet ILSVRC的类别摘录

完整类别列表见:https://gist.github.com/gregchu/134677e041cd78639fea84e3e619415b

如果你研究的物体对象不在该列表中,或者像医学图像分析中具有多种差异较大的背景,遇到这些情况该怎么办?可以借助迁移学习(transfer learning)和微调(fine-tuning),我们以后再另外写文章讲。

图像识别

图像识别,或者说物体识别是什么?它回答了一个问题:“这张图像中描绘了哪几个物体对象?”如果你研究的是基于图像内容进行标记,确定盘子上的食物类型,对癌症患者或非癌症患者的医学图像进行分类,以及更多的实际应用,那么就能用到图像识别。

Keras和TensorFlow

Keras是一个高级神经网络库,能够作为一种简单好用的抽象层,接入到数值计算库TensorFlow中。另外,它可以通过其keras.applications模块获取在ILSVRC竞赛中获胜的多个卷积网络模型,如由Microsoft Research开发的ResNet50网络和由Google Research开发的InceptionV3网络,这一切都是免费和开源的。具体安装参照以下说明进行操作:

Keras安装:https://keras.io/#installation

TensorFlow安装:https://www.tensorflow.org/install/

实现过程

我们的最终目标是编写一个简单的python程序,只需要输入本地图像文件的路径或是图像的URL链接就能实现物体识别。

以下是输入非洲大象照片的示例:

代码语言:javascript
复制
1. python classify.py --image African_Bush_Elephant.jpg
2. python classify.py --image_url http://i.imgur.com/wpxMwsR.jpg

输入:

输出将如下所示:

该图像最可能的前3种预测类别及其相应概率

预测功能

我们接下来要载入ResNet50网络模型。首先,要加载keras.preprocessingkeras.applications.resnet50模块,并使用在ImageNet ILSVRC比赛中已经训练好的权重。

想了解ResNet50的原理,可以阅读论文《基于深度残差网络的图像识别》。地址:https://arxiv.org/pdf/1512.03385.pdf

代码语言:javascript
复制
import numpy as np
from keras.preprocessing import image
from keras.applications.resnet50 
import ResNet50, preprocess_input, decode_predictions
model = ResNet50(weights='imagenet')

接下来定义一个预测函数:

代码语言:javascript
复制
def predict(model, img, target_size, top_n=3):
  """Run model prediction on image
  Args:
    model: keras model
    img: PIL format image
    target_size: (width, height) tuple
    top_n: # of top predictions to return
  Returns:
    list of predicted labels and their probabilities
  """
  if img.size != target_size:
    img = img.resize(target_size)
  x = image.img_to_array(img)
  x = np.expand_dims(x, axis=0)
  x = preprocess_input(x)
  preds = model.predict(x)  
return decode_predictions(preds, top=top_n)[0]

在使用ResNet50网络结构时需要注意,输入大小target_size必须等于(224,224)。许多CNN网络结构具有固定的输入大小,ResNet50正是其中之一,作者将输入大小定为(224,224)

image.img_to_array:将PIL格式的图像转换为numpy数组。

np.expand_dims:将我们的(3,224,224)大小的图像转换为(1,3,224,224)。因为model.predict函数需要4维数组作为输入,其中第4维为每批预测图像的数量。这也就是说,我们可以一次性分类多个图像。

preprocess_input:使用训练数据集中的平均通道值对图像数据进行零值处理,即使得图像所有点的和为0。这是非常重要的步骤,如果跳过,将大大影响实际预测效果。这个步骤称为数据归一化。

model.predict:对我们的数据分批处理并返回预测值。

decode_predictions:采用与model.predict函数相同的编码标签,并从ImageNet ILSVRC集返回可读的标签。

keras.applications模块还提供4种结构:ResNet50、InceptionV3、VGG16、VGG19和XCeption,你可以用其中任何一种替换ResNet50。更多信息可以参考https://keras.io/applications/。

绘图

我们可以使用matplotlib函数库将预测结果做成柱状图,如下所示:

代码语言:javascript
复制
def plot_preds(image, preds):  
  """Displays image and the top-n predicted probabilities 
     in a bar graph  
  Args:    
    image: PIL image
    preds: list of predicted labels and their probabilities  
  """  
  #image
  plt.imshow(image)
  plt.axis('off')  #bar graph
  plt.figure()  
  order = list(reversed(range(len(preds))))  
  bar_preds = [pr[2] for pr in preds]
  labels = (pr[1] for pr in preds)
  plt.barh(order, bar_preds, alpha=0.5)
  plt.yticks(order, labels)
  plt.xlabel('Probability')
  plt.xlim(0, 1.01)
  plt.tight_layout()
  plt.show()

主体部分

为了实现以下从网络中加载图片的功能:

代码语言:javascript
复制
1. python classify.py --image African_Bush_Elephant.jpg
2. python classify.py --image_url http://i.imgur.com/wpxMwsR.jpg

我们将定义主函数如下:

代码语言:javascript
复制
if __name__=="__main__":
  a = argparse.ArgumentParser()
  a.add_argument("--image", 
help="path to image")
  a.add_argument("--image_url", 
help="url to image")
  args = a.parse_args()
if args.image is None and args.image_url is None:
    a.print_help()
    sys.exit(1)
if args.image is not None:
    img = Image.open(args.image)
    print_preds(predict(model, img, target_size))
if args.image_url is not None:
    response = requests.get(args.image_url)
    img = Image.open(BytesIO(response.content))
    print_preds(predict(model, img, target_size))

其中在写入image_url功能后,用python中的Requests库就能很容易地从URL链接中下载图像。

完工

将上述代码组合起来,你就创建了一个图像识别系统。项目的完整程序和示例图像请查看GitHub链接:

https://github.com/DeepLearningSandbox/DeepLearningSandbox/tree/master/image_recognition

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2017-04-23,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 量子位 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 你想识别什么?
  • 图像识别
  • Keras和TensorFlow
  • 实现过程
  • 预测功能
  • 绘图
  • 主体部分
  • 完工
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档