class CachedDatasetFolder(root, loader=torchvision.datasets.folder.default_loader, extensions=torchvision.datasets.folder.IMG_EXTENSIONS, transform=None, target_transform=None, is_valid_file=None, num_classes=1000)[source]

Bases: torchvision.datasets.DatasetFolder

A cached version of torchvision.datasets.DatasetFolder where the classes and image list are static and cached, skipping the costly os.walk and os.scandir calls
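The caching idea can be sketched in a few lines. This is a hypothetical helper (`find_classes_cached` is not the library's API, and the pickle cache format is an assumption): the directory is scanned once with os.scandir, and subsequent calls reuse the pickled class list instead of rescanning.

```python
import os
import pickle

def find_classes_cached(root, cache_path):
    """Hypothetical sketch of the caching idea: scan the directory tree
    once, then reuse the pickled result to skip the costly os.scandir
    calls on later runs."""
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)  # cached (classes, class_to_idx)
    # First run: discover the class folders the way DatasetFolder does.
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {cls: i for i, cls in enumerate(classes)}
    with open(cache_path, "wb") as f:
        pickle.dump((classes, class_to_idx), f)
    return classes, class_to_idx
```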

class PreprocessedDataset(cachefilepath, basename, qualifiers)[source]



Call this to load the next copy into memory, such as at the end of an epoch.

Returns the name of the file that was actually loaded.


Call this to load a copy of the dataset with the specific qualifier into memory.

Returns the name of the file that was actually loaded.
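The behavior described above can be sketched as follows. The class and method names, the `.pkl` file naming scheme, and the pickle format are all assumptions for illustration, not the library's actual API: several preprocessed copies of a dataset live on disk, and one copy at a time is swapped into memory.

```python
import os
import pickle

class PreprocessedDatasetSketch:
    """Hypothetical sketch: preprocessed copies of a dataset are stored
    as '<basename><qualifier>.pkl' files, and one is loaded at a time."""

    def __init__(self, cachefilepath, basename, qualifiers):
        self.path = cachefilepath
        self.basename = basename
        self.qualifiers = qualifiers
        self.current = -1
        self.tensors = None
        self.load_next()  # load the first copy

    def load_next(self):
        # Load the next copy into memory, e.g. at the end of an epoch.
        self.current = (self.current + 1) % len(self.qualifiers)
        return self.load_qualifier(self.qualifiers[self.current])

    def load_qualifier(self, qualifier):
        # Load the copy with the given qualifier and return the name
        # of the file that was actually loaded.
        file_name = os.path.join(self.path, f"{self.basename}{qualifier}.pkl")
        with open(file_name, "rb") as f:
            self.tensors = pickle.load(f)
        return file_name
```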

class UnionDataset(datasets, transform)[source]


Dataset used to create unions of two or more datasets. The union is created by applying the given transformation to the corresponding items from each dataset.

  • datasets – list of datasets of the same size to merge

  • transform – function used to merge the corresponding items from the datasets
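A minimal sketch of the union idea (the class name is hypothetical, and indexing behavior is assumed from the description): datasets of equal length are merged item-wise by the user-supplied transform.

```python
class UnionDatasetSketch:
    """Hypothetical sketch: merge same-sized datasets item by item
    with a user-supplied transform."""

    def __init__(self, datasets, transform):
        # All datasets must be the same size for an item-wise union.
        assert len({len(d) for d in datasets}) == 1, "datasets must be the same size"
        self.datasets = datasets
        self.transform = transform

    def __len__(self):
        return len(self.datasets[0])

    def __getitem__(self, index):
        # Apply the transform to the corresponding item from each dataset.
        return self.transform(*(d[index] for d in self.datasets))
```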

create_validation_data_sampler(dataset, ratio)[source]

Create samplers used to split the dataset into 2 randomly sampled subsets. The first should be used for training and the second for validation.

  • dataset – A valid dataset (e.g. torchvision.datasets.MNIST)

  • ratio – The fraction of the dataset to be used for training. The remaining (1 - ratio) will be used for validation


tuple with 2 samplers (train, validate)
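The splitting logic can be sketched without torch dependencies. The function name and the seed are hypothetical, and plain index lists stand in for the samplers the real function returns: shuffle the indices once, then cut them at `ratio` into disjoint train/validation sets.

```python
import random

def create_validation_split_sketch(num_items, ratio, seed=0):
    """Hypothetical sketch: shuffle all indices once and cut at `ratio`,
    yielding disjoint train and validation index lists. (The actual
    function returns samplers; index lists keep the sketch self-contained.)"""
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)
    cut = int(num_items * ratio)
    return indices[:cut], indices[cut:]
```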

select_subset(classes, class_to_idx, samples, num_classes)[source]

Selects a subset of the classes based on a given number of classes. A fixed seed ensures the same classes are always chosen, in either train or val. Example: num_classes=11 will select the same classes as num_classes=10 plus 1 extra.
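One way to get the nesting property described above is to shuffle the full class list once with a fixed seed and keep the first `num_classes`. This is a hypothetical sketch, not the library's implementation: the function name, the seed value, and the assumption that `samples` is a list of (item, class_name) pairs are all illustrative.

```python
import random

def select_subset_sketch(classes, class_to_idx, samples, num_classes, seed=42):
    """Hypothetical sketch: a fixed-seed shuffle of the full class list
    makes num_classes=11 select the classes of num_classes=10 plus one
    extra, since the shuffled order never changes."""
    order = classes[:]
    random.Random(seed).shuffle(order)
    keep = set(order[:num_classes])
    # Rebuild the class list, index mapping, and sample list for the subset.
    kept_classes = sorted(keep)
    kept_class_to_idx = {c: i for i, c in enumerate(kept_classes)}
    kept_samples = [(item, cls) for item, cls in samples if cls in keep]
    return kept_classes, kept_class_to_idx, kept_samples
```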

split_dataset(dataset, groupby)[source]

Split the given dataset into multiple datasets grouped by the given groupby function. For example:

# Split the MNIST dataset into 10 datasets, one dataset for each label
split_dataset(mnist, groupby=lambda x: x[1])

# Split the MNIST dataset into 5 datasets, one dataset for each label pair:
# [0,1], [2,3], ...
split_dataset(mnist, groupby=lambda x: x[1] // 2)

  • dataset – Source dataset to split

  • groupby – Group by function. See itertools.groupby()


List of datasets, one per group
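The grouping can be sketched with itertools.groupby, which the docstring points to. The helper name is hypothetical, and plain lists of (item, label) pairs stand in for real datasets; note the input must be sorted by the group key first, since groupby only merges consecutive equal keys.

```python
from itertools import groupby

def split_dataset_sketch(dataset, groupby_fn):
    """Hypothetical sketch: sort items by the group key, then collect
    each run of equal keys into its own list."""
    ordered = sorted(dataset, key=groupby_fn)
    return [list(group) for _, group in groupby(ordered, key=groupby_fn)]
```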