我正在尝试使用PyTorch数据加载器加载图像数据集,但生成的转换是平铺的,并且没有像我预期的那样将原始图像裁剪到中心。
transform = transforms.Compose([transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor()])
dataset = datasets.ImageFolder('ml-models/downloads/', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
images, labels = next(iter(dataloader))
import matplotlib.pyplot as plt
plt.imshow(images[6].reshape(224, 224, 3))生成的图像是平铺的,而不是中心裁剪的。
所提供的转换中是否存在错误?(下图显示链接:) 1:https://i.stack.imgur.com/HtrIa.png
发布于 2020-11-06 05:41:08
Pytorch以通道优先格式存储张量,因此3通道图像是形状(3,H,W)的张量。Matplotlib期望数据是通道最后的格式,即(H,W,3)。调整形状不会重新排列尺寸,因此您需要Tensor.permute。
plt.imshow(images[6].permute(1, 2, 0))https://stackoverflow.com/questions/64705364
复制相似问题