前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >mask rcnn训练自己的数据集_fasterrcnn训练自己的数据集

mask rcnn训练自己的数据集_fasterrcnn训练自己的数据集

作者头像
全栈程序员站长
发布2022-09-23 11:15:35
7310
发布2022-09-23 11:15:35
举报

大家好,又见面了,我是你们的朋友全栈君。

这篇博客是 基于 Google Colab 的 mask rcnn 训练自己的数据集(以实例分割为例)文章中 数据集的制作 这部分的一些补充

温馨提示:

代码语言:javascript
复制
实例分割是针对同一个类别的不同个体或者不同部分之间进行区分
我的任务是对同一个类别的不同个体进行区分,在标注的时候,不同的个体需要设置不同的标签名称

在进行标注的时候不要勾选 labelme 界面左上角 File 下拉菜单中的 Stay With Images Data 选项
否则生成的json会包含 Imagedata 信息(是很长的一大串加密的软链接),会占用很大的内存
在这里插入图片描述
在这里插入图片描述

1.首先要人为划分训练集和测试集(图片和标注文件放在同一个文件夹里面)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2.在同级目录下新建一个 labels.txt 文件

代码语言:javascript
复制
__ignore__
__background__
seedling #根据自己的实际情况更改
在这里插入图片描述
在这里插入图片描述

3.在datasets目录下新建 seed_trainseed_val 两个文件夹

代码语言:javascript
复制
分别存放的训练集和测试集图片和整合后的标签文件
seed_train 
seed_val

把整合后的标签文件剪切复制到同级目录下
seed_train_annotation.josn
seed_val_annotation.json
在这里插入图片描述
在这里插入图片描述

完整代码

代码语言:javascript
复制
说明:
一次只能操作一个文件夹,也就是说:
训练集生成需要执行一次代码
测试集生成就需要更改路径之后再执行一次代码
代码语言:javascript
复制
import argparse
import collections
import datetime
import glob
import json
import os
import os.path as osp
import sys
import uuid
import time

import imgviz
import numpy as np

import labelme

try:
    import pycocotools.mask
except ImportError:
    print("Please install pycocotools:\n\n pip install pycocotools\n")
    sys.exit(1)

#https://github.com/pascal1129/kaggle_airbus_ship_detection/blob/master/0_rle_to_coco/1_ships_to_coco.py

def main():
    parser = argparse.ArgumentParser(description="json2coco")
    #原始json文件保存的路径
    parser.add_argument("--input_dir", help="input annotated directory",default="E:/Deep_learning/seed-mask/data/seed/seed_train")
    #整合后的json文件保存的路径
    parser.add_argument("--output_dir", help="output dataset directory",default="E:/Deep_learning/seed-mask/data/seed/datasets/seed_train")
    parser.add_argument("--labels", help="labels file", default='E:/Deep_learning/seed-mask/data/seed/labels.txt')#required=True
    parser.add_argument( "--noviz", help="no visualization", action="store_true" ,default="--noviz")
    args = parser.parse_args()

    now = datetime.datetime.now()
    start= time.time()

    data = dict(
        info=dict(
            description="seedling datasets",
            url=None,
            version="label=4.5.6",
            year=now.year,
            contributor=None,
            date_created=now.strftime("%Y-%m-%d %H:%M:%S.%f"),
        ),
        #licenses=[dict(url=None, id=0, name=None,)],
        images=[
            # license, url, file_name, height, width, date_captured, id
        ],
        type="instances",
        annotations=[
            # segmentation, area, iscrowd, image_id, bbox, category_id, id
        ],
        categories=[
            # supercategory, id, name
        ],
    )

    class_name_to_id = { 
   }
    for i, line in enumerate(open(args.labels).readlines()):
        class_id = i - 1  # starts with -1
        class_name = line.strip()
        if class_id == -1:
            assert class_name == "__ignore__"
            continue
        if class_id == 0:
            assert class_name == "__background__"
            continue        
        class_name_to_id[class_name] = class_id
        #print(class_id,class_name,'\n')
        data["categories"].append(
            dict(supercategory="seedling", id=class_id, name=class_name,)#一类目标+背景,id=0表示背景
        )
    print("categories 生成完成",'\n')

    out_ann_file = osp.join(args.output_dir, "seed_train_anno.json")#自动添加"/" 这里要改 
    
    
    label_files = glob.glob(osp.join(args.input_dir, "*.json"))#图像id从json文件中读取
    for image_id, filename in enumerate(label_files):
        print(image_id, filename)
        #print("Generating dataset from:", filename)

        label_file = labelme.LabelFile(filename=filename)

        base = osp.splitext(osp.basename(filename))[0]#图片名
        out_img_file = osp.join(args.output_dir, base + ".jpg")# 保存图片路径

        img = labelme.utils.img_data_to_arr(label_file.imageData)
        imgviz.io.imsave(out_img_file, img)
        data["images"].append(
            dict(
                #license=0,
                #url=None,
                file_name=osp.relpath(out_img_file, osp.dirname(out_ann_file)),
                height=img.shape[0],
                width=img.shape[1],
                #date_captured=None,
                id=image_id,
            )
        )

        masks = { 
   }  # for area
        segmentations = collections.defaultdict(list)  # for segmentation
        for shape in label_file.shapes:
            points = shape["points"]
            label = shape["label"]
            group_id = shape.get("group_id")
            shape_type = shape.get("shape_type", "polygon")
            mask = labelme.utils.shape.shape_to_mask(img.shape[:2], points, shape_type)#labelme=4.5.6的shape_to_mask函数
            if group_id is None:
                group_id = uuid.uuid1()

            instance = (label, group_id)
            #print(instance)

            if instance in masks:
                masks[instance] = masks[instance] | mask
            else:
                masks[instance] = mask

            if shape_type == "rectangle":
                (x1, y1), (x2, y2) = points
                x1, x2 = sorted([x1, x2])
                y1, y2 = sorted([y1, y2])
                points = [x1, y1, x2, y1, x2, y2, x1, y2]
            else:
                points = np.asarray(points).flatten().tolist()

            segmentations[instance].append(points)
        segmentations = dict(segmentations)

        for instance, mask in masks.items():            
            cls_name, group_id = instance
# if cls_name not in class_name_to_id:
# continue
# cls_id = class_name_to_id[cls_name]

            mask = np.asfortranarray(mask.astype(np.uint8))
            
            mask = pycocotools.mask.encode(mask)
            
            area = float(pycocotools.mask.area(mask))
            bbox = pycocotools.mask.toBbox(mask).flatten().tolist()
            

            data["annotations"].append(
                dict(
                    id=len(data["annotations"]),
                    image_id=image_id,
                    category_id=1,#都是1类cls_id
                    segmentation=segmentations[instance],
                    area=area,
                    bbox=bbox,
                    iscrowd=0,
                )
            )
    
    print("annotations 生成完成",'\n')

# if not args.noviz:
# labels, captions, masks = zip(
# *[
# (class_name_to_id[cnm], cnm, msk)
# for (cnm, gid), msk in masks.items()
# if cnm in class_name_to_id
# ]
# )
# viz = imgviz.instances2rgb(
# image=img,
# labels=labels,
# masks=masks,
# captions=captions,
# font_size=15,
# line_width=2,
# )
# out_viz_file = osp.join(
# args.output_dir, "Visualization", base + ".jpg"
# )
# imgviz.io.imsave(out_viz_file, viz)
    
    with open(out_ann_file, "w") as f:
        json.dump(data, f,indent = 2)
        
    cost_time =(time.time()-start)/1000
    print("cost_time:{:.2f}s".format(cost_time) )


if __name__ == "__main__":
    main()

发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/172437.html原文链接:https://javaforall.cn

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档