callbacks

Callback class(es) for use during model training.

class graphnet.training.callbacks.PiecewiseLinearLR(optimizer, milestones, factors, last_epoch, verbose)[source]

Bases: _LRScheduler

Interpolate learning rate linearly between milestones.

Construct PiecewiseLinearLR.

Each milestone denotes a step index, and for each milestone a factor multiplying the base learning rate is given. For steps between two milestones, the learning rate is interpolated linearly between the factors of the two closest milestones. For steps before the first milestone, the factor of the first milestone is used; similarly, the factor of the last milestone is used for steps after it.

Parameters:
  • optimizer (Optimizer) – Wrapped optimizer.

  • milestones (List[int]) – List of step indices. Must be increasing.

  • factors (List[float]) – List of multiplicative factors. Must be same length as milestones.

  • last_epoch (int, default: -1) – The index of the last epoch.

  • verbose (bool, default: False) – If True, prints a message to stdout for each update.

get_lr()[source]

Get the effective learning rate(s) for each parameter group of the wrapped optimizer.

Return type:

List[float]
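
Example (a minimal sketch, not taken from the graphnet documentation; the model, the step counts, and the factor values are illustrative):

    import torch
    from torch.optim import Adam

    from graphnet.training.callbacks import PiecewiseLinearLR

    model = torch.nn.Linear(10, 1)  # placeholder model
    optimizer = Adam(model.parameters(), lr=1e-3)

    # Warm up from 1% of the base learning rate at step 0 to 100% at
    # step 1000, then decay linearly back to 1% by step 10000. Outside
    # the milestones, the nearest milestone's factor is held constant.
    scheduler = PiecewiseLinearLR(
        optimizer,
        milestones=[0, 1000, 10000],
        factors=[0.01, 1.0, 0.01],
    )

    for step in range(10001):
        optimizer.step()
        scheduler.step()  # milestones are step indices, so step once per batch

    # E.g. at step 500 the factor is interpolated halfway between 0.01
    # and 1.0, i.e. 0.505, giving an effective learning rate of 5.05e-4.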

class graphnet.training.callbacks.ProgressBar(refresh_rate, process_position)[source]

Bases: TQDMProgressBar

Custom progress bar for graphnet.

Customises the default progress bar in pytorch-lightning.

Parameters:
  • refresh_rate (int) – Number of batches between progress bar updates.

  • process_position (int) – Line offset of the progress bar, used to avoid overlapping bars when several are shown at once.

init_validation_tqdm()[source]

Override for customisation.

Return type:

Bar

init_predict_tqdm()[source]

Override for customisation.

Return type:

Bar

init_test_tqdm()[source]

Override for customisation.

Return type:

Bar

init_train_tqdm()[source]

Override for customisation.

Return type:

Bar

get_metrics(trainer, model)[source]

Override to hide the version number in the logged metrics.

Return type:

Dict

Parameters:
  • trainer (Trainer) –

  • model (LightningModule) –

on_train_epoch_start(trainer, model)[source]

Print the results of the previous epoch on a separate line.

This allows the user to see the losses/metrics for previous epochs while the current one is training. The default behaviour in pytorch-lightning is to overwrite the progress bar from previous epochs.

Return type:

None

Parameters:
  • trainer (Trainer) –

  • model (LightningModule) –

on_train_epoch_end(trainer, model)[source]

Log the final progress bar for the epoch to file.

Don’t duplicate it to stdout.

Return type:

None

Parameters:
  • trainer (Trainer) –

  • model (LightningModule) –
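
Example (a minimal sketch; assumes a standard pytorch-lightning setup, and `model` is a hypothetical LightningModule defined elsewhere):

    from pytorch_lightning import Trainer

    from graphnet.training.callbacks import ProgressBar

    trainer = Trainer(
        max_epochs=5,
        callbacks=[ProgressBar()],  # replaces Lightning's default TQDM progress bar
    )
    # trainer.fit(model)  # `model` is a hypothetical LightningModule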