首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >加快用于自定义消息丢失的pytorch操作

加快用于自定义消息丢失的pytorch操作
EN

Stack Overflow用户
提问于 2022-04-17 09:31:44
回答 1查看 85关注 0票数 0

我试图在我的自定义MessagePassing卷积中实现PyTorch几何中的消息丢失。消息丢失包括随机忽略图中边沿的p%。我的想法是从forward()中的输入forward()中随机删除其中的p%。

edge_index是形状(2, num_edges)的张量,其中第1维是从节点ID,第2维是从节点ID到节点ID。所以我认为我可以做的是选择一个随机的range(N)样本,然后用它来掩盖其余的索引:

代码语言:javascript
运行
复制
    def forward(self, x, edge_index, edge_attr=None):
        if self.message_dropout is not None:
            # TODO: this is way too slow (4-5 times slower than without it)
            # message dropout -> randomly ignore p % of edges in the graph i.e. keep only (1-p) % of them
            random_keep_inx = random.sample(range(edge_index.shape[1]), int((1.0 - self.message_dropout) * edge_index.shape[1]))
            edge_index_to_use = edge_index[:, random_keep_inx]
            edge_attr_to_use = edge_attr[random_keep_inx] if edge_attr is not None else None
        else:
            edge_index_to_use = edge_index
            edge_attr_to_use = edge_attr

        ...

然而,它太慢了,它使一个时代走向5‘而不是1’没有(5倍的慢)。在PyTorch中有更快的方法来做到这一点吗?

编辑:瓶颈似乎是random.sample()调用,而不是掩蔽。所以我想我应该要求的是更快的替代方案。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-04-17 09:56:37

我使用PyTorch的Dropout函数创建了一个布尔掩码,速度更快。现在一个时代又需要1‘了。比其他我在其他地方找到的置换解决方案要好。

代码语言:javascript
运行
复制
    def forward(self, x, edge_index, edge_attr=None):
        if self.message_dropout is not None:
            # message dropout -> randomly ignore p % of edges in the graph
            mask = F.dropout(torch.ones(edge_index.shape[1]), self.message_dropout, self.training) > 0
            edge_index_to_use = edge_index[:, mask]
            edge_attr_to_use = edge_attr[mask] if edge_attr is not None else None
        else:
            edge_index_to_use = edge_index
            edge_attr_to_use = edge_attr

        ...
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71900767

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档