utils
Utility functions for graphnet.training.
- graphnet.training.utils.collate_fn(graphs)
Collate a list of graphs into a single Batch, removing graphs with fewer than two DOM hits.
Such graphs should not occur in “production”.
- Return type:
Batch
- Parameters:
graphs (List[Data])
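The filter-then-collate behaviour can be sketched without PyTorch Geometric. This is a minimal stand-in, not graphnet's implementation: the Graph dataclass, its n_pulses property (assumed to count DOM hits, one per node) and the dictionary returned in place of a torch_geometric Batch are all hypothetical. Node features are concatenated and a batch vector records which graph each node came from, which mirrors how a PyTorch Geometric Batch (built with Batch.from_data_list) is laid out.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Graph:
    """Hypothetical stand-in for torch_geometric.data.Data."""

    x: List[List[float]]  # one feature row per pulse (node)

    @property
    def n_pulses(self) -> int:
        # Assumption: each node corresponds to one DOM hit.
        return len(self.x)


def collate_fn(graphs: List[Graph]) -> Dict[str, list]:
    """Drop graphs with fewer than two hits, then concatenate the rest."""
    kept = [g for g in graphs if g.n_pulses > 1]
    x: List[List[float]] = []
    batch: List[int] = []
    for graph_index, g in enumerate(kept):
        x.extend(g.x)
        # Record which graph in the batch each node belongs to.
        batch.extend([graph_index] * g.n_pulses)
    return {"x": x, "batch": batch, "num_graphs": len(kept)}


# Usage: the single-hit graph in the middle is dropped.
graphs = [Graph(x=[[1.0], [2.0]]), Graph(x=[[3.0]]), Graph(x=[[4.0], [5.0], [6.0]])]
out = collate_fn(graphs)
print(out["batch"])  # [0, 0, 1, 1, 1]
```

Filtering before concatenation keeps the batch vector contiguous, so downstream pooling never sees an index for a graph that was removed.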
- graphnet.training.utils.make_dataloader(db, pulsemaps, graph_definition, features, truth, *, batch_size, shuffle, selection, num_workers, persistent_workers, node_truth, truth_table, node_truth_table, string_selection, loss_weight_table, loss_weight_column, index_column, labels)
Construct a DataLoader instance.
- Return type:
DataLoader
- Parameters:
db (str)
pulsemaps (str | List[str])
graph_definition (GraphDefinition)
features (List[str])
truth (List[str])
batch_size (int)
shuffle (bool)
selection (List[int] | None)
num_workers (int)
persistent_workers (bool)
node_truth (List[str] | None)
truth_table (str)
node_truth_table (str | None)
string_selection (List[int] | None)
loss_weight_table (str | None)
loss_weight_column (str | None)
index_column (str)
labels (Dict[str, Callable] | None)
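How selection, batch_size and shuffle interact can be illustrated with a small standard-library sketch. The helper batch_indices and its n_events and seed arguments are hypothetical; the sketch assumes that selection=None means every event is used and that a smaller final batch is kept, which is PyTorch's DataLoader default (drop_last=False). It is not graphnet's code.

```python
import random
from typing import List, Optional


def batch_indices(
    n_events: int,
    selection: Optional[List[int]],
    batch_size: int,
    shuffle: bool,
    seed: int = 0,
) -> List[List[int]]:
    """Group event indices into batches (hypothetical helper).

    Assumes selection=None means "use every event", and that a final,
    smaller batch is kept (PyTorch's DataLoader default, drop_last=False).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    indices = list(range(n_events)) if selection is None else list(selection)
    if shuffle:
        # Shuffle a copy so the caller's selection is left untouched.
        random.Random(seed).shuffle(indices)
    return [indices[i : i + batch_size] for i in range(0, len(indices), batch_size)]


print(batch_indices(5, None, 2, shuffle=False))  # [[0, 1], [2, 3], [4]]
```

With shuffle=False the batches follow the order of selection, which is convenient for inference; shuffling is normally reserved for training.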
- graphnet.training.utils.make_train_validation_dataloader(db, graph_definition, selection, pulsemaps, features, truth, *, batch_size, database_indices, seed, test_size, num_workers, persistent_workers, node_truth, truth_table, node_truth_table, string_selection, loss_weight_column, loss_weight_table, index_column, labels)
Construct training and validation DataLoader instances.
- Return type:
Tuple[DataLoader, DataLoader]
- Parameters:
db (str)
graph_definition (GraphDefinition)
selection (List[int] | None)
pulsemaps (str | List[str])
features (List[str])
truth (List[str])
batch_size (int)
database_indices (List[int] | None)
seed (int)
test_size (float)
num_workers (int)
persistent_workers (bool)
node_truth (str | None)
truth_table (str)
node_truth_table (str | None)
string_selection (List[int] | None)
loss_weight_column (str | None)
loss_weight_table (str | None)
index_column (str)
labels (Dict[str, Callable] | None)
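The split behind the two returned loaders can be sketched as follows, assuming test_size is the fraction of the selected events held out for validation and seed makes the split reproducible. The helper split_selection is hypothetical, rounds the validation count up, and ignores database_indices; a library routine such as scikit-learn's train_test_split is the usual tool for this step.

```python
import math
import random
from typing import List, Tuple


def split_selection(
    selection: List[int], test_size: float, seed: int
) -> Tuple[List[int], List[int]]:
    """Randomly split event indices into training and validation subsets.

    Hypothetical stand-in: test_size is the held-out validation fraction
    (rounded up) and seed fixes the shuffle so the split is reproducible.
    """
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be between 0 and 1")
    shuffled = list(selection)
    random.Random(seed).shuffle(shuffled)
    n_validation = math.ceil(test_size * len(shuffled))
    return shuffled[n_validation:], shuffled[:n_validation]


train, validation = split_selection(list(range(10)), test_size=0.25, seed=42)
print(len(train), len(validation))  # 7 3
```

Each subset would then back its own DataLoader, with shuffling typically enabled only for the training one.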
- graphnet.training.utils.get_predictions(trainer, model, dataloader, prediction_columns, *, node_level, additional_attributes)
Get model predictions on the events in dataloader and return them as a DataFrame.
- Return type:
DataFrame
- Parameters:
trainer (Trainer)
model (Model)
dataloader (DataLoader)
prediction_columns (List[str])
node_level (bool)
additional_attributes (List[str] | None)
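The final assembly step, turning per-batch outputs into named columns, can be sketched with the standard library. The helper assemble_predictions is hypothetical, and a dict of columns stands in for the pandas DataFrame. It assumes each batch yields one row per event (or presumably per node when node_level is set, which the sketch does not model) with one value per entry in prediction_columns, and that each name in additional_attributes can be read from every batch. The column names in the usage line are illustrative only.

```python
from typing import Dict, List, Optional, Sequence


def assemble_predictions(
    batch_outputs: Sequence[List[List[float]]],
    prediction_columns: List[str],
    batch_attributes: Optional[Sequence[Dict[str, list]]] = None,
    additional_attributes: Optional[List[str]] = None,
) -> Dict[str, list]:
    """Concatenate per-batch model outputs into named columns (hypothetical)."""
    additional_attributes = additional_attributes or []
    rows = [row for batch in batch_outputs for row in batch]
    for row in rows:
        if len(row) != len(prediction_columns):
            raise ValueError("each prediction row needs one value per column")
    # Column i of every row becomes the list stored under prediction_columns[i].
    columns = {
        name: [row[i] for row in rows] for i, name in enumerate(prediction_columns)
    }
    for attr in additional_attributes:
        if batch_attributes is None:
            raise ValueError("additional_attributes requires batch_attributes")
        columns[attr] = [v for batch in batch_attributes for v in batch[attr]]
        if len(columns[attr]) != len(rows):
            raise ValueError(f"attribute {attr!r} does not match the number of rows")
    return columns


out = assemble_predictions(
    [[[0.1, 1.0], [0.2, 2.0]], [[0.3, 3.0]]],
    ["energy_pred", "zenith_pred"],
    batch_attributes=[{"event_no": [10, 11]}, {"event_no": [12]}],
    additional_attributes=["event_no"],
)
print(out["event_no"])  # [10, 11, 12]
```

Checking the row width against prediction_columns up front turns a silent column mismatch into an immediate error.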