utils

Utility functions for graphnet.training.

graphnet.training.utils.collate_fn(graphs)

Collate graphs into a Batch, removing graphs with fewer than two DOM hits.

Such graphs should not occur in “production”.

Return type:

Batch

Parameters:

graphs (List[Data])
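The filter-then-collate step can be illustrated without PyTorch Geometric. This is a minimal sketch under stated assumptions, not the library's code: ToyGraph and toy_collate are hypothetical names, a graph's hit count is assumed to equal its number of nodes, and the stacking imitates what PyTorch Geometric's Batch.from_data_list does, namely concatenating node features and recording a batch vector that maps each node back to its graph.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyGraph:
    """Hypothetical stand-in for a torch_geometric Data object."""

    x: List[List[float]]  # one feature row per node (assumed: one per DOM hit)

    @property
    def n_hits(self) -> int:
        # Assumption: the number of DOM hits equals the number of nodes.
        return len(self.x)


def toy_collate(graphs: List[ToyGraph]) -> Tuple[List[List[float]], List[int]]:
    """Drop graphs with fewer than two hits, then stack the remaining ones.

    Returns the concatenated node features and a batch vector that maps
    each node to the index of its graph within the filtered batch.
    """
    kept = [g for g in graphs if g.n_hits > 1]
    x: List[List[float]] = []
    batch: List[int] = []
    for i, g in enumerate(kept):
        x.extend(g.x)
        batch.extend([i] * g.n_hits)
    return x, batch


graphs = [ToyGraph([[0.1], [0.2]]), ToyGraph([[0.3]]), ToyGraph([[0.4], [0.5], [0.6]])]
x, batch = toy_collate(graphs)
print(batch)  # [0, 0, 1, 1, 1]
```

The single-hit graph in the middle is dropped, so the surviving graphs are renumbered 0 and 1 in the batch vector.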

graphnet.training.utils.make_dataloader(db, pulsemaps, graph_definition, features, truth, *, batch_size, shuffle, selection, num_workers, persistent_workers, node_truth, truth_table, node_truth_table, string_selection, loss_weight_table, loss_weight_column, index_column, labels)

Construct a DataLoader instance.

Return type:

DataLoader

Parameters:
  • db (str)
  • pulsemaps (str | List[str])
  • graph_definition (GraphDefinition)
  • features (List[str])
  • truth (List[str])
  • batch_size (int)
  • shuffle (bool)
  • selection (List[int] | None)
  • num_workers (int)
  • persistent_workers (bool)
  • node_truth (List[str] | None)
  • truth_table (str)
  • node_truth_table (str | None)
  • string_selection (List[int] | None)
  • loss_weight_table (str | None)
  • loss_weight_column (str | None)
  • index_column (str)
  • labels (Dict[str, Callable] | None)
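How batch_size, shuffle and selection interact can be sketched with plain event indices. This is an illustrative stdlib sketch, not what make_dataloader does internally: toy_index_batches and its n_events and seed parameters are hypothetical, and it assumes that selection lists the event indices to include, with None meaning every event.

```python
import random
from typing import Iterator, List, Optional


def toy_index_batches(
    n_events: int,
    batch_size: int,
    shuffle: bool,
    selection: Optional[List[int]] = None,
    seed: int = 0,
) -> Iterator[List[int]]:
    """Yield batches of event indices, imitating DataLoader batching.

    Assumption: selection=None means "use every event in the database".
    """
    indices = list(range(n_events)) if selection is None else list(selection)
    if shuffle:
        # A seeded RNG keeps the sketch reproducible; a real DataLoader
        # with shuffle=True draws a new permutation every epoch.
        random.Random(seed).shuffle(indices)
    # The final batch may be smaller than batch_size.
    for start in range(0, len(indices), batch_size):
        yield indices[start : start + batch_size]


print(list(toy_index_batches(5, batch_size=2, shuffle=False)))
# [[0, 1], [2, 3], [4]]
print(list(toy_index_batches(10, batch_size=3, shuffle=False, selection=[7, 2, 9])))
# [[7, 2, 9]]
```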

graphnet.training.utils.make_train_validation_dataloader(db, graph_definition, selection, pulsemaps, features, truth, *, batch_size, database_indices, seed, test_size, num_workers, persistent_workers, node_truth, truth_table, node_truth_table, string_selection, loss_weight_column, loss_weight_table, index_column, labels)

Construct training and validation DataLoader instances.

Return type:

Tuple[DataLoader, DataLoader]

Parameters:
  • db (str)
  • graph_definition (GraphDefinition)
  • selection (List[int] | None)
  • pulsemaps (str | List[str])
  • features (List[str])
  • truth (List[str])
  • batch_size (int)
  • database_indices (List[int] | None)
  • seed (int)
  • test_size (float)
  • num_workers (int)
  • persistent_workers (bool)
  • node_truth (str | None)
  • truth_table (str)
  • node_truth_table (str | None)
  • string_selection (List[int] | None)
  • loss_weight_column (str | None)
  • loss_weight_table (str | None)
  • index_column (str)
  • labels (Dict[str, Callable] | None)
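The role of seed and test_size can be sketched as a reproducible split of event indices. This is a hedged illustration, not the function's actual splitting code: toy_train_validation_split is a hypothetical name, and test_size is assumed to be the fraction of events held out for validation. A library routine such as scikit-learn's train_test_split, which also takes test_size and a random_state seed, could replace the hand-rolled shuffle, though its rounding convention may differ.

```python
import random
from typing import List, Tuple


def toy_train_validation_split(
    indices: List[int], test_size: float, seed: int
) -> Tuple[List[int], List[int]]:
    """Split event indices into (training, validation) using a fixed seed.

    Assumption: test_size is the fraction of events held out for validation.
    """
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must lie strictly between 0 and 1")
    shuffled = list(indices)
    random.Random(seed).shuffle(shuffled)
    # Hold out at least one event for validation.
    n_validation = max(1, round(test_size * len(shuffled)))
    return shuffled[n_validation:], shuffled[:n_validation]


train, validation = toy_train_validation_split(list(range(10)), test_size=0.2, seed=42)
print(len(train), len(validation))  # 8 2
```

Because the split depends only on the indices and the seed, rerunning with the same seed reproduces it exactly, and the two index sets never overlap, so no event appears in both DataLoaders.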

graphnet.training.utils.get_predictions(trainer, model, dataloader, prediction_columns, *, node_level, additional_attributes)

Get the model's predictions on a dataloader.

Return type:

DataFrame

Parameters:
  • trainer (Trainer)
  • model (Model)
  • dataloader (DataLoader)
  • prediction_columns (List[str])
  • node_level (bool)
  • additional_attributes (List[str] | None)
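Turning batched model outputs into named columns can be sketched with the standard library alone. This is a sketch under assumptions, not the function's implementation: toy_predictions_table is hypothetical, each batch output is assumed to hold one row per event (or per node when node_level is true), prediction_columns is assumed to name the values in each row, and additional attributes such as the hypothetical event_no are assumed to arrive as per-batch lists.

```python
from typing import Any, Dict, List, Optional, Sequence


def toy_predictions_table(
    batch_outputs: Sequence[List[List[float]]],
    prediction_columns: List[str],
    batch_attributes: Optional[Sequence[Dict[str, List[Any]]]] = None,
) -> Dict[str, List[Any]]:
    """Concatenate per-batch model outputs into named columns.

    Each batch output holds one row per event (or per node, in the
    node-level case), with one value per prediction column.
    """
    table: Dict[str, List[Any]] = {name: [] for name in prediction_columns}
    for i, rows in enumerate(batch_outputs):
        for row in rows:
            # Every row must supply exactly one value per named column.
            if len(row) != len(prediction_columns):
                raise ValueError(
                    f"model returned {len(row)} values, "
                    f"expected {len(prediction_columns)}"
                )
            for name, value in zip(prediction_columns, row):
                table[name].append(value)
        # Attach extra per-row attributes (e.g. an event identifier).
        if batch_attributes is not None:
            for key, values in batch_attributes[i].items():
                table.setdefault(key, []).extend(values)
    return table


outputs = [[[0.9, 1.2], [0.1, 3.4]], [[0.5, 0.7]]]
attributes = [{"event_no": [0, 1]}, {"event_no": [2]}]
table = toy_predictions_table(outputs, ["energy_pred", "zenith_pred"], attributes)
print(table["energy_pred"])  # [0.9, 0.1, 0.5]
```

Since every column ends up with the same length, passing the resulting dict to pandas.DataFrame would produce a table of the shape the real function returns.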

graphnet.training.utils.save_results(db, tag, results, archive, model)

Save the trained model and prediction results in db.

Return type:

None

Parameters:
  • db (str)
  • tag (str)
  • results (DataFrame)
  • archive (str)
  • model (Model)
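One way to organise saved results by database and tag can be sketched with the standard library. The layout below (archive, then the database file name without extension, then tag) is a hypothetical choice for illustration, not a documented guarantee of save_results. toy_save_results is a hypothetical name, results is a plain dict of columns standing in for the DataFrame, and saving the model is left out.

```python
import csv
import os
import tempfile
from typing import Dict, List


def toy_save_results(
    db: str, tag: str, results: Dict[str, List[object]], archive: str
) -> str:
    """Write results as CSV under <archive>/<database name>/<tag>/.

    The directory layout is an assumption made for this sketch; model
    saving is omitted.
    """
    db_name = os.path.splitext(os.path.basename(db))[0]
    out_dir = os.path.join(archive, db_name, tag)
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, "results.csv")
    columns = list(results)
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(columns)
        # Transpose the column lists into rows.
        writer.writerows(zip(*(results[c] for c in columns)))
    return path


with tempfile.TemporaryDirectory() as archive:
    path = toy_save_results(
        "/data/sim_events.db",
        "dynedge_energy",
        {"event_no": [0, 1], "energy_pred": [1.5, 2.5]},
        archive,
    )
    print(os.path.relpath(path, archive))
```

Keying the output directory on both the database name and the tag keeps results from different datasets and training runs from overwriting one another.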