weight_fitting
Classes for fitting per-event weights for training.
- class graphnet.training.weight_fitting.WeightFitter(database_path, truth_table, index_column)
Bases: ABC, Logger
Produces per-event weights.
Weights are returned by the public method fit() and can optionally be saved as a table in the database.
Construct the weight fitter.
- Parameters:
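database_path (str) – Path to the database that contains the truth table.
truth_table (str) – Name of the truth table.
index_column (str) – Name of the column that identifies each event (e.g. event_no).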
- fit(bins, variable, weight_name, add_to_database, selection, transform, **kwargs)
Fit weights.
Calls the private _fit_weights method. The output is returned as a pandas.DataFrame and can optionally be saved to the SQL database.
- Parameters:
bins (ndarray) – Desired bins used for fitting.
variable (str) – The name of the variable. Must match the corresponding column name in the truth table.
weight_name (Optional[str], default: None) – Name of the weights.
add_to_database (bool, default: False) – If True, the weights are saved to the SQL database in a table named weight_name.
selection (Optional[List[int]], default: None) – A list of event_no values. If given, only events in the selection are used for fitting.
transform (Optional[Callable], default: None) – A callable that transforms the variable into a desired space, e.g. np.log10 for energy. If given, fitting happens in this space.
**kwargs (Any) – Additional arguments passed to _fit_weights.
- Return type:
DataFrame
- Returns:
DataFrame containing the weights and the corresponding event_no values.
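The control flow described for fit() (read the truth table, optionally restrict it to a selection, optionally transform the variable, delegate to the abstract _fit_weights, optionally write the result to a table named weight_name) can be sketched with the standard-library sqlite3 module and NumPy. This is a minimal, hypothetical sketch, not graphnet's implementation: the class names SketchWeightFitter and ConstantWeights, the use of an open connection instead of database_path, the default weight name, and the plain dict returned in place of a pandas.DataFrame are all assumptions made to keep the example self-contained.

```python
import sqlite3
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional

import numpy as np


class SketchWeightFitter(ABC):
    """Hypothetical stand-in for WeightFitter; not graphnet's implementation."""

    def __init__(self, con: sqlite3.Connection, truth_table: str, index_column: str) -> None:
        # Takes an open connection (instead of database_path) to stay self-contained.
        self._con = con
        self._truth_table = truth_table
        self._index_column = index_column

    def fit(
        self,
        bins: np.ndarray,
        variable: str,
        weight_name: Optional[str] = None,
        add_to_database: bool = False,
        selection: Optional[List[int]] = None,
        transform: Optional[Callable] = None,
        **kwargs: Any,
    ) -> Dict[str, np.ndarray]:
        # Read event indices and the variable from the truth table.
        query = f"SELECT {self._index_column}, {variable} FROM {self._truth_table}"
        params: List[int] = []
        if selection is not None:
            # Restrict fitting to the selected events.
            query += f" WHERE {self._index_column} IN ({','.join('?' * len(selection))})"
            params = list(selection)
        query += f" ORDER BY {self._index_column}"
        rows = self._con.execute(query, params).fetchall()
        index = np.array([r[0] for r in rows], dtype=int)
        values = np.array([r[1] for r in rows], dtype=float)
        if transform is not None:
            values = transform(values)  # fit in the transformed space, e.g. log10(energy)
        weight_name = weight_name or f"{variable}_weight"  # hypothetical default name
        weights = self._fit_weights(values, bins, **kwargs)
        if add_to_database:
            # Store the weights in a table named after weight_name.
            self._con.execute(
                f"CREATE TABLE {weight_name} ({self._index_column} INTEGER, {weight_name} REAL)"
            )
            self._con.executemany(
                f"INSERT INTO {weight_name} VALUES (?, ?)",
                zip(index.tolist(), weights.tolist()),
            )
            self._con.commit()
        return {self._index_column: index, weight_name: weights}

    @abstractmethod
    def _fit_weights(self, values: np.ndarray, bins: np.ndarray, **kwargs: Any) -> np.ndarray:
        """Return one weight per entry of `values` (implemented by subclasses)."""


class ConstantWeights(SketchWeightFitter):
    """Toy subclass that gives every event weight 1, only to exercise fit()."""

    def _fit_weights(self, values: np.ndarray, bins: np.ndarray, **kwargs: Any) -> np.ndarray:
        return np.ones_like(values)


# Usage with an in-memory database holding a small truth table.
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE truth (event_no INTEGER, energy REAL)")
con.executemany("INSERT INTO truth VALUES (?, ?)", [(0, 10.0), (1, 100.0), (2, 1000.0)])
fitter = ConstantWeights(con, "truth", "event_no")
out = fitter.fit(
    np.linspace(0.0, 4.0, 5), "energy", weight_name="w",
    add_to_database=True, selection=[0, 2], transform=np.log10,
)
```

Keeping the data access, selection, transform, and storage in the base class means a subclass only has to decide how values map to weights, which is the division of labour the fit() description implies.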
- class graphnet.training.weight_fitting.Uniform(database_path, truth_table, index_column)
Bases: WeightFitter
Produces per-event weights that make the distribution of variable uniform.
Construct the weight fitter.
- Parameters:
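database_path (str) – Path to the database that contains the truth table.
truth_table (str) – Name of the truth table.
index_column (str) – Name of the column that identifies each event (e.g. event_no).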
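One common way to make a distribution uniform is inverse-bin-count weighting: each event is weighted by 1 divided by the number of events in its bin, so every occupied bin carries the same total weight. The sketch below assumes that approach; the helper name uniform_weights, the mean-of-one normalisation, and the requirement that all values lie inside the bin range are assumptions, and graphnet's Uniform may normalise differently.

```python
import numpy as np


def uniform_weights(values: np.ndarray, bins: np.ndarray) -> np.ndarray:
    """Hypothetical helper: inverse-bin-count weights that flatten `values`."""
    values = np.asarray(values, dtype=float)
    if np.any((values < bins[0]) | (values > bins[-1])):
        raise ValueError("all values must lie within [bins[0], bins[-1]]")
    counts, edges = np.histogram(values, bins=bins)
    # Bin index of every event; np.histogram closes the last bin on the right,
    # so a value equal to bins[-1] is clipped into the final bin.
    idx = np.clip(np.digitize(values, edges) - 1, 0, len(counts) - 1)
    weights = 1.0 / counts[idx]  # every occupied bin now carries total weight 1
    return weights / weights.mean()  # assumed normalisation: mean weight of 1


# Usage: flatten a toy energy sample in log10 space (cf. the transform argument of fit()).
energies = np.array([5.0, 12.0, 15.0, 18.0, 95.0])
w = uniform_weights(np.log10(energies), np.linspace(0.5, 2.0, 4))
```

After weighting, the three log10 bins each sum to the same total weight even though the middle bin holds three events and the outer bins one each.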
- class graphnet.training.weight_fitting.BjoernLow(database_path, truth_table, index_column)
Bases: WeightFitter
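Produces per-event weights.
Events below x_low are weighted to be uniform, whereas events above x_low are weighted to follow a 1/(1 + a*(x - x_low)) curve, where x is the value of variable and a > 0 sets how quickly the weights fall off above x_low. Since x_low is neither a constructor argument nor an explicit fit() parameter, it is supplied through the **kwargs of fit().
Construct the weight fitter.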
- Parameters:
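database_path (str) – Path to the database that contains the truth table.
truth_table (str) – Name of the truth table.
index_column (str) – Name of the column that identifies each event (e.g. event_no).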
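A minimal sketch of this scheme, under the reading that events are first flattened with inverse-bin-count weights (as for a uniform fit) and the flattened weights above x_low are then multiplied by 1/(1 + a*(x - x_low)). The factor is written with x - x_low so that it equals 1 at x_low, keeping the weights continuous there, and decreases above it for a > 0. The function name bjoern_low_weights, the parameter name a, and the mean-of-one normalisation are hypothetical; graphnet's own _fit_weights may name its arguments and normalise differently.

```python
import numpy as np


def bjoern_low_weights(values: np.ndarray, bins: np.ndarray, x_low: float, a: float) -> np.ndarray:
    """Hypothetical sketch: flat below x_low, 1/(1 + a*(x - x_low)) fall-off above."""
    values = np.asarray(values, dtype=float)
    if np.any((values < bins[0]) | (values > bins[-1])):
        raise ValueError("all values must lie within [bins[0], bins[-1]]")
    counts, edges = np.histogram(values, bins=bins)
    idx = np.clip(np.digitize(values, edges) - 1, 0, len(counts) - 1)
    weights = 1.0 / counts[idx]  # first flatten the distribution
    above = values > x_low
    # Scale events above x_low; the factor is 1 at x_low and decreases for a > 0.
    weights[above] *= 1.0 / (1.0 + a * (values[above] - x_low))
    return weights / weights.mean()  # assumed normalisation: mean weight of 1


# Usage: two events in the flat region, two above x_low = 1.0.
x = np.array([0.5, 0.5, 1.5, 2.5])
w = bjoern_low_weights(x, np.array([0.0, 1.0, 2.0, 3.0]), x_low=1.0, a=1.0)
```

Relative to the flat bin below x_low, the bins around x = 1.5 and x = 2.5 end up with total weight 1/1.5 and 1/2.5 respectively, i.e. the weighted distribution traces the fall-off curve above x_low.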