toad.nn.module module
- class toad.nn.module.Module
Bases: Module (torch.nn.Module)
Base module for every model.
Examples
>>> from toad.nn import Module
... from torch import nn
...
... class Net(Module):
...     def __init__(self, inputs, hidden, outputs):
...         super().__init__()
...         self.model = nn.Sequential(
...             nn.Linear(inputs, hidden),
...             nn.ReLU(),
...             nn.Linear(hidden, outputs),
...             nn.Sigmoid(),
...         )
...
...     def forward(self, x):
...         return self.model(x)
...
...     def fit_step(self, batch):
...         x, y = batch
...         y_hat = self(x)
...
...         # log into history
...         self.log('y', y)
...         self.log('y_hat', y_hat)
...
...         return nn.functional.mse_loss(y_hat, y)
...
... model = Net(10, 4, 1)
...
... # train_loader: a DataLoader yielding (x, y) batches
... model.fit(train_loader)
- property device
The device the model resides on (for example, the CPU or a CUDA device).
- fit(loader, trainer=None, optimizer=None, loss=None, early_stopping=None, **kwargs)
Train the model.
- Parameters
loader (DataLoader) – loader that yields the training batches
trainer (Trainer) – trainer that runs the training loop
optimizer (torch.optim.Optimizer) – optimizer to use; defaults to Adam with lr=1e-3
loss (Callable) – loss function, called as ‘loss(y_hat, y)’
early_stopping (earlystopping) – early-stopping policy; defaults to loss_earlystopping. Set it to False to disable early stopping.
epoch (int) – number of training epochs (passed through **kwargs)
callback (callable) – function called once per epoch (passed through **kwargs)
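The parameters above describe an epoch loop that runs a per-epoch callback and can stop early. The sketch below shows that control flow without any framework, under stated assumptions. `run_epochs`, `step`, `patience`, and `min_delta` are hypothetical names and are not part of toad's API. The stopping rule is also an assumption: toad's `loss_earlystopping` may use a different criterion.

```python
from typing import Callable, Iterable, List, Optional


def run_epochs(
    step: Callable[[object], float],
    loader: Iterable,
    epoch: int = 10,
    callback: Optional[Callable[[int, float], None]] = None,
    patience: Optional[int] = 3,
    min_delta: float = 0.0,
) -> List[float]:
    """Hypothetical fit loop; returns the mean loss of every epoch that ran.

    `step` stands in for one training step (a fit_step call plus the optimizer
    update) and returns that batch's loss. `patience=None` disables early
    stopping, playing the role of early_stopping=False.
    """
    history: List[float] = []
    best = float("inf")
    stale_epochs = 0
    for ep in range(epoch):
        losses = [step(batch) for batch in loader]
        if not losses:
            raise ValueError("loader yielded no batches")
        mean_loss = sum(losses) / len(losses)
        history.append(mean_loss)
        if callback is not None:
            callback(ep, mean_loss)  # invoked once per epoch
        if patience is None:
            continue  # early stopping disabled
        # Assumed rule: stop once the epoch loss fails to beat the best value
        # by more than min_delta for `patience` consecutive epochs.
        if mean_loss < best - min_delta:
            best = mean_loss
            stale_epochs = 0
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                break
    return history
```

Consider one batch per epoch, per-epoch losses of 1.0, 0.5, 0.6, 0.7, 0.8, 0.9, and `patience=2`. The loop records `[1.0, 0.5, 0.6, 0.7]` and then stops, because two consecutive epochs fail to improve on 0.5.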
- evaluate(loader, trainer=None)
Evaluate the model.
- Parameters
loader (DataLoader) – loader that yields the evaluation batches
trainer (Trainer) – trainer used to run the evaluation
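Unlike training, an evaluation pass only calls the model and updates nothing. The framework-free sketch below assumes the evaluation reports the average loss over all samples, with each batch weighted by its size. toad's `evaluate` may compute or return something different. `evaluate_sketch`, `predict`, and `mse` are hypothetical names.

```python
from typing import Callable, Iterable, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]


def mse(y_hat: Sequence[float], y: Sequence[float]) -> float:
    """Mean squared error over one batch."""
    return sum((a - b) ** 2 for a, b in zip(y_hat, y)) / len(y)


def evaluate_sketch(
    predict: Callable[[Sequence[float]], Sequence[float]],
    loader: Iterable[Batch],
    loss: Callable[[Sequence[float], Sequence[float]], float] = mse,
) -> float:
    """Hypothetical evaluation pass: size-weighted mean loss, no updates."""
    total, count = 0.0, 0
    for x, y in loader:
        y_hat = predict(x)  # forward pass only; nothing is trained here
        n = len(y)
        total += loss(y_hat, y) * n  # loss is assumed to be a per-batch mean
        count += n
    if count == 0:
        raise ValueError("loader yielded no samples")
    return total / count
```

Weighting by batch size keeps a short final batch from counting as much as a full one.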
- fit_step(batch, loss=None, *args, **kwargs)
Run one training step on a batch and return its loss.
- Parameters
batch (Any) – one batch of data from the DataLoader
loss (Callable) – loss function, called as ‘loss(y_hat, y)’
- Returns
The loss for this step.
- Return type
Tensor
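The `Net.fit_step` in the example above follows the contract documented here: it unpacks `(x, y)` from the batch, runs the forward pass, and returns `loss(y_hat, y)`. The sketch below mirrors that contract in plain Python with a one-weight linear model. `LinearSketch` and `mse_loss` are hypothetical. Falling back to MSE when `loss` is None is an assumption made for this sketch only.

```python
from typing import Callable, List, Optional, Sequence


def mse_loss(y_hat: Sequence[float], y: Sequence[float]) -> float:
    """Mean squared error over one batch."""
    return sum((a - b) ** 2 for a, b in zip(y_hat, y)) / len(y)


class LinearSketch:
    """Hypothetical stand-in for a toad Module: y_hat = w * x, no PyTorch."""

    def __init__(self, w: float = 0.0) -> None:
        self.w = w

    def __call__(self, x: Sequence[float]) -> List[float]:
        # Forward pass.
        return [self.w * v for v in x]

    def fit_step(self, batch, loss: Optional[Callable] = None, *args, **kwargs) -> float:
        # Extra args are accepted only to mirror the documented signature.
        x, y = batch
        y_hat = self(x)
        # Assumption for this sketch: default to MSE when no loss is given.
        loss_fn = loss if loss is not None else mse_loss
        return loss_fn(y_hat, y)
```

In a real toad model the returned value would be a Tensor, and the trainer would backpropagate it. Here it is a plain float.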
- class toad.nn.module.DistModule(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=None, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=False, static_graph=False, delay_all_reduce_named_params=None, param_to_hook_all_reduce=None, mixed_precision: Optional[_MixedPrecision] = None, device_mesh=None)
Bases: DistributedDataParallel (torch.nn.parallel.DistributedDataParallel)
Distributed module class.