toad.nn.module module

class toad.nn.module.Module[source]

Bases: torch.nn.modules.module.Module

base class for all models

Examples

>>> from toad.nn import Module
>>> from torch import nn
>>>
>>> class Net(Module):
...     def __init__(self, inputs, hidden, outputs):
...         super().__init__()
...         self.model = nn.Sequential(
...             nn.Linear(inputs, hidden),
...             nn.ReLU(),
...             nn.Linear(hidden, outputs),
...             nn.Sigmoid(),
...         )
...
...     def forward(self, x):
...         return self.model(x)
...
...     def fit_step(self, batch):
...         x, y = batch
...         y_hat = self(x)
...
...         # log into history
...         self.log('y', y)
...         self.log('y_hat', y_hat)
...
...         return nn.functional.mse_loss(y_hat, y)
...
>>> model = Net(10, 4, 1)
>>> model.fit(train_loader)
__init__()[source]

define the model structure

device

device of model

fit(loader, trainer=None, optimizer=None, loss=None, early_stopping=None, **kwargs)[source]

train model

Parameters:
  • loader (DataLoader) – loader for training the model
  • trainer (Trainer) – trainer used to train the model
  • optimizer (torch.optim.Optimizer) – the default optimizer is Adam(lr = 1e-3)
  • loss (Callable) – a callable invoked as loss(y_hat, y)
  • early_stopping (earlystopping) – defaults to loss_earlystopping; set it to False to disable early stopping
  • epoch (int) – number of epochs for the training loop
  • callback (callable) – callable that is invoked after every epoch
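
A minimal usage sketch based on the signature above; train_loader, the learning rate and the epoch count are placeholders for this example, not library defaults.

>>> from torch.optim import Adam
>>> from torch.nn.functional import mse_loss
>>>
>>> model.fit(
...     train_loader,
...     optimizer=Adam(model.parameters(), lr=1e-4),
...     loss=mse_loss,
...     epoch=20,
... )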
evaluate(loader, trainer=None)[source]

evaluate model

Parameters:
  • loader (DataLoader) – loader for evaluating the model
  • trainer (Trainer) – trainer used to evaluate the model
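
A sketch of running evaluation on a held-out loader; valid_loader is an assumed DataLoader, not something the library provides.

>>> # evaluate the trained model on a held-out loader
>>> model.evaluate(valid_loader)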
fit_step(batch, loss=None, *args, **kwargs)[source]

step for fitting

Parameters:
  • batch (Any) – batch data from dataloader
  • loss (Callable) – a callable invoked as loss(y_hat, y)
Returns:

loss of this step

Return type:

Tensor
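
Subclasses override fit_step to define how a batch becomes a loss; a short sketch, where the cross-entropy loss is an illustrative choice rather than a library default.

>>> from toad.nn import Module
>>> from torch import nn
>>>
>>> class Classifier(Module):
...     def __init__(self, inputs, classes):
...         super().__init__()
...         self.linear = nn.Linear(inputs, classes)
...
...     def forward(self, x):
...         return self.linear(x)
...
...     def fit_step(self, batch):
...         x, y = batch
...         logits = self(x)
...         # the returned tensor is used as the loss for this step
...         return nn.functional.cross_entropy(logits, y)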

save(path)[source]

save model

load(path)[source]

load model
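
A sketch of persisting and restoring a trained model; the file name is arbitrary.

>>> model.save('net.pt')    # save the model to the given path
>>> model.load('net.pt')    # load it back into this instance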

log(key, value)[source]

log values to history

Parameters:
  • key (str) – name of the logged value
  • value (Tensor) – tensor of values
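
log is meant to be called from inside fit_step; a small sketch of a fit_step body (defined inside a Module subclass) that also tracks the loss itself, in addition to the tensors logged in the class example above.

>>> def fit_step(self, batch):
...     x, y = batch
...     y_hat = self(x)
...     loss = nn.functional.mse_loss(y_hat, y)
...     self.log('loss', loss)    # keep the per-step loss in history
...     return loss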
distributed(backend=None, **kwargs)[source]

get a distributed version of the model
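
A hedged sketch of switching an existing model to distributed training; it assumes the process-group environment (rank, world size, master address) is already set up, and 'nccl' is only an example backend.

>>> dist_model = model.distributed(backend='nccl')    # returns a wrapped, distributed model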

class toad.nn.module.DistModule(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, static_graph=False)[source]

Bases: torch.nn.parallel.distributed.DistributedDataParallel

distributed module class
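
The constructor mirrors torch.nn.parallel.DistributedDataParallel, so a model can also be wrapped directly; local_rank below is a placeholder for the current process's GPU index, and the default process group is assumed to be initialised already.

>>> from toad.nn.module import DistModule
>>> dist_model = DistModule(model, device_ids=[local_rank])    # local_rank: placeholder GPU index for this process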