# PyTorch实现自由的数据读取

PyTorch数据读入函数介绍

## 问题来源

class custom_dset(Dataset):
    """Dataset built from an annotation file of ``<image-name> <label>`` lines.

    Each non-blank line of ``txt_path`` is split on whitespace. The first field
    is joined onto ``img_path`` to form the sample path, and the second field is
    kept (as a string) as the label.

    Args:
        img_path: Directory prefix joined onto each image name.
        txt_path: Path to the annotation text file.
        img_transform: Optional callable applied to the sample in ``__getitem__``.
        loader: Stored as ``self.loader``; not currently invoked by ``__getitem__``.
    """

    def __init__(self, img_path, txt_path, img_transform=None, loader=default_loader):
        with open(txt_path, 'r') as f:
            # Split each line once. Skip blank lines (e.g. a trailing empty line),
            # because they have no image-name field to index.
            records = [fields for fields in (line.split() for line in f) if fields]
        self.img_list = [os.path.join(img_path, fields[0]) for fields in records]
        self.label_list = [fields[1] for fields in records]
        self.img_transform = img_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(sample, label)`` for ``index``.

        The sample is the image *path*, not decoded image data. ``self.loader``
        is not called here, so ``img_transform`` (if given) receives the path.
        """
        img = self.img_list[index]
        label = self.label_list[index]
        # NOTE(review): use ``img = self.loader(img)`` here if decoded images,
        # rather than paths, are wanted.
        if self.img_transform is not None:
            img = self.img_transform(img)
        return img, label

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.label_list)

def collate_fn(batch, pad_value=0):
    """Collate ``(img, label)`` samples whose labels have different lengths.

    Samples are ordered by label length, longest first. Each label is
    right-padded with ``pad_value`` to the length of the longest label.

    Args:
        batch: Iterable of ``(img, label)`` pairs; each ``label`` must support
            ``len()`` and iteration. The iterable is not modified.
        pad_value: Value used to pad shorter labels (default ``0``).

    Returns:
        A tuple ``(img, pad_label, lens)``:
        - ``img``: tuple of the images, in the sorted order;
        - ``pad_label``: list of padded labels (lists), in the same order;
        - ``lens``: list of the unpadded label lengths, in the same order.
        An empty ``batch`` yields ``((), [], [])``.
    """
    # sorted() rather than list.sort(): the caller's list is not reordered in
    # place, and tuples or other iterables are accepted. The sort is stable.
    ordered = sorted(batch, key=lambda sample: len(sample[1]), reverse=True)
    if not ordered:
        return (), [], []
    img, label = zip(*ordered)
    lens = [len(lab) for lab in label]
    # The longest label comes first after the descending sort.
    max_len = lens[0]
    pad_label = [list(lab) + [pad_value] * (max_len - n) for lab, n in zip(label, lens)]
    return img, pad_label, lens

