前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Keras-RetinaNet训练自己的数据详细教程

Keras-RetinaNet训练自己的数据详细教程

作者头像
大黄大黄大黄
发布2019-03-15 10:52:20
2.4K0
发布2019-03-15 10:52:20
举报

准备工作:

1、代码开源框架使用的是 fizyr/keras-retinanet 2、Keras版本要2.2.4以上

下面进入正题。


第一部分:模型准备

(1)模型下载地址:fizyr/keras-retinanet (2)模型安装可以使用以下命令:

pip install numpy --user
pip install . --user

安装过程中,会检查依赖库,比如opencv-python,如果没有安装,会加载并安装。这里提一句,如果在安装时某个包下载安装不成功,自己记下来版本,比如opencv-python 3.4.5.20,可以直接先去利用pip或conda安装,但是一定要记得对应的版本。

(3)模型编译可以使用以下命令:

python setup.py build_ext --inplace

编译的时候可能会出现提示,没有某个版本C++的编译器,我提示的时没有2014版,把错误提示直接百度,就会出现解决方法,我是下载了一个3M的14版的编译工具。(当然,最好就是有相应版本的完整C++)

第二部分:数据准备

(1)在keras-retinanet-master/keras_retinanet/文件夹下面新建一个文件夹CSV用来存放自己制作的数据集。

数据文件夹格式如下:

<CSV>
|———— train_annotations.csv # 必须
|———— val_annotations.csv # 必须
|———— classes.csv # 必须
|
|____ data  # (可选),这样 annotations.csv可以使用图片的相对路径       
         └─ *.jpg

(2)根据官网的样例,自己制作的Annotations数据集格式如下:

path/to/image.jpg,x1,y1,x2,y2,class_name

如果一张图片中没有包含任何要检测的物体,则格式如下:

path/to/image.jpg,,,,,

一个完整的例子:

/data/imgs/img_001.jpg,837,346,981,456,cow
/data/imgs/img_002.jpg,215,312,279,391,cat
/data/imgs/img_002.jpg,22,5,89,84,bird
/data/imgs/img_003.jpg,,,,,

下面,我就贴出自己写的一个代码:

def restrict_image_info(label_path):
    with open(label_path, 'r') as load_f:
        load_dict = json.load(load_f)
        image_collect = load_dict['images']
        image_num = len(image_collect)
        anno_collect = load_dict['annotations']
        anno_num = len(anno_collect)

        img_path_list = []
        x1_list = []
        y1_list = []
        x2_list = []
        y2_list = []
        category_list = []

        mapper = {0: 'tieke', 1: 'heiding',
                  2: 'daoju', 3: 'dian', 4: 'jiandao'}

        for i in range(image_num):
            img = image_collect[i]
            img_name = img['file_name']
            img_id = img['id']
            img_height = img['height']
            img_width = img['width']

            for j in range(anno_num):
                if anno_collect[j]['image_id'] == img_id:
                    bbox = anno_collect[j]['bbox']
                    img_path_list.append(restrict_rele_path+img_name)
                    x1_list.append(int(np.rint(bbox[0])))
                    y1_list.append(int(np.rint(bbox[1])))
                    x2_list.append(
                        int(np.rint(bbox[0] + bbox[2])))
                    y2_list.append(
                        int(np.rint((bbox[1]+bbox[3]))))
                    category_list.append(anno_collect[j]['category_id']-1)

        anno = pd.DataFrame()
        anno['img_path'] = img_path_list
        anno['x1'] = x1_list
        anno['y1'] = y1_list
        anno['x2'] = x2_list
        anno['y2'] = y2_list
        anno['class'] = category_list
        anno['class'] = anno['class'].map(mapper)

        # anno.to_csv('CSV/annotations.csv', index=None, header=None)
        train_anno, val_anno = train_test_split(anno, test_size=0.1)
        train_anno.to_csv('CSV/train_annotations.csv', index=None, header=None)
        val_anno.to_csv('CSV/val_annotations.csv', index=None, header=None)

其中代码段:

   train_anno, val_anno = train_test_split(anno, test_size=0.1)
   train_anno.to_csv('CSV/train_annotations.csv', index=None, header=None)
   val_anno.to_csv('CSV/val_annotations.csv', index=None, header=None)

是对图片进行训练集、验证集的随机划分。

训练图片生成的数据格式如下:

data/jinnan2_round1_train_20190305/restricted/190119_184244_00166940.jpg,88,253,206,295,daoju
data/jinnan2_round1_train_20190305/restricted/190119_184244_00166940.jpg,296,244,414,344,jiandao
data/jinnan2_round1_train_20190305/restricted/190119_184244_00166940.jpg,231,239,299,341,jiandao
data/jinnan2_round1_train_20190305/restricted/190119_184244_00166940.jpg,99,278,194,320,dian

验证图片生成的数据格式如下:

data/jinnan2_round1_train_20190305/restricted/190119_182957_00166754.jpg,314,237,326,265,dian
data/jinnan2_round1_train_20190305/restricted/190127_100838_00177153.jpg,246,229,304,279,tieke
data/jinnan2_round1_train_20190305/restricted/190119_184522_00166980.jpg,668,409,717,432,dian
data/jinnan2_round1_train_20190305/restricted/190119_183142_00166782.jpg,565,326,708,432,jiandao
data/jinnan2_round1_train_20190305/restricted/190127_143529_00178527.jpg,8,262,45,326,heiding

(3)根据官网的样例,自己制作的classes数据集格式如下:

class_name,id

一个完整的例子:

cow,0
cat,1
bird,2

最后生成的数据格式如下:

tieke,0
heiding,1
daoju,2
dian,3
jiandao,4

注意:保存的csv文件是没有头部行的,不然后续代码会报错!

(4)检查生成的数据是否合格

要进行这一步,必须先要完成第一步中模型的下载与编译!

检查数据可以使用以下命令:

python keras_retinanet/bin/debug.py csv keras_retinanet/CSV/train_annotations.csv keras_retinanet/CSV/classes.csv

其中第一个参数csv代表要检查的数据是自己制作的数据集,第二个参数是train_annotations.csv对应的路径,第三个参数是classes.csv对应的路径。

(5)图片存放位置

这个可以根据自己的需要定,但是最好放在上面新建的CSV文件夹下面,这个使用路径比较方便。在我自己这个代码中,我是在CSV文件夹下新建一个data文件夹下存放自己的图片,此时注意与train_annotations.csv文件中的图片路径要一致,比如我这时候就应该是这样:

data/jinnan2_round1_train_20190222/restricted/190119_185206_00167075.jpg,125,279,177,339,tieke
data/jinnan2_round1_train_20190222/restricted/190119_185206_00167075.jpg,153,363,238,549,daoju

(6)关于模型的图片输入尺寸

https://github.com/fizyr/keras-retinanet/blob/master/keras_retinanet/bin/train.py中的409、410行有设置输入的默认参数(800*1333):

parser.add_argument('--image-min-side',   help='Rescale the image so the smallest side is min_side.', type=int, default=800)
parser.add_argument('--image-max-side',   help='Rescale the image if the largest side is larger than max_side.', type=int, default=1333)

第三部分:模型训练

模型训练可以使用以下命令:

python keras_retinanet/bin/train.py csv keras_retinanet/CSV/train_annotations.csv keras_retinanet/CSV/classes.csv --val-annotations keras_retinanet/CSV/val_annotations.csv

其中第一个参数csv代表要检查的数据是自己制作的数据集,第二个参数是train_annotations.csv对应的路径,第三个参数是classes.csv对应的路径,第四个参数--val-annotationsval_annotations.csv对应的路径。

多卡训练可用如下命令:

python keras_retinanet/bin/train.py --multi-gpu-force --multi-gpu 2 --batch-size 2 csv keras_retinanet/CSV/train_annotations.csv keras_retinanet/CSV/classes.csv --val-annotations keras_retinanet/CSV/val_annotations.csv

可能会遇到的错误:

(1)ImportError: No module named 'keras_resnet'

解决办法:pip install keras-resnet --user


参考资料: 1、Retinanet训练自己的数据(2):模型准备

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019年03月06日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档