我有一批形状[5,1,100,100] (batch_size x dims x ht x wd)的分段掩码,我必须用RGB图像批处理[5,3,100,100]在tensorboardX中显示。我想在分割掩码的第二个轴上添加两个虚拟尺寸,使其成为[5,3,100,100],这样当我将其传递给torch.utils.make_grid时,就不会出现任何维度不匹配的错误。我试过unsqueeze,expand和view,但是我做不到。有什么建议吗?
发布于 2019-07-09 12:44:05
您可以使用expand、repeat或repeat_interleave
import torch
x = torch.randn((5, 1, 100, 100))
x1_3channels = x.expand(-1, 3, -1, -1)
x2_3channels = x.repeat(1, 3, 1, 1)
x3_3channels = x.repeat_interleave(3, dim=1)
print(x1_3channels.shape) # torch.Size([5, 3, 100, 100])
print(x2_3channels.shape) # torch.Size([5, 3, 100, 100])
print(x3_3channels.shape) # torch.Size([5, 3, 100, 100])注意,如文档中所述:
扩展张量的不分配新的内存,而只是在现有张量上创建一个新视图,其中尺寸为1的维度通过将量程设置为0扩展到更大的尺寸。大小为1的任何维度都可以扩展为任意值,而无需分配新的内存。
与
expand()不同,这个函数复制张量的数据。
https://stackoverflow.com/questions/56952598
复制相似问题