predict

k_fold_predict(dataset, extractor, classifier, *, k, random_state=None, sampling_function, balance_sample_weights=True, log)

Run k-fold cross-validation on a dataset.

Parameters:
  • dataset (AnnotatedDataset) – The dataset to cross-validate.

  • extractor (BaseExtractor[F], where F is a TypeVar bound to Shaped) – The extractor to use.

  • classifier (Classifier) – The classifier to use.

  • k (int) – The number of folds.

  • random_state (int | Generator | None, default: None) – The random state to use.

  • sampling_function (SamplingFunction) – The sampling function to use.

  • balance_sample_weights (bool, default: True) – Whether to balance sample weights.

  • log (Logger | None) – The logger to use.

Return type:

DatasetClassificationResult

Returns:

The dataset classification result.
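The library's internals are not documented here, so the following is only a minimal, library-independent sketch of what k-fold out-of-fold prediction generally looks like. It assumes that each sample is predicted by a model trained on the other k − 1 folds. It also assumes that balance_sample_weights means inverse class-frequency weighting, which is an interpretation and not confirmed by this reference. The names k_fold_predict_sketch, balanced_weights, k_fold_indices, fit_centroids and predict_nearest are hypothetical. The seed argument stands in for random_state, and sampling_function is not modeled at all.

```python
from __future__ import annotations

import random
from collections import Counter
from typing import Any, Callable, Sequence


def balanced_weights(labels: Sequence[str]) -> list[float]:
    """ASSUMED scheme: weight each sample inversely to its class frequency,
    so every class contributes the same total weight."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return [n_samples / (n_classes * counts[label]) for label in labels]


def k_fold_indices(n: int, k: int, seed: int | None = None) -> list[list[int]]:
    """Shuffle indices 0..n-1 and split them into k nearly equal folds."""
    if not 2 <= k <= n:
        raise ValueError("k must satisfy 2 <= k <= n")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    return [indices[i::k] for i in range(k)]


def k_fold_predict_sketch(
    features: Sequence[Any],
    labels: Sequence[str],
    fit_fn: Callable[[list, list, list], Any],
    predict_fn: Callable[[Any, list], list],
    *,
    k: int,
    seed: int | None = None,
    balance: bool = True,
) -> list[str]:
    """Hypothetical helper: return one out-of-fold prediction per sample."""
    predictions: list = [None] * len(labels)
    for test_idx in k_fold_indices(len(labels), k, seed):
        test_set = set(test_idx)
        train_idx = [i for i in range(len(labels)) if i not in test_set]
        train_y = [labels[i] for i in train_idx]
        weights = balanced_weights(train_y) if balance else [1.0] * len(train_y)
        # Train only on the other k - 1 folds, then predict the held-out fold.
        model = fit_fn([features[i] for i in train_idx], train_y, weights)
        test_x = [features[i] for i in test_idx]
        for i, y_hat in zip(test_idx, predict_fn(model, test_x)):
            predictions[i] = y_hat
    return predictions


# Toy classifier for the usage example: a weighted nearest-centroid model on 1-D features.
def fit_centroids(xs: list, ys: list, ws: list) -> dict:
    totals: dict = {}
    mass: dict = {}
    for x, y, w in zip(xs, ys, ws):
        totals[y] = totals.get(y, 0.0) + w * x
        mass[y] = mass.get(y, 0.0) + w
    return {y: totals[y] / mass[y] for y in totals}


def predict_nearest(centroids: dict, xs: list) -> list:
    return [min(centroids, key=lambda y: abs(x - centroids[y])) for x in xs]
```

For example, with two well-separated classes every held-out sample should be predicted correctly. Here `xs = [0.1, 0.2, 0.3, 0.4, 5.0, 5.1, 5.2, 5.3]` and `ys = ["low"] * 4 + ["high"] * 4`, and the call is `k_fold_predict_sketch(xs, ys, fit_centroids, predict_nearest, k=4, seed=0)`. Because each prediction comes from a model that never saw that sample, the collected predictions give an unbiased basis for scoring the classifier. That is the usual reason to return per-sample results from cross-validation rather than a single score.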

predict(sampleable, classifier, extractor, *, encoding_function=None, categories=None, exclude=None, log=None)

Run classification on a sampleable object.

Parameters:
Return type:

ClassificationResult | GroupClassificationResult | DatasetClassificationResult

Returns:

The classification result.