Merging datasets
Basic usage
The MergeDataset class is used to merge multiple datasets:
from alodataset import FlyingThings3DSubsetDataset, MergeDataset
dataset1 = FlyingThings3DSubsetDataset(sequence_size=2, transform_fn=lambda f: f["left"])
dataset2 = FlyingThings3DSubsetDataset(sequence_size=2, transform_fn=lambda f: f["right"])
dataset = MergeDataset([dataset1, dataset2])
The merged datasets can then be shuffled together, so that a sampled batch can contain items from different datasets:
# this batch can contain items from dataset1 and/or dataset2
batch = next(iter(dataset.train_loader(batch_size=4)))
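To illustrate why a shuffled batch can mix sources, here is a minimal pure-Python sketch. It is not alodataset's implementation: the lists `left_items` and `right_items` and the helper `shuffled_batches` are hypothetical stand-ins for the two datasets and for the training loader.

```python
import random

# Hypothetical stand-ins for dataset1 and dataset2 (not real alodataset objects).
left_items = [("left", i) for i in range(6)]
right_items = [("right", i) for i in range(6)]


def shuffled_batches(datasets, batch_size, seed=0):
    """Shuffle the samples of all datasets together, then cut them into batches."""
    merged = [item for ds in datasets for item in ds]
    rng = random.Random(seed)
    rng.shuffle(merged)
    for start in range(0, len(merged), batch_size):
        yield merged[start:start + batch_size]


batches = list(shuffled_batches([left_items, right_items], batch_size=4))
# For each batch, record which source datasets its items came from.
sources_per_batch = [{source for source, _ in batch} for batch in batches]
```

With 6 items per source and batches of 4, no split can keep every batch to a single source, so at least one entry of `sources_per_batch` contains both `"left"` and `"right"`.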
Each dataset can have its own specific transformation, and a single global transformation can then be applied to every item of the merged dataset:
from alodataset.transforms import RandomCrop
dataset1 = FlyingThings3DSubsetDataset(sequence_size=2, transform_fn=lambda f: f["left"]) # specific transform
dataset2 = FlyingThings3DSubsetDataset(sequence_size=2, transform_fn=lambda f: f["right"]) # specific transform
dataset = MergeDataset([dataset1, dataset2], transform_fn=RandomCrop(size=(368, 496))) # global transform
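The order of application described above (specific transform first, global transform second) can be sketched in plain Python. Everything here is a hypothetical illustration, assuming that order: `frames`, `select_left`, `select_right`, `crop_first_two` and `merged_samples` are made-up names, and `crop_first_two` only stands in for a real transform such as RandomCrop.

```python
# Hypothetical frame: each holds a "left" and a "right" view as a small list.
frames = [{"left": [1, 2, 3, 4], "right": [5, 6, 7, 8]}]


def select_left(frame):
    """Specific transform for the first dataset."""
    return frame["left"]


def select_right(frame):
    """Specific transform for the second dataset."""
    return frame["right"]


def crop_first_two(view):
    """Stand-in for a global transform such as a crop."""
    return view[:2]


def merged_samples(datasets, global_transform=None):
    """Yield samples: per-dataset transform first, then the global transform."""
    for dataset_frames, specific_transform in datasets:
        for frame in dataset_frames:
            sample = specific_transform(frame)
            if global_transform is not None:
                sample = global_transform(sample)
            yield sample


samples = list(merged_samples(
    [(frames, select_left), (frames, select_right)],
    global_transform=crop_first_two,
))
```

Here the first sample comes from the "left" view and the second from the "right" view, and both have been cropped by the same global function.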
MergeDataset API
- class alodataset.merge_dataset.MergeDataset(datasets, transform_fn=None)
Bases: Generic[torch.utils.data.dataset.T_co]
Dataset merging multiple alodataset.BaseDataset
Iterating sequentially over the dataset yields all samples from the first dataset, then all samples from the next dataset, and so on until the last dataset.
Shuffling the dataset shuffles the samples of all datasets together.
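One common way to obtain this sequential behaviour is to map a global index to a (dataset, local index) pair using cumulative lengths. The sketch below assumes that approach. `MergedView` is a hypothetical class written for illustration and is not alodataset's actual code.

```python
from bisect import bisect_right
from itertools import accumulate


class MergedView:
    """Hypothetical sketch of sequential merging; not alodataset's actual code."""

    def __init__(self, datasets, transform_fn=None):
        self.datasets = datasets
        self.transform_fn = transform_fn
        # cumulative[i] is the total number of samples in datasets[0..i]
        self.cumulative = list(accumulate(len(ds) for ds in datasets))

    def __len__(self):
        return self.cumulative[-1] if self.cumulative else 0

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        # Find the dataset that owns this global index
        dataset_id = bisect_right(self.cumulative, index)
        previous_total = self.cumulative[dataset_id - 1] if dataset_id else 0
        sample = self.datasets[dataset_id][index - previous_total]
        return self.transform_fn(sample) if self.transform_fn else sample


merged = MergedView([["a0", "a1"], ["b0", "b1", "b2"]])
ordered = [merged[i] for i in range(len(merged))]
```

Because shuffling only permutes the global indices, any sampler that draws indices at random from `range(len(merged))` mixes the samples of all datasets.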
- Parameters
- datasets : List[alodataset.BaseDataset]
List of datasets
- transform_fn : function
Transformation applied to each sample
- stream_loader(num_workers=2)
Get a stream loader from the dataset. Compared to train_loader(), stream_loader() does not add a batch dimension and does not shuffle the dataset.
- Parameters
- dataset : torch.utils.data.Dataset
Dataset to make dataloader
- num_workers : int
Number of workers, by default 2
- Returns
- torch.utils.data.DataLoader
A generator
- train_loader(batch_size=1, num_workers=2, sampler=<class 'torch.utils.data.sampler.RandomSampler'>)
Get a training loader from the dataset.