我正在处理一个stl-10图像数据集,它包含10个不同的类。我想把这个多类图像分类问题归结为二值类图像分类,如1类Vs rest。我正在使用PyTorch torchvision下载和使用stl数据,但是我无法像一个人和其他人一样做到这一点。
train_data=torchvision.datasets.STL10(root='data',split='train',transform=data_transforms['train'], download=True)
test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True)
train_dataloader = DataLoader(train_data,batch_size = 64,shuffle=True,num_workers=2)
test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)发布于 2022-04-18 13:49:37
对于torchvision数据集,有一种内置的方法可以做到这一点。您需要定义转换函数或类,并在创建数据集时将其添加到target_transform中。
torchvision.datasets.STL10(root: str, split: str = 'train', folds: Union[int, NoneType] = None, transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, download: bool = False)下面是一个可供参考的示例:
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
class Multi2UniLabelTfm():
def __init__(self,pos_label=5):
if isinstance(pos_label,int) or isinstance(pos_label,float):
pos_label = [pos_label,]
self.pos_label = pos_label
def __call__(self,y):
# if y==self.pos_label:
if y in self.pos_label:
return 1
else:
return 0
if __name__=='__main__':
test_tfms = transforms.Compose([
transforms.ToTensor()
])
data_transforms = {'val':test_tfms}
#Original Labels
# target_transform = None
# Label 5 is converted to 1. Rest are 0.
# target_transform = Multi2UniLabelTfm(pos_label=5)
# Labels 5,6,7 are converted to 1. Rest are 0.
target_transform = Multi2UniLabelTfm(pos_label=[5,6,7])
test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True, target_transform=target_transform)
test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)
for idx,(x,y) in enumerate(test_dataloader):
print(idx,y)
if idx == 5:
break发布于 2022-04-16 14:37:08
你需要重新命名图像。开始时,0类对应于标签0,1类对应于标签1,.,10类对应于标签9。如果要实现二进制分类,则需要将类别1(或其他)的图片的标签更改为0,将所有其他类别的图片更改为1。
发布于 2022-04-17 05:29:11
一种方法是在运行时更新标签值,然后将它们传递给训练循环中的丢失函数。假设我们要将5类重命名为1,其余部分为0:
my_class_id = 5
for imgs, labels in train_dataloader:
labels = torch.where(labels == my_class_id, 1, 0)
...您可能还需要对test_dataloader进行类似的重新标记。另外,我不确定labels的数据类型。如果它的浮动,相应地改变。
https://stackoverflow.com/questions/71889622
复制相似问题