toad.nn.module module

class toad.nn.module.Module

Bases: torch.nn.Module

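Base module for every model. Subclasses define their layers in __init__, the forward pass in forward, and the per-batch training loss in fit_step.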

Examples

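>>> import torch
>>> from torch import nn
>>> from torch.utils.data import DataLoader, TensorDataset
>>> from toad.nn import Module
>>> class Net(Module):
...     def __init__(self, inputs, hidden, outputs):
...         super().__init__()
...         self.model = nn.Sequential(
...             nn.Linear(inputs, hidden),
...             nn.ReLU(),
...             nn.Linear(hidden, outputs),
...             nn.Sigmoid(),
...         )
...
...     def forward(self, x):
...         return self.model(x)
...
...     def fit_step(self, batch):
...         x, y = batch
...         y_hat = self(x)
...
...         # log into history
...         self.log('y', y)
...         self.log('y_hat', y_hat)
...
...         return nn.functional.mse_loss(y_hat, y)
...
>>> # toy data for illustration only: 100 samples, 10 features, 0/1 targets of shape (100, 1)
>>> X = torch.rand(100, 10)
>>> y = torch.randint(0, 2, (100, 1)).float()
>>> train_loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
>>> model = Net(10, 4, 1)
>>> model.fit(train_loader)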
__init__()

Define the model structure.

property device

The device on which the model resides.

fit(loader, trainer=None, optimizer=None, loss=None, early_stopping=None, **kwargs)

Train the model.

Parameters
  • loader (DataLoader) – loader that yields the training batches

  • trainer (Trainer) – trainer used to train the model

  • optimizer (torch.optim.Optimizer) – optimizer used for training; defaults to Adam(lr=1e-3)

  • loss (Callable) – loss function, called as ‘loss(y_hat, y)’

  • early_stopping (earlystopping) – early-stopping policy; defaults to loss_earlystopping. Set it to False to disable early stopping.

  • epoch (int) – number of training epochs (passed through **kwargs)

  • callback (callable) – function called once per epoch (passed through **kwargs)
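How `epoch`, `callback` and `early_stopping` interact during training can be pictured with the pure-Python sketch below. It is not toad's implementation: `train_loop`, `LossEarlyStopping` and its `patience`/`delta` parameters are hypothetical names, and the stopping rule (stop once the epoch loss has not improved for `patience` consecutive epochs) is an assumed stand-in for `loss_earlystopping`.

```python
from typing import Callable, Iterable, List, Optional


class LossEarlyStopping:
    """Hypothetical loss-based early stopping: stop when the loss has not
    improved by more than `delta` for `patience` consecutive epochs."""

    def __init__(self, patience: int = 3, delta: float = 0.0):
        self.patience = patience
        self.delta = delta
        self.best = float("inf")
        self.bad_epochs = 0

    def should_stop(self, loss: float) -> bool:
        if loss < self.best - self.delta:
            # improvement: remember the new best loss and reset the counter
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def train_loop(
    loader: Iterable,
    fit_step: Callable[[object], float],
    epoch: int = 10,
    callback: Optional[Callable[[int, float], None]] = None,
    early_stopping=None,
) -> List[float]:
    """Run up to `epoch` passes over `loader`; return the mean loss per epoch."""
    history = []
    for ep in range(epoch):
        losses = [fit_step(batch) for batch in loader]
        mean_loss = sum(losses) / len(losses)
        history.append(mean_loss)
        if callback is not None:
            callback(ep, mean_loss)  # called once per epoch
        # early_stopping=False (or None) disables the check entirely
        if early_stopping and early_stopping.should_stop(mean_loss):
            break
    return history


# Usage: one batch per epoch; the loss plateaus at 0.5 after the second epoch.
scripted = iter([1.0, 0.5, 0.5, 0.5, 0.5, 0.4])
seen = []
history = train_loop(
    loader=[None],
    fit_step=lambda batch: next(scripted),
    epoch=10,
    callback=lambda ep, loss: seen.append(ep),
    early_stopping=LossEarlyStopping(patience=3),
)
print(history)  # [1.0, 0.5, 0.5, 0.5, 0.5]
```

Training stops after three non-improving epochs, so the final scripted loss of 0.4 is never reached; passing `early_stopping=False` instead would run all `epoch` passes.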

evaluate(loader, trainer=None)

Evaluate the model.

Parameters
  • loader (DataLoader) – loader for evaluating the model

  • trainer (Trainer) – trainer used to evaluate the model

fit_step(batch, loss=None, *args, **kwargs)

Run one training step on a batch.

Parameters
  • batch (Any) – a batch of data from the DataLoader

  • loss (Callable) – loss function, called as ‘loss(y_hat, y)’

Returns

The loss of this step.

Return type

Tensor
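Any callable that accepts `(y_hat, y)` and returns a scalar satisfies the `loss` contract. The sketch below shows one way a `fit_step` override might honour the optional `loss` argument and fall back to a default when it is `None`. It uses plain Python lists instead of tensors so it runs without PyTorch, and `mse`, `mae` and `predict` are hypothetical helpers, not part of toad.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]


def mse(y_hat: Sequence[float], y: Sequence[float]) -> float:
    """Mean squared error, used as the fallback loss in this sketch."""
    return sum((a - b) ** 2 for a, b in zip(y_hat, y)) / len(y)


def mae(y_hat: Sequence[float], y: Sequence[float]) -> float:
    """Mean absolute error: another callable usable as loss(y_hat, y)."""
    return sum(abs(a - b) for a, b in zip(y_hat, y)) / len(y)


def predict(x: Sequence[float]) -> List[float]:
    """Hypothetical stand-in for the model's forward pass: y_hat = 2 * x."""
    return [2.0 * v for v in x]


def fit_step(batch: Batch, loss: Optional[Callable] = None) -> float:
    x, y = batch
    y_hat = predict(x)
    if loss is None:
        loss = mse  # default used when the caller passes no loss
    return loss(y_hat, y)


batch = ([1.0, 2.0], [2.0, 6.0])  # predictions are [2.0, 4.0]
print(fit_step(batch))            # mse: (0 + 4) / 2 = 2.0
print(fit_step(batch, loss=mae))  # mae: (0 + 2) / 2 = 1.0
```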

save(path)

Save the model to path.

load(path)

Load the model from path.

log(key, value)

Log values to the training history.

Parameters
  • key (str) – name under which the values are logged

  • value (Tensor) – tensor of values
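Conceptually, logging appends each value under its key so that per-step values can be combined later, for example to compute an epoch-level metric from the logged `y` and `y_hat`. The class below is a hypothetical pure-Python sketch of such a history store, not toad's own history object; its `log`, `get` and `reset` method names are assumptions, and plain lists of floats stand in for tensors.

```python
from collections import defaultdict
from typing import Dict, List


class History:
    """Hypothetical per-epoch store mapping key -> list of logged values."""

    def __init__(self) -> None:
        self._store: Dict[str, List[float]] = defaultdict(list)

    def log(self, key: str, value: List[float]) -> None:
        # extend rather than append so batches concatenate into one sequence
        self._store[key].extend(value)

    def get(self, key: str) -> List[float]:
        # return a copy so callers cannot mutate the stored values
        return list(self._store[key])

    def reset(self) -> None:
        # in this sketch, called at the start of each epoch
        self._store.clear()


history = History()
for y, y_hat in [([1.0, 0.0], [0.9, 0.2]), ([1.0], [0.7])]:  # two batches
    history.log("y", y)
    history.log("y_hat", y_hat)

print(history.get("y"))      # [1.0, 0.0, 1.0]
print(history.get("y_hat"))  # [0.9, 0.2, 0.7]
```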

distributed(backend=None, **kwargs)

Get a distributed version of the model.

class toad.nn.module.DistModule(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=None, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, static_graph=False, delay_all_reduce_named_params=None, param_to_hook_all_reduce=None, mixed_precision: Optional[_MixedPrecision] = None, device_mesh=None)

Bases: DistributedDataParallel

Distributed module class: a DistributedDataParallel subclass that wraps module for distributed data-parallel training.