Merging datasets

Basic usage

The class MergeDataset is used to merge multiple datasets:

from alodataset import FlyingThings3DSubsetDataset, MergeDataset
dataset1 = FlyingThings3DSubsetDataset(sequence_size=2, transform_fn=lambda f: f["left"])
dataset2 = FlyingThings3DSubsetDataset(sequence_size=2, transform_fn=lambda f: f["right"])
dataset = MergeDataset([dataset1, dataset2])

The merged datasets can then be shuffled together, so a sampled batch can contain items from different datasets:

# this batch can contain items from dataset1 and/or dataset2
batch = next(iter(dataset.train_loader(batch_size=4)))
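The snippet below is a plain-Python sketch of why one batch can mix items from both datasets. It is not alodataset's implementation. It pools the samples of every dataset, shuffles the pool as a whole, and then cuts it into batches. The helper `shuffled_batches` and the dataset sizes are hypothetical.

```python
import random
from typing import List, Tuple


def shuffled_batches(sizes: List[int], batch_size: int, seed: int = 0) -> List[List[Tuple[int, int]]]:
    """Hypothetical helper: shuffle (dataset_id, sample_id) pairs from all datasets together, then batch them."""
    # One entry per sample across every merged dataset
    pool = [(d, i) for d, n in enumerate(sizes) for i in range(n)]
    # Shuffle across datasets, not within each dataset separately
    random.Random(seed).shuffle(pool)
    # Consecutive slices of the shuffled pool form the batches
    return [pool[k:k + batch_size] for k in range(0, len(pool), batch_size)]


batches = shuffled_batches(sizes=[5, 3], batch_size=4)
for batch in batches:
    print([f"dataset{d + 1}[{i}]" for d, i in batch])
```

The shuffle runs on the pooled indices, so each batch can hold any combination of samples from the first and second dataset.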

It is possible to apply specific transformations to each dataset, and then apply the same global transformation to the items:

from alodataset.transforms import RandomCrop
dataset1 = FlyingThings3DSubsetDataset(sequence_size=2, transform_fn=lambda f: f["left"]) # specific transform
dataset2 = FlyingThings3DSubsetDataset(sequence_size=2, transform_fn=lambda f: f["right"]) # specific transform
dataset = MergeDataset([dataset1, dataset2], transform_fn=RandomCrop(size=(368, 496))) # global transform
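The order of the two transformation stages can be illustrated with a small, self-contained sketch. Here `ToyDataset` and `ToyMerge` are hypothetical stand-ins, not alodataset classes. The assumption is that each dataset's own `transform_fn` runs first and the merged dataset's `transform_fn` then runs on the result. `str.lower` stands in for a global transform such as `RandomCrop`.

```python
from typing import Any, Callable, List, Optional


class ToyDataset:
    """Hypothetical stand-in for a dataset that applies its own transform_fn."""

    def __init__(self, frames: List[dict], transform_fn: Callable[[dict], Any]):
        self.frames = frames
        self.transform_fn = transform_fn

    def __len__(self) -> int:
        return len(self.frames)

    def __getitem__(self, idx: int) -> Any:
        # Dataset-specific transform, e.g. selecting the "left" or "right" view
        return self.transform_fn(self.frames[idx])


class ToyMerge:
    """Hypothetical sketch: the global transform runs after each dataset's own transform."""

    def __init__(self, datasets: List[ToyDataset], transform_fn: Optional[Callable[[Any], Any]] = None):
        self.datasets = datasets
        self.transform_fn = transform_fn

    def __len__(self) -> int:
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx: int) -> Any:
        # Walk the datasets in order until idx falls inside one of them
        for d in self.datasets:
            if idx < len(d):
                item = d[idx]
                return self.transform_fn(item) if self.transform_fn else item
            idx -= len(d)
        raise IndexError(idx)


frames = [{"left": "L0", "right": "R0"}, {"left": "L1", "right": "R1"}]
merged = ToyMerge(
    [ToyDataset(frames, lambda f: f["left"]), ToyDataset(frames, lambda f: f["right"])],
    transform_fn=str.lower,  # stands in for a global transform such as RandomCrop
)
print([merged[i] for i in range(len(merged))])  # ['l0', 'l1', 'r0', 'r1']
```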

MergeDataset API

class alodataset.merge_dataset.MergeDataset(datasets, transform_fn=None)

Bases: Generic[torch.utils.data.dataset.T_co]

Dataset merging multiple alodataset.BaseDataset instances.

Iterating sequentially over the dataset yields all samples from the first dataset, then all samples from the next dataset, and so on until the last dataset.

Shuffling the dataset shuffles the samples of all datasets together.
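The sequential ordering described above means that a global index can be mapped to a (dataset, local index) pair using cumulative dataset sizes. The sketch below does this with `bisect`. The approach is similar in spirit to `torch.utils.data.ConcatDataset`, but alodataset may implement the lookup differently. The function name `locate` is hypothetical.

```python
import bisect
from itertools import accumulate
from typing import List, Tuple


def locate(global_idx: int, sizes: List[int]) -> Tuple[int, int]:
    """Hypothetical helper: map an index into the merged dataset to (dataset_id, local_idx)."""
    # End offset of each dataset, e.g. [3, 5] for sizes [3, 2]
    cumulative = list(accumulate(sizes))
    if not 0 <= global_idx < cumulative[-1]:
        raise IndexError(global_idx)
    # First dataset whose end offset is strictly greater than global_idx (skips empty datasets)
    d = bisect.bisect_right(cumulative, global_idx)
    start = cumulative[d - 1] if d > 0 else 0
    return d, global_idx - start


print([locate(i, [3, 2]) for i in range(5)])
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
```

Iterating `global_idx` from 0 upward reproduces the sequential order, with every sample of dataset 0 before any sample of dataset 1. Shuffling the global indices instead mixes the datasets together.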

Parameters
datasets : List[alodataset.BaseDataset]

List of datasets to merge.

transform_fn : function, optional

Global transformation applied to each sample, after each dataset's own transformation. Defaults to None.

stream_loader(num_workers=2)

Get a stream loader from the dataset. Unlike train_loader(), stream_loader() does not add a batch dimension and does not shuffle the dataset.

Parameters
dataset : torch.utils.data.Dataset

Dataset to make dataloader

num_workers : int

Number of workers, by default 2

Returns
torch.utils.data.DataLoader

A generator

train_loader(batch_size=1, num_workers=2, sampler=torch.utils.data.sampler.RandomSampler)

Get a training loader from the dataset.