FCOS3D label assignment

烤粽子
Published 2022-09-01 16:39:31

Label assignment in FCOS3D is much the same as in 2D FCOS: targets are assigned mainly in the image coordinate system:

    def _get_target_single(self, gt_bboxes, gt_labels, gt_bboxes_3d,
                           gt_labels_3d, centers2d, depths, attr_labels,
                           points, regress_ranges, num_points_per_lvl):
        """Compute regression and classification targets for a single image."""
        num_points = points.size(0)
        num_gts = gt_labels.size(0)
        if not isinstance(gt_bboxes_3d, torch.Tensor):
            gt_bboxes_3d = gt_bboxes_3d.tensor.to(gt_bboxes.device)
        if num_gts == 0:
            return gt_labels.new_full((num_points,), self.background_label), \
                   gt_bboxes.new_zeros((num_points, 4)), \
                   gt_labels_3d.new_full(
                       (num_points,), self.background_label), \
                   gt_bboxes_3d.new_zeros((num_points, self.bbox_code_size)), \
                   gt_bboxes_3d.new_zeros((num_points,)), \
                   attr_labels.new_full(
                       (num_points,), self.attr_background_label)

        # change orientation to local yaw
        gt_bboxes_3d[..., 6] = -torch.atan2(
            gt_bboxes_3d[..., 0], gt_bboxes_3d[..., 2]) + gt_bboxes_3d[..., 6]

        areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
            gt_bboxes[:, 3] - gt_bboxes[:, 1])  # (x1, y1, x2, y2) -> box area
        areas = areas[None].repeat(num_points, 1)  # [num_gts] -> [num_points, num_gts], e.g. [2] -> [30929, 2]
        regress_ranges = regress_ranges[:, None, :].expand(
            num_points, num_gts, 2)  # [num_points, 2] -> [num_points, num_gts, 2]
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        centers2d = centers2d[None].expand(num_points, num_gts, 2)
        gt_bboxes_3d = gt_bboxes_3d[None].expand(num_points, num_gts,
                                                 self.bbox_code_size)
        depths = depths[None, :, None].expand(num_points, num_gts, 1)
        # (x, y) coordinates of every point
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None].expand(num_points, num_gts)
        ys = ys[:, None].expand(num_points, num_gts)

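        # offsets from each gt center to each point
        # centers2d: each gt's center in 2D image coordinates
        delta_xs = (xs - centers2d[..., 0])[..., None]
        delta_ys = (ys - centers2d[..., 1])[..., None]
        # 0. everything above builds up to this step: assemble 3D targets
        #    with the same layout as the network's regression output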
        bbox_targets_3d = torch.cat(
            (delta_xs, delta_ys, depths, gt_bboxes_3d[..., 3:]), dim=-1)

        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)

        assert self.center_sampling is True, 'Setting center_sampling to '\
            'False has not been implemented for FCOS3D.'
        # condition1: inside a `center bbox`
        radius = self.center_sample_radius # 1.5
        center_xs = centers2d[..., 0]
        center_ys = centers2d[..., 1]
        center_gts = torch.zeros_like(gt_bboxes)
        stride = center_xs.new_zeros(center_xs.shape)

        # 1. project the points on the current level back to the original
        #    (input image) scale
        lvl_begin = 0
        for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):  # e.g. [23200, 5800, 1450, 375, 104]
            lvl_end = lvl_begin + num_points_lvl
            # per-point half side of the center box: level stride * radius,
            # e.g. [8, 16, 32, 64, 128] * 1.5
            stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
            lvl_begin = lvl_end
        
        # 2. points inside the center box are positive candidates
        #    center box: square around the gt center with half side radius * stride
        center_gts[..., 0] = center_xs - stride
        center_gts[..., 1] = center_ys - stride
        center_gts[..., 2] = center_xs + stride
        center_gts[..., 3] = center_ys + stride

        cb_dist_left = xs - center_gts[..., 0]  # signed distance from each point to the left edge of the center box
        cb_dist_right = center_gts[..., 2] - xs
        cb_dist_top = ys - center_gts[..., 1]
        cb_dist_bottom = center_gts[..., 3] - ys
        center_bbox = torch.stack(
            (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
        # a point is valid only if it lies strictly inside the center box, i.e.
        # within radius * stride of the gt center along both axes
        inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0

        # condition2: limit the regression range for each location
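        # 3. a point can be positive for a gt only if its largest distance to
        #    the gt's 2D box edges falls in the regress range of its level
        max_regress_distance = bbox_targets.max(-1)[0]
        # keep only pairs within the level's regress range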
        inside_regress_range = (
            (max_regress_distance >= regress_ranges[..., 0])
            & (max_regress_distance <= regress_ranges[..., 1]))

        # 4. center-based criterion to deal with ambiguity
        # 4.1 pick, per point, the gt with the smallest offset (and its index)
        dists = torch.sqrt(torch.sum(bbox_targets_3d[..., :2]**2, dim=-1))  # Euclidean norm of the offsets, [num_points, num_gts]
        dists[inside_gt_bbox_mask == 0] = INF  # mask out pairs outside the center box
        dists[inside_regress_range == 0] = INF
        min_dist, min_dist_inds = dists.min(dim=1)

        labels = gt_labels[min_dist_inds]  # label of the selected gt
        labels_3d = gt_labels_3d[min_dist_inds]
        attr_labels = attr_labels[min_dist_inds]
        labels[min_dist == INF] = self.background_label  # set as BG 10
        labels_3d[min_dist == INF] = self.background_label  # set as BG
        attr_labels[min_dist == INF] = self.attr_background_label

        # 4.2 pick, per point, the targets of its selected gt
        bbox_targets = bbox_targets[range(num_points), min_dist_inds]  # [num_points, num_gts, 4] -> [num_points, 4]
        bbox_targets_3d = bbox_targets_3d[range(num_points), min_dist_inds]
        # 4.3 centerness targets
        # offset norm / (1.414 * radius * stride), i.e. the distance to the
        # gt center relative to the half diagonal of the center box
        relative_dists = torch.sqrt(
            torch.sum(bbox_targets_3d[..., :2]**2,
                      dim=-1)) / (1.414 * stride[:, 0])
        # [N] / [N]
        centerness_targets = torch.exp(-self.centerness_alpha * relative_dists)  # e.g. exp(-2.5 * relative_dists)

        return labels, bbox_targets, labels_3d, bbox_targets_3d, \
            centerness_targets, attr_labels
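
The vectorized code above can be hard to follow, so here is a minimal scalar sketch of the same rules for a single point, in plain Python rather than PyTorch. The helper names `local_yaw` and `assign_point`, the dict layout of `gts`, and the toy boxes are made up for illustration. The regress range `(-1, 48)` is only an assumed value for the stride-8 level, and `radius=1.5` and `alpha=2.5` follow the example values in the comments above. The sketch uses `math.sqrt(2)` where the original hard-codes `1.414`, and returns `None` for background points, whereas the original still computes a centerness value for them.

```python
import math

INF = float("inf")


def local_yaw(x, z, yaw):
    """Global yaw -> local yaw, mirroring `-atan2(x, z) + yaw` above."""
    return yaw - math.atan2(x, z)


def assign_point(point, level_stride, regress_range, gts, radius=1.5, alpha=2.5):
    """Assign one point to at most one gt; return (gt_index, centerness).

    point: (x, y) in input-image pixels.
    level_stride: stride of the feature level the point belongs to.
    regress_range: (low, high) allowed for that level.
    gts: list of dicts with 'bbox' (x1, y1, x2, y2) and 'center2d' (cx, cy).
    """
    px, py = point
    half = level_stride * radius  # half side of the center box
    best_idx, best_dist = None, INF
    for idx, gt in enumerate(gts):
        x1, y1, x2, y2 = gt["bbox"]
        cx, cy = gt["center2d"]
        # condition 1: the point lies strictly inside the center box
        in_center_box = min(px - (cx - half), (cx + half) - px,
                            py - (cy - half), (cy + half) - py) > 0
        # condition 2: largest distance to the 2D box edges is in range
        max_reg = max(px - x1, py - y1, x2 - px, y2 - py)
        in_range = regress_range[0] <= max_reg <= regress_range[1]
        # center-based criterion: keep the valid gt with the closest center
        dist = math.hypot(px - cx, py - cy)
        if in_center_box and in_range and dist < best_dist:
            best_idx, best_dist = idx, dist
    if best_idx is None:
        return None, None  # background point
    # distance relative to the half diagonal of the center box
    centerness = math.exp(-alpha * best_dist / (math.sqrt(2) * half))
    return best_idx, centerness


# Two overlapping toy boxes; the larger one is out of range for this level.
gts = [
    {"bbox": (0, 0, 100, 100), "center2d": (50, 50)},
    {"bbox": (40, 40, 80, 80), "center2d": (60, 60)},
]
print(assign_point((56, 56), 8, (-1, 48), gts))
print(assign_point((200, 200), 8, (-1, 48), gts))
print(local_yaw(10.0, 10.0, 0.0))
```

In this toy case the point (56, 56) falls inside both center boxes, but its largest distance to the edges of the first box is 56, which is outside the assumed range, so it is assigned to the second gt. A point exactly on a gt center gets centerness 1, and the value decays as the point moves toward the corners of the center box.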
This article is part of the Tencent Cloud self-media sharing program and is shared from the author's personal site/blog.
Originally published: 2022-06-30.
