toad.nn.module module

class toad.nn.module.Module[source]

Bases: torch.nn.modules.module.Module

base module for every model

__init__()[source]

define model structure
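
A minimal sketch of a subclass, assuming a single linear layer and an MSE objective for illustration; the class name, layer sizes, and loss function are not part of the API:

    import torch
    import torch.nn.functional as F
    from toad.nn.module import Module

    class MyModel(Module):
        def __init__(self):
            super().__init__()
            # hypothetical structure: a single linear layer
            self.linear = torch.nn.Linear(10, 1)

        def forward(self, x):
            return self.linear(x)

        def fit_step(self, batch):
            # unpack a (features, target) batch and return this step's loss
            x, y = batch
            return F.mse_loss(self(x), y)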

device

device of model

fit(loader, trainer=None, optimizer=None, early_stopping=None, **kwargs)[source]

train model

Parameters:
  • loader (DataLoader) – loader for training model
  • trainer (Trainer) – trainer for training model
  • optimizer (torch.Optimizer) – the default optimizer is Adam(lr=1e-3)
  • early_stopping (earlystopping) – the default value is loss_earlystopping; you can set it to False to disable early stopping
  • epoch (int) – number of epochs for the training loop
  • callback (callable) – function that will be called every epoch
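
A usage sketch for fit, reusing the MyModel subclass from the sketch above; the random data, batch size, and epoch count are arbitrary assumptions:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # hypothetical random data for illustration
    X = torch.rand(100, 10)
    y = torch.rand(100, 1)
    loader = DataLoader(TensorDataset(X, y), batch_size=32)

    model = MyModel()  # the subclass sketched above
    # relies on the default Adam(lr=1e-3) optimizer and loss-based early stopping
    model.fit(loader, epoch=5)
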
evaluate(loader, trainer=None)[source]

evaluate model

Parameters:
  • loader (DataLoader) – loader for evaluating the model
  • trainer (Trainer) – trainer for evaluating the model
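
A brief usage sketch, assuming a held-out loader built the same way as the training loader above:

    # hypothetical held-out data, shaped like the training data
    X_val = torch.rand(20, 10)
    y_val = torch.rand(20, 1)
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=32)

    model.evaluate(val_loader)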

fit_step(batch, *args, **kwargs)[source]

step for fitting

Parameters:
  • batch (Any) – batch data from dataloader

Returns: loss of this step
Return type: Tensor
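
Overriding fit_step is how a subclass defines its objective. A hedged variant for classification; the layer sizes and cross-entropy loss are assumptions for illustration:

    import torch
    import torch.nn.functional as F
    from toad.nn.module import Module

    class Classifier(Module):
        def __init__(self):
            super().__init__()
            # hypothetical structure: 10 features, 3 classes
            self.linear = torch.nn.Linear(10, 3)

        def forward(self, x):
            return self.linear(x)

        def fit_step(self, batch):
            x, y = batch
            # the returned Tensor is the loss used for this training step
            return F.cross_entropy(self(x), y)
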
save(path)[source]

save model

load(path)[source]

load model
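
A usage sketch, assuming save and load persist the model's parameters; the file path is illustrative:

    model.save('mymodel.pt')

    # restore the parameters into a model with the same structure
    restored = MyModel()
    restored.load('mymodel.pt')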

log(key, value)[source]

log values to history

Parameters:
  • key (str) – name of the logged value
  • value (Tensor) – tensor of values
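
A sketch of logging from inside a fit_step override in a Module subclass; the key name 'loss' is an arbitrary choice:

    def fit_step(self, batch):
        x, y = batch
        loss = F.mse_loss(self(x), y)
        # record this step's loss under the key 'loss' in the model history
        self.log('loss', loss)
        return loss
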
distributed(backend=None, **kwargs)[source]

get distributed model
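
A hedged sketch; the 'nccl' backend is an assumption, and distributed training additionally requires launching one process per device:

    # hypothetical: wrap the model for distributed data-parallel training,
    # which returns a DistModule (see below)
    dist_model = model.distributed(backend='nccl')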

class toad.nn.module.DistModule(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False)[source]

Bases: torch.nn.parallel.distributed.DistributedDataParallel

distributed module class