nupic.research.frameworks.pytorch.dataset_utils

class CachedDatasetFolder(root, loader=torchvision.datasets.folder.default_loader, extensions=torchvision.datasets.folder.IMG_EXTENSIONS, transform=None, target_transform=None, is_valid_file=None, num_classes=1000)[source]

Bases: torchvision.datasets.DatasetFolder

A cached version of torchvision.datasets.DatasetFolder in which the class and image lists are static and cached, skipping the costly os.walk and os.scandir calls.
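The caching idea can be sketched in plain Python: scan the directory tree once, pickle the result, and reload the pickle on subsequent constructions instead of walking the tree again. This is a hypothetical standalone sketch, not the class's actual implementation; the helper name `list_images_cached` and the pickle format are assumptions.

```python
import os
import pickle


def list_images_cached(root, cache_path):
    """Scan ``root`` once and cache the file list to ``cache_path``.

    On reuse, the cached list is loaded directly, skipping the
    os.walk/os.scandir traversal entirely (the costly part that
    CachedDatasetFolder avoids for its class and sample lists).
    """
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)  # cache hit: no directory walk

    files = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in sorted(filenames):
            files.append(os.path.join(dirpath, name))

    with open(cache_path, "wb") as f:
        pickle.dump(files, f)
    return files
```

The trade-off is the usual one for static caches: if images are added or removed under `root`, the stale cache file must be deleted by hand.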

class PreprocessedDataset(cachefilepath, basename, qualifiers)[source]

Bases: torch.utils.data.Dataset

load_next()[source]

Call this to load the next copy into memory, such as at the end of an epoch.

Returns

Name of the file that was actually loaded.

load_qualifier(qualifier)[source]

Call this to load a copy of the dataset with the specific qualifier into memory.

Returns

Name of the file that was actually loaded.
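The relationship between the two methods can be sketched as follows: each qualifier names one preprocessed copy of the dataset, and load_next() simply advances through the qualifiers cyclically. This is an assumed minimal model; the class name `CyclingLoader`, the filename pattern, and the `.npz` extension are illustrative guesses, and the real class would load the file's contents into memory rather than just return its name.

```python
import itertools


class CyclingLoader:
    """Sketch of PreprocessedDataset's qualifier cycling (assumed behavior)."""

    def __init__(self, basename, qualifiers):
        self.basename = basename
        self._cycle = itertools.cycle(qualifiers)  # wraps around forever
        self.current = None

    def load_qualifier(self, qualifier):
        # The real class would load the file's arrays into memory here.
        self.current = "{}{}.npz".format(self.basename, qualifier)
        return self.current

    def load_next(self):
        # Advance to the next preprocessed copy, e.g. at the end of an epoch.
        return self.load_qualifier(next(self._cycle))
```

Calling load_next() once per epoch means each epoch sees a different preprocessed copy, cycling back to the first after the last qualifier.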

class UnionDataset(datasets, transform)[source]

Bases: torch.utils.data.Dataset

Dataset used to create unions of two or more datasets. The union is created by applying the given transform to merge the items found at the same index in each dataset.

Parameters
  • datasets – list of datasets of the same size to merge

  • transform – function used to merge 2 items in the datasets
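A minimal sketch of the same idea over plain Python lists, assuming the transform receives the list of same-index items and returns the merged item (the real class works with torch datasets; the class name `UnionOfLists` is illustrative):

```python
class UnionOfLists:
    """Merge same-size indexable datasets item-by-item via ``transform``."""

    def __init__(self, datasets, transform):
        size = len(datasets[0])
        # The union is only well defined when all datasets line up.
        assert all(len(d) == size for d in datasets), "datasets must be the same size"
        self.datasets = datasets
        self.transform = transform

    def __len__(self):
        return len(self.datasets[0])

    def __getitem__(self, index):
        # Gather the item at ``index`` from every dataset, then merge.
        return self.transform([d[index] for d in self.datasets])
```

For example, merging two lists of numbers with `transform=sum` yields the element-wise sums.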

create_validation_data_sampler(dataset, ratio)[source]

Create torch.utils.data.Sampler objects used to split the dataset into 2 randomly sampled subsets. The first should be used for training and the second for validation.

Parameters
  • dataset – A valid torch.utils.data.Dataset (e.g. torchvision.datasets.MNIST)

  • ratio – The fraction of the dataset to be used for training. The remaining (1 - ratio) will be used for validation

Returns

Tuple with 2 torch.utils.data.Sampler: (train, validate)
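The underlying index split can be sketched without torch: shuffle all indices with a fixed seed, then take the first ratio fraction for training and the rest for validation. The function name `split_indices` and the seed value are assumptions; the real function would wrap the two index lists in samplers such as torch.utils.data.SubsetRandomSampler rather than return raw lists.

```python
import random


def split_indices(dataset_len, ratio, seed=42):
    """Split ``range(dataset_len)`` into (train, validation) index lists.

    A fixed seed makes the split reproducible across runs.
    """
    indices = list(range(dataset_len))
    random.Random(seed).shuffle(indices)
    split = int(dataset_len * ratio)
    return indices[:split], indices[split:]
```

The two lists are disjoint and together cover every index, so no sample is used for both training and validation.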

select_subset(classes, class_to_idx, samples, num_classes)[source]

Selects a subset of the classes based on a given number of classes. A fixed seed ensures the same classes are always chosen, in either train or val. Example: num_classes=11 will select the same classes as num_classes=10 plus 1 extra.
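The prefix property described above (num_classes=11 extends the selection for num_classes=10) follows directly from shuffling with a fixed seed and taking a prefix. This is a hedged sketch of that mechanism only; the helper name `pick_classes` and the seed value are assumptions, not the function's actual signature.

```python
import random


def pick_classes(all_classes, num_classes, seed=42):
    """Deterministically pick ``num_classes`` classes from ``all_classes``.

    Because the shuffle uses the same fixed seed every time, the choice
    for n classes is always a prefix of the choice for n + 1 classes.
    """
    shuffled = list(all_classes)
    random.Random(seed).shuffle(shuffled)
    return shuffled[:num_classes]
```

Train and validation code calling this with the same seed therefore agree on exactly which classes are kept.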

split_dataset(dataset, groupby)[source]

Split the given dataset into multiple datasets grouped by the given groupby function. For example:

# Split mnist dataset into 10 datasets, one dataset for each label
split_dataset(mnist, groupby=lambda x: x[1])

# Split mnist dataset into 5 datasets, one dataset for each label pair:
# [0,1], [2,3],...
split_dataset(mnist, groupby=lambda x: x[1] // 2)
Parameters
  • dataset – Source dataset to split

  • groupby – Group by function. See itertools.groupby()

Returns

List of datasets
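A self-contained sketch of the grouping step, runnable without torch: bucket the items by their groupby key and return one list per group, ordered by key. The name `split_by_group` is illustrative, and the real function returns torch dataset objects (and, per the docstring, relies on itertools.groupby) rather than plain lists.

```python
from collections import defaultdict


def split_by_group(dataset, groupby):
    """Split an iterable of items into one list per groupby key."""
    groups = defaultdict(list)
    for item in dataset:
        # Items keep their original relative order within each group.
        groups[groupby(item)].append(item)
    return [groups[key] for key in sorted(groups)]
```

For an MNIST-style dataset of (image, label) pairs, `groupby=lambda x: x[1]` yields one list per label, matching the first example above.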