event.dispatcher
- class pydgn.training.event.dispatcher.EventDispatcher
Bases:
object
Class implementing the publisher/subscriber pattern. It is used to register subscribers that implement the EventHandler interface.
- _dispatch(event_name: str, state: pydgn.training.event.state.State)
Triggers the callback event_name for all subscribers (note: order matters!)
- Parameters
event_name (str) – the name of the callback to trigger
state (State) – object holding training information
- register(event_handler)
Registers a subscriber.
- Parameters
event_handler – an object implementing the EventHandler interface
- unregister(event_handler)
De-registers a subscriber.
- Parameters
event_handler – an object implementing the EventHandler interface
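The dispatch mechanism can be sketched as follows. This is an illustrative stand-in, not the actual pydgn implementation: SimpleDispatcher and LossLogger are hypothetical classes, and a plain dict stands in for the State object.

```python
class SimpleDispatcher:
    """Minimal sketch of the publisher/subscriber pattern used by
    EventDispatcher: subscribers register themselves, and _dispatch
    invokes the callback with the given name on each of them."""

    def __init__(self):
        self._subscribers = []

    def register(self, event_handler):
        self._subscribers.append(event_handler)

    def unregister(self, event_handler):
        self._subscribers.remove(event_handler)

    def _dispatch(self, event_name, state):
        # Callbacks fire in registration order, hence "order matters"
        for subscriber in self._subscribers:
            callback = getattr(subscriber, event_name, None)
            if callback is not None:
                callback(state)


class LossLogger:
    """Hypothetical subscriber that records the epoch loss."""

    def __init__(self):
        self.seen = []

    def on_epoch_end(self, state):
        self.seen.append(state["epoch_loss"])


dispatcher = SimpleDispatcher()
logger = LossLogger()
dispatcher.register(logger)
dispatcher._dispatch("on_epoch_end", {"epoch_loss": 0.42})
```

Only subscribers that actually define a method named after the event receive the callback; the rest are skipped.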
event.handler
- class pydgn.training.event.handler.EventHandler
Bases:
object
Interface that adheres to the publisher/subscriber pattern for training. It defines the main methods that a subscriber should implement. Each subscriber can make use of the State object that is passed to each method, so detailed knowledge about that object is required.
This class defines a set of callbacks that should cover a sufficient number of use cases. These are meant to work closely with the TrainingEngine object, which implements the overall training and evaluation process. This training engine is fairly general to accommodate a number of situations, so we expect we won't need to change it much to deal with static graph problems.
We list below some pre/post-conditions for each method that depend on the current implementation of the main training engine TrainingEngine. These are clearly not strict conditions, but they can help design new training engines with their own publisher/subscriber patterns, or create subclasses of TrainingEngine that require special modifications.
- ON_BACKWARD = 'on_backward'
- ON_COMPUTE_METRICS = 'on_compute_metrics'
- ON_EPOCH_END = 'on_epoch_end'
- ON_EPOCH_START = 'on_epoch_start'
- ON_EVAL_BATCH_END = 'on_eval_batch_end'
- ON_EVAL_BATCH_START = 'on_eval_batch_start'
- ON_EVAL_EPOCH_END = 'on_eval_epoch_end'
- ON_EVAL_EPOCH_START = 'on_eval_epoch_start'
- ON_FETCH_DATA = 'on_fetch_data'
- ON_FIT_END = 'on_fit_end'
- ON_FIT_START = 'on_fit_start'
- ON_FORWARD = 'on_forward'
- ON_TRAINING_BATCH_END = 'on_training_batch_end'
- ON_TRAINING_BATCH_START = 'on_training_batch_start'
- ON_TRAINING_EPOCH_END = 'on_training_epoch_end'
- ON_TRAINING_EPOCH_START = 'on_training_epoch_start'
- on_backward(state: pydgn.training.event.state.State)
Updates the parameters of the model using loss information.
- Parameters
state (State) – object holding training information
- Pre-condition:
- The following fields have been initialized:
state.batch_loss: a dictionary holding the loss of the minibatch
- on_compute_metrics(state: pydgn.training.event.state.State)
Computes the metrics of interest using the output and ground truth information obtained so far. The loss-related subscriber MUST be called before the score-related one.
- Parameters
state (State) – object holding training information
- Pre-condition:
- The following fields have been initialized:
state.batch_input: the input to be fed to the model
state.batch_targets: the ground truth values to be fed to the model (if any, otherwise a dummy value can be used)
state.batch_outputs: the output produced by the model (a tuple of values)
- Post-condition:
- The following fields have been initialized:
state.batch_loss: a dictionary holding the loss of the minibatch
state.batch_loss_extra: a dictionary containing extra info, e.g., intermediate loss scores
state.batch_score: a dictionary holding the score of the minibatch
- on_epoch_end(state: pydgn.training.event.state.State)
Perform bookkeeping operations at the end of an epoch, e.g., early stopping, plotting, etc.
- Parameters
state (State) – object holding training information
- Pre-condition:
- The following fields have been initialized:
state.epoch_loss: a dictionary containing the aggregated loss value across all minibatches
state.epoch_score: a dictionary containing the aggregated score value across all minibatches
- Post-condition:
- The following fields have been initialized:
state.stop_training: do/don't train the model
state.optimizer_state: the internal state of the optimizer (can be None)
state.scheduler_state: the internal state of the scheduler (can be None)
state.best_epoch_results: a dictionary with the best results computed so far (can be used when resuming training, either for early stopping or to keep some information about the last checkpoint)
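As an illustration of these conditions, a subscriber's on_epoch_end callback might implement patience-based early stopping. This is a hedged sketch: the PatienceStopper class and the "loss" dictionary key are hypothetical, and SimpleNamespace stands in for the State object; only the state.epoch_loss and state.stop_training field names come from the conditions above.

```python
from types import SimpleNamespace

class PatienceStopper:
    """Hypothetical subscriber: signals the engine to stop once the
    aggregated epoch loss has not improved for `patience` epochs."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def on_epoch_end(self, state):
        # Pre-condition: state.epoch_loss holds the aggregated loss
        current = state.epoch_loss["loss"]
        if current < self.best_loss:
            self.best_loss = current
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Post-condition: state.stop_training tells the engine to stop
        state.stop_training = self.bad_epochs >= self.patience

# Simulate three epochs in which the loss stops improving
stopper = PatienceStopper(patience=2)
state = SimpleNamespace(stop_training=False)
for loss in [0.5, 0.6, 0.7]:
    state.epoch_loss = {"loss": loss}
    stopper.on_epoch_end(state)
```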
- on_epoch_start(state: pydgn.training.event.state.State)
Initialize/reset some internal state at the start of a training/evaluation epoch.
- Parameters
state (State) – object holding training information
- Pre-condition:
- The following fields have been initialized:
state.epoch: the current epoch
state.return_node_embeddings: do/don't return node embeddings for each graph at the end of the epoch
- on_eval_batch_end(state: pydgn.training.event.state.State)
Initialize/reset some internal state after evaluating on a new minibatch of data.
- Parameters
state (State) – object holding training information
- Pre-condition:
- The following fields have been initialized:
state.set: the dataset type (can be TRAINING, VALIDATION or TEST)
state.batch_num_graphs: the total number of graphs in the minibatch
state.batch_num_nodes: the total number of nodes in the minibatch
state.batch_num_targets: the total number of ground truth values in the minibatch
state.batch_loss: a dictionary holding the loss of the minibatch
state.batch_loss_extra: a dictionary containing extra info, e.g., intermediate loss scores
state.batch_score: a dictionary holding the score of the minibatch
- on_eval_batch_start(state: pydgn.training.event.state.State)
Initialize/reset some internal state before evaluating on a new minibatch of data.
- Parameters
state (State) – object holding training information
- Pre-condition:
- The following fields have been initialized:
state.set: the dataset type (can be TRAINING, VALIDATION or TEST)
state.batch_input: the input to be fed to the model
state.batch_targets: the ground truth values to be fed to the model (if any, otherwise a dummy value can be used)
state.batch_num_graphs: the total number of graphs in the minibatch
state.batch_num_nodes: the total number of nodes in the minibatch
state.batch_num_targets: the total number of ground truth values in the minibatch
- on_eval_epoch_end(state: pydgn.training.event.state.State)
Initialize/reset some internal state at the end of an evaluation epoch.
- Parameters
state (State) – object holding training information
- Post-condition:
- The following fields have been initialized:
state.epoch_loss: a dictionary containing the aggregated loss value across all minibatches
state.epoch_score: a dictionary containing the aggregated score value across all minibatches
- on_eval_epoch_start(state: pydgn.training.event.state.State)
Initialize/reset some internal state at the start of an evaluation epoch.
- Parameters
state (State) – object holding training information
- Pre-condition:
- The following fields have been initialized:
state.set: the dataset type (can be TRAINING, VALIDATION or TEST)
- on_fetch_data(state: pydgn.training.event.state.State)
Load the next batch of data, possibly applying some kind of additional pre-processing not included in the transform package.
- Parameters
state (State) – object holding training information
- Pre-condition:
The data loader is contained in state.loader_iterable and the minibatch ID (i.e., a counter) is stored in state.id_batch
- Post-condition:
The state object now has a field batch_input with the next batch of data
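A minimal sketch of what such a callback might look like. The FetchHandler class is illustrative (not part of pydgn), and SimpleNamespace stands in for the State object; only the loader_iterable, id_batch and batch_input field names come from the conditions above.

```python
from types import SimpleNamespace

class FetchHandler:
    """Hypothetical subscriber: pulls the next minibatch from the
    loader iterator and exposes it on the state object."""

    def on_fetch_data(self, state):
        # Pre-condition: state.loader_iterable is an iterator over
        # minibatches and state.id_batch counts the batches seen so far
        batch = next(state.loader_iterable)
        # Post-condition: the fetched batch is stored as state.batch_input
        state.batch_input = batch

state = SimpleNamespace(loader_iterable=iter([[1, 2], [3, 4]]), id_batch=0)
FetchHandler().on_fetch_data(state)
```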
- on_fit_end(state: pydgn.training.event.state.State)
Training has ended, free all resources, e.g., close Tensorboard writers.
- Parameters
state (State) – object holding training information
- on_fit_start(state: pydgn.training.event.state.State)
Initialize an object at the beginning of the training phase, e.g., the internals of an optimizer, using the information contained in state.
- Parameters
state (State) – object holding training information
- Pre-condition:
- The following fields have been initialized:
state.initial_epoch: the initial epoch from which to start/resume training
state.stop_training: do/don't train the model
state.optimizer_state: the internal state of the optimizer (can be None)
state.scheduler_state: the internal state of the scheduler (can be None)
state.best_epoch_results: a dictionary with the best results computed so far (can be used when resuming training, either for early stopping or to keep some information about the last checkpoint)
- on_forward(state: pydgn.training.event.state.State)
Feed the input data to the model.
- Parameters
state (State) – object holding training information
- Pre-condition:
- The following fields have been initialized:
state.batch_input: the input to be fed to the model
state.batch_targets: the ground truth values to be fed to the model (if any, otherwise a dummy value can be used)
- Post-condition:
- The following fields have been initialized:
state.batch_outputs: the output produced by the model (a tuple of values)
- on_training_batch_end(state: pydgn.training.event.state.State)
Initialize/reset some internal state after training on a new minibatch of data.
- Parameters
state (State) – object holding training information
- Pre-condition:
- The following fields have been initialized:
state.set: it must be set to TRAINING
state.batch_num_graphs: the total number of graphs in the minibatch
state.batch_num_nodes: the total number of nodes in the minibatch
state.batch_num_targets: the total number of ground truth values in the minibatch
state.batch_loss: a dictionary holding the loss of the minibatch
state.batch_loss_extra: a dictionary containing extra info, e.g., intermediate loss scores
state.batch_score: a dictionary holding the score of the minibatch
- on_training_batch_start(state: pydgn.training.event.state.State)
Initialize/reset some internal state before training on a new minibatch of data.
- Parameters
state (State) – object holding training information
- Pre-condition:
- The following fields have been initialized:
state.set: it must be set to TRAINING
state.batch_input: the input to be fed to the model
state.batch_targets: the ground truth values to be fed to the model (if any, otherwise a dummy value can be used)
state.batch_num_graphs: the total number of graphs in the minibatch
state.batch_num_nodes: the total number of nodes in the minibatch
state.batch_num_targets: the total number of ground truth values in the minibatch
- on_training_epoch_end(state: pydgn.training.event.state.State)
Initialize/reset some internal state at the end of a training epoch.
- Parameters
state (State) – object holding training information
- Post-condition:
- The following fields have been initialized:
state.epoch_loss: a dictionary containing the aggregated loss value across all minibatches
state.epoch_score: a dictionary containing the aggregated score value across all minibatches
- on_training_epoch_start(state: pydgn.training.event.state.State)
Initialize/reset some internal state at the start of a training epoch.
- Parameters
state (State) – object holding training information
- Pre-condition:
- The following fields have been initialized:
state.set: it must be set to TRAINING
event.state
- class pydgn.training.event.state.State(model, optimizer, device)
Bases:
object
Any object of this class contains training information that is handled and modified by a TrainingEngine as well as by the EventHandler objects implementing callbacks.
- Parameters
model (torch.nn.Module) – the model
optimizer (training.callback.optimizer.Optimizer) – the optimizer
device (str) – the device on which to run computations
- update(**values: dict)
The method sets new attributes or updates existing ones using the key/value pairs in values.
- Parameters
values – a dictionary of key/value pairs to store in the global state
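The documented behaviour can be sketched with a minimal stand-in; MiniState is illustrative, not the actual pydgn class.

```python
class MiniState:
    """Minimal stand-in mimicking the documented update() semantics:
    each keyword argument becomes (or overwrites) an attribute."""

    def update(self, **values):
        for name, value in values.items():
            setattr(self, name, value)

state = MiniState()
state.update(epoch=3, stop_training=False)  # sets new attributes
state.update(epoch=4)                       # updates an existing one
```

Because the pairs are stored as plain attributes, both the training engine and every subscriber see the same mutable global state.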