model
Base class(es) for building models.
- class graphnet.models.model.Model(*args, **kwargs)
Bases: Logger, Configurable, LightningModule, ABC
Base class for all models in graphnet.
Construct Logger.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
object
- abstract forward(x)
Forward pass.
- Return type:
Union[Tensor, Data]
- Parameters:
x (Tensor | Data)
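Because `forward` is abstract and `Model` derives from `ABC`, a concrete model has to override `forward` before it can be instantiated. The sketch below reproduces only that contract using the standard library; `SketchModel` and `Doubler` are hypothetical stand-ins, and plain lists replace `Tensor`/`Data` so the example runs without torch or torch_geometric.

```python
from abc import ABC, abstractmethod
from typing import List


class SketchModel(ABC):
    """Hypothetical stand-in for an abstract model base class."""

    @abstractmethod
    def forward(self, x: List[float]) -> List[float]:
        """Forward pass; subclasses must override this."""

    def __call__(self, x: List[float]) -> List[float]:
        # Mirror the torch.nn.Module convention of routing calls to forward().
        return self.forward(x)


class Doubler(SketchModel):
    """Hypothetical concrete model: multiplies every input by two."""

    def forward(self, x: List[float]) -> List[float]:
        return [2.0 * value for value in x]


try:
    SketchModel()  # fails: forward() is still abstract
except TypeError as error:
    print("cannot instantiate:", type(error).__name__)

print(Doubler()([1.0, 2.5]))  # [2.0, 5.0]
```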
- fit(train_dataloader, val_dataloader, *, max_epochs, gpus, callbacks, ckpt_path, logger, log_every_n_steps, gradient_clip_val, distribution_strategy, **trainer_kwargs)
Fit Model using pytorch_lightning.Trainer.
- Return type:
None
- Parameters:
train_dataloader (DataLoader)
val_dataloader (DataLoader | None)
max_epochs (int)
gpus (List[int] | int | None)
callbacks (List[Callback] | None)
ckpt_path (str | None)
logger (Logger | None)
log_every_n_steps (int)
gradient_clip_val (float | None)
distribution_strategy (str | None)
trainer_kwargs (Any)
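`fit` delegates training to `pytorch_lightning.Trainer`. One plausible way to route the arguments listed above into a Trainer is sketched below. The helper name `build_trainer_kwargs`, its default values, and the mapping from `gpus` onto the Trainer's `accelerator`/`devices` arguments are assumptions for illustration, not graphnet's confirmed implementation. The sketch only builds the keyword dictionary, so it runs without Lightning installed.

```python
from typing import Any, Dict, List, Optional, Union


def build_trainer_kwargs(
    *,
    max_epochs: int = 10,  # illustrative default, not graphnet's
    gpus: Optional[Union[List[int], int]] = None,
    log_every_n_steps: int = 1,  # illustrative default, not graphnet's
    gradient_clip_val: Optional[float] = None,
    distribution_strategy: Optional[str] = None,
    **trainer_kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical helper: assemble keyword arguments for a Lightning Trainer."""
    if gpus:
        # Assumption: any non-empty GPU spec selects the GPU accelerator.
        accelerator, devices = "gpu", gpus
    else:
        accelerator, devices = "cpu", 1
    kwargs: Dict[str, Any] = {
        "accelerator": accelerator,
        "devices": devices,
        "max_epochs": max_epochs,
        "log_every_n_steps": log_every_n_steps,
        "gradient_clip_val": gradient_clip_val,
    }
    if distribution_strategy is not None:
        # Lightning's Trainer expects the strategy under the name `strategy`.
        kwargs["strategy"] = distribution_strategy
    # Remaining keyword arguments (e.g. callbacks, logger) pass through unchanged.
    kwargs.update(trainer_kwargs)
    return kwargs


print(build_trainer_kwargs(max_epochs=5, gpus=[0, 1], distribution_strategy="ddp"))
```

In Lightning, `ckpt_path` is an argument to `Trainer.fit` rather than to the `Trainer` constructor, which is why the sketch leaves it out.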
- predict(dataloader, gpus, distribution_strategy)
Return predictions for dataloader.
Returns a list of Tensors, one for each model output.
- Return type:
List[Tensor]
- Parameters:
dataloader (DataLoader)
gpus (List[int] | int | None)
distribution_strategy (str | None)
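Since `predict` returns one tensor per model output rather than one per batch, a natural reading is that per-batch results are concatenated along the batch dimension for each output. The standard-library sketch below shows that regrouping, with nested lists standing in for tensors; `merge_batch_outputs` is a hypothetical name, and with torch the inner concatenation would typically be `torch.cat(..., dim=0)`.

```python
from typing import List, Sequence

# One batch yields one entry per model output; each entry is a list of rows.
BatchOutputs = Sequence[List[List[float]]]


def merge_batch_outputs(batches: Sequence[BatchOutputs]) -> List[List[List[float]]]:
    """Hypothetical helper: concatenate per-batch results for each model output."""
    if not batches:
        return []
    n_outputs = len(batches[0])
    merged: List[List[List[float]]] = [[] for _ in range(n_outputs)]
    for batch in batches:
        if len(batch) != n_outputs:
            raise ValueError("every batch must produce the same number of outputs")
        for index, rows in enumerate(batch):
            # Stand-in for torch.cat([...], dim=0) along the batch axis.
            merged[index].extend(rows)
    return merged


# Two batches from a model with two outputs (e.g. energy and direction).
batch_a = [[[1.0]], [[0.1, 0.2]]]
batch_b = [[[2.0], [3.0]], [[0.3, 0.4], [0.5, 0.6]]]
print(merge_batch_outputs([batch_a, batch_b]))
```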
- predict_as_dataframe(dataloader, prediction_columns, *, additional_attributes, gpus, distribution_strategy)
Return predictions for dataloader as a DataFrame.
Include additional_attributes as additional columns in the output DataFrame.
- Return type:
DataFrame
- Parameters:
dataloader (DataLoader)
prediction_columns (List[str])
additional_attributes (List[str] | None)
gpus (List[int] | int | None)
distribution_strategy (str | None)
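`predict_as_dataframe` names the prediction columns after `prediction_columns` and appends `additional_attributes` as extra columns. The sketch below assembles such a table as a list of row dictionaries using only the standard library. `rows_from_predictions` is hypothetical, the attribute values come from a plain dict rather than the dataloader's graph objects, and the column-count check is an assumed validation step, not documented behavior. Passing the result to `pandas.DataFrame(rows)` would yield a DataFrame.

```python
from typing import Any, Dict, List, Optional, Sequence


def rows_from_predictions(
    predictions: Sequence[Sequence[float]],
    prediction_columns: List[str],
    attributes: Optional[Dict[str, Sequence[Any]]] = None,
) -> List[Dict[str, Any]]:
    """Hypothetical helper: one dict per event, predictions plus extra attributes."""
    attributes = attributes or {}
    rows = []
    for event_index, values in enumerate(predictions):
        if len(values) != len(prediction_columns):
            raise ValueError(
                f"got {len(values)} predicted values for "
                f"{len(prediction_columns)} prediction columns"
            )
        row: Dict[str, Any] = dict(zip(prediction_columns, values))
        for name, column in attributes.items():
            # Additional attributes become extra columns, aligned by event index.
            row[name] = column[event_index]
        rows.append(row)
    return rows


rows = rows_from_predictions(
    predictions=[[1.2, 0.9], [3.4, 0.1]],
    prediction_columns=["energy_pred", "zenith_pred"],
    attributes={"event_no": [0, 1]},
)
print(rows[0])  # {'energy_pred': 1.2, 'zenith_pred': 0.9, 'event_no': 0}
```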
- save_state_dict(path)
Save model state_dict to path.
- Return type:
None
- Parameters:
path (str)
- load_state_dict(path, **kargs)
Load model state_dict from path.
- Return type:
- Parameters:
path (str | Dict)
kargs (Any | None)
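`save_state_dict` writes a state dict to `path`, and `load_state_dict` accepts either a path or an already-loaded `Dict`. The round trip below is a standard-library sketch of that `str | Dict` dispatch. `pickle` stands in for torch serialization, and `save_state` / `load_state` are hypothetical names; with torch one would typically call `torch.save(model.state_dict(), path)` and `torch.load(path)` instead.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Union


def save_state(state: Dict[str, Any], path: str) -> None:
    """Hypothetical stand-in for save_state_dict: serialize state to path."""
    with open(path, "wb") as handle:
        pickle.dump(state, handle)


def load_state(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical stand-in for load_state_dict's path-or-dict handling."""
    if isinstance(source, str):
        # A string is treated as a file path and deserialized first.
        with open(source, "rb") as handle:
            return pickle.load(handle)
    # A dict is assumed to already be a state dict and is returned as-is.
    return source


state = {"layer.weight": [0.5, -0.25], "layer.bias": [0.0]}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_state.pkl")
    save_state(state, path)
    restored = load_state(path)
print(restored == state)  # True
```

Unpickling can execute arbitrary code, and `torch.load` is built on pickle, so only load state files from trusted sources.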
- classmethod from_config(source, trust, load_modules)
Construct Model instance from source configuration.
- Parameters:
trust (bool, default: False) – Whether to trust the ModelConfig file enough to eval(…) any lambda function expressions it contains.
load_modules (Optional[List[str]], default: None) – Modules used in the model definition that must therefore be loaded into the global namespace. Defaults to loading torch.
source (ModelConfig | str)
- Raises:
ValueError – If the ModelConfig contains lambda functions but trust = False.
- Return type:
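The `trust` flag guards against evaluating code embedded in a configuration. The sketch below illustrates the guard described under Raises: string values starting with `lambda` are passed to `eval` only when `trust=True`, and the modules named in `load_modules` are imported into the evaluation namespace. `resolve_lambdas` and the flat dictionary standing in for a `ModelConfig` are hypothetical, and the sketch defaults to `math` instead of `torch` so it runs without torch installed.

```python
import importlib
from typing import Any, Dict, List, Optional


def _is_lambda(value: Any) -> bool:
    """Hypothetical check: does this config value look like a lambda expression?"""
    return isinstance(value, str) and value.strip().startswith("lambda")


def resolve_lambdas(
    config: Dict[str, Any],
    trust: bool = False,
    load_modules: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical sketch: evaluate lambda strings in a config only if trusted."""
    if any(_is_lambda(value) for value in config.values()) and not trust:
        raise ValueError(
            "Configuration contains lambda functions; pass trust=True to evaluate them."
        )
    # Make the requested modules visible to the evaluated expressions.
    namespace: Dict[str, Any] = {
        name: importlib.import_module(name) for name in (load_modules or ["math"])
    }
    resolved: Dict[str, Any] = {}
    for key, value in config.items():
        # eval() is only reached when trust=True (or no lambdas are present).
        resolved[key] = eval(value, namespace) if _is_lambda(value) else value
    return resolved


config = {"hidden_size": 64, "transform": "lambda x: math.sqrt(x)"}
resolved = resolve_lambdas(config, trust=True)
print(resolved["transform"](16.0))  # 4.0
```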