nupic.research.frameworks.pytorch.dataset_utils

class PreprocessedDataset(cachefilepath, basename, qualifiers)[source]

Bases: torch.utils.data.Dataset

load_next()[source]

Call this to load the next copy into memory, such as at the end of an epoch.

Returns

Name of the file that was actually loaded.

load_qualifier(qualifier)[source]

Call this to load a copy of the dataset with the specified qualifier into memory.

Returns

Name of the file that was actually loaded.
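
Example (illustrative sketch). The constructor arguments are interpreted here as a cache directory (cachefilepath), a file name prefix (basename), and a list of qualifiers identifying the preprocessed copies; the exact file naming convention is an assumption, not part of the documented API:

import torch.utils.data

from nupic.research.frameworks.pytorch.dataset_utils import PreprocessedDataset

# Assumed cache layout: one preprocessed copy of the data per qualifier,
# stored under cachefilepath with file names derived from basename.
dataset = PreprocessedDataset(
    cachefilepath="data/preprocessed",  # hypothetical cache directory
    basename="mnist_train",             # hypothetical file name prefix
    qualifiers=[0, 1, 2, 3],            # one qualifier per cached copy
)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

for epoch in range(4):
    for batch in loader:
        pass  # training step goes here
    # Swap in the next preprocessed copy at the end of the epoch.
    print("loaded:", dataset.load_next())

# Alternatively, load the copy for a specific qualifier directly.
dataset.load_qualifier(2)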

class UnionDataset(datasets, transform)[source]

Bases: torch.utils.data.Dataset

Dataset used to create unions of two or more datasets. The union is created by applying the given transform to the items in the datasets.

Parameters
  • datasets – list of datasets of the same size to merge

  • transform – function used to merge 2 items in the datasets
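
Example (illustrative sketch). It is assumed here that the transform receives the corresponding items from each underlying dataset; the exact call signature expected for transform, and the structure of the items returned by the union, should be checked against the implementation. The pixel-wise torch.max merge is purely illustrative:

import torch
from torchvision import datasets, transforms

from nupic.research.frameworks.pytorch.dataset_utils import UnionDataset

to_tensor = transforms.ToTensor()
mnist_a = datasets.MNIST("data", train=True, download=True, transform=to_tensor)
mnist_b = datasets.MNIST("data", train=True, download=True, transform=to_tensor)

# Hypothetical merge function: overlay two images pixel-wise.
def merge(x1, x2):
    return torch.max(x1, x2)

union = UnionDataset(datasets=[mnist_a, mnist_b], transform=merge)
item = union[0]  # merged item; exact return structure depends on the implementation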

create_validation_data_sampler(dataset, ratio)[source]

Create torch.utils.data.Sampler objects used to split the dataset into 2 randomly sampled subsets. The first should be used for training and the second for validation.

Parameters
  • dataset – A valid torch.utils.data.Dataset (i.e. torchvision.datasets.MNIST)

  • ratio – The fraction of the dataset to be used for training. The remaining (1 - ratio) will be used for validation

Returns

Tuple with 2 torch.utils.data.Sampler objects: (train, validate)
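
Example (illustrative sketch): the returned samplers are passed to two DataLoader instances over the same dataset:

import torch.utils.data
from torchvision import datasets, transforms

from nupic.research.frameworks.pytorch.dataset_utils import (
    create_validation_data_sampler,
)

mnist = datasets.MNIST("data", train=True, download=True,
                       transform=transforms.ToTensor())

# 90% of the samples for training, the remaining 10% for validation.
train_sampler, validate_sampler = create_validation_data_sampler(mnist, ratio=0.9)

train_loader = torch.utils.data.DataLoader(mnist, batch_size=64,
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(mnist, batch_size=64,
                                                sampler=validate_sampler)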

split_dataset(dataset, groupby)[source]

Split the given dataset into multiple datasets grouped by the given groupby function. For example:

# Split mnist dataset into 10 datasets, one dataset for each label
split_dataset(mnist, groupby=lambda x: x[1])

# Split mnist dataset into 5 datasets, one dataset for each label pair:
# [0,1], [2,3],...
split_dataset(mnist, groupby=lambda x: x[1] // 2)

Parameters
  • dataset – Source dataset to split

  • groupby – Group by function. See itertools.groupby()

Returns

List of datasets
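
Follow-up sketch wrapping each split in its own DataLoader, assuming the returned datasets behave as ordinary torch.utils.data.Dataset objects:

import torch.utils.data
from torchvision import datasets, transforms

from nupic.research.frameworks.pytorch.dataset_utils import split_dataset

mnist = datasets.MNIST("data", train=True, download=True,
                       transform=transforms.ToTensor())

# One dataset per digit label.
per_label = split_dataset(mnist, groupby=lambda x: x[1])

# One DataLoader per split, e.g. for per-class evaluation.
loaders = [
    torch.utils.data.DataLoader(ds, batch_size=64, shuffle=False)
    for ds in per_label
]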