
PyTorch runtime error?

Asked on 2020-03-06 17:15:29

PyTorch reports the following error:

RuntimeError: Expected object of scalar type Byte but got scalar type Double for sequence element 2 in sequence argument at position #1 'tensors'
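The message says that while a batch was being assembled, the tensors for one of the dict keys did not share a dtype: the expected type was Byte (`torch.uint8`, the type of the batch's first tensor for that key), but the sample at index 2 of the batch (the third one, counting from 0) was Double (`torch.float64`). A minimal sketch for locating such samples, assuming the dataset returns a dict of tensors (as the `batch['image']` / `batch['mask']` access in the training loop suggests). `find_dtype_mismatches` and `ToyDataset` are hypothetical names used only for illustration:

```python
import torch
from torch.utils.data import Dataset


def find_dtype_mismatches(dataset, limit=None):
    """Scan samples and report, per dict key, which indices produce which dtype.

    Only keys whose tensors come in more than one dtype are returned, e.g.
    {'mask': {torch.uint8: [0, 1], torch.float64: [2]}}.
    """
    n = len(dataset) if limit is None else min(limit, len(dataset))
    seen = {}
    for i in range(n):
        for key, value in dataset[i].items():
            if torch.is_tensor(value):
                seen.setdefault(key, {}).setdefault(value.dtype, []).append(i)
    # Keep only keys where more than one dtype was observed.
    return {key: by_dtype for key, by_dtype in seen.items() if len(by_dtype) > 1}


# Hypothetical toy dataset mimicking the error: samples 0 and 1 return
# uint8 (Byte) masks, sample 2 returns a float64 (Double) mask.
class ToyDataset(Dataset):
    def __len__(self):
        return 3

    def __getitem__(self, i):
        mask = torch.zeros(1, 4, 4, dtype=torch.uint8 if i < 2 else torch.float64)
        return {'image': torch.zeros(3, 4, 4), 'mask': mask}


print(find_dtype_mismatches(ToyDataset()))
# {'mask': {torch.uint8: [0, 1], torch.float64: [2]}}
```

Running it on the real training dataset (optionally with `limit` to scan only the first few hundred samples) shows which key and which files are loaded with the odd dtype.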

Full traceback:

```
Traceback (most recent call last):
  File "D:/python workspace/Pytorch-UNet/train.py", line 175, in <module>
    val_percent=args.val / 100)
  File "D:/python workspace/Pytorch-UNet/train.py", line 67, in train_net
    for batch in train_loader:
  File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\dataloader.py", line 582, in __next__
    return self._process_next_batch(batch)
  File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\dataloader.py", line 608, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
RuntimeError: Traceback (most recent call last):
  File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\_utils\worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\_utils\collate.py", line 63, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
  File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\_utils\collate.py", line 63, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
  File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\_utils\collate.py", line 43, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: Expected object of scalar type Byte but got scalar type Double for sequence element 2 in sequence argument at position #1 'tensors'
```
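The traceback shows the failure is inside the DataLoader worker's `default_collate`: for dict samples it collates each key separately and finally calls `torch.stack`, which, in the PyTorch version used here, rejects a list of tensors with mixed dtypes. A quick workaround sketch is a custom `collate_fn` that casts every tensor to float32 before delegating to `default_collate`. `float_collate` is a hypothetical name, not part of PyTorch:

```python
import torch
from torch.utils.data.dataloader import default_collate


def float_collate(batch):
    """Collate a list of dict samples after casting every tensor to float32.

    default_collate ends in torch.stack, which here fails when samples
    disagree in dtype; casting first gives all tensors the same dtype.
    Non-tensor values are passed through unchanged.
    """
    cast = [
        {key: value.float() if torch.is_tensor(value) else value
         for key, value in sample.items()}
        for sample in batch
    ]
    return default_collate(cast)
```

It would be passed as `collate_fn=float_collate` to the `DataLoader` that builds `train_loader`; on Windows with `num_workers > 0`, keep it defined at module level so worker processes can pickle it. Since the training loop converts images to float32 and masks to float32 or long anyway, casting early loses nothing, but it only hides the dtype mismatch: if the mismatched samples also differ in value range (for example 0 to 255 versus 0 to 1), that still has to be fixed in the dataset.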

Relevant code:

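```python
import torch
from tqdm import tqdm

# Excerpt from train_net() in train.py; net, device, epochs, n_train and
# train_loader are defined earlier in that function.
for epoch in range(epochs):
    net.train()
    epoch_loss = 0
    with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
        for batch in train_loader:  # the error is raised here, while the DataLoader collates a batch
            imgs = batch['image']
            true_masks = batch['mask']
            assert imgs.shape[1] == net.n_channels, \
                f'Network has been defined with {net.n_channels} input channels, ' \
                f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'

            # Convert to the dtypes the network and loss expect.
            imgs = imgs.to(device=device, dtype=torch.float32)
            mask_type = torch.float32 if net.n_classes == 1 else torch.long
            true_masks = true_masks.to(device=device, dtype=mask_type)
            # ... (rest of the loop body not shown)
```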
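A more robust fix is to make the dataset return the same dtype for every sample. One common way to end up with a Byte/Double mix is conditional normalization, e.g. dividing by 255 only when an array's maximum exceeds 1, which leaves some arrays as uint8 and turns others into float64. The dataset code is not shown in the question, so this is an assumption to check. A sketch of unconditional preprocessing, assuming images and masks are loaded as NumPy arrays in HW or HWC layout and the task is binary segmentation (`net.n_classes == 1`); `image_to_tensor` and `binary_mask_to_tensor` are hypothetical helpers:

```python
import numpy as np
import torch


def image_to_tensor(img):
    """HW or HWC array -> float32 CHW tensor; uint8 input is scaled to [0, 1]."""
    arr = np.asarray(img)
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]  # HW -> HW1
    # Scale is chosen by dtype, not by inspecting the maximum value,
    # so every sample goes through the same conversion.
    scale = 255.0 if arr.dtype == np.uint8 else 1.0
    arr = arr.transpose(2, 0, 1).astype(np.float32) / scale  # HWC -> CHW, float32
    return torch.from_numpy(arr)


def binary_mask_to_tensor(mask):
    """Mask stored as 0/1 or 0/255 -> float32 1xHxW tensor of 0.0/1.0."""
    arr = np.asarray(mask)
    if arr.ndim == 3:
        arr = arr[:, :, 0]  # assumption: all mask channels are identical
    return torch.from_numpy((arr > 0).astype(np.float32)[np.newaxis])
```

For multi-class masks, where pixel values are class indices, thresholding is not appropriate; there the mask should instead be converted unconditionally to one integer dtype (the loop above turns it into `torch.long` anyway).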
