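Implementation flow of build_train_loader(cfg) in the DefaultTrainer class
File path: /detectron2/engine/default.py
Within the class, build_train_loader mainly calls build_detection_train_loader(cfg, mapper=None), so the rest of this post analyzes that function.
That function is defined in /detectron2/data/build.py.
The official docstring of build_detection_train_loader describes what it does: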
def build_detection_train_loader(cfg, mapper=None):
    """
    A data loader is created by the following steps:
    1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
    2. Start workers to work on the dicts. Each worker will:
       * Map each metadata dict into another format to be consumed by the model.
       * Batch them by simply putting dicts into a list.
    The batched ``list[mapped_dict]`` is what this dataloader will return.
    Args:
        cfg (CfgNode): the config
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            By default it will be `DatasetMapper(cfg, True)`.
    Returns:
        a torch DataLoader object
    """
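The two steps in this docstring can be mimicked without detectron2. The following is a minimal pure-Python sketch, not detectron2 code: `CATALOG`, `register`, `toy_mapper` and `iterate_batches` are hypothetical names, `CATALOG` only plays the role of `DatasetCatalog`, and "batching" is literally collecting mapped dicts into a list, as the docstring says.

```python
from typing import Callable, Dict, Iterator, List

# Hypothetical stand-in for DatasetCatalog: dataset name -> function returning list[dict]
CATALOG: Dict[str, Callable[[], List[dict]]] = {}


def register(name: str, loader: Callable[[], List[dict]]) -> None:
    """Register a dataset under a name (a toy version of DatasetCatalog.register)."""
    CATALOG[name] = loader


def toy_mapper(record: dict) -> dict:
    """Turn one metadata dict into a 'model-ready' dict (here: just count instances)."""
    return {"file_name": record["file_name"], "num_instances": len(record.get("annotations", []))}


def iterate_batches(
    names: List[str], mapper: Callable[[dict], dict], batch_size: int
) -> Iterator[List[dict]]:
    # Step 1: query the catalog for every dataset name and concatenate the dicts
    dicts = [d for name in names for d in CATALOG[name]()]
    batch: List[dict] = []
    for record in dicts:
        # Step 2a: map each metadata dict into the model's input format
        batch.append(mapper(record))
        # Step 2b: "batch" by simply putting mapped dicts into a list
        if len(batch) == batch_size:
            yield batch
            batch = []
    # A trailing incomplete batch is discarded in this sketch


register("toy_train", lambda: [
    {"file_name": "a.jpg", "annotations": [{"bbox": [0, 0, 5, 5]}]},
    {"file_name": "b.jpg", "annotations": []},
    {"file_name": "c.jpg", "annotations": [{"bbox": [1, 1, 2, 2]}, {"bbox": [3, 3, 4, 4]}]},
])

batches = list(iterate_batches(["toy_train"], toy_mapper, batch_size=2))
print(batches)
```

In the real loader the mapping runs inside DataLoader workers, which is why the mapper has to be a picklable callable object such as `DatasetMapper`.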
Main steps:
1) Collect all image dicts for the dataset names given in the config.
dataset_dicts = get_detection_dataset_dicts(
    cfg.DATASETS.TRAIN,  # list of dataset names given in the config file
    filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
    min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
    if cfg.MODEL.KEYPOINT_ON
    else 0,
    proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
)
dataset = DatasetFromList(dataset_dicts, copy=False)
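Conceptually, this step concatenates the dicts of every training dataset and, when FILTER_EMPTY_ANNOTATIONS is on, drops images that have no usable (non-crowd) annotation. The function below is a simplified, hypothetical stand-in (`gather_dicts` is not a detectron2 function): it ignores the keypoint and proposal options and assumes COCO-style annotations that may carry an `iscrowd` flag.

```python
from typing import Dict, List


def gather_dicts(
    per_dataset: Dict[str, List[dict]], names: List[str], filter_empty: bool = True
) -> List[dict]:
    # Concatenate the records of all requested datasets, in the given order
    dicts = [record for name in names for record in per_dataset[name]]
    assert len(dicts) > 0, "no images found for {}".format(names)
    if filter_empty:
        # Keep an image only if it has at least one non-crowd annotation
        dicts = [
            r for r in dicts
            if any(ann.get("iscrowd", 0) == 0 for ann in r.get("annotations", []))
        ]
    return dicts


data = {
    "train_a": [
        {"file_name": "1.jpg", "annotations": [{"iscrowd": 0}]},
        {"file_name": "2.jpg", "annotations": []},
    ],
    "train_b": [
        {"file_name": "3.jpg", "annotations": [{"iscrowd": 1}]},
        {"file_name": "4.jpg", "annotations": [{"iscrowd": 1}, {"iscrowd": 0}]},
    ],
}
kept = gather_dicts(data, ["train_a", "train_b"])
print([r["file_name"] for r in kept])  # ['1.jpg', '4.jpg']
```

The resulting list is then wrapped by `DatasetFromList`; with `copy=False` it hands out the stored dicts themselves, so whoever modifies them (the mapper) must copy first.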
2) Map each image dict into the format the model consumes internally.
if mapper is None:
    mapper = DatasetMapper(cfg, True)
dataset = MapDataset(dataset, mapper)
class DatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and map it into a format used by the model.
    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own one for customized logic.
    The callable currently does the following:
    1. Read the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """
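The three duties listed in this docstring can be illustrated with NumPy alone. In the sketch below, `ToyMapper` and `fake_loader` are hypothetical (a real mapper reads the file with `utils.read_image`), a fixed horizontal flip stands in for the configured transforms, boxes are assumed to be XYXY absolute, and the output image is a float32 CHW array, mirroring the `transpose(2, 0, 1)` in the real `__call__`. The input dict is deep-copied first because the mapper edits it in place.

```python
import copy
from typing import Callable

import numpy as np


class ToyMapper:
    """Hypothetical mapper: read -> transform image and boxes -> CHW float32."""

    def __init__(self, load_image: Callable[[str], np.ndarray], flip: bool = True):
        self.load_image = load_image  # stand-in for utils.read_image
        self.flip = flip

    def __call__(self, dataset_dict: dict) -> dict:
        # The dict is modified below, so work on a copy to keep the source list intact
        dataset_dict = copy.deepcopy(dataset_dict)
        # 1. Read the image (H, W, C) from "file_name"
        image = self.load_image(dataset_dict["file_name"])
        w = image.shape[1]
        # 2. Apply a geometric transform to the image and, consistently, to the boxes
        if self.flip:
            image = image[:, ::-1, :]
            for ann in dataset_dict.get("annotations", []):
                x0, y0, x1, y1 = ann["bbox"]  # assumed XYXY absolute coordinates
                ann["bbox"] = [w - x1, y0, w - x0, y1]
        # 3. HWC uint8 -> CHW float32 (the real code wraps this in torch.as_tensor)
        dataset_dict["image"] = image.transpose(2, 0, 1).astype("float32")
        return dataset_dict


def fake_loader(file_name: str) -> np.ndarray:
    # Synthetic 4x6 RGB image; a real loader would decode file_name from disk
    return np.zeros((4, 6, 3), dtype=np.uint8)


record = {"file_name": "x.jpg", "annotations": [{"bbox": [1, 0, 3, 2]}]}
out = ToyMapper(fake_loader)(record)
print(out["image"].shape, out["annotations"][0]["bbox"])  # (3, 4, 6) [3, 0, 5, 2]
print(record["annotations"][0]["bbox"])  # [1, 0, 3, 2]: the source dict is untouched
```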
The concrete implementation is in its __call__ method:
    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        # it will be modified by code below
        dataset_dict = copy.deepcopy(dataset_dict)
        # Read the image
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)
        # Augment the image; the returned transforms are later reused for the annotations
        if "annotations" not in dataset_dict:
            image, transforms = T.apply_transform_gens(
                ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
            )
        else:
            # Crop around an instance if there are instances in the image.
            # USER: Remove if you don't use cropping
            if self.crop_gen:
                crop_tfm = utils.gen_crop_transform_with_instance(
                    self.crop_gen.get_crop_size(image.shape[:2]),
                    image.shape[:2],
                    np.random.choice(dataset_dict["annotations"]),
                )
                image = crop_tfm.apply_image(image)
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            if self.crop_gen:
                transforms = crop_tfm + transforms
        image_shape = image.shape[:2]  # h, w
        # Convert the image to a CHW float32 tensor
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
        # Can use uint8 if it turns out to be slow some day
        # .....
        return dataset_dict
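When a crop generator is configured and the image has annotations, the code above picks one random instance and builds a crop around it, so the crop is unlikely to be empty. The pure-Python function below sketches that idea under stated assumptions: `crop_around_instance` is a hypothetical name, boxes are XYXY absolute, and it follows the intent of `utils.gen_crop_transform_with_instance` (a crop that contains the chosen box's center) rather than its exact code.

```python
import math
import random
from typing import List, Tuple


def crop_around_instance(
    crop_hw: Tuple[int, int], image_hw: Tuple[int, int], bbox: List[float], rng: random.Random
) -> Tuple[int, int, int, int]:
    """Return (x0, y0, w, h) of a random crop that contains the center of bbox (XYXY)."""
    crop_h, crop_w = crop_hw
    img_h, img_w = image_hw
    assert crop_h <= img_h and crop_w <= img_w, "crop larger than image"
    cy, cx = (bbox[1] + bbox[3]) / 2, (bbox[0] + bbox[2]) / 2
    assert 0 <= cy <= img_h and 0 <= cx <= img_w, "box center outside the image"
    # Lowest offset whose crop still reaches the center
    y_min = max(math.ceil(cy) - crop_h, 0)
    x_min = max(math.ceil(cx) - crop_w, 0)
    # Highest offset that starts at or before the center and keeps the crop inside the image
    y_max = min(img_h - crop_h, math.floor(cy))
    x_max = min(img_w - crop_w, math.floor(cx))
    y0 = rng.randint(y_min, y_max)  # randint is inclusive on both ends
    x0 = rng.randint(x_min, x_max)
    return x0, y0, crop_w, crop_h


x0, y0, w, h = crop_around_instance((4, 4), (10, 10), [6, 6, 8, 8], random.Random(0))
print(x0, y0, w, h)
```

Only the offset is random; the crop size comes from the crop generator, which is why the real code asks `self.crop_gen.get_crop_size(...)` first.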
Note: the image augmentation methods are implemented in /detectron2/data/transforms/transform_gen.py.
def apply_transform_gens(transform_gens, img):
    """
    Apply a list of :class:`TransformGen` on the input image, and
    returns the transformed image and a list of transforms.
    We cannot simply create and return all transforms without
    applying it to the image, because a subsequent transform may
    need the output of the previous one.
    Args:
        transform_gens (list): list of :class:`TransformGen` instance to
            be applied.
        img (ndarray): uint8 or floating point images with 1 or 3 channels.
    Returns:
        ndarray: the transformed image
        TransformList: contain the transforms that's used.
    """
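The docstring's point, that each generator must see the output of the previous transform, can be shown with a toy version of the loop. `ResizeGen`, `CenterCropGen` and `apply_gens` are hypothetical stand-ins (nearest-neighbour resize, plain Python callables as transforms), not detectron2 classes; the real function returns a `TransformList` rather than a Python list. The crop generator reads the *current* image shape, so it only produces the intended crop because the resize has already been applied.

```python
from typing import Callable, List, Tuple

import numpy as np

Transform = Callable[[np.ndarray], np.ndarray]


class ResizeGen:
    """Hypothetical generator: nearest-neighbour resize to a fixed (h, w)."""

    def __init__(self, h: int, w: int):
        self.h, self.w = h, w

    def get_transform(self, img: np.ndarray) -> Transform:
        rows = np.arange(self.h) * img.shape[0] // self.h
        cols = np.arange(self.w) * img.shape[1] // self.w
        return lambda im: im[rows][:, cols]


class CenterCropGen:
    """Hypothetical generator: crop the central half, sized from the current shape."""

    def get_transform(self, img: np.ndarray) -> Transform:
        h, w = img.shape[:2]
        y0, x0 = h // 4, w // 4
        return lambda im: im[y0:y0 + h // 2, x0:x0 + w // 2]


def apply_gens(gens: list, img: np.ndarray) -> Tuple[np.ndarray, List[Transform]]:
    tfms = []
    for g in gens:
        tfm = g.get_transform(img)  # built from the image as transformed so far
        img = tfm(img)
        tfms.append(tfm)
    return img, tfms


img = np.zeros((100, 60, 3), dtype=np.uint8)
out, tfms = apply_gens([ResizeGen(40, 80), CenterCropGen()], img)
print(out.shape)  # (20, 40, 3)
```

Because the recorded transforms are deterministic, they can be replayed on another input of the same size, which is how the mapper reuses them for annotations.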
The augmentation methods currently provided are:
"RandomBrightness",
"RandomContrast",
"RandomCrop",
"RandomExtent",
"RandomFlip",
"RandomSaturation",
"RandomLighting",
"Resize",
"ResizeShortestEdge",
# rotated-box operations
"HFlip_rotated_box",
"Resize_rotated_box",
To add a custom augmentation method, implement it in this file alongside the existing ones.
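A custom augmentation follows the same split as the built-in ones: a generator makes the random decision once and returns a deterministic transform that can be applied to the image and then to the annotations. The sketch below uses NumPy only; `RandomVFlipGen`, `VFlip` and `NoOp` are hypothetical names whose `get_transform` / `apply_image` / `apply_box` methods are modelled loosely on detectron2's, so this is an illustration of the pattern, not a drop-in `TransformGen` subclass.

```python
import random
from typing import List, Optional

import numpy as np


class VFlip:
    """Deterministic vertical flip for an image of height h."""

    def __init__(self, h: int):
        self.h = h

    def apply_image(self, img: np.ndarray) -> np.ndarray:
        return img[::-1]

    def apply_box(self, box: List[float]) -> List[float]:
        x0, y0, x1, y1 = box  # assumed XYXY absolute; y is mirrored, x is unchanged
        return [x0, self.h - y1, x1, self.h - y0]


class NoOp:
    """Identity transform, returned when the coin toss says 'do nothing'."""

    def apply_image(self, img: np.ndarray) -> np.ndarray:
        return img

    def apply_box(self, box: List[float]) -> List[float]:
        return list(box)


class RandomVFlipGen:
    """Hypothetical custom augmentation: flip vertically with probability `prob`."""

    def __init__(self, prob: float = 0.5, rng: Optional[random.Random] = None):
        self.prob = prob
        self.rng = rng or random.Random()

    def get_transform(self, img: np.ndarray):
        # The coin is tossed here, once; the returned transform is deterministic
        return VFlip(img.shape[0]) if self.rng.random() < self.prob else NoOp()


img = np.arange(12, dtype=np.uint8).reshape(4, 3, 1)
tfm = RandomVFlipGen(prob=1.0).get_transform(img)
print(tfm.apply_image(img)[0, :, 0].tolist(), tfm.apply_box([0, 0, 2, 1]))  # [9, 10, 11] [0, 3, 2, 4]
```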
class MapDataset(data.Dataset):
    """
    Map a function over the elements in a dataset.
    Args:
        dataset: a dataset where map function is applied.
        map_func: a callable which maps the element in dataset. map_func is
            responsible for error handling, when error happens, it needs to
            return None so the MapDataset will randomly use other
            elements from the dataset.
    """
    def __getitem__(self, idx):
        retry_count = 0
        cur_idx = int(idx)
        while True:
            data = self._map_func(self._dataset[cur_idx])
            if data is not None:
                self._fallback_candidates.add(cur_idx)
                return data
            # _map_func fails for this idx, use a random new index from the pool
            retry_count += 1
            self._fallback_candidates.discard(cur_idx)
            cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0]
            if retry_count >= 3:
                logger = logging.getLogger(__name__)
                logger.warning(
                    "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
                        idx, retry_count
                    )
                )
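The fallback logic can be exercised without detectron2. `RetryMapDataset` below is a hypothetical re-implementation of the same idea: a failed index is removed from the candidate pool and a random surviving index is tried instead. Two deliberate differences from the excerpt: it samples from a sorted list, because `random.sample` no longer accepts a set on Python 3.11+, and it raises an error once every candidate has failed instead of looping.

```python
import logging
import random
from typing import Callable, Sequence


class RetryMapDataset:
    """Sketch of MapDataset's fallback: on failure, retry with another random index."""

    def __init__(self, dataset: Sequence, map_func: Callable, seed: int = 42):
        self._dataset = dataset
        self._map_func = map_func
        self._rng = random.Random(seed)
        # Indices not known to fail; failing indices are removed as they are found
        self._fallback_candidates = set(range(len(dataset)))

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, idx: int):
        retry_count = 0
        cur_idx = int(idx)
        while True:
            data = self._map_func(self._dataset[cur_idx])
            if data is not None:
                self._fallback_candidates.add(cur_idx)
                return data
            # map_func returned None: drop this index and try a random other one
            retry_count += 1
            self._fallback_candidates.discard(cur_idx)
            if not self._fallback_candidates:
                raise RuntimeError("map_func failed on every element")
            cur_idx = self._rng.choice(sorted(self._fallback_candidates))
            if retry_count >= 3:
                logging.getLogger(__name__).warning(
                    "Failed to apply map_func for idx: %s, retry count: %s", idx, retry_count
                )


def mapper(x):
    return None if x < 0 else x * 10  # negative values simulate corrupt samples


ds = RetryMapDataset([1, -1, 2, -2, 3], mapper)
print(ds[0])  # 10
print(ds[1] in {10, 20, 30})  # True: index 1 fails and a valid element is returned instead
```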
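3) Wrap the dataset in a torch DataLoader.
This step covers random sampling of sample indices and the data loading itself; for details see the blog post https://www.cnblogs.com/marsggbo/p/11308889.html.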
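The sampling side can be sketched in plain Python. Assuming the behaviour of detectron2's `TrainingSampler` (an infinite stream of indices, reshuffled on every pass, with each process keeping every `world_size`-th index starting at its `rank`), the generators below reproduce that pattern. `infinite_indices`, `shard`, `batches` and `trivial_collate` are hypothetical helpers; the last one is the identity on a list, which is all a "trivial" batch collator needs to do when the model takes a list of dicts.

```python
import itertools
import random
from typing import Iterator, List


def infinite_indices(size: int, seed: int) -> Iterator[int]:
    rng = random.Random(seed)  # same seed on every process -> same permutations
    while True:
        perm = list(range(size))
        rng.shuffle(perm)
        yield from perm


def shard(indices: Iterator[int], rank: int, world_size: int) -> Iterator[int]:
    # Each process keeps every world_size-th index, starting at its rank
    return itertools.islice(indices, rank, None, world_size)


def batches(indices: Iterator[int], batch_size: int) -> Iterator[List[int]]:
    # The stream is infinite, so every batch is full
    while True:
        yield list(itertools.islice(indices, batch_size))


def trivial_collate(batch: list) -> list:
    return batch  # no stacking: the model receives a plain list


r0 = list(itertools.islice(shard(infinite_indices(6, seed=0), 0, 2), 3))
r1 = list(itertools.islice(shard(infinite_indices(6, seed=0), 1, 2), 3))
print(sorted(r0 + r1))  # [0, 1, 2, 3, 4, 5]: the two ranks split one pass without overlap
```

In the real loader, each index in a batch is first turned into a mapped dict by the dataset before the collator sees it.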
The corresponding code is as follows:
sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
logger = logging.getLogger(__name__)
logger.info("Using training sampler {}".format(sampler_name))
if sampler_name == "TrainingSampler":
    sampler = samplers.TrainingSampler(len(dataset))
elif sampler_name == "RepeatFactorTrainingSampler":
    sampler = samplers.RepeatFactorTrainingSampler(
        dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
    )
else:
    raise ValueError("Unknown training sampler: {}".format(sampler_name))
# images_per_worker, group_bin_edges and aspect_ratios are computed earlier
# in the function (not shown in this excerpt)
batch_sampler = build_batch_data_sampler(
    sampler, images_per_worker, group_bin_edges, aspect_ratios
)
data_loader = torch.utils.data.DataLoader(
    dataset,
    num_workers=cfg.DATALOADER.NUM_WORKERS,
    batch_sampler=batch_sampler,
    collate_fn=trivial_batch_collator,  # returns the batch unchanged, as a list of dicts
    worker_init_fn=worker_init_reset_seed,
)
return data_loader
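`RepeatFactorTrainingSampler` oversamples images that contain rare categories (the LVIS-style repeat-factor scheme). For each category c with image frequency f(c), the category repeat factor is r(c) = max(1, sqrt(t / f(c))), where t is `cfg.DATALOADER.REPEAT_THRESHOLD`; an image gets the largest r(c) among its categories, and each epoch includes it floor(r) times plus once more with probability equal to the fractional part. The sketch below computes this from plain dicts; `repeat_factors` and `epoch_indices` are hypothetical names, only the `category_id` field of each annotation is assumed, and images without annotations default to a factor of 1.

```python
import math
import random
from collections import defaultdict
from typing import Dict, List


def repeat_factors(dataset_dicts: List[dict], thresh: float) -> List[float]:
    # f(c): fraction of images that contain category c
    counts: Dict[int, int] = defaultdict(int)
    for d in dataset_dicts:
        for c in {ann["category_id"] for ann in d["annotations"]}:
            counts[c] += 1
    n = len(dataset_dicts)
    cat_rep = {c: max(1.0, math.sqrt(thresh / (k / n))) for c, k in counts.items()}
    # r(I): maximum category factor over the categories present in image I
    return [
        max((cat_rep[ann["category_id"]] for ann in d["annotations"]), default=1.0)
        for d in dataset_dicts
    ]


def epoch_indices(factors: List[float], rng: random.Random) -> List[int]:
    indices: List[int] = []
    for i, r in enumerate(factors):
        # Stochastic rounding: floor(r) copies, plus one more with probability frac(r)
        reps = int(r) + (1 if rng.random() < r - int(r) else 0)
        indices.extend([i] * reps)
    return indices


dicts = [{"annotations": [{"category_id": 0}]}] * 3 + [{"annotations": [{"category_id": 1}]}]
factors = repeat_factors(dicts, thresh=0.5)
print(factors)  # [1.0, 1.0, 1.0, 1.4142135623730951]
```

Here category 1 appears in 1 of 4 images, so f = 0.25 and r = sqrt(0.5 / 0.25) = sqrt(2); the common category stays at the floor of 1.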
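Original content notice: this article is published on the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.
In case of infringement, contact cloudcommunity@tencent.com to have it removed.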