utils

class Classifier(*args, **kwargs)

Bases: Protocol

Protocol for classifiers.

This protocol defines the methods that a classifier should implement: fit, get_params, predict, and predict_proba.

See also

sklearn.base.ClassifierMixin for the classifier interface in scikit-learn.

fit(*args, **kwargs)
Return type:

Self

get_params()
Return type:

dict[str, Any]

predict(*args, **kwargs)
Return type:

ndarray

predict_proba(*args, **kwargs)
Return type:

ndarray
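As an illustration of what satisfying this protocol can look like, the sketch below defines a stand-in protocol with the same four methods and a toy classifier that conforms to it. The names ClassifierLike and MajorityClassifier are hypothetical and not part of this module, and the positional arguments (X, y) are assumed from the scikit-learn convention, since the documented signatures only show *args and **kwargs.

```python
from typing import Any, Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class ClassifierLike(Protocol):
    """Hypothetical stand-in mirroring the four documented methods."""

    def fit(self, *args, **kwargs) -> "ClassifierLike": ...
    def get_params(self) -> dict[str, Any]: ...
    def predict(self, *args, **kwargs) -> np.ndarray: ...
    def predict_proba(self, *args, **kwargs) -> np.ndarray: ...


class MajorityClassifier:
    """Toy classifier that predicts from the class frequencies seen in fit."""

    def __init__(self, random_state=None):
        self.random_state = random_state

    def fit(self, X, y, sample_weight=None):
        # sample_weight is accepted but ignored in this toy example.
        self.classes_, counts = np.unique(y, return_counts=True)
        self.priors_ = counts / counts.sum()
        return self

    def get_params(self):
        return {"random_state": self.random_state}

    def predict_proba(self, X):
        # Every sample receives the same class-frequency vector.
        return np.tile(self.priors_, (len(X), 1))

    def predict(self, X):
        return self.classes_[np.argmax(self.predict_proba(X), axis=1)]


clf = MajorityClassifier().fit(np.zeros((4, 2)), np.array([0, 1, 1, 1]))
print(isinstance(clf, ClassifierLike))  # structural check: method names only
print(clf.predict(np.zeros((2, 2))))
```

Note that a runtime isinstance check against a protocol only verifies that the methods exist, not that their signatures or return types match.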

fit_classifier(classifier, X, y, *, sample_weight=None, log=None)

Fit the given classifier to the given data.

Parameters:
  • classifier (Classifier) – The classifier to fit.

  • X (ndarray) – The feature matrix.

  • y (ndarray) – The target vector.

  • sample_weight (ndarray | None, default: None) – The sample weights.

  • log (Logger | None, default: None) – The logger to use.
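One plausible way to handle the optional arguments is sketched below. It is not the implementation of fit_classifier: the helper fit_with_optional_weights and the RecordingClassifier class are hypothetical, and both forwarding sample_weight only when it is given and returning the fitted classifier are assumptions, since no return type is documented above.

```python
import logging


def fit_with_optional_weights(classifier, X, y, *, sample_weight=None, log=None):
    """Hypothetical helper, not the library's fit_classifier."""
    kwargs = {}
    if sample_weight is not None:
        # Forward the weights only when given, so classifiers whose fit()
        # does not accept sample_weight still work in the unweighted case.
        kwargs["sample_weight"] = sample_weight
    if log is not None:
        log.info("fitting %s on %d samples", type(classifier).__name__, len(X))
    return classifier.fit(X, y, **kwargs)


class RecordingClassifier:
    """Toy classifier that records the keyword arguments passed to fit."""

    def fit(self, X, y, **kwargs):
        self.fit_kwargs = kwargs
        return self


X = [[0.0], [1.0], [2.0]]
y = [0, 1, 1]
logger = logging.getLogger("example")

unweighted = fit_with_optional_weights(RecordingClassifier(), X, y, log=logger)
weighted = fit_with_optional_weights(
    RecordingClassifier(), X, y, sample_weight=[1.0, 2.0, 2.0]
)
print(unweighted.fit_kwargs)
print(weighted.fit_kwargs)
```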

init_new_classifier(classifier, random_state)

Initialize a new classifier with the same parameters as the given classifier.

Parameters:
  • classifier (Classifier) – The classifier to copy the parameters from.

  • random_state (int | Generator | None) – The random state to use for the new classifier.

Return type:

Classifier
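The idea of copying parameters into a fresh instance can be sketched as follows. This is a minimal illustration under two assumptions: that the dictionary returned by get_params maps directly onto constructor keyword arguments (the scikit-learn convention), and that the constructor accepts a random_state keyword. The names ToyClassifier and clone_with_random_state are hypothetical and not part of this module.

```python
class ToyClassifier:
    """Toy classifier whose constructor arguments mirror get_params()."""

    def __init__(self, depth=3, random_state=None):
        self.depth = depth
        self.random_state = random_state

    def get_params(self):
        return {"depth": self.depth, "random_state": self.random_state}


def clone_with_random_state(classifier, random_state):
    """Hypothetical helper, not the library's init_new_classifier."""
    # Copy the parameters so the original classifier is left untouched.
    params = dict(classifier.get_params())
    params["random_state"] = random_state
    # Build an unfitted instance of the same class with those parameters.
    return type(classifier)(**params)


original = ToyClassifier(depth=5, random_state=0)
copy = clone_with_random_state(original, random_state=42)
print(copy.get_params())
```

The new instance shares no fitted state with the original; only the constructor parameters carry over.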

to_predictions(y, y_proba, category_names, timestamps)

Convert the given predictions to a DataFrame.

Parameters:
  • y (ndarray) – The target vector.

  • y_proba (ndarray) – The predicted probabilities.

  • category_names (Iterable[str]) – The category names.

  • timestamps (ndarray) – The timestamps.

Return type:

DataFrame

validate_predictions(predictions, annotations, *, on='predictions', key_columns=('group', 'actor', 'recipient'))

Validate the predictions or annotations.

For each predicted interval, this calculates the mean and maximum probabilities. With on="predictions", the ground truth category of a predicted interval is the category of the annotated interval with the highest overlap. With on="annotations", each annotated interval is instead assigned the category of the predicted interval with the highest overlap.

Parameters:
  • predictions (DataFrame) – The predictions.

  • annotations (DataFrame) – The annotations.

  • on (Literal['predictions', 'annotations'], default: 'predictions') – Which intervals to validate: the predicted intervals ('predictions') or the annotated intervals ('annotations').

  • key_columns (Iterable[str], default: ('group', 'actor', 'recipient')) – The key columns.

Return type:

DataFrame

Returns:

The validated predictions or annotations.
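The highest-overlap matching described above can be illustrated for a single predicted interval with plain Python. This is only a sketch of the idea, not the function's implementation: the tuple layout (start, stop, category), the half-open interval convention, and the helper names overlap and best_match are assumptions made for the example.

```python
def overlap(a_start, a_stop, b_start, b_stop):
    """Length of the intersection of two half-open intervals (0 if disjoint)."""
    return max(0, min(a_stop, b_stop) - max(a_start, b_start))


def best_match(interval, candidates):
    """Category of the candidate with the largest overlap, or None if none overlap."""
    start, stop = interval
    best = max(
        candidates,
        key=lambda c: overlap(start, stop, c[0], c[1]),
        default=None,
    )
    if best is None or overlap(start, stop, best[0], best[1]) == 0:
        return None
    return best[2]


# Hypothetical annotated intervals: (start, stop, category).
annotations = [(0, 10, "none"), (10, 30, "attack"), (30, 40, "none")]

# One hypothetical predicted interval and its per-timestamp probabilities.
predicted = (8, 25)
probabilities = [0.6, 0.9, 0.8]

true_category = best_match(predicted, annotations)
mean_probability = sum(probabilities) / len(probabilities)
max_probability = max(probabilities)
print(true_category, max_probability)
```

Here the predicted interval overlaps the first annotation by 2 units and the second by 15, so the second annotation's category is taken as the ground truth. Swapping the roles of the two interval lists gives the on="annotations" direction.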