前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用TensorFlow一步步进行目标检测(5)

使用TensorFlow一步步进行目标检测(5)

作者头像
云水木石
发布2019-07-01 11:39:11
4670
发布2019-07-01 11:39:11
举报

本教程进行到这一步,您选择了预训练的目标检测模型,转换现有数据集或创建自己的数据集并将其转换为TFRecord文件,修改模型配置文件,并开始训练模型。接下来,您需要保存模型并将其部署到项目中。

将检查点模型(.ckpt)保存为.pb文件

回到TensorFlow目标检测文件夹,并将export_inference_graph.py文件复制到包含模型配置文件的文件夹中。

python export_inference_graph.py --input_type image_tensor --pipeline_config_path ./rfcn_resnet101_coco.config --trained_checkpoint_prefix ./models/train/model.ckpt-5000 --output_directory ./fine_tuned_model

这将创建一个新目录fine_tuned_model,里面名为frozen_inference_graph.pb的模型就是您训练出来的模型。

在项目中使用模型

我在本教程中一直在研究的项目是创建一个红绿灯分类器。在Python中,我将此分类器实现为一个类。 在类的初始化部分,我创建了一个TensorFlow会话,这样就不需要在每次需要分类时创建它。

class TrafficLightClassifier(object):
   def __init__(self):
       PATH_TO_MODEL = 'frozen_inference_graph.pb'
       self.detection_graph = tf.Graph()
       with self.detection_graph.as_default():
           od_graph_def = tf.GraphDef()
           # Works up to here.
           with tf.gfile.GFile(PATH_TO_MODEL, 'rb') as fid:
               serialized_graph = fid.read()
               od_graph_def.ParseFromString(serialized_graph)
               tf.import_graph_def(od_graph_def, name='')
           self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
           self.d_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
           self.d_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
           self.d_classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
           self.num_d = self.detection_graph.get_tensor_by_name('num_detections:0')
       self.sess = tf.Session(graph=self.detection_graph)

在该类中,我创建了一个函数,该函数对图像进行分类,并返回图像中分类的边界框、分数和类别。

def get_classification(self, img):
   # Bounding Box Detection.
   with self.detection_graph.as_default():
       # Expand dimension since the model expects image to have shape [1, None, None, 3].
       img_expanded = np.expand_dims(img, axis=0)  
       (boxes, scores, classes, num) = self.sess.run(
           [self.d_boxes, self.d_scores, self.d_classes, self.num_d],
           feed_dict={self.image_tensor: img_expanded})
   return boxes, scores, classes, num

此时,您需要过滤低于指定分数阈值的结果。结果会自动从最高分数到最低分数排序,因此这很容易实现。通过上面的函数返回分类结果,就是这样,您做到了!

您可以在下图中看到我实现的红绿灯分类器。

我最初创建本教程是因为我很难找到有关如何使用Object Detection API的资讯。我希望通过阅读本教程,您可以启动项目,让项目快速实现,这样您可以将更多时间集中在您真正感兴趣的内容上!

相关文章
  1. 使用TensorFlow一步步进行目标检测(1)
  2. 使用TensorFlow一步步进行目标检测(2)
  3. 使用TensorFlow一步步进行目标检测(3)
  4. 使用TensorFlow一步步进行目标检测(4)
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-08-01,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 云水木石 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 将检查点模型(.ckpt)保存为.pb文件
  • 在项目中使用模型
  • 相关文章
相关产品与服务
图像识别
腾讯云图像识别基于深度学习等人工智能技术,提供车辆,物体及场景等检测和识别服务, 已上线产品子功能包含车辆识别,商品识别,宠物识别,文件封识别等,更多功能接口敬请期待。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档