toad.nn.trainer module

class toad.nn.trainer.History[source]

Bases: object

model history

__init__()[source]

Initialize self. See help(type(self)) for accurate signature.

log(key, value)[source]

log message to history

Parameters:
  • key (str) – name of message
  • value (Tensor) – tensor of values
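
Conceptually, log accumulates values under a name, like a dict of lists. A minimal pure-Python sketch of that behavior (illustrative only, not toad's actual implementation):

```python
class MiniHistory:
    """Illustrative stand-in for History: values accumulate per key."""

    def __init__(self):
        self._store = {}

    def log(self, key, value):
        # append the value to the list kept under this key
        self._store.setdefault(key, []).append(value)

    def __getitem__(self, key):
        return self._store[key]


h = MiniHistory()
h.log("loss", 0.9)
h.log("loss", 0.7)
print(h["loss"])  # → [0.9, 0.7]
```
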

class toad.nn.trainer.callback(*args, **kwargs)[source]

Bases: toad.utils.decorator.Decorator

callback for trainer

Examples

>>> @callback
... def savemodel(model):
...     model.save("path_to_file")
>>> trainer.train(model, callback = savemodel)

__init__(*args, **kwargs)[source]

Initialize self. See help(type(self)) for accurate signature.

class toad.nn.trainer.earlystopping(*args, **kwargs)[source]

Bases: toad.nn.trainer.callback.callback

Examples

>>> @earlystopping(delta = 1e-3, patience = 5)
... def auc(history):
...     return AUC(history['y_hat'], history['y'])

setup(delta=-0.001, patience=10, skip=0)[source]
Parameters:
  • delta (float) – stop training if the improvement of the new score over the best score is smaller than delta
  • patience (int) – number of rounds without improvement before training stops
  • skip (int) – number of warm-up rounds at the start of training that are ignored by early stopping
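
The delta/patience rule above can be sketched in plain Python. This is an illustrative reimplementation of the idea, not toad's code, and `should_stop` is a hypothetical helper:

```python
def should_stop(scores, delta=-1e-3, patience=10, skip=0):
    """Return True once the score has failed to improve on the best
    score by more than `delta` for `patience` consecutive rounds,
    ignoring the first `skip` warm-up rounds."""
    scores = scores[skip:]
    if not scores:
        return False
    best = scores[0]
    stale = 0
    for score in scores[1:]:
        if score - best > delta:     # improved enough: reset patience
            best = score
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return True
    return False


print(should_stop([0.70, 0.71, 0.71, 0.71], delta=0.0, patience=2))  # → True
print(should_stop([0.70, 0.75, 0.80, 0.85], delta=0.0, patience=2))  # → False
```
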
get_best_state()[source]

get best state of model

reset()[source]

class toad.nn.trainer.Trainer(model, loader=None, optimizer=None, loss=None, keep_history=None, early_stopping=None)[source]

Bases: toad.nn.trainer.event.Event

trainer for training models

__init__(model, loader=None, optimizer=None, loss=None, keep_history=None, early_stopping=None)[source]

initialization

Parameters:
  • model (nn.Module) – model will be trained
  • loader (torch.DataLoader) – training data loader
  • optimizer (torch.optim.Optimizer) – the default optimizer is Adam(lr = 1e-3)
  • loss (Callable) – could be called as ‘loss(y_hat, y)’
  • early_stopping (earlystopping) – the default value is loss_earlystopping; set it to False to disable early stopping
  • keep_history (int) – keep only the last n epochs of logs; None keeps all
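
The keep_history semantics can be pictured as a bounded buffer: an int keeps only the last n epoch logs, while None keeps everything. A minimal sketch under that assumption (`history_buffer` is a hypothetical helper, not toad's internals):

```python
from collections import deque

def history_buffer(keep_history=None):
    """Illustrate keep_history: deque(maxlen=None) keeps every epoch
    log, deque(maxlen=n) silently drops all but the last n."""
    return deque(maxlen=keep_history)


buf = history_buffer(keep_history=2)
for log in ["epoch0", "epoch1", "epoch2"]:
    buf.append(log)
print(list(buf))  # → ['epoch1', 'epoch2']
```
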
set_model(model)[source]

setup model

distributed(address=None, workers=4, gpu=False)[source]

set up the distributed environment and initialize a Ray cluster connection

Parameters:
  • address (string) – address of the Ray cluster head node
  • workers (int) – number of workers for the compute task
  • gpu (bool) – whether to use GPU

train(loader=None, epoch=10, **kwargs)[source]
Parameters:
  • loader (torch.DataLoader) – training data loader
  • epoch (int) – number of epochs for the training loop
  • callback (list[Callback]) – callable functions invoked after every epoch, each receiving the following parameters:

    model (nn.Module): the model being trained
    history (History): history of the total log records
    epoch (int): current epoch number
    trainer (Trainer): the trainer itself
  • start (int) – epoch number to start training from
  • backward_rounds (int) – perform the backward pass after every n rounds
Returns:

the model with the best performance

Return type:

Module
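
The backward_rounds parameter reads like gradient accumulation: losses from n consecutive batches are combined before a single backward pass and optimizer step. A framework-free sketch of that counting (illustrative only; `count_steps` is a hypothetical helper):

```python
def count_steps(n_batches, backward_rounds=1):
    """Count optimizer steps when a backward pass happens only after
    every `backward_rounds` batches (gradient-accumulation style)."""
    steps = 0
    for i in range(1, n_batches + 1):
        # loss for batch i is accumulated here ...
        if i % backward_rounds == 0:
            steps += 1          # ... and flushed by one backward + step
    return steps


print(count_steps(8, backward_rounds=2))  # → 4
print(count_steps(5, backward_rounds=2))  # → 2
```
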

evaluate(loader, callback=None)[source]

evaluate model

Parameters:
  • loader (torch.DataLoader) – evaluation data loader
  • callback (callable) – callback function
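
An evaluation pass amounts to averaging the loss over the loader and optionally handing the result to the callback. A minimal framework-free sketch of that flow (illustrative, not toad's implementation; `evaluate_sketch` is a hypothetical helper):

```python
def evaluate_sketch(loader, loss_fn, callback=None):
    """Average loss_fn over (y_hat, y) pairs; pass the result to callback."""
    total, count = 0.0, 0
    for y_hat, y in loader:
        total += loss_fn(y_hat, y)
        count += 1
    avg = total / count
    if callback is not None:
        callback(avg)          # e.g. log or print the score
    return avg


pairs = [(0.5, 1.0), (1.5, 1.0)]
mse = lambda y_hat, y: (y_hat - y) ** 2
print(evaluate_sketch(pairs, mse))  # → 0.25
```
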