toad.nn.trainer module
- class toad.nn.trainer.callback(*args, **kwargs)
Bases: Decorator
Callback decorator for the trainer.
Examples
>>> @callback
... def savemodel(model):
...     model.save("path_to_file")
...
>>> trainer.train(loader, callback = savemodel)
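A decorator like this is useful because each callback can declare only the arguments it needs. The sketch below shows one way that can work. It assumes that the trainer calls every callback with the keyword arguments model, history, epoch and trainer, and that the decorator forwards only the ones the wrapped function declares. The class name callback_sketch is hypothetical and is not part of toad.

```python
import inspect
from typing import Any, Callable


class callback_sketch:
    """Hypothetical stand-in for toad.nn.trainer.callback (illustration only)."""

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func
        # Remember which parameter names the wrapped function accepts.
        self.params = set(inspect.signature(func).parameters)

    def __call__(self, **kwargs: Any) -> Any:
        # Assumption: forward only the keyword arguments the function declares.
        accepted = {k: v for k, v in kwargs.items() if k in self.params}
        return self.func(**accepted)


@callback_sketch
def report(epoch, history):
    # This callback ignores `model` and `trainer`.
    return f"epoch {epoch}: {len(history)} records"


# A training loop would call every callback with the same keyword set.
print(report(model=None, history=[0.5, 0.4], epoch=2, trainer=None))
```

Filtering by signature keeps the call site in the loop uniform while letting each callback stay minimal.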
- class toad.nn.trainer.earlystopping(*args, **kwargs)
Bases: callback
Early-stopping callback decorator for the trainer.
Examples
>>> @earlystopping(delta = 1e-3, patience = 5)
... def auc(history):
...     return AUC(history['y_hat'], history['y'])
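The delta and patience arguments suggest a patience-based stopping rule. The following is a minimal sketch of such a rule, not toad's implementation. It assumes that the decorated function returns a score where higher is better, that a score counts as an improvement only if it beats the best score so far by more than delta, and that training stops after patience consecutive epochs without improvement. EarlyStoppingSketch and auc_like are hypothetical names.

```python
from typing import Callable, List, Optional


class EarlyStoppingSketch:
    """Hypothetical patience-based early stopping (not toad's implementation).

    Assumes `score_fn` returns a value where higher is better.
    """

    def __init__(self, score_fn: Callable[[dict], float],
                 delta: float = 1e-3, patience: int = 5) -> None:
        self.score_fn = score_fn
        self.delta = delta
        self.patience = patience
        self.best: Optional[float] = None
        self.wait = 0  # epochs since the last improvement

    def step(self, history: dict) -> bool:
        """Score one epoch; return True when training should stop."""
        score = self.score_fn(history)
        if self.best is None or score - self.best > self.delta:
            # Improvement larger than delta: remember it and reset the counter.
            self.best = score
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


def auc_like(history: dict) -> float:
    # Placeholder metric; a real callback would compute e.g. an AUC from history.
    return history["score"]


stopper = EarlyStoppingSketch(auc_like, delta=1e-3, patience=2)
scores: List[float] = [0.70, 0.75, 0.7502, 0.7504, 0.74]
stopped_at = None
for epoch, s in enumerate(scores):
    if stopper.step({"score": s}):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)
```

Here the gains after epoch 1 are smaller than delta, so two epochs without improvement trigger a stop at epoch 3.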
- class toad.nn.trainer.Trainer(model, loader=None, optimizer=None, loss=None, keep_history=None, early_stopping=None)
Bases: Event
Trainer for training models.
- __init__(model, loader=None, optimizer=None, loss=None, keep_history=None, early_stopping=None)
Initialize the trainer.
- Parameters
model (nn.Module) – the model to train
loader (torch.utils.data.DataLoader) – training data loader
optimizer (torch.optim.Optimizer) – optimizer used for training; defaults to Adam(lr = 1e-3)
loss (Callable) – loss function, called as ‘loss(y_hat, y)’
keep_history (int) – number of most recent epochs whose logs are kept; None keeps all of them
early_stopping (earlystopping) – early-stopping callback; defaults to loss_earlystopping. Set it to False to disable early stopping
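The keep_history option can be pictured as a bounded buffer of per-epoch logs. The sketch below assumes that each epoch's logs are a dict and uses collections.deque, where maxlen=None means unbounded (keep everything) and an integer n keeps only the n most recent entries. make_history is a hypothetical helper, not part of toad.

```python
from collections import deque
from typing import Deque, Dict, Optional


def make_history(keep_history: Optional[int]) -> Deque[Dict[str, float]]:
    """Hypothetical helper: a log buffer that mirrors keep_history.

    maxlen=None gives an unbounded deque (keep all epochs); an int n keeps
    only the n most recent epoch logs, discarding older ones automatically.
    """
    return deque(maxlen=keep_history)


bounded = make_history(keep_history=3)
unbounded = make_history(keep_history=None)
for epoch in range(5):
    log = {"epoch": epoch, "loss": 1.0 / (epoch + 1)}
    bounded.append(log)
    unbounded.append(log)

print([h["epoch"] for h in bounded])    # [2, 3, 4]
print(len(unbounded))                   # 5
```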
- distributed(address=None, workers=4, gpu=False)
Set up the distributed environment and initialize a connection to a Ray cluster.
- Parameters
address (string) – address of the Ray cluster head node
workers (int) – number of workers (compute resources) allocated to the training task
gpu (bool) – whether to use GPUs
- train(loader=None, epoch=10, **kwargs)
- Parameters
loader (torch.utils.data.DataLoader) – training data loader
epoch (int) – number of epochs to train for
callback (list[callback]) – callables invoked once per epoch. A callback can receive the following parameters:
  model (nn.Module): the model being trained
  history (History): history of all log records
  epoch (int): current epoch number
  trainer (Trainer): the trainer itself
start (int) – epoch number to start training from
backward_rounds (int) – run the backward pass once every n rounds instead of after every round
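The interaction of epoch, start, backward_rounds and callback can be illustrated with a control-flow sketch that has no model or autograd. It rests on several assumptions that are not confirmed by this page: epochs run from start up to epoch - 1, one "round" is one batch, losses are accumulated between backward passes, a final partial group is flushed on the last batch, and callbacks are called with keyword arguments after each epoch. train_sketch is a hypothetical name, and each loader element simply stands in for one batch's loss.

```python
from typing import Callable, Iterable, List, Sequence, Tuple


def train_sketch(loader: Sequence[float], epoch: int = 10, start: int = 0,
                 backward_rounds: int = 1,
                 callback: Iterable[Callable[..., None]] = ()) -> List[Tuple[int, int, float]]:
    """Hypothetical control flow only: no model, no autograd.

    Each element of `loader` stands in for one batch's loss value.
    Returns (epoch, batch_index, accumulated_loss) for every simulated
    backward pass, so the effect of backward_rounds is visible.
    """
    steps: List[Tuple[int, int, float]] = []
    callbacks = list(callback)
    for ep in range(start, epoch):  # assumption: start..epoch-1
        accumulated = 0.0
        for i, batch_loss in enumerate(loader, start=1):
            accumulated += batch_loss
            # Simulated backward pass every `backward_rounds` batches and on the last batch.
            if i % backward_rounds == 0 or i == len(loader):
                steps.append((ep, i, accumulated))  # stands in for backward() + optimizer step
                accumulated = 0.0
        for cb in callbacks:
            # `history` here is just the list of simulated steps.
            cb(epoch=ep, history=steps)
    return steps


seen: List[int] = []
steps = train_sketch([1.0, 2.0, 3.0, 4.0, 5.0], epoch=3, start=1,
                     backward_rounds=2,
                     callback=[lambda epoch, history: seen.append(epoch)])
print(steps)
print(seen)
```

With five batches and backward_rounds=2, each epoch performs three simulated backward passes (after batches 2, 4 and 5), and only epochs 1 and 2 run because start=1.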
- Returns
the model with the best performance
- Return type