
I don't understand some of the code in S3DISDataset and would like to know what it does

Stack Overflow user
Asked 2022-01-31 15:16:06

Source code from https://github.com/yanx27/Pointnet_Pointnet2_pytorch:

import os

import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm

class S3DISDataset(Dataset):
    def __init__(self, split='train', data_root='trainval_fullarea',
                 num_point=4096, test_area=5, block_size=1.0, sample_rate=1.0, transform=None):
        super().__init__()
        self.num_point = num_point    # number of points sampled per block
        self.block_size = block_size  # side length (x/y) of the block cropped from a room
        self.transform = transform
        rooms = sorted(os.listdir(data_root))
        rooms = [room for room in rooms if 'Area_' in room]

        if split == 'train':
            rooms_split = [room for room in rooms if not 'Area_{}'.format(test_area) in room]   # ['Area_1_WC_1.txt','Area_1_conferenceRoom_1.txt',...]
        else:
            rooms_split = [room for room in rooms if 'Area_{}'.format(test_area) in room]     # ['Area_5_WC_1.txt', 'Area_5_WC_2.txt',...]

        self.room_points, self.room_labels = [], []        # point-cloud data, point-cloud labels
        self.room_coord_min, self.room_coord_max = [], []  # min coordinates, max coordinates
        num_point_all = []
        labelweights = np.zeros(13)     # array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

        for room_name in tqdm(rooms_split, total=len(rooms_split)):
            room_path = os.path.join(data_root, room_name)
            room_data = np.load(room_path)  # xyzrgbl, N*7
            points, labels = room_data[:, 0:6], room_data[:, 6]  # xyzrgb, N*6; l, N
            tmp, _ = np.histogram(labels, range(14))
            labelweights += tmp
            coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
            self.room_points.append(points), self.room_labels.append(labels)
            self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max)
            num_point_all.append(labels.size)
        labelweights = labelweights.astype(np.float32)
        labelweights = labelweights / np.sum(labelweights)
        self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
        print(self.labelweights)
        sample_prob = num_point_all / np.sum(num_point_all)
        num_iter = int(np.sum(num_point_all) * sample_rate / num_point)
        room_idxs = []
        for index in range(len(rooms_split)):
            room_idxs.extend([index] * int(round(sample_prob[index] * num_iter)))
        self.room_idxs = np.array(room_idxs)
        print("Totally {} samples in {} set.".format(len(self.room_idxs), split))

    def __getitem__(self, idx):
        room_idx = self.room_idxs[idx]
        points = self.room_points[room_idx]   # N * 6
        labels = self.room_labels[room_idx]   # N
        N_points = points.shape[0]

        while (True):
            center = points[np.random.choice(N_points)][:3]
            block_min = center - [self.block_size / 2.0, self.block_size / 2.0, 0]
            block_max = center + [self.block_size / 2.0, self.block_size / 2.0, 0]
            point_idxs = np.where((points[:, 0] >= block_min[0]) & (points[:, 0] <= block_max[0]) & (points[:, 1] >= block_min[1]) & (points[:, 1] <= block_max[1]))[0]
            if point_idxs.size > 1024:
                break

        if point_idxs.size >= self.num_point:
            selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=False)
        else:
            selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=True)

        # normalize
        selected_points = points[selected_point_idxs, :]  # num_point * 6
        current_points = np.zeros((self.num_point, 9))  # num_point * 9
        current_points[:, 6] = selected_points[:, 0] / self.room_coord_max[room_idx][0]
        current_points[:, 7] = selected_points[:, 1] / self.room_coord_max[room_idx][1]
        current_points[:, 8] = selected_points[:, 2] / self.room_coord_max[room_idx][2]
        selected_points[:, 0] = selected_points[:, 0] - center[0]
        selected_points[:, 1] = selected_points[:, 1] - center[1]
        selected_points[:, 3:6] /= 255.0
        current_points[:, 0:6] = selected_points
        current_labels = labels[selected_point_idxs]
        if self.transform is not None:
            current_points, current_labels = self.transform(current_points, current_labels)
        return current_points, current_labels

    def __len__(self):
        return len(self.room_idxs)

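Above is the code that loads the S3DIS dataset, but I can't understand some parts of it.

The code below is the part I don't understand, and I'd like to know what it does.

I can't understand what labelweights means, nor the last line of this code.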

labelweights = labelweights.astype(np.float32)
labelweights = labelweights / np.sum(labelweights)
self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)

I don't understand the following code either. I'd like to know what each line means.

sample_prob = num_point_all / np.sum(num_point_all)
num_iter = int(np.sum(num_point_all) * sample_rate / num_point)
room_idxs = []
for index in range(len(rooms_split)):
    room_idxs.extend([index] * int(round(sample_prob[index] * num_iter)))
self.room_idxs = np.array(room_idxs)

Could you explain this code to me? Thanks a lot!


1 Answer

Stack Overflow user

Posted 2022-01-31 16:37:07

I haven't read the paper, so I suggest you dig into it to get more details. That said, here is what I understood from the code.

self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)

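Before this assignment, labelweights holds a probability distribution (the per-class point counts divided by their sum). The array is then normalized by its maximum probability and inverted (np.amax(labelweights) / labelweights). Finally, the cube root of the inverted ratio is taken.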

In short, if p is the probability distribution, this line computes cubic_root((p / max(p)) ^ {-1}), i.e. cubic_root(max(p) / p).
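A small worked sketch of the same steps, using made-up labels with 3 classes instead of S3DIS's 13 (the label array is hypothetical toy data, not taken from the dataset):

```python
import numpy as np

# Hypothetical toy labels: class 0 is common, class 2 is rare.
labels = np.array([0] * 80 + [1] * 15 + [2] * 5)

# Same steps as in S3DISDataset, with range(4) instead of range(14) for 3 classes.
counts, _ = np.histogram(labels, range(4))       # points per class: [80, 15, 5]
p = counts.astype(np.float32) / np.sum(counts)   # probability distribution: [0.8, 0.15, 0.05]
ratio = np.amax(p) / p                           # max(p) / p: approx. [1.0, 5.33, 16.0]
weights = np.power(ratio, 1 / 3.0)               # cube root: approx. [1.0, 1.75, 2.52]

print(counts)
print(weights.round(2))
```

The most frequent class always gets weight 1, and every other class gets a weight above 1 that grows with how rare it is.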

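Given its name, the variable is probably used to weight or sample observations according to their class frequency, assigning larger weights to under-represented classes. Since every ratio max(p) / p is at least 1, the cube root pulls the weights toward 1 and compresses their range, so rare classes do not receive extreme weights.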
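To see the damping effect, the sketch below compares plain inverse-frequency ratios with their cube roots on hypothetical class frequencies. How the repository's training code actually consumes self.labelweights (for example as per-class loss weights) is not shown in the code above, so treat that use as an assumption to verify.

```python
import numpy as np

# Hypothetical class frequencies spanning a 10x range.
p = np.array([0.50, 0.30, 0.15, 0.05])
p = p / p.sum()

raw = np.amax(p) / p               # plain inverse-frequency ratio: approx. [1, 1.67, 3.33, 10]
damped = np.power(raw, 1 / 3.0)    # cube root, as in S3DISDataset: approx. [1, 1.19, 1.49, 2.15]

# Spread between the largest and the smallest weight, before and after the cube root.
print(raw.max() / raw.min(), damped.max() / damped.min())
```

A 10x imbalance between the rarest and the most common class becomes a weight ratio of only about 2.15 (the cube root of 10).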

sample_prob = num_point_all / np.sum(num_point_all)
num_iter = int(np.sum(num_point_all) * sample_rate / num_point)
room_idxs = []
for index in range(len(rooms_split)):
    room_idxs.extend([index] * int(round(sample_prob[index] * num_iter)))
self.room_idxs = np.array(room_idxs)

sample_prob is a probability distribution holding, for each room, the fraction of all points that the room contains.
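A minimal sketch with hypothetical per-room point counts (the values in num_point_all are made up, standing in for labels.size of each room):

```python
import numpy as np

# Hypothetical point counts for three rooms.
num_point_all = [100_000, 300_000, 600_000]

# Dividing a Python list by a NumPy scalar converts the list to an ndarray.
sample_prob = num_point_all / np.sum(num_point_all)
print(sample_prob)  # [0.1 0.3 0.6]
```

This relies on NumPy converting the plain list; np.asarray(num_point_all) / np.sum(num_point_all) would make the conversion explicit.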

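num_iter takes the fraction sample_rate of the total number of points over all rooms (np.sum(num_point_all)) and divides it by num_point. Since each sample contains num_point points, this is the number of samples needed to cover that fraction of the points. Because __len__ returns len(self.room_idxs), which is approximately num_iter, it is effectively the number of samples per epoch (not the number of batches).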
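The arithmetic for hypothetical room sizes, using the defaults num_point=4096 and sample_rate=1.0 from the __init__ signature (the room sizes are made up):

```python
import numpy as np

# Hypothetical rooms; num_point and sample_rate are the __init__ defaults.
num_point_all = [100_000, 300_000, 600_000]
num_point = 4096
sample_rate = 1.0

# Total points to cover, divided by points per sample, truncated by int().
num_iter = int(np.sum(num_point_all) * sample_rate / num_point)
print(num_iter)  # 1_000_000 / 4096 = 244.14..., truncated to 244

# Halving sample_rate roughly halves the number of samples per epoch.
half = int(np.sum(num_point_all) * 0.5 / num_point)
print(half)  # 122
```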

room_idxs is built by concatenating several lists ([el] * n produces a list containing n copies of el).

Each of these lists contains the same room index repeated round(sample_prob[index] * num_iter) times, so len(room_idxs) ≈ num_iter.
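The loop can be run end to end on hypothetical data. Here num_point_all describes three made-up rooms, and len(num_point_all) stands in for len(rooms_split), which has the same length in the original code:

```python
import numpy as np

num_point_all = [100_000, 300_000, 600_000]   # hypothetical rooms
num_point, sample_rate = 4096, 1.0

sample_prob = num_point_all / np.sum(num_point_all)               # [0.1, 0.3, 0.6]
num_iter = int(np.sum(num_point_all) * sample_rate / num_point)   # 244

room_idxs = []
for index in range(len(num_point_all)):
    # [index] * n is a list holding n copies of this room's index.
    room_idxs.extend([index] * int(round(sample_prob[index] * num_iter)))
room_idxs = np.array(room_idxs)

print(np.bincount(room_idxs))   # copies per room: [ 24  73 146]
print(len(room_idxs), num_iter) # 243 244
```

Because each count is rounded separately, the total can differ slightly from num_iter (243 instead of 244 here), which is why the length is only approximately num_iter.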

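In short, the purpose of this code is to sample room indices in proportion to the number of points each room contains: larger rooms appear more often in room_idxs, so __getitem__ crops more random blocks from them.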
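To illustrate the effect over one epoch, the sketch below assumes a hypothetical room_idxs for three rooms holding 24, 73 and 146 copies of their indices, and a DataLoader with shuffle=True that visits every idx once; __getitem__ then maps idx to room_idxs[idx]:

```python
import numpy as np

# Hypothetical room_idxs: rooms 0, 1, 2 repeated 24, 73 and 146 times.
room_idxs = np.repeat([0, 1, 2], [24, 73, 146])

# One epoch visits every idx in range(len(room_idxs)) once, in shuffled order.
rng = np.random.default_rng(0)
epoch_order = rng.permutation(len(room_idxs))
rooms_used = room_idxs[epoch_order]           # what __getitem__ reads as room_idx

# Share of the epoch's samples drawn from each room.
share = np.bincount(rooms_used) / len(rooms_used)
print(share.round(3))  # close to the per-room point fractions [0.1, 0.3, 0.6]
```

Shuffling only changes the order; each room's share of the epoch is fixed by how many times its index was repeated, i.e. by its point count.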

Original page content provided by Stack Overflow; translation supported by the Tencent Cloud Xiaowei IT-domain engine.
Original link:

https://stackoverflow.com/questions/70928310
