AttributeError: DataLoader is missing the 'persistent_workers' attribute?

Asked 2023-01-11 16:48:56
  File "C:\Users\26001\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\utils\data\dataloader.py", line 428, in __iter__
    if self.persistent_workers and self.num_workers > 0:
AttributeError: 'DataLoader' object has no attribute 'persistent_workers'

Iterating over a DataLoader raises an AttributeError saying the 'persistent_workers' attribute is missing. The full traceback:

Traceback (most recent call last):
  File "H:\Users\Administrator\Desktop\fakeNewsDetection\FND\_023_modelingMain.py", line 93, in <module>
    mainFunc()
  File "H:\Users\Administrator\Desktop\fakeNewsDetection\FND\_023_modelingMain.py", line 86, in mainFunc
    net.trainTVT()
  File "H:\Users\Administrator\Desktop\fakeNewsDetection\FND\_022_train.py", line 71, in trainTVT
    for it,(text,mask,label,domainLabel) in enumerate(self.dataLoaderTra):
  File "C:\Users\26001\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\utils\data\dataloader.py", line 428, in __iter__
    if self.persistent_workers and self.num_workers > 0:
AttributeError: 'DataLoader' object has no attribute 'persistent_workers'

        for i in range(self.start_epoch, self.final_epoch):
            timeStart=time.time()            
            # train
            self.net.train()
            correct=0.0
            total=0.0
            for it,(text,mask,label,domainLabel) in enumerate(self.dataLoaderTra):
                text,label=text.long(),label.long()
                text,mask,label,domainLabel=(text.to(self.device),
                                             mask.to(self.device),
                                             label.to(self.device),
                                             domainLabel.to(self.device))
                self.optimizer.zero_grad()
                outputs=self.net(text, mask)
                outputsLog=torch.log(outputs)
                loss=self.loss(outputsLog,label)                
                loss.backward()

The class is defined as:

class Network(object):
    def __init__(self, opt):
        self.seed=opt.seed
        setupSeed(self.seed)
        self.batch=opt.batch
        self.lr=opt.lr
        self.start_epoch=opt.start_epoch
        self.final_epoch=opt.final_epoch
        self.inter=opt.inter
        self.mode=opt.mode
        self.models=opt.models
        self.plots=opt.plots
        # define dataloader
        self.dataLoaderTra, self.dataLoaderVal, self.dataLoaderTes, self.weightAr, self.coreList = getDatasetTVT(opt.traValTesName)
        # add weightAr to opt
        opt.weightAr = self.weightAr

def getDatasetTVT(traValTesName):
    # Load the pickled train/val/test DataLoaders, the weight array and the core list.
    # The with-block closes the file handle even if unpickling fails.
    with open(traValTesName, 'rb') as f:
        dataLoaderTra, dataLoaderVal, dataLoaderTes, weightArray, coreListTra = pickle.load(f)
    return dataLoaderTra, dataLoaderVal, dataLoaderTes, weightArray, coreListTra
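
A possible explanation, offered as a guess rather than a confirmed diagnosis: `getDatasetTVT` unpickles whole `DataLoader` objects. By default, pickle restores an object's saved `__dict__` without calling `__init__`, so a loader pickled under an older PyTorch (as far as I know, `persistent_workers` arrived in 1.7) would lack any attribute that a newer `__iter__` expects. The sketch below reproduces that mechanism in plain Python; `LoaderV1`, `LoaderV2` and `restore_like_pickle` are hypothetical stand-ins, not PyTorch classes.

```python
class LoaderV1:
    """Stand-in for an older DataLoader: it never sets `persistent_workers`."""

    def __init__(self, data, num_workers=0):
        self.data = data
        self.num_workers = num_workers


class LoaderV2:
    """Stand-in for a newer DataLoader whose __init__ adds `persistent_workers`."""

    def __init__(self, data, num_workers=0, persistent_workers=False):
        self.data = data
        self.num_workers = num_workers
        self.persistent_workers = persistent_workers

    def __iter__(self):
        # Same kind of check as the line shown in the traceback above.
        if self.persistent_workers and self.num_workers > 0:
            pass
        return iter(self.data)


def restore_like_pickle(cls, state):
    """Mimic pickle's default restore: allocate the object, then copy the saved __dict__."""
    obj = cls.__new__(cls)  # __init__ is NOT called here
    obj.__dict__.update(state)
    return obj


# State as the "old version" would have saved it: no persistent_workers key.
saved_state = dict(vars(LoaderV1([1, 2, 3])))
restored = restore_like_pickle(LoaderV2, saved_state)

try:
    list(restored)
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

print(error_message)
```

Comparing `torch.__version__` in the environment that wrote the pickle with the one that loads it would confirm or rule this out.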

In the dataloader file:

    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

Where does the problem come from, and how should I fix it?
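
If the pickle was written by an older PyTorch than the one loading it (an assumption; the post does not say), one workaround is to keep only the dataset from each unpickled loader and build a fresh `DataLoader` around it, so that the current `__init__` sets every attribute. A minimal sketch: `rebuild_loader` is a hypothetical helper, it assumes the original loaders were created with a plain `batch_size` rather than a custom sampler or `batch_sampler`, and `shuffle` has to be supplied again because the loader does not store it.

```python
def rebuild_loader(old_loader, loader_cls, shuffle=False):
    """Build a fresh loader from the settings stored on an unpickled one.

    `loader_cls` is the class to instantiate (torch.utils.data.DataLoader in
    real use). Passing it in keeps this helper free of a hard torch import.
    """
    # Read the raw saved state; this avoids touching attributes that a
    # newer class expects but the old pickle never stored.
    state = vars(old_loader)
    return loader_cls(
        state["dataset"],
        batch_size=state.get("batch_size", 1),
        shuffle=shuffle,  # assumption: the caller knows the original setting
        num_workers=state.get("num_workers", 0),
        collate_fn=state.get("collate_fn"),
        pin_memory=state.get("pin_memory", False),
        drop_last=state.get("drop_last", False),
    )


# Usage with PyTorch (not executed here):
#   from torch.utils.data import DataLoader
#   self.dataLoaderTra = rebuild_loader(self.dataLoaderTra, DataLoader, shuffle=True)
#   self.dataLoaderVal = rebuild_loader(self.dataLoaderVal, DataLoader)
```

Longer term, pickling the datasets rather than the `DataLoader` objects tends to avoid tying the file to one PyTorch version.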
