nupic.research.frameworks.pytorch.model_utils

count_nonzero_params(model)

Count the total number of non-zero weights in the model, including bias weights.
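
As a rough illustration of what such a count involves, the sketch below sums the non-zero entries over every parameter tensor of a model. The helper name count_nonzero_sketch is hypothetical and is not the library function; in particular, the exact return value of count_nonzero_params is not stated here, so the sketch simply returns a single integer.

```python
import torch
from torch import nn


def count_nonzero_sketch(model: nn.Module) -> int:
    """Hypothetical stand-in: count non-zero entries across all parameters."""
    # model.parameters() yields weight and bias tensors alike.
    return sum(int(torch.count_nonzero(p)) for p in model.parameters())


# Tiny demo layer with hand-set values so the count is easy to check.
layer = nn.Linear(3, 2)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[1.0, 0.0, 2.0], [0.0, 0.0, 3.0]]))
    layer.bias.copy_(torch.tensor([0.0, 4.0]))

# Three non-zero weights plus one non-zero bias.
nonzero = count_nonzero_sketch(layer)
```

Counting over model.parameters() is one reasonable reading of "including bias weights"; buffers such as batch-norm running statistics are not parameters and are not counted in this sketch.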

evaluate_model(model, loader, device, batches_in_epoch=9223372036854775807, criterion=torch.nn.functional.nll_loss, progress=None)

Evaluate a pre-trained model using the given test dataset loader.

Parameters
  • model (torch.nn.Module) – Pretrained pytorch model

  • loader (torch.utils.data.DataLoader) – test dataset loader

  • device (torch.device) – device to use (‘cpu’ or ‘cuda’)

  • batches_in_epoch (int) – Max number of mini batches to test on.

  • criterion (function) – loss function to use

  • progress (dict or None) – Optional tqdm progress bar args. None for no progress bar

Returns

dictionary with computed “mean_accuracy”, “mean_loss”, “total_correct”.

Return type

dict
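
The parameters and return value above can be pictured with a minimal evaluation loop. This is a sketch under stated assumptions, not the library's implementation: the name evaluate_sketch is hypothetical, the loader is assumed to yield (data, target) pairs, the criterion is assumed to accept reduction="sum" as torch.nn.functional.nll_loss does, and the progress bar is omitted.

```python
import sys

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_sketch(model, loader, device,
                    batches_in_epoch=sys.maxsize, criterion=F.nll_loss):
    """Hypothetical evaluation loop returning the three documented keys."""
    model.eval()
    total, correct, loss = 0, 0, 0.0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(loader):
            # Assumption: stop once batches_in_epoch batches have been seen.
            if batch_idx >= batches_in_epoch:
                break
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per-sample losses so the mean is over samples, not batches.
            loss += criterion(output, target, reduction="sum").item()
            correct += (output.argmax(dim=1) == target).sum().item()
            total += target.size(0)
    return {
        "total_correct": correct,
        "mean_loss": loss / total if total else 0.0,
        "mean_accuracy": correct / total if total else 0.0,
    }


# Demo on random data: 10 samples, batch size 5, evaluate one batch only.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))
dataset = TensorDataset(torch.randn(10, 4), torch.randint(0, 3, (10,)))
loader = DataLoader(dataset, batch_size=5)
results = evaluate_sketch(model, loader, torch.device("cpu"), batches_in_epoch=1)
```

The model ends in LogSoftmax because nll_loss expects log-probabilities; a model that emits raw logits would need a different criterion.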

set_random_seed(seed, deterministic_mode=True)

Set pytorch, python random, and numpy random seeds (these are all the seeds we normally use).
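
A minimal sketch of seeding the three generators named above might look as follows. The helper name seed_everything_sketch is hypothetical, and the meaning of deterministic_mode is not documented here; the sketch assumes it toggles the cuDNN determinism flags, which may differ from what the library does.

```python
import random

import numpy as np
import torch


def seed_everything_sketch(seed, deterministic_mode=True):
    """Hypothetical seeding helper; not the library's implementation."""
    random.seed(seed)        # Python's built-in generator
    np.random.seed(seed)     # NumPy's global generator
    torch.manual_seed(seed)  # PyTorch generators
    if deterministic_mode:
        # Assumption: "deterministic mode" refers to these cuDNN settings.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


# Re-seeding with the same value should reproduce the same draws.
seed_everything_sketch(42)
first = (random.random(), np.random.rand(), torch.rand(1).item())
seed_everything_sketch(42)
second = (random.random(), np.random.rand(), torch.rand(1).item())
```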

train_model(model, loader, optimizer, device, criterion=torch.nn.functional.nll_loss, batches_in_epoch=9223372036854775807, pre_batch_callback=None, post_batch_callback=None, progress_bar=None)

Train the given model by iterating through mini batches. An epoch ends after one pass through the training set, or if the number of mini batches exceeds the parameter “batches_in_epoch”.

Parameters
  • model (torch.nn.Module) – pytorch model to be trained

  • loader (torch.utils.data.DataLoader) – train dataset loader

  • optimizer – Optimizer object used to train the model. This function trains the model on every batch using this optimizer and the loss function given by criterion (torch.nn.functional.nll_loss by default)

  • batches_in_epoch (int) – Max number of mini batches to train.

  • device (torch.device) – device to use (‘cpu’ or ‘cuda’)

  • criterion (function) – loss function to use

  • post_batch_callback (function) – Callback function to be called after every batch with the following parameters: model, batch_idx

  • pre_batch_callback (function) – Callback function to be called before every batch with the following parameters: model, batch_idx

  • progress_bar (dict or None) – Optional tqdm progress bar args. None for no progress bar
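
To show how the batch limit and the two callbacks fit together, here is a minimal training-loop sketch. It is not the library's implementation: the name train_sketch is hypothetical, the loader is assumed to yield (data, target) pairs, the callbacks are assumed to be called with the keyword arguments model and batch_idx, the loop is assumed to stop once batches_in_epoch batches have run, and the progress bar is omitted.

```python
import sys

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(model, loader, optimizer, device, criterion=F.nll_loss,
                 batches_in_epoch=sys.maxsize,
                 pre_batch_callback=None, post_batch_callback=None):
    """Hypothetical one-epoch training loop with per-batch callbacks."""
    model.train()
    for batch_idx, (data, target) in enumerate(loader):
        # Assumption: the epoch ends once the batch limit is reached.
        if batch_idx >= batches_in_epoch:
            break
        if pre_batch_callback is not None:
            pre_batch_callback(model=model, batch_idx=batch_idx)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        if post_batch_callback is not None:
            post_batch_callback(model=model, batch_idx=batch_idx)


# Demo: 4 batches are available, but only 2 are trained on.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(20, 4), torch.randint(0, 3, (20,)))
loader = DataLoader(dataset, batch_size=5)

events = []  # records the order in which the callbacks fire
train_sketch(
    model, loader, optimizer, torch.device("cpu"),
    batches_in_epoch=2,
    pre_batch_callback=lambda model, batch_idx: events.append(("pre", batch_idx)),
    post_batch_callback=lambda model, batch_idx: events.append(("post", batch_idx)),
)
```

Callbacks of this shape are a convenient place for per-batch bookkeeping such as logging or adjusting model state between optimizer steps.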