PyTorch 报错:
RuntimeError: Expected object of scalar type Byte but got scalar type Double for sequence element 2 in sequence argument at position #1 'tensors'
完整错误描述:
Traceback (most recent call last):
File "D:/python workspace/Pytorch-UNet/train.py", line 175, in <module>
val_percent=args.val / 100)
File "D:/python workspace/Pytorch-UNet/train.py", line 67, in train_net
for batch in train_loader:
File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\dataloader.py", line 582, in __next__
return self._process_next_batch(batch)
File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\dataloader.py", line 608, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
RuntimeError: Traceback (most recent call last):
File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\_utils\worker.py", line 99, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\_utils\collate.py", line 63, in default_collate
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\_utils\collate.py", line 63, in <dictcomp>
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
File "E:\anaconda\envs\python36\lib\site-packages\torch\utils\data\_utils\collate.py", line 43, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: Expected object of scalar type Byte but got scalar type Double for sequence element 2 in sequence argument at position #1 'tensors'
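说明：从 traceback 可以看出，报错并非发生在训练代码本身，而是发生在 DataLoader 的 default_collate 中。Dataset 返回的是 dict，default_collate 会对每个键调用 torch.stack，把同一 batch 中各样本的该字段拼接起来。torch.stack 要求所有张量的 dtype 相同，但这里第 0 个样本是 Byte (uint8)，第 2 个样本（从 0 开始计数）却是 Double (float64)。仅凭 traceback 无法确定出问题的是 'image' 还是 'mask'。这说明 Dataset 的 __getitem__ 对不同样本返回了 dtype 不一致的张量。训练循环里对 imgs / true_masks 的 .to(dtype=...) 转换发生在 collate 之后，无法避免该错误；应在 __getitem__（或其预处理函数）中把返回的张量统一转换为同一种 dtype（例如 float32）。
部分代码: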
for epoch in range(epochs):
net.train()
epoch_loss = 0
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
for batch in train_loader: #############报错行
imgs = batch['image']
true_masks = batch['mask']
assert imgs.shape[1] == net.n_channels, \
f'Network has been defined with {net.n_channels} input channels, ' \
f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
'the images are loaded correctly.'
imgs = imgs.to(device=device, dtype=torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device=device, dtype=mask_type)
相似问题