pytorch运行时错误?

  • 回答 (2)
  • 关注 (0)
  • 查看 (643)

pytorch报错:

RuntimeError: Expected object of scalar type Byte but got scalar type Double for sequence element 2 in sequence argument at position #1 'tensors'

完整错误描述:

raceback (most recent call last):

File "D:/python workspace/Pytorch-UNet/train.py", line 175, in <module>

val_percent=args.val / 100)

File "D:/python workspace/Pytorch-UNet/train.py", line 67, in train_net

for batch in train_loader:

File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\dataloader.py", line 582, in __next__

return self._process_next_batch(batch)

File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\dataloader.py", line 608, in _process_next_batch

raise batch.exc_type(batch.exc_msg)

RuntimeError: Traceback (most recent call last):

File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\_utils\worker.py", line 99, in _worker_loop

samples = collate_fn([dataset[i] for i in batch_indices])

File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\_utils\collate.py", line 63, in default_collate

return {key: default_collate([d[key] for d in batch]) for key in batch[0]}

File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\_utils\collate.py", line 63, in <dictcomp>

return {key: default_collate([d[key] for d in batch]) for key in batch[0]}

File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\_utils\collate.py", line 43, in default_collate

return torch.stack(batch, 0, out=out)

RuntimeError: Expected object of scalar type Byte but got scalar type Double for sequence element 2 in sequence argument at position #1 'tensors'

部分代码:

for epoch in range(epochs):
    net.train()
    epoch_loss = 0
    with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
        for batch in train_loader:    #############报错行
            imgs = batch['image']
            true_masks = batch['mask']
            assert imgs.shape[1] == net.n_channels, \
                f'Network has been defined with {net.n_channels} input channels, ' \
                f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'

            imgs = imgs.to(device=device, dtype=torch.float32)
            mask_type = torch.float32 if net.n_classes == 1 else torch.long
            true_masks = true_masks.to(device=device, dtype=mask_type)
腾讯智能钛AI开发者

腾讯云 · 智能钛产品团队 (已认证)

腾讯智能钛产品团队官方运营账号。分享产品最新动态,第一时间解答用户疑问。回答于

您好,您这边可以提供下您的具体任务流链接吗?

用户5012811回答于

可能回答问题的人

  • Superbeet

    8 粉丝0 提问0 回答
  • 腾讯云AI中心

    腾讯云 · 人工智能 (已认证)

    102 粉丝0 提问19 回答
  • 腾讯智能钛AI开发者

    腾讯云 · 智能钛产品团队 (已认证)

    164 粉丝0 提问61 回答
  • rodson

    腾讯 · web前端开发 (已认证)

    4 粉丝0 提问0 回答
  • DJ213

    2 粉丝0 提问0 回答
  • 晏栋栋栋

    3 粉丝0 提问2 回答

扫码关注云+社区

领取腾讯云代金券