toad.nn.module module
-
class toad.nn.module.Module
Bases: torch.nn.modules.module.Module
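Base module for every model. Subclasses define forward() and can override fit_step() to compute the training loss, as in the example below.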
Examples
>>> from toad.nn import Module
>>> from torch import nn
>>> class Net(Module):
...     def __init__(self, inputs, hidden, outputs):
...         super().__init__()
...         self.model = nn.Sequential(
...             nn.Linear(inputs, hidden),
...             nn.ReLU(),
...             nn.Linear(hidden, outputs),
...             nn.Sigmoid(),
...         )
...
...     def forward(self, x):
...         return self.model(x)
...
...     def fit_step(self, batch):
...         x, y = batch
...         y_hat = self(x)
...
...         # log into history
...         self.log('y', y)
...         self.log('y_hat', y_hat)
...
...         return nn.functional.mse_loss(y_hat, y)
...
>>> model = Net(10, 4, 1)
>>> # train_loader: a DataLoader yielding (x, y) batches
>>> model.fit(train_loader)
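The example above assumes a `train_loader` already exists. A minimal sketch of one, assuming the loader only needs to yield `(x, y)` pairs that match `Net(10, 4, 1)` and the unpacking in its `fit_step`; the random data and the batch size of 32 are illustrative only:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative synthetic data: 256 samples with 10 features each,
# and targets in [0, 1) to match the Sigmoid output of Net(10, 4, 1).
x = torch.randn(256, 10)
y = torch.rand(256, 1)

# Each batch is an (x, y) pair, which is what Net.fit_step unpacks.
train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

x_batch, y_batch = next(iter(train_loader))
print(x_batch.shape, y_batch.shape)  # torch.Size([32, 10]) torch.Size([32, 1])
```

With 256 samples and a batch size of 32, the loader produces 8 full batches per epoch.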
-
device
Device of the model.
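A sketch of how a model's device can be looked up in plain PyTorch, assuming (this page does not confirm it) that `device` reports where the model's parameters live. `model_device` is a hypothetical helper, not part of toad:

```python
import torch
from torch import nn


def model_device(module: nn.Module) -> torch.device:
    """Hypothetical helper: report the device of a module's first parameter."""
    # Assumes all parameters share one device, which holds for typical
    # single-device models but not for manually sharded ones.
    return next(module.parameters()).device


net = nn.Linear(10, 1)
print(model_device(net))  # cpu
```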
-
fit(loader, trainer=None, optimizer=None, loss=None, early_stopping=None, **kwargs)
Train the model.
Parameters: - loader (DataLoader) – loader that supplies the training batches
- trainer (Trainer) – trainer used to train the model
- optimizer (torch.optim.Optimizer) – optimizer to use; defaults to Adam(lr=1e-3)
- loss (Callable) – loss function, called as loss(y_hat, y)
- early_stopping (earlystopping) – early-stopping strategy; defaults to loss_earlystopping, set it to False to disable early stopping
- epoch (int) – number of epochs in the training loop (passed via **kwargs)
- callback (Callable) – function called once per epoch (passed via **kwargs)
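How `epoch`, `early_stopping`, and `callback` interact can be pictured with a plain-Python sketch. This is not toad's implementation: `fit_loop`, `patience`, and the rule "stop once the epoch loss has not improved for `patience` epochs" are assumptions chosen to illustrate loss-based early stopping, and `step` stands in for a `fit_step` that returns a float loss.

```python
from typing import Callable, Iterable, List, Optional


def fit_loop(
    step: Callable[[object], float],
    batches: Iterable[object],
    epoch: int = 10,
    patience: Optional[int] = 3,
    callback: Optional[Callable[[int, float], None]] = None,
) -> List[float]:
    """Hypothetical training loop; returns the mean loss of each epoch that ran."""
    batches = list(batches)
    history: List[float] = []
    best = float("inf")
    stale = 0  # epochs since the loss last improved

    for ep in range(epoch):
        # One epoch: run the step on every batch and average the losses.
        losses = [step(batch) for batch in batches]
        epoch_loss = sum(losses) / len(losses)
        history.append(epoch_loss)

        if callback is not None:
            callback(ep, epoch_loss)  # called once per epoch

        # patience=None plays the role of early_stopping=False.
        if patience is None:
            continue
        if epoch_loss < best:
            best, stale = epoch_loss, 0
        else:
            stale += 1
            if stale >= patience:
                break

    return history


# Fake step: the loss falls for a few epochs, then plateaus at 0.5.
scripted = iter([1.0, 0.8, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5])
seen = []
history = fit_loop(lambda batch: next(scripted), batches=[None], epoch=8,
                   patience=2, callback=lambda ep, loss: seen.append(ep))
print(history)  # [1.0, 0.8, 0.6, 0.5, 0.5, 0.5]
```

Training stops after two epochs without improvement, so only 6 of the 8 requested epochs run.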
-
evaluate(loader, trainer=None)
Evaluate the model.
Parameters: - loader (DataLoader) – loader that supplies the evaluation batches
- trainer (Trainer) – trainer used to evaluate the model
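A minimal sketch of what evaluating over a loader typically involves in PyTorch: switch to eval mode, disable gradient tracking, and average a loss over batches. `evaluate_sketch` and its sample-weighted averaging are assumptions for illustration; this page does not say what toad's `evaluate` computes or returns.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_sketch(model: nn.Module, loader: DataLoader, loss_fn) -> float:
    """Hypothetical helper: sample-weighted mean loss over a loader."""
    model.eval()  # e.g. disables dropout
    total, count = 0.0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for x, y in loader:
            y_hat = model(x)
            # loss_fn is assumed to return a per-batch mean, so weight by batch size.
            total += loss_fn(y_hat, y).item() * len(x)
            count += len(x)
    return total / count


# A linear model with all-zero weights predicts 0, so MSE against targets of 1 is 1.0.
model = nn.Linear(10, 1)
nn.init.zeros_(model.weight)
nn.init.zeros_(model.bias)

loader = DataLoader(TensorDataset(torch.randn(50, 10), torch.ones(50, 1)), batch_size=16)
result = evaluate_sketch(model, loader, nn.functional.mse_loss)
print(result)  # 1.0
```

Weighting by batch size keeps the smaller final batch (2 of 50 samples here) from skewing the average.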
-
fit_step(batch, loss=None, *args, **kwargs)
Run a single training step on one batch.
Parameters: - batch (Any) – batch of data from the DataLoader
- loss (Callable) – loss function, called as loss(y_hat, y)
Returns: the loss for this step
Return type: Tensor
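A sketch of a `fit_step` that honours the optional `loss` argument, falling back to MSE when none is given. It subclasses plain `torch.nn.Module` so it runs without toad; the MSE fallback and the class name `TinyNet` are assumptions, and with toad you would subclass `toad.nn.Module` as in the example above.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical model illustrating the fit_step(batch, loss=None) contract."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        return self.linear(x)

    def fit_step(self, batch, loss=None, *args, **kwargs):
        x, y = batch  # assumes the loader yields (x, y) pairs
        y_hat = self(x)
        # Use the caller's loss if one was passed, otherwise fall back to MSE.
        loss_fn = loss if loss is not None else nn.functional.mse_loss
        return loss_fn(y_hat, y)  # a scalar Tensor


net = TinyNet()
batch = (torch.randn(8, 3), torch.randn(8, 1))

step_loss = net.fit_step(batch)  # falls back to MSE
l1_loss = net.fit_step(batch, loss=nn.functional.l1_loss)
step_loss.backward()  # gradients flow back into net.linear
print(step_loss.shape)  # torch.Size([])
```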
-
class toad.nn.module.DistModule(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, static_graph=False)
Bases: torch.nn.parallel.distributed.DistributedDataParallel
Distributed module class: wraps a module for distributed training with DistributedDataParallel.
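`DistModule` subclasses `DistributedDataParallel`, and its listed signature matches DDP's constructor, so it presumably needs an initialised process group before it is constructed. The sketch below shows that setup with plain `torch.nn.parallel.DistributedDataParallel` in a single CPU process (gloo backend, file-based rendezvous). It does not use `DistModule` itself, and real training would launch one process per device (for example with `torchrun`).

```python
import os
import tempfile

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel

# Single-process "world" on CPU: rank 0 of world_size 1, using the gloo backend.
# A file-based rendezvous avoids having to pick a free TCP port.
init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
dist.init_process_group(
    backend="gloo",
    init_method=f"file://{init_file}",
    rank=0,
    world_size=1,
)

try:
    # device_ids stays None for CPU modules.
    ddp_model = DistributedDataParallel(nn.Linear(10, 1))
    out = ddp_model(torch.randn(4, 10))
    out.sum().backward()  # gradients are all-reduced across ranks (trivially here)
    print(out.shape)  # torch.Size([4, 1])
finally:
    dist.destroy_process_group()
```

The wrapped model stays reachable as `ddp_model.module`, which is where parameters and gradients live.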