toad.nn.module module
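class toad.nn.module.Module
Bases: torch.nn.modules.module.Module

Base module for every model. It extends torch.nn.Module with methods for training (fit), evaluation (evaluate) and the per-batch loss computation (fit_step).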

device
Device of the model.

fit(loader, trainer=None, optimizer=None, early_stopping=None, **kwargs)
Train the model.
Parameters:
- loader (DataLoader) – data loader that supplies the training batches
- trainer (Trainer) – trainer used to train the model
- optimizer (torch.optim.Optimizer) – optimizer for the model parameters; defaults to Adam(lr=1e-3)
- early_stopping (earlystopping) – early-stopping policy; defaults to loss_earlystopping. Set it to False to disable early stopping
- epoch (int) – number of epochs in the training loop (passed through **kwargs)
- callback (callable) – function called once per epoch (passed through **kwargs)
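The parameters above describe a conventional epoch loop: for each epoch, draw batches from the loader, get each batch's loss from fit_step, update the parameters with the optimizer, call the callback, and stop once early stopping triggers. The sketch below reproduces that control flow in plain Python so it runs without torch; it is not toad's implementation. The classes `LinearModel` and `SGD`, the `fit` function, the `patience`-based stopping rule, and the `callback(epoch, loss)` signature are all hypothetical, and the gradient is computed by hand where torch would use `loss.backward()`.

```python
# Plain-Python sketch of an epoch loop with early stopping.
# Hypothetical: LinearModel, SGD and fit below are illustrations, not toad's code.


class LinearModel:
    """Hypothetical 1-D model y = w * x standing in for a Module subclass."""

    def __init__(self):
        self.w = 0.0
        self.grad = 0.0  # plays the role of a parameter's .grad in torch

    def fit_step(self, batch):
        # Mean squared error over the batch. The gradient is computed by hand
        # here; in torch, loss.backward() would fill it in instead.
        xs, ys = batch
        n = len(xs)
        errors = [self.w * x - y for x, y in zip(xs, ys)]
        self.grad = sum(2 * e * x for e, x in zip(errors, xs)) / n
        return sum(e * e for e in errors) / n


class SGD:
    """Hypothetical optimizer stand-in: plain gradient descent on model.w."""

    def __init__(self, model, lr):
        self.model = model
        self.lr = lr

    def step(self):
        self.model.w -= self.lr * self.model.grad


def fit(model, loader, optimizer, epoch=10, patience=2, callback=None):
    """Train for up to `epoch` epochs; stop after `patience` epochs without improvement."""
    best, bad_epochs = float("inf"), 0
    for ep in range(epoch):
        losses = []
        for batch in loader:
            losses.append(model.fit_step(batch))  # per-batch loss from fit_step
            optimizer.step()  # update the parameter using the stored gradient
        mean_loss = sum(losses) / len(losses)
        if callback is not None:
            callback(ep, mean_loss)  # called once per epoch
        if mean_loss < best:
            best, bad_epochs = mean_loss, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:  # early stopping
                break
    return model


# Usage: two batches drawn from y = 2x.
loader = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])]
model = LinearModel()
history = []
fit(model, loader, SGD(model, lr=0.02), epoch=50,
    callback=lambda ep, loss: history.append(loss))
print(round(model.w, 3))  # 2.0
```

The patience rule used here (stop after a fixed number of epochs without a lower mean loss) is only one common early-stopping criterion; toad's loss_earlystopping may decide differently.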

evaluate(loader, trainer=None)
Evaluate the model.
Parameters:
- loader (DataLoader) – data loader that supplies the evaluation batches
- trainer (Trainer) – trainer used to evaluate the model
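One natural reading of evaluate is: compute the same per-batch loss over the loader without updating any parameters, then average it. The plain-Python sketch below shows that reading only; the source does not state how toad's evaluate aggregates batches or what it returns, so the `FixedModel` class, the `evaluate` function and the size-weighted average are hypothetical. In torch, such a pass would typically also call `model.eval()` and run inside `torch.no_grad()`.

```python
# Hypothetical plain-Python sketch of an evaluation pass: no parameter updates.


class FixedModel:
    """Hypothetical 1-D model y = w * x with a fixed weight."""

    def __init__(self, w):
        self.w = w

    def fit_step(self, batch):
        # Mean squared error of the batch; nothing is updated here.
        xs, ys = batch
        return sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def evaluate(model, loader):
    """Return the sample-weighted mean loss over all batches in `loader`."""
    total, count = 0.0, 0
    for xs, ys in loader:
        # Weight each batch by its size so a short final batch is not over-counted.
        total += model.fit_step((xs, ys)) * len(xs)
        count += len(xs)
    return total / count


# Usage: the second batch is shorter and has one wrong target.
loader = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [5.0])]
model = FixedModel(2.0)
print(evaluate(model, loader))  # 0.3333333333333333
```

Weighting by batch size makes the result equal to the mean loss per sample; a plain average of batch losses would give 0.5 for the same data.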

fit_step(batch, *args, **kwargs)
Run a single fitting step on one batch.
Parameters:
- batch (Any) – batch of data from the data loader
Returns: loss of this step
Return type: Tensor
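Because fit_step receives one batch and returns that batch's loss, a model can change how its loss is computed by overriding this single method while the surrounding training code stays the same. The plain-Python sketch below shows that pattern with a hypothetical `BaseModel` standing in for Module: its default fit_step uses mean squared error, and the hypothetical subclass `L1Model` swaps in mean absolute error. The default loss shown is an assumption; the source does not say what toad's own fit_step computes.

```python
# Hypothetical sketch of overriding fit_step; BaseModel stands in for toad's Module.


class BaseModel:
    def __init__(self, w):
        self.w = w

    def forward(self, x):
        return self.w * x

    def fit_step(self, batch, *args, **kwargs):
        # Assumed default loss: mean squared error over the batch.
        xs, ys = batch
        return sum((self.forward(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


class L1Model(BaseModel):
    def fit_step(self, batch, *args, **kwargs):
        # Override: mean absolute error; everything else is inherited.
        xs, ys = batch
        return sum(abs(self.forward(x) - y) for x, y in zip(xs, ys)) / len(xs)


# Usage: same batch, two different losses.
batch = ([1.0, 2.0], [0.0, 0.0])
print(BaseModel(1.0).fit_step(batch))  # 2.5
print(L1Model(1.0).fit_step(batch))    # 1.5
```

In an actual torch model, the override would typically compute predictions with self(x) and return a loss tensor from a function such as torch.nn.functional.l1_loss.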

class toad.nn.module.DistModule(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False)
Bases: torch.nn.parallel.distributed.DistributedDataParallel

Distributed module class: a DistributedDataParallel wrapper around module for distributed training.