Merging datasets
Basic usage
The class MergeDataset is used to merge multiple datasets:
from alodataset import FlyingThings3DSubsetDataset, MergeDataset
dataset1 = FlyingThings3DSubsetDataset(sequence_size=2, transform_fn=lambda f: f["left"])
dataset2 = FlyingThings3DSubsetDataset(sequence_size=2, transform_fn=lambda f: f["right"])
dataset = MergeDataset([dataset1, dataset2])
It is then possible to shuffle the datasets together and sample batches that can contain items from different datasets:
# this batch can contain items from dataset1 and/or dataset2
batch = next(iter(dataset.train_loader(batch_size=4)))
It is possible to apply specific transformations to each dataset, and then apply the same global transformation to the items:
from alodataset.transforms import RandomCrop
dataset1 = FlyingThings3DSubsetDataset(sequence_size=2, transform_fn=lambda f: f["left"]) # specific transform
dataset2 = FlyingThings3DSubsetDataset(sequence_size=2, transform_fn=lambda f: f["right"]) # specific transform
dataset = MergeDataset([dataset1, dataset2], transform_fn=RandomCrop(size=(368, 496))) # global transform
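The order of the two transform stages can be illustrated with a small pure-Python sketch. It assumes that the dataset-specific transform runs first and the global transform runs on its result, as the text above suggests. `ToyDataset` and `ToyMerge` are hypothetical stand-ins written for this illustration; they are not part of alodataset and do not reproduce its implementation.

```python
# Hypothetical stand-ins for illustration only; not the alodataset classes.
class ToyDataset:
    """List-backed dataset that applies its own transform to each frame."""

    def __init__(self, frames, transform_fn=None):
        self.frames = frames
        self.transform_fn = transform_fn

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        frame = self.frames[idx]
        return self.transform_fn(frame) if self.transform_fn else frame


class ToyMerge:
    """Concatenates datasets, then applies one global transform to each item."""

    def __init__(self, datasets, transform_fn=None):
        self.datasets = datasets
        self.transform_fn = transform_fn

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        if idx < 0:
            raise IndexError("negative indices are not supported in this sketch")
        for dataset in self.datasets:
            if idx < len(dataset):
                item = dataset[idx]  # dataset-specific transform runs here
                break
            idx -= len(dataset)
        else:
            raise IndexError("index out of range")
        # Global transform runs on the already-transformed item.
        return self.transform_fn(item) if self.transform_fn else item


frames = [{"left": "L0", "right": "R0"}, {"left": "L1", "right": "R1"}]
left = ToyDataset(frames, transform_fn=lambda f: f["left"])    # specific transform
right = ToyDataset(frames, transform_fn=lambda f: f["right"])  # specific transform
merged = ToyMerge([left, right], transform_fn=str.lower)       # global transform
items = [merged[i] for i in range(len(merged))]
```

Here each frame is first reduced to its left or right view, and only then lower-cased by the global transform, which mirrors the role `RandomCrop` plays in the example above.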
MergeDataset API
- class alodataset.merge_dataset.MergeDataset(datasets, transform_fn=None)
- Bases: Generic[torch.utils.data.dataset.T_co]
- Dataset merging multiple alodataset.BaseDataset.
- Iterating sequentially over the dataset yields all samples from the first dataset, then all samples from the next dataset, and so on until the last dataset.
- Shuffling the dataset shuffles the samples of all datasets together.
- Parameters
- datasets: List[alodataset.BaseDataset]
- List of datasets
- transform_fn: function
- Transformation applied to each sample
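The iteration order described above can be sketched as an index mapping. This is a minimal illustration that assumes the merged dataset behaves like a concatenation of its parts; the helper `locate` is a hypothetical name introduced here and is not part of the alodataset API.

```python
from bisect import bisect_right
from itertools import accumulate
import random


def locate(global_idx, sizes):
    """Map an index in the merged dataset to (dataset_index, local_index).

    `sizes` holds the length of each underlying dataset, in merge order.
    """
    total = sum(sizes)
    if not 0 <= global_idx < total:
        raise IndexError(f"index {global_idx} out of range for {total} samples")
    ends = list(accumulate(sizes))  # cumulative end offsets, e.g. [3, 5]
    dataset_idx = bisect_right(ends, global_idx)
    start = ends[dataset_idx] - sizes[dataset_idx]
    return dataset_idx, global_idx - start


sizes = [3, 2]  # a first dataset with 3 samples, a second with 2

# Sequential iteration: every sample of dataset 0, then every sample of dataset 1.
sequential = [locate(i, sizes) for i in range(sum(sizes))]

# Shuffling permutes the global indices, so samples from both datasets interleave.
order = list(range(sum(sizes)))
random.Random(0).shuffle(order)
shuffled = [locate(i, sizes) for i in order]
```

Shuffling at the level of global indices is what allows a single batch to mix samples from several datasets.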
 
 - stream_loader(num_workers=2)
- Get a stream loader from the dataset. Unlike train_loader(), stream_loader() does not add a batch dimension and does not shuffle the dataset.
- Parameters
- dataset: torch.utils.data.Dataset
- Dataset to make dataloader
- num_workers: int
- Number of workers, by default 2
 
- Returns
- torch.utils.data.DataLoader
- A generator 
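The difference between the two loaders can be illustrated without any library code. The sketch below is only a pure-Python analogy under the behaviour stated above (no batch dimension and no shuffling for the stream loader, shuffled batches for the training loader). The functions `stream` and `train_batches` are hypothetical names for this illustration and do not reproduce how alodataset builds its `torch.utils.data.DataLoader` objects.

```python
import random


def stream(dataset):
    """Yield items one at a time, in dataset order, with no batch dimension."""
    for idx in range(len(dataset)):
        yield dataset[idx]


def train_batches(dataset, batch_size=1, seed=None):
    """Yield lists of up to `batch_size` items drawn in a random order."""
    order = list(range(len(dataset)))
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [dataset[i] for i in order[start:start + batch_size]]


# A plain list stands in for a merged dataset with samples from two sources.
data = ["a0", "a1", "a2", "b0", "b1"]

streamed = list(stream(data))                           # single items, original order
batches = list(train_batches(data, batch_size=2, seed=0))  # shuffled groups of items
```

In this analogy the stream yields each sample on its own, while the training batches group samples after shuffling, so one batch may contain items from both sources.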
 
 
 - train_loader(batch_size=1, num_workers=2, sampler=<class 'torch.utils.data.sampler.RandomSampler'>)
- Get a training loader from the dataset.