model

Base class(es) for building models.

class graphnet.models.model.Model(*args, **kwargs)

Bases: Logger, Configurable, LightningModule, ABC

Base class for all models in graphnet.

Construct Logger.

Parameters:
  • args (Any) –

  • kwargs (Any) –

Return type:

object

abstract forward(x)

Forward pass.

Return type:

Union[Tensor, Data]

Parameters:

x (Tensor | Data) –
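Because Model derives from ABC and declares forward as abstract, a concrete model must override forward before it can be instantiated. The stdlib-only sketch below illustrates that contract with a hypothetical ToyModel stand-in (not graphnet's class, and without any torch or Lightning machinery); Doubler is likewise an illustrative name.

```python
from abc import ABC, abstractmethod
from typing import List


class ToyModel(ABC):
    """Hypothetical stand-in for Model: reproduces only the abstract contract."""

    @abstractmethod
    def forward(self, x: List[float]) -> List[float]:
        """Subclasses must implement the forward pass."""

    def __call__(self, x: List[float]) -> List[float]:
        # Route calls to forward(), as torch modules do.
        return self.forward(x)


class Doubler(ToyModel):
    """Concrete subclass: it implements forward, so it can be instantiated."""

    def forward(self, x: List[float]) -> List[float]:
        return [2.0 * v for v in x]


try:
    ToyModel()  # fails: forward is still abstract
except TypeError as err:
    print(f"cannot instantiate: {err}")

print(Doubler()([1.0, 2.5]))  # [2.0, 5.0]
```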

fit(train_dataloader, val_dataloader, *, max_epochs, gpus, callbacks, ckpt_path, logger, log_every_n_steps, gradient_clip_val, distribution_strategy, **trainer_kwargs)

Fit Model using pytorch_lightning.Trainer.

Return type:

None

Parameters:
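  • train_dataloader (DataLoader) – DataLoader yielding the training batches.

  • val_dataloader (DataLoader | None) – DataLoader yielding the validation batches, or None to skip validation.

  • max_epochs (int) – Maximum number of training epochs.

  • gpus (List[int] | int | None) – GPU device indices, or the number of GPUs, to train on; None trains on CPU.

  • callbacks (List[Callback] | None) – Lightning callbacks passed to the Trainer.

  • ckpt_path (str | None) – Path to a checkpoint from which to resume training.

  • logger (Logger | None) – Lightning logger passed to the Trainer.

  • log_every_n_steps (int) – Interval, in training steps, at which metrics are logged.

  • gradient_clip_val (float | None) – Value at which to clip gradients; None disables clipping.

  • distribution_strategy (str | None) – Distribution strategy passed to the Trainer, e.g. "ddp".

  • trainer_kwargs (Any) – Additional keyword arguments forwarded to pytorch_lightning.Trainer.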
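The sketch below shows one plausible way fit-style arguments could map onto pytorch_lightning.Trainer keywords. The helper build_trainer_kwargs, its default values, and the CPU fallback for an empty gpus value are assumptions for illustration, not graphnet's implementation; it only builds a dictionary, so it runs without Lightning installed.

```python
from typing import Any, Dict, List, Optional, Union


def build_trainer_kwargs(
    max_epochs: int,
    gpus: Optional[Union[List[int], int]] = None,
    callbacks: Optional[List[Any]] = None,
    logger: Any = None,
    log_every_n_steps: int = 1,
    gradient_clip_val: Optional[float] = None,
    distribution_strategy: Optional[str] = None,
    **trainer_kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical helper: map fit()-style arguments onto Trainer keywords.

    The defaults here are illustrative placeholders, not graphnet's defaults.
    """
    if gpus:
        # Non-empty list of device indices, or a positive device count.
        accelerator, devices = "gpu", gpus
    else:
        # None, 0 or [] fall back to CPU training (an assumption of this sketch).
        accelerator, devices = "cpu", "auto"
    kwargs: Dict[str, Any] = {
        "accelerator": accelerator,
        "devices": devices,
        "max_epochs": max_epochs,
        "callbacks": callbacks,
        "log_every_n_steps": log_every_n_steps,
        "gradient_clip_val": gradient_clip_val,
    }
    if logger is not None:
        kwargs["logger"] = logger
    if distribution_strategy is not None:
        kwargs["strategy"] = distribution_strategy
    # Remaining options are forwarded unchanged, like **trainer_kwargs.
    kwargs.update(trainer_kwargs)
    return kwargs


print(build_trainer_kwargs(max_epochs=5, gpus=[0, 1], distribution_strategy="ddp"))
```

With Lightning installed, the result could be passed as Trainer(**kwargs). In Lightning, a resume checkpoint such as ckpt_path is given to Trainer.fit rather than to the Trainer constructor, which is why the sketch leaves it out.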

predict(dataloader, gpus, distribution_strategy)

Return predictions for dataloader.

Returns a list of Tensors, one for each model output.

Return type:

List[Tensor]

Parameters:
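  • dataloader (DataLoader) – DataLoader yielding the batches to predict on.

  • gpus (List[int] | int | None) – GPU device indices, or the number of GPUs, to use; None runs on CPU.

  • distribution_strategy (str | None) – Distribution strategy passed to the Trainer, e.g. "ddp".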
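predict is documented as returning one Tensor per model output. A plausible reading is that per-batch outputs are concatenated along the event axis, giving one result per output; the pure-Python sketch below shows that bookkeeping with lists in place of Tensors. collect_predictions is a hypothetical name, and the two outputs in the example are illustrative.

```python
from typing import List, Sequence

Batch = Sequence[Sequence[float]]  # one inner sequence per model output


def collect_predictions(batches: Sequence[Batch]) -> List[List[float]]:
    """Concatenate per-batch outputs into one flat list per model output.

    Hypothetical pure-Python analogue of gathering Tensors across batches.
    """
    if not batches:
        return []
    n_outputs = len(batches[0])
    collected: List[List[float]] = [[] for _ in range(n_outputs)]
    for batch in batches:
        if len(batch) != n_outputs:
            raise ValueError("every batch must yield the same number of outputs")
        for i, output in enumerate(batch):
            collected[i].extend(output)
    return collected


# Two batches, each producing two outputs (illustrative values).
preds = collect_predictions([([1.0, 2.0], [0.1, 0.2]), ([3.0], [0.3])])
print(preds)  # [[1.0, 2.0, 3.0], [0.1, 0.2, 0.3]]
```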

predict_as_dataframe(dataloader, prediction_columns, *, additional_attributes, gpus, distribution_strategy)

Return predictions for dataloader as a DataFrame.

Include additional_attributes as extra columns in the output DataFrame.

Return type:

DataFrame

Parameters:
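  • dataloader (DataLoader) – DataLoader yielding the batches to predict on.

  • prediction_columns (List[str]) – Column names for the model outputs, one per predicted quantity.

  • additional_attributes (List[str] | None) – Names of additional attributes to include as extra columns in the output.

  • gpus (List[int] | int | None) – GPU device indices, or the number of GPUs, to use; None runs on CPU.

  • distribution_strategy (str | None) – Distribution strategy passed to the Trainer, e.g. "ddp".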
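The sketch below shows how predictions could be laid out as DataFrame columns: each prediction_columns entry names one output value per event, and additional attributes are appended as extra columns of the same length. predictions_to_columns and the column names in the example are hypothetical; the function returns a plain dict of lists so that it runs without pandas.

```python
from typing import Any, Dict, List, Optional, Sequence


def predictions_to_columns(
    predictions: Sequence[Sequence[float]],
    prediction_columns: List[str],
    attributes: Optional[Dict[str, Sequence[Any]]] = None,
) -> Dict[str, List[Any]]:
    """Hypothetical sketch: name each prediction value and append attributes.

    `predictions` holds one row per event and one value per prediction column.
    """
    columns: Dict[str, List[Any]] = {name: [] for name in prediction_columns}
    for row in predictions:
        if len(row) != len(prediction_columns):
            raise ValueError(
                f"expected {len(prediction_columns)} values per row, got {len(row)}"
            )
        for name, value in zip(prediction_columns, row):
            columns[name].append(value)
    # Additional attributes become extra columns aligned with the event rows.
    for name, values in (attributes or {}).items():
        if len(values) != len(predictions):
            raise ValueError(f"attribute {name!r} has {len(values)} entries")
        columns[name] = list(values)
    return columns


table = predictions_to_columns(
    [[1.0, 0.1], [2.0, 0.2]],
    prediction_columns=["energy_pred", "zenith_pred"],
    attributes={"event_no": [7, 8]},
)
print(table)
# {'energy_pred': [1.0, 2.0], 'zenith_pred': [0.1, 0.2], 'event_no': [7, 8]}
```

If pandas is available, pandas.DataFrame(table) turns the dictionary into a DataFrame with one row per event.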

save(path)

Save entire model to path.

Return type:

None

Parameters:

path (str) – File path to save the model to.

classmethod load(path)

Load entire model from path.

Return type:

Model

Parameters:

path (str) – File path to load the model from.

save_state_dict(path)

Save model state_dict to path.

Return type:

None

Parameters:

path (str) – File path to save the state_dict to.

load_state_dict(path, **kargs)

Load model state_dict from path.

Return type:

Model

Parameters:
  • path (str | Dict) – File path to load the state_dict from, or a state_dict dictionary.

  • kargs (Any | None) –
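save and load persist the entire model object, whereas save_state_dict and load_state_dict persist only its parameters. The stdlib analogue below illustrates the trade-off with pickle (torch.save is itself pickle-based by default): restoring a whole object requires the class to be importable at load time, while restoring a state dict requires first constructing a fresh instance. TinyModel is a hypothetical stand-in, and its load_state_dict accepts a path or a dict to mirror the str | Dict type documented above.

```python
import os
import pickle
import tempfile
from typing import Dict, Union


class TinyModel:
    """Hypothetical stand-in with a single parameter dictionary."""

    def __init__(self) -> None:
        self.weights: Dict[str, float] = {"w": 0.0, "b": 0.0}

    def state_dict(self) -> Dict[str, float]:
        return dict(self.weights)

    def load_state_dict(self, source: Union[str, Dict[str, float]]) -> "TinyModel":
        # Accept either a path to a pickled state dict or the dict itself.
        if isinstance(source, str):
            with open(source, "rb") as f:
                state: Dict[str, float] = pickle.load(f)
        else:
            state = source
        self.weights.update(state)
        return self


with tempfile.TemporaryDirectory() as tmp:
    model = TinyModel()
    model.weights.update(w=1.5, b=-0.5)

    # Whole-object save: loading needs the TinyModel class to be importable.
    full_path = os.path.join(tmp, "model.pkl")
    with open(full_path, "wb") as f:
        pickle.dump(model, f)
    with open(full_path, "rb") as f:
        restored = pickle.load(f)

    # State-dict save: only parameters are stored; load into a fresh instance.
    state_path = os.path.join(tmp, "state.pkl")
    with open(state_path, "wb") as f:
        pickle.dump(model.state_dict(), f)
    fresh = TinyModel().load_state_dict(state_path)

    print(restored.weights == fresh.weights)  # True
```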

classmethod from_config(source, trust, load_modules)

Construct Model instance from source configuration.

Parameters:
  • source (ModelConfig | str) – ModelConfig instance, or path to a ModelConfig file.

  • trust (bool, default: False) – Whether to trust the ModelConfig file enough to eval(…) any lambda function expressions it contains.

  • load_modules (Optional[List[str]], default: None) – Modules used in the model definition, which therefore need to be loaded into the global namespace. Defaults to loading torch.

Raises:

ValueError – If the ModelConfig contains lambda functions but trust = False.

Return type:

Model
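The trust flag guards against evaluating arbitrary code stored in a config. The sketch below shows one way such a gate could work: string values that start with "lambda" are rejected with a ValueError unless trust=True, and are otherwise passed to eval with the requested modules in scope. resolve_lambdas and its plain-dict config are hypothetical; graphnet's ModelConfig handling may differ, and unlike the documented default (torch) the sketch loads no modules unless asked.

```python
import importlib
from typing import Any, Dict, List, Optional


def resolve_lambdas(
    config: Dict[str, Any],
    trust: bool = False,
    load_modules: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical trust gate: eval(...) lambda strings only if trust=True."""
    lambda_keys = [
        key
        for key, value in config.items()
        if isinstance(value, str) and value.strip().startswith("lambda")
    ]
    if lambda_keys and not trust:
        raise ValueError(
            f"config contains lambda functions {lambda_keys}; pass trust=True"
        )
    # Modules named in load_modules become globals for the evaluated lambdas.
    # (The documented default loads torch; this sketch loads nothing by default.)
    namespace: Dict[str, Any] = {
        name: importlib.import_module(name) for name in (load_modules or [])
    }
    resolved = dict(config)
    for key in lambda_keys:
        resolved[key] = eval(config[key], namespace)  # reached only when trusted
    return resolved


cfg = {"hidden": 64, "transform": "lambda x: math.sqrt(x)"}
try:
    resolve_lambdas(cfg)
except ValueError as err:
    print(err)
out = resolve_lambdas(cfg, trust=True, load_modules=["math"])
print(out["transform"](16.0))  # 4.0
```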