首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在pytorch中更新我的ImageFolder数据集?

如何在pytorch中更新我的ImageFolder数据集?
EN

Stack Overflow用户
提问于 2021-06-26 14:23:32
回答 1查看 43关注 0票数 1

我正在处理一个数据集,其中我需要找到少于20个样本的类的准确性。因此,首先我使用pytorch的ImageFolder来获取文件夹中的所有图像。

代码语言:javascript
复制
dataset = ImageFolder('/content/drive/MyDrive/data/Dataset/')

现在,为了得到我使用的少于20个样本的类:

代码语言:javascript
复制
def get_class_distribution(dataset_obj):
    count_dict = {k:0 for k,v in dataset_obj.class_to_idx.items()}
    
    for element in dataset_obj:
        y_lbl = element[1]
        y_lbl = idx2class[y_lbl]
        count_dict[y_lbl] += 1
            
    return count_dict
# print("Distribution of classes: \n", get_class_distribution(dataset))
class_distribution = get_class_distribution(dataset)

sampled_classes = [classes  for (classes, samples) in class_distribution.items() if samples <= 20]

我正确地获得了类的列表,但我的疑问是,我如何进一步进行推理?如何将其转换/更新为ImageFolder,以便可以在以下代码中使用过滤后的数据集:

代码语言:javascript
复制
# Test model performance for classes with less than 20 samples.

y_pred_list = []
y_true_list = []
with torch.no_grad():
    for x_batch, y_batch in tqdm(data_loader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        y_test_pred = model(x_batch)
        _, y_pred_tag = torch.max(y_test_pred, dim = 1)
        y_pred_list.append(y_pred_tag.cpu().numpy())
        y_true_list.append(y_batch.cpu().numpy())
EN

回答 1

Stack Overflow用户

发布于 2021-06-26 15:28:22

不需要写入第一个块

改用这个

代码语言:javascript
复制
test_data = datasets.ImageFolder('test/', transform=test_transforms)
data_loader = torch.utils.data.DataLoader(test_data, batch_size=16)

y_pred_list = []
accuracy = []
with torch.no_grad():
    for x_batch, y_batch in tqdm(data_loader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        y_test_pred = model(x_batch)
        top_p, top_class = y_test_pred.topk(1, dim=1)
        equals = top_class == y_batch.view(*top_class.shape)
        accuracy += torch.mean(equals.type(torch.FloatTensor)).item()


print(accuracy/len(data_loader)*100) # this would print %
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68139811

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档