我想连接多个标签互不相交(即彼此不共享标签)的数据集。我是这样做的:
class ConcatDataset(Dataset):
    """
    Concatenate `datasets` and relabel them assuming their label sets are mutually exclusive.

    NOTE(review): this is the original version from the question and still contains the bugs discussed in the
    answers; see the NOTE(review) comments below.

    ref: https://discuss.pytorch.org/t/concat-image-datasets-with-different-size-and-number-of-channels/36362/12
    """

    def __init__(self, datasets: list[Dataset]):
        """
        Build the concatenated dataset and the index <-> label maps, checking every item along the way.

        :param datasets: datasets whose items are (x, y) pairs.
        """
        # I think concat is better than passing data to a self.data = x obj since concat likely using the getitem method of the passed dataset and thus if the passed dataset doesnt put all the data in memory concat won't either
        self.concat_datasets = torch.utils.data.ConcatDataset(datasets)
        # maps a class label to a list of sample indices with that label.
        self.labels_to_indices = defaultdict(list)
        # maps a sample index to its corresponding class label.
        # (defaultdict(None) has no default factory, so it behaves like a plain dict.)
        self.indices_to_labels = defaultdict(None)
        # - do the relabeling
        offset: int = 0
        new_idx: int = 0
        for dataset_idx, dataset in enumerate(datasets):
            assert len(dataset) == len(self.concat_datasets.datasets[dataset_idx])
            assert dataset == self.concat_datasets.datasets[dataset_idx]
            for x, y in dataset:
                y = int(y)
                # NOTE(review): BUG - new_idx is only incremented after this inner loop (once per dataset), so
                # every x of a dataset is compared against the same element of the concatenated dataset.
                _x, _y = self.concat_datasets[new_idx]
                _y = int(_y)
                # assert y == _y
                # NOTE(review): torch.equal needs tensors; without a transform these datasets yield PIL images
                # here, which is the TypeError shown in the question's traceback.
                assert torch.equal(x, _x)
                new_label = y + offset
                self.indices_to_labels[new_idx] = new_label
                # NOTE(review): assigns instead of appending, so only one index per label survives.
                self.labels_to_indices[new_label] = new_idx
            # NOTE(review): this is the max label, not the number of labels (max + 1 for 0-indexed labels), so the
            # last label of one dataset collides with label 0 of the next; it also re-iterates the whole dataset.
            num_labels_for_current_dataset: int = max([y for _, y in dataset])
            offset += num_labels_for_current_dataset
            new_idx += 1
        assert len(self.indices_to_labels.keys()) == len(self.concat_datasets)
        # contains the list of labels from 0 - total num labels after concat
        self.labels = range(offset)
        self.target_transform = lambda data: torch.tensor(data, dtype=torch.int)

    def __len__(self):
        # number of data points across all concatenated datasets
        return len(self.concat_datasets)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        # NOTE(review): the concatenated dataset returns an (x, original_y) pair, so `x` here is that whole pair.
        x = self.concat_datasets[idx]
        y = self.indices_to_labels[idx]
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y
但它甚至连 x 图像都对不齐(所以先别管我的重新标记是否正确!)。为什么?
def check_xs_align_cifar100():
    """
    Concatenate the CIFAR100 train and test splits (no transform, so items are PIL images) and print the sizes.

    NOTE(review): CIFAR100 train and test share the same classes, so treating them as label-disjoint only serves
    as a sanity-check setup here.
    """
    from pathlib import Path
    root = Path("~/data/").expanduser()
    # root = Path(".").expanduser()
    train = torchvision.datasets.CIFAR100(root=root, train=True, download=True)
    test = torchvision.datasets.CIFAR100(root=root, train=False, download=True)
    concat = ConcatDataset([train, test])
    print(f'{len(concat)=}')
    print(f'{len(concat.labels)=}')
错误
Files already downloaded and verified
Files already downloaded and verified
Traceback (most recent call last):
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1491, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 405, in <module>
check_xs_align()
File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 391, in check_xs_align
concat = ConcatDataset([train, test])
File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 71, in __init__
assert torch.equal(x, _x)
TypeError: equal(): argument 'input' (position 1) must be Tensor, not Image
python-BaseException
附加问题:请告诉我重新标记的做法是否正确。
编辑1: PIL比较失败
我根据Compare images Python PIL做了一个PIL图像比较,但是失败了:
Traceback (most recent call last):
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1491, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 419, in <module>
check_xs_align_cifar100()
File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 405, in check_xs_align_cifar100
concat = ConcatDataset([train, test])
File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 78, in __init__
assert diff.getbbox(), f'comparison of imgs failed: {diff.getbbox()=}'
AssertionError: comparison of imgs failed: diff.getbbox()=None
python-BaseException
diff
PyDev console: starting.
<PIL.Image.Image image mode=RGB size=32x32 at 0x7FBE897A21C0>
代码比较:
diff = ImageChops.difference(x, _x)  # https://stackoverflow.com/questions/35176639/compare-images-python-pil
# NOTE(review): getbbox() returns None when the difference image is entirely zero, i.e. when x and _x are
# identical, so this assertion is inverted: it fails exactly when the images match (diff.getbbox()=None).
# The intended check is `assert diff.getbbox() is None`.
assert diff.getbbox(), f'comparison of imgs failed: {diff.getbbox()=}'
这也失败了:
# Compare the two PIL images pixel by pixel (exact equality of their full pixel sequences).
assert list(x.getdata()) == list(_x.getdata()), f'\n{list(x.getdata())=}, \n{list(_x.getdata())=}'
AssertionError: ...long msg...
断言声明是:
# Pixel-by-pixel comparison of the two PIL images; on failure the message dumps both pixel lists.
assert list(x.getdata()) == list(_x.getdata()), f'\n{list(x.getdata())=}, \n{list(_x.getdata())=}'
编辑2:张量比较失败
我试图将图像转换为张量,但仍然失败:
AssertionError: Error for some reason, got: data_idx=1, x.norm()=tensor(45.9401), _x.norm()=tensor(33.9407), x=tensor([[[1.0000, 0.9922, 0.9922, ..., 0.9922, 0.9922, 1.0000],
代码:
class ConcatDataset(Dataset):
    """
    Concatenate `datasets` and relabel them assuming their label sets are mutually exclusive (question, edit 2).

    NOTE(review): still contains the bug the accepted answer points out (new_idx not advanced per data point);
    see the NOTE(review) comments below.

    ref:
    - https://discuss.pytorch.org/t/concat-image-datasets-with-different-size-and-number-of-channels/36362/12
    - https://stackoverflow.com/questions/73913522/why-dont-the-images-align-when-concatenating-two-data-sets-in-pytorch-using-tor
    """

    def __init__(self, datasets: list[Dataset]):
        """
        Build the concatenated dataset and the index <-> label maps, comparing every item (as tensors) with the
        corresponding item of the concatenated dataset.

        :param datasets: datasets whose items are (x, y) pairs with PIL-image x.
        """
        # I think concat is better than passing data to a self.data = x obj since concat likely using the getitem method of the passed dataset and thus if the passed dataset doesnt put all the data in memory concat won't either
        self.concat_datasets = torch.utils.data.ConcatDataset(datasets)
        # maps a class label to a list of sample indices with that label.
        self.labels_to_indices = defaultdict(list)
        # maps a sample index to its corresponding class label.
        self.indices_to_labels = defaultdict(None)
        # - do the relabeling
        img2tensor: Callable = torchvision.transforms.ToTensor()
        offset: int = 0
        new_idx: int = 0
        for dataset_idx, dataset in enumerate(datasets):
            assert len(dataset) == len(self.concat_datasets.datasets[dataset_idx])
            assert dataset == self.concat_datasets.datasets[dataset_idx]
            for data_idx, (x, y) in enumerate(dataset):
                y = int(y)
                # - get data point from concataned data set (to compare with the data point from the data set list)
                # NOTE(review): BUG - new_idx is only incremented after the inner loop, so every x of this dataset
                # is compared with the same concat element; this is why the check already fails at data_idx=1.
                _x, _y = self.concat_datasets[new_idx]
                _y = int(_y)
                # - sanity check concatanted data set aligns with the list of datasets
                # assert y == _y
                # from PIL import ImageChops
                # diff = ImageChops.difference(x, _x) # https://stackoverflow.com/questions/35176639/compare-images-python-pil
                # assert diff.getbbox(), f'comparison of imgs failed: {diff.getbbox()=}'
                # assert list(x.getdata()) == list(_x.getdata()), f'\n{list(x.getdata())=}, \n{list(_x.getdata())=}'
                # tensor comparison
                x, _x = img2tensor(x), img2tensor(_x)
                print(f'{data_idx=}, {x.norm()=}, {_x.norm()=}')
                assert torch.equal(x, _x), f'Error for some reason, got: {data_idx=}, {x.norm()=}, {_x.norm()=}, {x=}, {_x=}'
                # - relabling
                new_label = y + offset
                self.indices_to_labels[new_idx] = new_label
                # NOTE(review): assigns instead of appending, so only one index per label survives.
                self.labels_to_indices[new_label] = new_idx
            # NOTE(review): max label instead of the number of labels (max + 1), so label ranges of consecutive
            # datasets overlap by one; it also re-iterates the whole dataset.
            num_labels_for_current_dataset: int = max([y for _, y in dataset])
            offset += num_labels_for_current_dataset
            new_idx += 1
        assert len(self.indices_to_labels.keys()) == len(self.concat_datasets)
        # contains the list of labels from 0 - total num labels after concat
        self.labels = range(offset)
        self.target_transform = lambda data: torch.tensor(data, dtype=torch.int)

    def __len__(self):
        # number of data points across all concatenated datasets
        return len(self.concat_datasets)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        # NOTE(review): the concatenated dataset returns an (x, original_y) pair, so `x` here is that whole pair.
        x = self.concat_datasets[idx]
        y = self.indices_to_labels[idx]
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y
编辑3,澄清请求:
我设想的数据集是:把这些数据集串联在一起,并从第一个标签开始依次重新标记。最重要的是(依我看——在这一点上我可能是错的),串联完成后,我们应该以某种方式验证数据集的行为确实符合预期。我想到的一项检查是:分别从原始数据集列表和串联后的数据集对象中按索引取出数据点。如果数据集被正确串联,那么按这种索引取出的图像应当一一对应。也就是说,如果第一个数据集中的第一张图像有某种唯一标识(例如像素值),那么串联后数据集的第一张图像应当与数据集列表中第一个数据集的第一张图像相同。否则,当我开始创建新标签时,我怎么知道自己做得是否正确呢?
reddit链接:https://www.reddit.com/r/pytorch/comments/xurnu9/why_dont_the_images_align_when_concatenating_two/
发布于 2022-10-13 20:40:30
更正后的代码可以在这里找到:https://github.com/brando90/ultimate-utils/blob/master/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py 。您可以通过 pip install ultimate-utils 安装该库。
由于只给出链接不是好的回答方式,我也把代码连同它的测试和预期输出一起粘贴如下:
"""
do checks, loop through all data points, create counts for each label how many data points there are
do this for MI only
then check union and ur implementation?
compare the mappings of one & the other?
actually it's easy, just add the cummulative offset and that's it. :D the indices are already -1 indexed.
assert every image has a label between 0 --> n1+n2+... and every bin for each class is none empty
for it to work with any standard pytorch data set I think the workflow would be:
-> l2l元数据集->联合数据集-> .dataset字段->数据加载器
for l2l data sets:
l2l元数据集->联合数据集-> .dataset字段->数据加载器
but the last one might need to make sure .indices or .labels is created or a get labels function that checking the attribute
gets the right .labels or remaps it correctly
"""
from collections import defaultdict
from pathlib import Path
from typing import Callable, Optional
import torch
import torchvision
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
class ConcatDatasetMutuallyExclusiveLabels(Dataset):
    """
    Concatenate datasets assuming their label sets are mutually exclusive.

    The labels of dataset k are shifted by the number of labels of datasets 0..k-1, where the number of labels of a
    dataset is taken to be ``max label + 1`` (i.e. each dataset's labels are assumed to be 0-indexed integers).

    Useful attributes:
        - self.labels: contains all new USL labels i.e. contains the list of labels from 0 - total num labels after concat.
        - len(self): gives number of images after all images have been concatenated
        - self.indices_to_labels: maps the new concat idx to the new label (a 0-d ``torch.int`` tensor) after concat.
        - self.labels_to_indices: maps a new label (python int) to the list of concat indices carrying it.

    ref:
        - https://stackoverflow.com/questions/73913522/why-dont-the-images-align-when-concatenating-two-data-sets-in-pytorch-using-tor
        - https://discuss.pytorch.org/t/concat-image-datasets-with-different-size-and-number-of-channels/36362/12
    """

    def __init__(self, datasets: list[Dataset],
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 compare_imgs_directly: bool = False,
                 verify_xs_align: bool = False,
                 ):
        """
        Concatenates different data sets assuming the labels are mutually exclusive in the data sets.

        :param datasets: map-style datasets whose items are ``(x, y)`` pairs with an ``int()``-convertible ``y``.
        :param transform: optional transform applied to ``x`` in ``__getitem__``.
        :param target_transform: optional transform applied to the new label in ``__getitem__``.
        :param compare_imgs_directly: adds the additional test that imgs compare at the PIL image level.
        :param verify_xs_align: also check with ``torch.equal`` that every x matches the concatenated dataset's x.
        """
        self.datasets = datasets
        self.transform = transform
        self.target_transform = target_transform
        # torch's ConcatDataset indexes lazily into the wrapped datasets (through their __getitem__), so no data
        # is copied into memory here.
        self.concat_datasets = torch.utils.data.ConcatDataset(datasets)
        # maps a class label (python int) to a list of sample indices with that label.
        self.labels_to_indices = defaultdict(list)
        # maps a sample index to its corresponding class label (no default factory: missing keys raise KeyError).
        self.indices_to_labels = defaultdict(None)
        # - do the relabeling
        self._re_label_all_dataset(datasets, compare_imgs_directly, verify_xs_align)

    def __len__(self):
        """Number of data points across all concatenated datasets."""
        return len(self.concat_datasets)

    def _re_label_all_dataset(self, datasets: list[Dataset],
                              compare_imgs_directly: bool = False,
                              verify_xs_align: bool = False,
                              ):
        """
        Relabels according to a blind (mutually exclusive) assumption.

        Relabeling algorithm, for each dataset in order:
            y_new = y + total_num_labels_so_far
            total_num_labels_so_far += max label of the current dataset + 1   (the +1 corrects for 0-indexing)

        The max label is tracked while iterating, so every dataset is iterated exactly once.

        :param datasets: the datasets being concatenated (same order as ``self.concat_datasets``).
        :param compare_imgs_directly: compare x's as PIL images, pixel by pixel.
        :param verify_xs_align: set to false by default in case your transforms aren't deterministic.
        """
        # torchvision is only looked up when an image actually has to be converted for verification.
        self.img2tensor: Callable = lambda img: torchvision.transforms.ToTensor()(img)
        self.int2tensor: Callable = lambda data: torch.tensor(data, dtype=torch.int)
        total_num_labels_so_far: int = 0
        new_idx: int = 0
        for dataset_idx, dataset in enumerate(datasets):
            assert len(dataset) == len(self.concat_datasets.datasets[dataset_idx])
            assert dataset == self.concat_datasets.datasets[dataset_idx]
            max_label: int = -1  # -1 so that an empty dataset contributes 0 labels
            for data_idx, (x, y) in enumerate(dataset):
                y = int(y)
                # - get data point from concatenated data set (to compare with the data point from the data set list)
                _x, _y = self.concat_datasets[new_idx]
                _y = int(_y)
                # - sanity check concatenated data set aligns with the list of datasets
                assert y == _y
                if compare_imgs_directly:
                    # Pixel-by-pixel comparison. (ImageChops.difference(x, _x).getbbox() returns None for identical
                    # images, so an `assert diff.getbbox()` check would be inverted.)
                    assert list(x.getdata()) == list(_x.getdata()), f'\n{list(x.getdata())=}, \n{list(_x.getdata())=}'
                if verify_xs_align:
                    # this might fail if there are random ops in the getitem
                    if not isinstance(x, Tensor):
                        x, _x = self.img2tensor(x), self.img2tensor(_x)
                    assert torch.equal(x, _x), (f'Error for some reason, got: {dataset_idx=},'
                                                f' {new_idx=}, {data_idx=}, '
                                                f'{x.norm()=}, {_x.norm()=}, '
                                                f'{x=}, {_x=}')
                # - relabeling: keep the label a python int for dict keys (tensors hash by identity, which would give
                #   every sample its own key), but store it as a 0-d torch.int tensor for __getitem__ as before.
                new_label: int = y + total_num_labels_so_far
                self.indices_to_labels[new_idx] = self.int2tensor(new_label)
                self.labels_to_indices[new_label].append(new_idx)
                max_label = max(max_label, y)
                new_idx += 1
            # - you'd likely resolve unions if you wanted a proper union, the addition assumes mutual exclusivity
            num_labels_for_current_dataset: int = max_label + 1
            total_num_labels_so_far += num_labels_for_current_dataset
        assert len(self.indices_to_labels.keys()) == len(self.concat_datasets)
        # contains the list of labels from 0 - total num labels after concat, assume mutually exclusive
        self.labels = range(total_num_labels_so_far)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        """
        Get the data point at ``idx`` and its new label according to a mutually exclusive concatenation.

        The original label returned by the wrapped dataset is discarded in favour of ``self.indices_to_labels``.

        Possible future work - relabel on the fly instead of precomputing all labels, e.g.:
            dataset_idx = bisect.bisect_right(self.concat_datasets.cumulative_sizes, idx)
            total_num_labels_so_far = sum(max(y for _, y in d) + 1 for d in self.datasets[:dataset_idx])
            new_y = total_num_labels_so_far + y
        (caching the per-dataset label counts to avoid recomputing them).

        :param idx: index into the concatenated dataset; negative indices count from the end.
        :return: ``(x, y)`` with ``transform`` / ``target_transform`` applied when set.
        """
        x, _original_y = self.concat_datasets[idx]
        if idx < 0:
            # ConcatDataset accepts negative indices, but indices_to_labels is keyed by 0..len(self)-1
            idx += len(self)
        y = self.indices_to_labels[idx]
        if self.transform is not None:
            x = self.transform(x)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y
def assert_dataset_is_pytorch_dataset(datasets: list, verbose: bool = False):
    """
    Assert that every element of ``datasets`` is a ``torch.utils.data.Dataset``.

    To check a single dataset, wrap it in a list.

    :param datasets: the objects to check.
    :param verbose: print the type of each dataset and, when it has one, the type of its ``.dataset`` attribute.
    :raises AssertionError: if some element is not a ``Dataset``.
    """
    for dataset in datasets:
        if verbose:
            print(f'{type(dataset)=}')
            # Not every dataset wraps another one; accessing .dataset unconditionally raised AttributeError.
            if hasattr(dataset, 'dataset'):
                print(f'{type(dataset.dataset)=}')
        assert isinstance(dataset, Dataset), f'Expect dataset to be of type Dataset but got {type(dataset)=}.'
def get_relabling_counts(dataset: Dataset) -> dict:
    """
    Count how many data points carry each label.

    counts[new_label] -> counts/number of data points for that new label

    Scalar tensor labels are converted to python numbers first: tensors hash by identity, so using them directly
    as dict keys would give every data point its own key.

    :param dataset: a dataset whose items are ``(x, y)`` pairs.
    :return: dict mapping label -> number of data points with that label.
    """
    assert isinstance(dataset, Dataset), f'Expect dataset to be of type Dataset but got {type(dataset)=}.'
    counts: dict = {}
    for _x, y in dataset:
        if isinstance(y, Tensor) and y.numel() == 1:
            y = y.item()
        # the first occurrence counts as 1
        counts[y] = counts.get(y, 0) + 1
    return counts
def assert_relabling_counts(counts: dict, labels: int = 100, counts_per_label: int = 600):
    """
    Validate per-label counts after relabeling. Default values are for MI.

    - checks each label/class has the right number of expected images per class
    - checks the relabels start from 0 and increase by 1
    - checks the total number of labels after concat is what you expect

    ref: https://openreview.net/pdf?id=rJY0-Kcll
        Because the exact splits used in Vinyals et al. (2016)
        were not released, we create our own version of the Mini-Imagenet dataset by selecting a random
        100 classes from ImageNet and picking 600 examples of each class. We use 64, 16, and 20 classes
        for training, validation and testing, respectively.

    :param counts: dict mapping label -> number of data points (e.g. from ``get_relabling_counts``).
    :param labels: expected number of distinct labels; the labels must be exactly ``0, 1, ..., labels - 1``.
    :param counts_per_label: expected number of data points for every label.
    :raises AssertionError: if any check fails.
    """
    # - check each label has the expected number of data points
    for label, count in counts.items():
        assert count == counts_per_label, f'{label=} has {count=}, expected {counts_per_label=}'
    # - check the labels are consecutive, start at 0 and end at labels - 1
    seen_labels: list[int] = sorted(counts)
    assert seen_labels == list(range(labels)), f'expected labels 0..{labels - 1}, got {seen_labels=}'
def check_entire_data_via_the_dataloader(dataloader: DataLoader) -> dict:
    """
    Count data points per label by iterating over every batch of ``dataloader``.

    Each batch is expected to be an ``(xs, ys)`` pair. Scalar tensor labels are converted with ``.item()`` so
    that equal labels share one dict key (tensors hash by identity).

    :param dataloader: loader yielding ``(xs, ys)`` batches.
    :return: dict mapping label -> number of data points with that label.
    """
    counts: dict = {}
    for _xs, ys in dataloader:
        for y in ys:
            if isinstance(y, Tensor) and y.numel() == 1:
                y = y.item()
            # the first occurrence counts as 1
            counts[y] = counts.get(y, 0) + 1
    return counts
# - tests
def check_xs_align_mnist():
    """
    Sanity-check ConcatDatasetMutuallyExclusiveLabels on the MNIST train + test splits.

    - test 1: raw PIL images, compared pixel by pixel (``compare_imgs_directly=True``);
    - test 2: ``ToTensor`` images, compared with ``torch.equal`` (``verify_xs_align=True``). The two splits' labels
      are treated as disjoint (not a true union), so 70,000 data points and 20 labels are expected;
    - finally every batch of a DataLoader over the concatenation must consist of tensors.
    """
    root = Path('~/data/').expanduser()
    import torchvision
    # - test 1, imgs (not the recommended use)
    train = torchvision.datasets.MNIST(root=root, train=True, download=True)
    test = torchvision.datasets.MNIST(root=root, train=False, download=True)
    concat = ConcatDatasetMutuallyExclusiveLabels([train, test], compare_imgs_directly=True)
    print(f'{len(concat)=}')
    print(f'{len(concat.labels)=}')
    # - test 2, tensor imgs
    train = torchvision.datasets.MNIST(root=root, train=True, download=True,
                                       transform=torchvision.transforms.ToTensor(),
                                       target_transform=lambda data: torch.tensor(data, dtype=torch.int))
    test = torchvision.datasets.MNIST(root=root, train=False, download=True,
                                      transform=torchvision.transforms.ToTensor(),
                                      target_transform=lambda data: torch.tensor(data, dtype=torch.int))
    concat = ConcatDatasetMutuallyExclusiveLabels([train, test], verify_xs_align=True)
    print(f'{len(concat)=}')
    print(f'{len(concat.labels)=}')
    # the expected value in the message must match the asserted one
    assert len(concat) == 10 * 7000, f'Err, unexpected number of datapoints {len(concat)=} expected {10 * 7000}'
    assert len(
        concat.labels) == 20, f'Note it should be 20 (since it is not a true union), but got {len(concat.labels)=}'
    # - test dataloader
    loader = DataLoader(concat)
    for x, y in loader:
        assert isinstance(x, torch.Tensor)
        assert isinstance(y, torch.Tensor)
def check_xs_align_cifar100():
    """
    Sanity-check ConcatDatasetMutuallyExclusiveLabels on the CIFAR100 train + test splits.

    The raw PIL-image splits are concatenated first (images compared pixel by pixel). Then the ToTensor splits are
    concatenated (images compared as tensors). The two splits' labels are treated as disjoint, so 200 labels are
    expected. Finally a DataLoader over the result must yield tensors only.
    """
    from pathlib import Path
    import torchvision

    root = Path('~/data/').expanduser()

    def make_split(is_train: bool, **extra_kwargs):
        # one CIFAR100 split rooted at `root`, downloading it if needed
        return torchvision.datasets.CIFAR100(root=root, train=is_train, download=True, **extra_kwargs)

    # - test 1, imgs (not the recommended use)
    concat = ConcatDatasetMutuallyExclusiveLabels([make_split(True), make_split(False)], compare_imgs_directly=True)
    print(f'{len(concat)=}')
    print(f'{len(concat.labels)=}')

    # - test 2, tensor imgs
    tensor_kwargs = dict(transform=torchvision.transforms.ToTensor(),
                         target_transform=lambda data: torch.tensor(data, dtype=torch.int))
    concat = ConcatDatasetMutuallyExclusiveLabels([make_split(True, **tensor_kwargs),
                                                   make_split(False, **tensor_kwargs)],
                                                  verify_xs_align=True)
    print(f'{len(concat)=}')
    print(f'{len(concat.labels)=}')
    assert len(concat) == 100 * 600, f'Err, unexpected number of datapoints {len(concat)=} expected {100 * 600}'
    assert len(concat.labels) == 200, \
        f'Note it should be 200 (since it is not a true union), but got {len(concat.labels)=}'
    # details on cifar100: https://www.cs.toronto.edu/~kriz/cifar.html

    # - test dataloader
    for x, y in DataLoader(concat):
        assert isinstance(x, torch.Tensor)
        assert isinstance(y, torch.Tensor)
def concat_data_set_mi():
    """
    Concatenate the mini-ImageNet train/val/test splits into one USL dataset and iterate it via a DataLoader.

    Mini-ImageNet is used because its train, val and test splits have disjoint labels. In classic CIFAR100 the
    labels of train, val and test are shared (0-99) instead of being disjoint.
    """
    # - get mi data set
    from diversity_src.dataloaders.hdb1_mi_omniglot_l2l import get_mi_datasets
    train_split, val_split, test_split = get_mi_datasets()
    meta_splits = [train_split, val_split, test_split]
    assert_dataset_is_pytorch_dataset(meta_splits)
    # unwrap each split to the underlying dataset
    inner_datasets = [split.dataset for split in meta_splits]

    # - create usl data set
    union = ConcatDatasetMutuallyExclusiveLabels(inner_datasets)
    assert_dataset_is_pytorch_dataset([union])
    assert len(union) == 100 * 600, f'got {len(union)=}'
    assert len(union.labels) == 100, f'got {len(union.labels)=}'

    # - create dataloader
    from uutils.torch_uu.dataloaders.common import get_serial_or_distributed_dataloaders
    union_loader, _ = get_serial_or_distributed_dataloaders(train_dataset=union, val_dataset=union)
    for x, y in union_loader:
        assert x is not None
        assert y is not None
if __name__ == '__main__':
    from time import time as _now
    from uutils import report_times

    t_start = _now()
    # - run experiment: each check raises on failure
    for check in (check_xs_align_mnist, check_xs_align_cifar100, concat_data_set_mi):
        check()
    # - Done
    print(f"\nSuccess Done!: {report_times(t_start)}\a")
预期正确产出:
len(concat)=70000
len(concat.labels)=20
len(concat)=70000
len(concat.labels)=20
Files already downloaded and verified
Files already downloaded and verified
len(concat)=60000
len(concat.labels)=200
Files already downloaded and verified
Files already downloaded and verified
len(concat)=60000
len(concat.labels)=200
Success Done!: time passed: hours:0.16719497998555502, minutes=10.0316987991333, seconds=601.901927947998
警告:如果您的数据集中有随机变换(transform),数据集对齐的验证可能会失败,看起来好像两个数据点没有对齐。代码本身是正确的,所以这不是问题,但您可能需要以某种方式消除这种随机性。注意,我最终决定不强制用户检查其数据集中的所有图像(这些检查默认是关闭的),并相信代码在通过我的单元测试之后是正确的。还要注意,由于我在一开始就完成了全部重新标记,构造数据集的速度很慢;更好的做法是在取数据时即时重新标记。我在代码中概述了如何实现这一点,但决定不这样做,因为我们总是至少会遍历一次所有数据,所以摊销下来的代价与即时计算相同(注意,伪代码中会保存标签以避免重复计算)。
这样做更好:
def int2tensor(data):
    """Convert ``data`` (e.g. an int label) to a ``torch.long`` (int64) tensor rather than ``torch.int`` (int32)."""
    return torch.tensor(data, dtype=torch.long)
class ConcatDatasetMutuallyExclusiveLabels(Dataset):
    """
    Concatenate datasets, treating the label sets of different datasets as mutually exclusive.

    Each dataset's distinct original labels are sorted and mapped to consecutive new labels that start right after
    the labels of the previous datasets. Equal integers in two different datasets therefore become two different new
    labels (if dataset1 and dataset2 both contain "cat", the two cats are treated as different classes).

    Useful attributes:
        - self.labels: contains all new USL labels i.e. contains the list of labels from 0 - total num labels after concat.
        - len(self): gives number of images after all images have been concatenated
        - self.indices_to_labels: maps the new concat idx to the new (int) label after concat.
        - self.labels_to_indices: maps a new (int) label to the list of concat indices with that label.

    ref:
        - https://stackoverflow.com/questions/73913522/why-dont-the-images-align-when-concatenating-two-data-sets-in-pytorch-using-tor
        - https://discuss.pytorch.org/t/concat-image-datasets-with-different-size-and-number-of-channels/36362/12
    """

    def __init__(self, datasets: list[Dataset],
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 compare_imgs_directly: bool = False,
                 verify_xs_align: bool = False,
                 ):
        """
        Concatenates different data sets assuming the labels are mutually exclusive in the data sets.

        :param datasets: map-style datasets whose items are ``(x, y)`` pairs with an ``int()``-convertible ``y``.
        :param transform: optional transform applied to ``x`` in ``__getitem__``.
        :param target_transform: optional transform applied to the new label in ``__getitem__``.
        :param compare_imgs_directly: adds the additional test that imgs compare at the PIL image level.
        :param verify_xs_align: also check with ``torch.equal`` that every x matches the concatenated dataset's x.
        """
        self.datasets = datasets
        self.transform = transform
        self.target_transform = target_transform
        # torch's ConcatDataset indexes lazily into the wrapped datasets (through their __getitem__), so no data
        # is copied into memory here.
        self.concat_datasets = torch.utils.data.ConcatDataset(datasets)
        # maps a class label to a list of sample indices with that label.
        self.labels_to_indices = defaultdict(list)
        # maps a sample index to its corresponding class label (no default factory: missing keys raise KeyError).
        self.indices_to_labels = defaultdict(None)
        # - do the relabeling
        self._re_label_all_dataset(datasets, compare_imgs_directly, verify_xs_align)

    def __len__(self):
        """Number of data points across all concatenated datasets."""
        return len(self.concat_datasets)

    def _re_label_all_dataset(self, datasets: list[Dataset],
                              compare_imgs_directly: bool = False,
                              verify_xs_align: bool = False,
                              verbose: bool = False,
                              ):
        """
        Relabels according to a blind (mutually exclusive) assumption.

        Relabeling algorithm, for each dataset in order:
            1. group the concat (global) indices of its data points by original label;
            2. sort the distinct original labels and number them 0..k-1 (local labels);
            3. new global label = total_num_labels_so_far + local label;
            4. total_num_labels_so_far += k.

        :param datasets: the datasets being concatenated (same order as ``self.concat_datasets``).
        :param compare_imgs_directly: compare x's as PIL images, pixel by pixel.
        :param verify_xs_align: compare x's with ``torch.equal`` (may fail if ``__getitem__`` is random).
        :param verbose: print progress/debug information.
        """
        # torchvision is only looked up when an image actually has to be converted for verification.
        self.img2tensor: Callable = lambda img: torchvision.transforms.ToTensor()(img)
        total_num_labels_so_far: int = 0
        global_idx: int = 0  # index into self.concat_datasets
        assert len(self.indices_to_labels.keys()) == 0
        assert len(self.labels_to_indices.keys()) == 0
        for dataset_idx, dataset in enumerate(datasets):
            if verbose:
                print(f'{dataset_idx=} \n{len(dataset)=}')
                if hasattr(dataset, 'labels'):
                    print(f'{len(dataset.labels)=}')
            assert len(dataset) == len(self.concat_datasets.datasets[dataset_idx])
            assert dataset == self.concat_datasets.datasets[dataset_idx]
            original_label2global_idx: defaultdict = self._collect_original_labels(
                dataset, dataset_idx, global_idx, compare_imgs_directly, verify_xs_align)
            global_idx += len(dataset)
            if verbose:
                print(f'{global_idx=}')
            local_num_dps: int = sum(len(global_indices) for global_indices in original_label2global_idx.values())
            assert len(dataset) == local_num_dps, f'Error: \n{local_num_dps=} \n{len(dataset)=}'
            # - do relabeling - original labeling to new global labels
            # sort to avoid depending on the insertion order of the dict keys
            global_label2global_indices: dict = {
                total_num_labels_so_far + new_local_label: original_label2global_idx[original_label]
                for new_local_label, original_label in enumerate(sorted(original_label2global_idx.keys()))
            }
            if verbose:
                print(f'{total_num_labels_so_far=}')
                print(f'{list(self.labels_to_indices.keys())=}')
                print(f'{list(global_label2global_indices.keys())=}')
            # this is the step where classes are concatenated: new labels must not collide with earlier ones
            dup: list = sorted(set(self.labels_to_indices.keys()) & set(global_label2global_indices.keys()))
            assert len(dup) == 0, f'Error:\n{self.labels_to_indices.keys()=}\n{global_label2global_indices.keys()=}\n{dup=}'
            for global_label, global_indices in global_label2global_indices.items():
                for g_idx in global_indices:
                    self.labels_to_indices[global_label].append(g_idx)
                    self.indices_to_labels[g_idx] = global_label
            # - update number of labels seen so far
            num_labels_for_current_dataset: int = len(original_label2global_idx.keys())
            if verbose:
                print(f'{num_labels_for_current_dataset=}')
            total_num_labels_so_far += num_labels_for_current_dataset
            assert total_num_labels_so_far == len(self.labels_to_indices.keys()), f'Err:\n{total_num_labels_so_far=}' \
                                                                                  f'\n{len(self.labels_to_indices.keys())=}'
            assert global_idx == len(self.indices_to_labels.keys()), f'Err:\n{global_idx=}\n{len(self.indices_to_labels.keys())=}'
            if hasattr(dataset, 'labels'):
                assert len(dataset.labels) == num_labels_for_current_dataset, f'Err:\n{len(dataset.labels)=}' \
                                                                              f'\n{num_labels_for_current_dataset=}'
        # - relabeling done
        assert len(self.indices_to_labels.keys()) == len(
            self.concat_datasets), f'Err: \n{len(self.indices_to_labels.keys())=}' \
                                   f'\n {len(self.concat_datasets)=}'
        if all(hasattr(dataset, 'labels') for dataset in datasets):
            assert sum(len(dataset.labels) for dataset in datasets) == total_num_labels_so_far
        # - set & validate new labels: they must be exactly 0..total_num_labels_so_far-1
        self.labels = range(total_num_labels_so_far)
        labels = sorted(self.labels_to_indices.keys())
        assert labels == list(self.labels), f'labels should match and be consecutive, but got: \n{labels=}, \n{self.labels=}'

    def _collect_original_labels(self, dataset: Dataset, dataset_idx: int, start_global_idx: int,
                                 compare_imgs_directly: bool, verify_xs_align: bool) -> defaultdict:
        """Group the global indices of ``dataset``'s points by original label, checking each against the concat."""
        original_label2global_idx: defaultdict = defaultdict(list)
        global_idx: int = start_global_idx
        for original_data_idx, (x, original_y) in enumerate(dataset):
            original_y = int(original_y)
            # - get data point from concatenated data set (to compare with the data point from the data set list)
            _x, _y = self.concat_datasets[global_idx]
            _y = int(_y)
            # - sanity check concatenated data set aligns with the list of datasets
            assert original_y == _y, f'{original_y=}, {_y=}'
            if compare_imgs_directly:
                # Pixel-by-pixel comparison. (ImageChops.difference(x, _x).getbbox() returns None for identical
                # images, so an `assert diff.getbbox()` check would be inverted.)
                assert list(x.getdata()) == list(_x.getdata()), f'\n{list(x.getdata())=}, \n{list(_x.getdata())=}'
            if verify_xs_align:  # checks the data points after doing get item make them match.
                # this might fail if there are random ops in the getitem
                if not isinstance(x, Tensor):
                    x, _x = self.img2tensor(x), self.img2tensor(_x)
                assert torch.equal(x, _x), (f'Error for some reason, got: {dataset_idx=},'
                                            f' {global_idx=}, {original_data_idx=}, '
                                            f'{x.norm()=}, {_x.norm()=}, '
                                            f'{x=}, {_x=}')
            original_label2global_idx[original_y].append(global_idx)
            global_idx += 1
        return original_label2global_idx

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        """
        Get the data point at ``idx`` and its new label according to a mutually exclusive concatenation.

        The original label returned by the wrapped dataset is discarded in favour of ``self.indices_to_labels``.

        :param idx: index into the concatenated dataset; negative indices count from the end.
        :return: ``(x, y)`` with ``transform`` / ``target_transform`` applied when set.
        """
        x, _original_y = self.concat_datasets[idx]
        if idx < 0:
            # ConcatDataset accepts negative indices, but indices_to_labels is keyed by 0..len(self)-1
            idx += len(self)
        y = self.indices_to_labels[idx]
        if self.transform is not None:
            x = self.transform(x)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y
发布于 2022-10-03 05:15:07
您正在连接两个 torchvision 数据集,即 CIFAR100 的训练集和测试集。它们默认的 __getitem__ 返回 PIL.Image 对象,而不是张量。因此,您不能使用 torch.equal 来比较两个 PIL.Image 对象。
试一试:
train = torchvision.datasets.CIFAR100(root=root, train=True, download=True,
transform=torchvision.transforms.ToTensor())
test = torchvision.datasets.CIFAR100(root=root, train=False, download=True,
transform=torchvision.transforms.ToTensor())
concat = ConcatDataset([train, test])
添加 transform 会把数据集返回的 PIL.Image 转换为 torch.Tensor,从而让您可以在代码中使用 torch.equal。
发布于 2022-10-12 08:40:54
我保留原来的答案,因为它回答了你的问题。
您的代码中有一个小错误,它导致了所有问题:您没有在内部循环中递增 new_idx。因此,您实际上是把直接从 dataset 中取出的每个元素,都与 self.concat_datasets 中对应数据集的第一个元素进行比较。
工作的新代码:
class ConcatDataset(Dataset):
    """
    Concatenate datasets whose label sets are assumed to be mutually exclusive.

    The labels of each dataset are shifted by the number of labels of the datasets before it, where the number of
    labels of a dataset is ``max label + 1`` (labels assumed to be 0-indexed integers).

    ref:
    - https://discuss.pytorch.org/t/concat-image-datasets-with-different-size-and-number-of-channels/36362/12
    - https://stackoverflow.com/questions/73913522/why-dont-the-images-align-when-concatenating-two-data-sets-in-pytorch-using-tor
    """

    def __init__(self, datasets: list[Dataset]):
        """
        Build the concatenated dataset and the index <-> label maps, checking that every item of ``datasets``
        matches the corresponding item of the concatenated dataset.

        :param datasets: map-style datasets whose items are ``(x, y)`` pairs (x a tensor or PIL image).
        """
        # torch's ConcatDataset indexes lazily into the wrapped datasets (through their __getitem__), so no data
        # is copied into memory here.
        self.concat_datasets = torch.utils.data.ConcatDataset(datasets)
        # maps a class label to a list of sample indices with that label.
        self.labels_to_indices = defaultdict(list)
        # maps a sample index to its corresponding class label.
        self.indices_to_labels = defaultdict(None)
        # - do the relabeling
        # converts PIL images for the alignment check; torchvision is only looked up when a conversion is needed
        img2tensor: Callable = lambda img: torchvision.transforms.ToTensor()(img)
        offset: int = 0
        new_idx: int = 0
        for dataset_idx, dataset in enumerate(datasets):
            assert len(dataset) == len(self.concat_datasets.datasets[dataset_idx])
            assert dataset == self.concat_datasets.datasets[dataset_idx]
            max_label: int = -1  # tracked here so the dataset is iterated only once
            for data_idx, (x, y) in enumerate(dataset):
                y = int(y)
                # - get data point from concatenated data set (to compare with the data point from the data set list)
                _x, _y = self.concat_datasets[new_idx]
                _y = int(_y)
                # - sanity check concatenated data set aligns with the list of datasets
                assert y == _y, f'{data_idx=}, {y=}, {_y=}'
                # tensor comparison (ToTensor only accepts images, so skip it for data that already are tensors)
                if not isinstance(x, Tensor):
                    x, _x = img2tensor(x), img2tensor(_x)
                assert torch.equal(x, _x), f'Error for some reason, got: {data_idx=}, {x.norm()=}, {_x.norm()=}, {x=}, {_x=}'
                # - relabling
                new_label = y + offset
                self.indices_to_labels[new_idx] = new_label
                self.labels_to_indices[new_label].append(new_idx)
                max_label = max(max_label, y)
                # increment here!! new_idx must advance for every data point (inside the inner loop); otherwise
                # every x of a dataset is compared with the same element of the concatenated dataset.
                new_idx += 1
            # number of labels is max label + 1 (0-indexed labels); adding only the max would make the last label of
            # this dataset collide with label 0 of the next one.
            offset += max_label + 1
        assert len(self.indices_to_labels.keys()) == len(self.concat_datasets)
        # contains the list of labels from 0 - total num labels after concat
        self.labels = range(offset)
        self.target_transform = lambda data: torch.tensor(data, dtype=torch.int)

    def __len__(self):
        """Number of data points across all concatenated datasets."""
        return len(self.concat_datasets)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        """
        Return ``(x, new_label)`` for ``idx``; the original label of the wrapped dataset is dropped.

        :param idx: index into the concatenated dataset; negative indices count from the end.
        """
        x, _original_y = self.concat_datasets[idx]
        if idx < 0:
            # ConcatDataset accepts negative indices, but indices_to_labels is keyed by 0..len(self)-1
            idx += len(self)
        y = self.indices_to_labels[idx]
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y
https://stackoverflow.com/questions/73913522
复制相似问题