predict
- k_fold_predict(dataset, extractor, classifier, *, k, random_state=None, sampling_function, balance_sample_weights=True, log)
Run k-fold cross-validation on a dataset.
- Parameters:
dataset (AnnotatedDataset) – The dataset to cross-validate.
extractor (BaseExtractor[TypeVar(F, bound=Shaped)]) – The extractor to use.
classifier (Classifier) – The classifier to use.
k (int) – The number of folds.
random_state (int | Generator | None, default: None) – The random state to use.
sampling_function (SamplingFunction) – The sampling function to use.
balance_sample_weights (bool, default: True) – Whether to balance sample weights.
log (Logger | None) – The logger to use.
- Return type:
- Returns:
The dataset classification result.
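The following is a minimal, standard-library sketch of the general idea behind k-fold cross-validation with balanced sample weights. It does not use this package's API: `k_fold_indices` and `balanced_sample_weights` are hypothetical helpers written for illustration, the seed handling accepts only an `int` or `None` (not a `Generator`), and the inverse-class-frequency weighting is an assumption about what "balanced" means here.

```python
import random
from collections import Counter


def k_fold_indices(num_samples, k, random_state=None):
    """Yield (train_indices, test_indices) for each of k folds (hypothetical helper)."""
    if not 2 <= k <= num_samples:
        raise ValueError("k must be between 2 and num_samples")
    indices = list(range(num_samples))
    # A fixed seed makes the shuffle, and therefore the folds, reproducible.
    random.Random(random_state).shuffle(indices)
    # Fold i takes every k-th shuffled index, so fold sizes differ by at most one.
    folds = [indices[i::k] for i in range(k)]
    for i, test in enumerate(folds):
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        yield train, test


def balanced_sample_weights(labels):
    """Weight each sample inversely to its class frequency (assumed meaning of 'balanced')."""
    counts = Counter(labels)
    n, num_classes = len(labels), len(counts)
    return [n / (num_classes * counts[label]) for label in labels]


labels = ["a", "a", "a", "b", "a", "b", "a", "a", "b", "a"]
for train, test in k_fold_indices(len(labels), k=5, random_state=0):
    weights = balanced_sample_weights([labels[i] for i in train])
    # A classifier would be fit on `train` using `weights`, then evaluated on `test`.
```

Each sample lands in exactly one test fold, so combining the per-fold predictions gives one prediction per sample for the whole dataset.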
- predict(sampleable, classifier, extractor, *, encoding_function=None, categories=None, exclude=None, log=None)
Run classification on a sampleable object.
- Parameters:
sampleable (SampleableMixin) – The sampleable object to classify.
classifier (Classifier) – The classifier to use.
extractor (BaseExtractor[TypeVar(F, bound=Shaped)]) – The extractor to use.
encoding_function (EncodingFunction | None, default: None) – The encoding function to use.
categories (Iterable[str] | None, default: None) – The categories to use.
exclude (Iterable[str | int | tuple[str | int, str | int]] | None, default: None) – The identifiers to exclude.
log (Logger | None, default: None) – The logger to use.
- Return type:
ClassificationResult | GroupClassificationResult | DatasetClassificationResult
- Returns:
The classification result.