专栏首页wOw的Android小站[Tensorflow] 使用SSD-MobileNet训练模型

[Tensorflow] 使用SSD-MobileNet训练模型

使用SSD-MobileNet训练模型

因为Android Demo里的模型是已经训练好的,模型保存的label都是固定的,所以我们在使用的时候会发现还有很多东西它识别不出来。那么我们就需要用它来训练我们自己的数据。下面就是使用SSD-MobileNet训练模型的方法。

下载

  • 到Github上下载/克隆TensorModels,后面的操作都要在这个目录下执行
  • 下载数据集(数据集应该是自己制作的,制作数据集需要用到一些工具,另外介绍),我们使用VOC2012数据集
  • 下载SSD-MobileNet,我们做得,相当于在这个基础上进行再次训练(retrain)

环境设置

进入下载的Model目录:

cd models/research/
# 执行
protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH="${PYTHONPATH}:/home/Github/models:/home/Github/models/research/slim/"

这里的PYTHONPATH路径一定要填对,不然会影响到后面运行代码。 然后还要注意代码版本不同文件路径有差别,所以要对照自己目录看好。

继续在research/目录下执行:

# 如果找不到setup.py, 用find命令找对应的路径
python setup.py build
python setup.py install

配置及训练

object_detection/目录下创建目录ssd_model

mkdir object_detection/ssd_model

把下载好的数据集解压进去,数据集路径为

./object_detection/ssd_model/VOCdevkit/

执行配置文件

python ./object_detection/dataset_tools/create_pascal_tf_record.py --label_map_path=object_detection/data/pascal_label_map.pbtxt --data_dir=object_detection/ssd_model/VOCdevkit/ --year=VOC2012 --set=train --output_path=object_detection/ssd_model/pascal_train.record

python ./object_detection/dataset_tools/create_pascal_tf_record.py --label_map_path=object_detection/data/pascal_label_map.pbtxt --data_dir=object_detection/ssd_model/VOCdevkit/ --year=VOC2012 --set=val --output_path=object_detection/ssd_model/pascal_val.record

然后会在ssd_model/目录下生成pascal_train.recordpascal_val.record两个文件,分别有600M左右。 下一步复制训练pet数据用到的文件,我们在这个基础上修改配置,训练我们的数据

cp object_detection/data/pascal_label_map.pbtxt object_detection/ssd_model/
cp object_detection/samples/configs/ssd_mobilenet_v1_pets.config object_detection/ssd_model/

我们打开pascal_label_map.pbtxt看一下,这个文件里面是类似Json格式的label集,列出了数据集里有哪些label。Pascal这个数据集label共有20个。

然后打开配置文件ssd_mobilenet_v1_pets.config,把num_classes改为20 配置默认训练次数num_steps: 200000,我们根据自己需要改,注意这个训练是很慢的,差不多以天为单位,所以可以适当改小点。

然后改一些文件路径:

train_input_reader: {
  tf_record_input_reader {
    input_path: "/home/wow/Github/models/research/object_detection/ssd_model/pascal_train.record"
  }
  label_map_path: "/home/wow/Github/models/research/object_detection/ssd_model/pascal_label_map.pbtxt"
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "/home/wow/Github/models/research/object_detection/ssd_model/pascal_val.record"
  }
  label_map_path: "/home/wow/Github/models/research/object_detection/ssd_model/pascal_label_map.pbtxt"
  shuffle: false
  num_readers: 1
}

把之前下载的ssd_mobilenet解压到/object_detection/ssd_model/ssd_mobilenet

把路径填进配置文件

fine_tune_checkpoint: "/home/wow/Github/models/research/object_detection/ssd_model/ssd_mobilenet/model.ckpt"

完成之后,就可以训练模型了

python object_detection/train.py --train_dir object_detection/train --pipeline_config_path object_detection/ssd_model/ssd_mobilenet_v1_pets.config

经过漫长的等待,可以看到在/object_detection/train目录下生成了模型。然后创建文件夹ssd_model/model

python object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path object_detection/ssd_model/ssd_mobilenet_v1_pets.config --trained_checkpoint_prefix object_detection/train/model.ckpt-30000 --output_directory object_detection/ssd_model/model/

生成pb文件,再把pascal_label_map.pbtxt的内容改成.txt作为label文件,这个模型就可以使用了。

错误解决

错误1:

TypeError: x and y must have the same dtype, got tf.float32 != tf.int32

修改./object_detection/builders/post_processing_builder.py

def _score_converter_fn_with_logit_scale(tf_score_converter_fn, logit_scale):
  """Create a function to scale logits then apply a Tensorflow function."""
  def score_converter_fn(logits):
    cr = logit_scale
    cr = tf.constant([[cr]],tf.float32)
    print(logit_scale)
    print(logits)
    scaled_logits = tf.divide(logits, cr, name='scale_logits') #change logit_scale
    return tf_score_converter_fn(scaled_logits, name='convert_scores')
  score_converter_fn.__name__ = '%s_with_logit_scale' % (
      tf_score_converter_fn.__name__)
  return score_converter_fn

修改之后,需要再执行:

python setup.py build
python setup.py install

错误2:

ImportError: cannot import name rewriter_config_pb2
# 或者
AttributeError: 'module' object has no attribute 'mark_flag_as_required'

修改:

# 前一个错
pip install --upgrade tensorflow==1.2.0
# 后一个错
pip install --upgrade tensorflow==1.4.0

测试模型

import cv2
import numpy as np
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util


class TOD(object):
    def __init__(self):
        self.PATH_TO_CKPT = '/home/wow/Github/models/research/object_detection/ssd_model/model/frozen_inference_graph.pb'
        self.PATH_TO_LABELS = '/home/wow/Github/models/research/object_detection/ssd_model/pascal_label_map.pbtxt'
        self.NUM_CLASSES = 1
        self.detection_graph = self._load_model()
        self.category_index = self._load_label_map()

    def _load_model(self):
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        return detection_graph

    def _load_label_map(self):
        label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        categories = label_map_util.convert_label_map_to_categories(label_map,
                                                                    max_num_classes=self.NUM_CLASSES,
                                                                    use_display_name=True)
        category_index = label_map_util.create_category_index(categories)
        return category_index

    def detect(self, image):
        with self.detection_graph.as_default():
            with tf.Session(graph=self.detection_graph) as sess:
                # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                image_np_expanded = np.expand_dims(image, axis=0)
                image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
                boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
                scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
                classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
                # Actual detection.
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
                # Visualization of the results of a detection.
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    self.category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)

        cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
        cv2.imshow("detection", image)
        cv2.waitKey(0)

if __name__ == '__main__':
    image = cv2.imread('/home/wow/Github/models/research/object_detection/ssd_model/img/cow-in-pasture.jpg')
    detecotr = TOD()
    detecotr.detect(image)

效果:

数据制作

安装labelImg工具,进行图片的标注.标注后的是xml格式的文件,将这些文件按照一定比例分到traintest目录.

下载datitran,作者自己写了一套xmlcsv再转为record文件的脚本.比SSD的脚本方便使用.

首先编辑xml_to_csv.py,修改main函数:

def main():
    #image_path = os.path.join(os.getcwd(), 'annotations')
    image_path = os.path.join('/home/Github/models/research/object_detection/ssd_model/MyImgs/labels/test')
    xml_df = xml_to_csv(image_path)
    #xml_df.to_csv('raccoon_labels.csv', index=None)
    xml_df.to_csv('fish_test_labels.csv', index=None)
    print('Successfully converted xml to csv.')

执行

python xml_to_csv.py

会生成test的csv,同样,修改代码,生成train的csv.

然后进行csv到record的转换

首先修改generate_tfrecord.py,把main函数的path改成我们图片路径,然后把if row_label == 'raccoon':改成我们的label,比如fish.之后执行下面代码:

python generate_tfrecord.py --csv_input=fish_train_labels.csv --output_path=fish_train.record
python generate_tfrecord.py --csv_input=fish_test_labels.csv --output_path=fish_test.record

然后会生成对应的record文件:

Successfully created the TFRecords: /home/Github/raccoon_dataset/fish_train.record
Successfully created the TFRecords: /home/Github/raccoon_dataset/fish_test.record

回到我们之前训练SSD的目录,创建自己的label文件my_label_map.pbtxt

item {
  id: 1
  name: 'fish'
}

修改训练配置文件:

num_classes: 1 #20

再把所有PATH_TO_BE_CONFIGURED的地方改掉,就可以用前面的命令进行训练.

训练时会遇到这个错误:

INFO:tensorflow:Restoring parameters from object_detection/train/model.ckpt-5390
INFO:tensorflow:Error reported to Coordinator: <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>, Assign requires shapes of both tensors to match. lhs shape= [1,1,128,12] rhs shape= [1,1,128,126]

这是因为之前我有训练过模型,训练到5390次就停了.这里配置写的是200k次,所以它会接着之前的结果继续跑.但我们的数据发生了变化,所以会出现这个错误.解决方法就是把train目录删掉,重新生成即可

参考

深度学习入门篇—手把手教你用 TensorFlow 训练模型

tensorflow ssd mobilenet模型训练

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • [Objective-C] 常量和枚举

    #define宏定义 #define是一条预编译指令, 编译器在编译阶段前期会将所有使用到宏的地方简单地进行替换.

    wOw
  • [Android] Service服务详解以及如何使service服务不被杀死

      服务是一个应用程序组件,可以在后台执行长时间运行的操作,不提供用户界面。一个应用程序组件可以启动一个服务,它将继续在后台运行,即使用户切换到另一个应用程序。...

    wOw
  • [设计模式]之二:策略模式

    也很简单,同一个方法,把折扣作为一个参数,默认值为1,代码改为“单价 数量 折扣”即可。

    wOw
  • 那些不常见,但却非常实用的css属性(整理不易)

    可以把 块容器 中的内容限制为指定的行数。并且在超过行数后,在最后一行显示"..."

    winty
  • Xctf攻防世界-Web进阶题攻略

    攻防世界答题模块是一款提升个人信息安全水平的益智趣味答题,用户可任意选择题目类型进行答题。

    Aran
  • 订单自动过期实现方案

    这个太简单了,就是在查询的时候判断是否失效,如果失效了就给他设置失效状态。但是弊端也很明显,每次查询都要对未失效的订单做判断,如果用户不查询,订单就不失效,那么...

    Mshu
  • Medium网友开发了一款应用程序 让学习算法和数据结构变得更有趣

    Medium网友Peter Weinberg开发了一款名叫CS-Playground-React的应用程序,可以使大家更有意思、也更加轻松地学习算法和数据结构。...

    AiTechYun
  • Redis 过期时间与内存管理

    当 Redis 作为缓存使用时(此时缓存仅作为热点数据提高服务的访问性能),需要考虑内存的限制,以及如何随着业务的增长,仅保留热点数据。

    斯武丶风晴
  • App Store把app的评论扒下来

    链接:https://github.com/freesan44/AppReviews

    freesan44
  • 自定义动画函数JQuery实现

    skylark

扫码关注云+社区

领取腾讯云代金券