首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >将多类图像分类简化为二值分类

将多类图像分类简化为二值分类
EN

Stack Overflow用户
提问于 2022-04-15 22:48:48
回答 3查看 226关注 0票数 0

我正在处理一个stl-10图像数据集,它包含10个不同的类。我想把这个多类图像分类问题归结为二值类图像分类,如1类Vs rest。我正在使用PyTorch torchvision下载和使用stl数据,但是我无法像一个人和其他人一样做到这一点。

代码语言:javascript
运行
复制
train_data=torchvision.datasets.STL10(root='data',split='train',transform=data_transforms['train'], download=True)
test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True)

train_dataloader = DataLoader(train_data,batch_size = 64,shuffle=True,num_workers=2)
test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)
EN

回答 3

Stack Overflow用户

发布于 2022-04-18 13:49:37

对于torchvision数据集,有一种内置的方法可以做到这一点。您需要定义转换函数或类,并在创建数据集时将其添加到target_transform中。

代码语言:javascript
运行
复制
torchvision.datasets.STL10(root: str, split: str = 'train', folds: Union[int, NoneType] = None, transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, download: bool = False)

下面是一个可供参考的示例:

代码语言:javascript
运行
复制
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms


class Multi2UniLabelTfm():
    def __init__(self,pos_label=5):
        if isinstance(pos_label,int) or isinstance(pos_label,float):
            pos_label = [pos_label,]
        self.pos_label = pos_label

    def __call__(self,y):
        # if y==self.pos_label:
        if y in self.pos_label:
            return 1
        else:
            return 0

if __name__=='__main__':

    test_tfms = transforms.Compose([
        transforms.ToTensor()
    ])
    data_transforms = {'val':test_tfms}


    #Original Labels
    # target_transform = None   

    # Label 5 is converted to 1. Rest are 0.
    # target_transform = Multi2UniLabelTfm(pos_label=5)     

    # Labels 5,6,7 are converted to 1. Rest are 0.
    target_transform = Multi2UniLabelTfm(pos_label=[5,6,7])
    test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True, target_transform=target_transform)
    test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)

    for idx,(x,y) in enumerate(test_dataloader):
        print(idx,y)

        if idx == 5:
            break
票数 1
EN

Stack Overflow用户

发布于 2022-04-16 14:37:08

你需要重新命名图像。开始时,0类对应于标签0,1类对应于标签1,.,10类对应于标签9。如果要实现二进制分类,则需要将类别1(或其他)的图片的标签更改为0,将所有其他类别的图片更改为1。

票数 0
EN

Stack Overflow用户

发布于 2022-04-17 05:29:11

一种方法是在运行时更新标签值,然后将它们传递给训练循环中的丢失函数。假设我们要将5类重命名为1,其余部分为0:

代码语言:javascript
运行
复制
my_class_id = 5
for imgs, labels in train_dataloader:
    labels = torch.where(labels == my_class_id, 1, 0)
    ...

您可能还需要对test_dataloader进行类似的重新标记。另外,我不确定labels的数据类型。如果它的浮动,相应地改变。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71889622

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档