Results¶
[1]:
# render inline matplotlib figures as SVG (vector graphics) instead of the default raster format
%config InlineBackend.figure_formats = ['svg']
%matplotlib inline
[2]:
import os
from functools import partial
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, Normalize
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer, make_column_selector
from sklearn.impute import KNNImputer
from sklearn.pipeline import Pipeline
from helpers import smooth_model_outputs, subsample_train, score_priority
from sklearn.utils.class_weight import compute_sample_weight
from xgboost import XGBClassifier
from vassi.classification import (
plot_confusion_matrix,
predict,
)
from vassi.config import cfg
from vassi.features import DataFrameFeatureExtractor
from vassi.io import from_yaml, load_dataset, to_cache, from_cache, save_data, load_data
from vassi.classification.results import DatasetClassificationResult
from vassi.utils import Experiment
from vassi.sliding_metrics import (
SlidingWindowAggregator,
get_window_slices,
metrics,
)
import vassi._manuscript_utils as manuscript_utils
from vassi import visualization as vis
from interactive_table import Table
[3]:
# dataset-specific names of the keypoint and timestamp entries of each trajectory
# (the keypoint key is used below to index trajectories, e.g. trajectories[individual][cfg.key_keypoints])
keypoints_key, timestamp_key = "pose", "time_stamp"
cfg.key_keypoints = keypoints_key
cfg.key_timestamp = timestamp_key
cfg.trajectory_keys = (keypoints_key, timestamp_key)
Example train and test run¶
[4]:
# load the full dataset with dyads as targets and split it into training and test datasets
dataset_full = load_dataset(
    "cichlids",
    directory="../../datasets/social_cichlids",
    target="dyad",
    background_category="none",
)
# fixed random_state, the train test split always uses random_state=1
dataset_train, dataset_test = dataset_full.split(
    0.8,
    random_state=1,
)
# feature extraction pipeline: KNN imputation of missing values, then sliding-window
# median/q10/q90 aggregates of all columns next to the original (passthrough) columns.
# get_window_slices may adjust the requested time scale (see the logged warning: 91 -> 93)
time_scales, slices = get_window_slices(3, time_scales=(91,))
aggregator = ColumnTransformer(
    [
        (
            "aggregate",
            SlidingWindowAggregator(
                [metrics.median, metrics.q10, metrics.q90], max(time_scales), slices
            ),
            make_column_selector(),
        ),
        ("original", "passthrough", make_column_selector()),
    ],
)
pipeline = Pipeline(
    [("impute", KNNImputer()), ("aggregate", aggregator)]
).set_output(transform="pandas")
# extracted features are cached on disk; the default location is machine-specific, so it can be
# overridden with the VASSI_CACHE_DIRECTORY environment variable without editing the notebook
feature_cache_directory = os.environ.get(
    "VASSI_CACHE_DIRECTORY", "/media/paul/Data2/cichlids_cache"
)
extractor = DataFrameFeatureExtractor(
    cache_directory=feature_cache_directory,
    pipeline=pipeline,
).read_yaml("config_file-cichlids.yaml")
# and load optimized postprocessing parameters
best_parameters = from_yaml("optimization/optimization-summary.yaml")
# priority used to resolve overlapping predictions; the two weights sum to 1
priority_function = partial(
    score_priority,
    weight_max_probability=best_parameters["weight_max_probability"],
    weight_mean_probability=1 - best_parameters["weight_max_probability"],
)
# one optimized decision threshold per category, in dataset category order
best_thresholds = [best_parameters[f"threshold-{category}"] for category in dataset_test.categories]
# optionally, run a full example run with training and predictions
run_example_run = False
if run_example_run:
    # set a fixed random state for reproducible results
    random_state = np.random.default_rng(1)
    # subsample dataset and encode target from string to numeric
    X, y = subsample_train(
        dataset_train,
        extractor,
        random_state=random_state,
        log=None,
    )
    y = dataset_train.encode(y)
    # specify and fit classification model with class-balanced sample weights
    classifier = XGBClassifier(n_estimators=1000, random_state=random_state).fit(
        X.to_numpy(), y, sample_weight=compute_sample_weight("balanced", y)
    )
    # use model for predictions on the test dataset
    test_result = predict(
        dataset_test,
        classifier,
        extractor,
        log=None,
    )
    # postprocessing (smoothing + thresholds, overlap removal) and f1 scores
    f1_scores = (
        test_result
        .smooth(partial(smooth_model_outputs, best_parameters), decision_thresholds=best_thresholds)
        .remove_overlapping_predictions(
            priority_function=priority_function,
            prefilter_recipient_bouts=best_parameters["prefilter_recipient_bouts"],
            max_bout_gap=best_parameters["max_bout_gap"],
            max_allowed_bout_overlap=best_parameters["max_allowed_bout_overlap"],
        )
        .score()
    )
2025-06-24 15:19:27.403 [WARNING ] Loading categories (approach, chase, dart_bite, frontal_display, lateral_display, none, quiver) from observations file, specify categories argument if incomplete.
2025-06-24 15:19:41.475 [WARNING ] Time scales adjusted to match num_windows_per_scale: (91,) -> [93].
[5]:
## execute the following code to pre-compute and cache the features for all dyads in the dataset
## (slow; it iterates over every dyad of every group and extracts its features once)
# from tqdm.auto import tqdm
# for sampleable in tqdm([sampleable for _, group in dataset_full for _, sampleable in group]):
#     _ = sampleable.sample_X(extractor)
[6]:
# the subsampled training data of the example run is cached so that this cell is cheap to re-run
# NOTE(review): the cache name does not encode the seed or the extractor settings; invalidate the
# cached file after changing either
example_samples_cache = "example_samples.pkl"
try:
    X_example_train, y_example_train = from_cache(example_samples_cache)
except FileNotFoundError:
    X_example_train, y_example_train = subsample_train(
        dataset_train,
        extractor,
        random_state=1,
        log=None,
    )
    to_cache([X_example_train, y_example_train], example_samples_cache)
[7]:
# per-category sample counts: summed observation durations for the full, training and test
# datasets, and the number of samples per category in the subsampled training data
subsampled_categories, subsampled_counts = np.unique(y_example_train, return_counts=True)
sample_counts = pd.concat(
    [
        dataset.observations.groupby("category").aggregate(**{column: ("duration", "sum")})
        for column, dataset in (
            ("samples_full", dataset_full),
            ("samples_train", dataset_train),
            ("samples_test", dataset_test),
        )
    ]
    + [pd.DataFrame({"subsampled_train": pd.Series(subsampled_counts, index=subsampled_categories)})],
    axis=1,
)
[8]:
# add the share of each category (in %) next to every absolute count column.
# Only the absolute columns (selected by name) are used as input, so re-running this cell
# does not append percentages of percentages.
count_columns = ["samples_full", "samples_train", "subsampled_train", "samples_test"]
absolute_counts = sample_counts[count_columns]
percentages = (absolute_counts / absolute_counts.sum(axis=0) * 100).rename(
    columns=lambda column: f"{column} %"
)
# interleave: count, share, count, share, ...
sample_counts = pd.concat([absolute_counts, percentages], axis=1)[
    [name for column in count_columns for name in (column, f"{column} %")]
]
sample_counts
[8]:
| samples_full | samples_full % | samples_train | samples_train % | subsampled_train | subsampled_train % | samples_test | samples_test % | |
|---|---|---|---|---|---|---|---|---|
| approach | 7699 | 0.030793 | 4769 | 0.023913 | 4769 | 1.852701 | 2930 | 0.057916 |
| chase | 7412 | 0.029646 | 5546 | 0.027809 | 5546 | 2.154556 | 1866 | 0.036885 |
| dart_bite | 12318 | 0.049268 | 9840 | 0.049340 | 9810 | 3.811070 | 2478 | 0.048982 |
| frontal_display | 39569 | 0.158263 | 31663 | 0.158767 | 7915 | 3.074885 | 7906 | 0.156275 |
| lateral_display | 14179 | 0.056711 | 10946 | 0.054886 | 10946 | 4.252393 | 3233 | 0.063906 |
| none | 24910846 | 99.635110 | 19870951 | 99.638431 | 209078 | 81.224360 | 5039895 | 99.622021 |
| quiver | 10053 | 0.040209 | 9344 | 0.046853 | 9344 | 3.630035 | 709 | 0.014015 |
[9]:
# keep the category index: with index=False the exported rows could not be attributed to categories
sample_counts.to_csv("cichlid-samples.csv", index_label="category")
[42]:
def _num_dyads(dataset):
    """Count the dyads of a dataset by iterating over all of its groups."""
    return sum(1 for _group_id, group in dataset for _dyad_id, _sampleable in group)


print(
    f"full dataset: {_num_dyads(dataset_full)} dyads\n"
    f"dataset train: {_num_dyads(dataset_train)} dyads\n"
    f"dataset test: {_num_dyads(dataset_test)} dyads"
)
full dataset: 1862 dyads
dataset train: 1487 dyads
dataset test: 375 dyads
[47]:
# number of individuals in each group of the full dataset
dict(
    (group_identifier, len(group.individuals))
    for group_identifier, group in dataset_full
)
[47]:
{'GH010423': 15,
'GH010861': 15,
'GH013974': 14,
'GH019910': 15,
'GH030423': 15,
'GH030451': 15,
'GH030861': 15,
'GH039910': 15,
'GH039931': 15}
Evaluation¶
Uncomment the code in the following cell to run 20 training and test runs with different random states.
Note that this takes a long time. The evaluation-cichlids.py script can run the same code in parallel using MPI on appropriate hardware (e.g., a high-performance computing cluster).
[10]:
# Full evaluation over 20 runs with different random states (commented out, see the markdown above).
# Each run scores the test set after three postprocessing stages ("model_outputs", "smoothed",
# "thresholded") and stores numeric true/predicted labels on timestamp, annotation and prediction level.
# NOTE(review): this variant loads pre-trained classifiers from `cache_directory`; the training code is
# commented out twice. The loop body's indentation is not visible here and must be restored before
# uncommenting.
# NOTE(review): only the `test_result` of the last run is written to results.h5 ("test_dataset").
# NOTE(review): os.path.join builds the HDF5 group names below; on Windows this would insert backslashes.
# from tqdm.auto import tqdm
# cache_directory = "/media/paul/Data1/samples_cache"
# experiment = Experiment(20, random_state=1)
# for run in tqdm(experiment, total=experiment.num_runs):
# # X, y = subsample_train(
# # dataset_train,
# # extractor,
# # random_state=experiment.random_state,
# # log=None,
# # )
# # y = dataset_train.encode(y)
# # classifier = XGBClassifier(n_estimators=1000, random_state=experiment.random_state).fit(
# # X.to_numpy(), y, sample_weight=compute_sample_weight("balanced", y)
# # )
# classifier = from_cache(os.path.join(cache_directory, f"clf_{run:02d}.cache"))
# summary = []
# y = {"true": {}, "pred": {}}
# test_result = (
# predict(dataset_test, classifier, extractor, log=None)
# .remove_overlapping_predictions(
# priority_function=priority_function,
# prefilter_recipient_bouts=best_parameters["prefilter_recipient_bouts"],
# max_bout_gap=best_parameters["max_bout_gap"],
# max_allowed_bout_overlap=best_parameters["max_allowed_bout_overlap"],
# )
# )
# summary.append(
# manuscript_utils.summarize_scores(
# test_result,
# foreground_categories=dataset_test.foreground_categories,
# run=run,
# postprocessing_step="model_outputs",
# )
# )
# test_result = (
# test_result.smooth(partial(smooth_model_outputs, best_parameters))
# .remove_overlapping_predictions(
# priority_function=priority_function,
# prefilter_recipient_bouts=best_parameters["prefilter_recipient_bouts"],
# max_bout_gap=best_parameters["max_bout_gap"],
# max_allowed_bout_overlap=best_parameters["max_allowed_bout_overlap"],
# )
# )
# summary.append(
# manuscript_utils.summarize_scores(
# test_result,
# foreground_categories=dataset_test.foreground_categories,
# run=run,
# postprocessing_step="smoothed",
# )
# )
# test_result = (
# test_result.threshold(best_thresholds, default_decision="none")
# .remove_overlapping_predictions(
# priority_function=priority_function,
# prefilter_recipient_bouts=best_parameters["prefilter_recipient_bouts"],
# max_bout_gap=best_parameters["max_bout_gap"],
# max_allowed_bout_overlap=best_parameters["max_allowed_bout_overlap"],
# )
# )
# summary.append(
# manuscript_utils.summarize_scores(
# test_result,
# foreground_categories=dataset_test.foreground_categories,
# run=run,
# postprocessing_step="thresholded",
# )
# )
# summary = pd.concat(summary, ignore_index=True)
# y["true"]["timestamp"] = test_result.y_true_numeric
# y["pred"]["timestamp"] = test_result.y_pred_numeric
# y["true"]["annotation"] = dataset_test.encode(test_result.annotations["category"].to_numpy())
# y["pred"]["annotation"] = dataset_test.encode(test_result.annotations["predicted_category"].to_numpy())
# y["true"]["prediction"] = dataset_test.encode(test_result.predictions["true_category"].to_numpy())
# y["pred"]["prediction"] = dataset_test.encode(test_result.predictions["category"].to_numpy())
# experiment.add((summary, y))
# summary = pd.concat([summary for summary, _ in experiment.collect().values()], ignore_index=True)
# confusion = [y for _, y in experiment.collect().values()]
# for run, confusion_data in enumerate(confusion):
# save_data("results.h5", confusion_data["true"], os.path.join(f"run_{run:02d}", "true"))
# save_data("results.h5", confusion_data["pred"], os.path.join(f"run_{run:02d}", "pred"))
# save_data("results.h5", {"runs": np.array([f"run_{run:02d}" for run in range(len(confusion))])})
# summary.to_hdf("results.h5", key="summary")
# test_result.to_h5("results.h5", dataset_name="test_dataset")
[11]:
# instead, just load the results
# instead of re-running the evaluation, load its stored results:
# score summary, per-run true/predicted labels, and the test-set classification result
results_file = "results.h5"
summary = pd.read_hdf(results_file, key="summary")
run_names = load_data(results_file, "runs")
confusion = [load_data(results_file, run_name) for run_name in run_names]
test_result = DatasetClassificationResult.from_h5(results_file, dataset_name="test_dataset")
Interaction networks and timeline plot¶
[12]:
# short category labels, in dataset category order
# (approach, chase, dart_bite, frontal_display, lateral_display, none, quiver)
category_labels = "appr chase bite front lat none quiv".split()
# one color per category, same order as the labels
colors = [
    "#66C2A5",
    "#B84D3D",
    "#FC8D62",
    "#4C6688",
    "#8DA0CB",
    "#dddddd",
    "#EEA5CC",
]
# shared keyword arguments for all label boxes in the following visualizations
box_kwargs = dict(lw=0, joinstyle="round", alpha=0.5)
[13]:
from vassi.classification.predict import k_fold_predict
from vassi.logging import set_logging_level
from vassi.io import save_dataset
from tqdm.auto import tqdm
run_k_fold = False
if run_k_fold:
    # computing (and caching) the features of all dyads first takes ~2h ...
    all_sampleables = [sampleable for _, group in dataset_full for _, sampleable in group]
    for sampleable in tqdm(all_sampleables):
        sampleable.sample_X(extractor)
    # ... but makes the cross-validation quicker in return (still takes 6h to finish)
    k_fold_result = k_fold_predict(
        dataset_full,
        extractor,
        XGBClassifier(n_estimators=1000),
        k=5,
        random_state=1,  # fixed for reproducibility
        sampling_function=subsample_train,
        balance_sample_weights=True,
        log=set_logging_level("info"),
    )
    # postprocessing with the optimized parameters: smoothing + thresholds, then overlap removal
    smoothed_result = k_fold_result.smooth(
        partial(smooth_model_outputs, best_parameters), decision_thresholds=best_thresholds
    )
    overlap_parameters = {
        parameter: best_parameters[parameter]
        for parameter in ("prefilter_recipient_bouts", "max_bout_gap", "max_allowed_bout_overlap")
    }
    postprocessed_result = smoothed_result.remove_overlapping_predictions(
        priority_function=priority_function, **overlap_parameters
    )
    # convert the predictions into a dataset that reuses the original trajectories
    dataset_cv = postprocessed_result.to_dataset(
        trajectories={identifier: group.trajectories for identifier, group in dataset_full},
        background_category="none",
    )
    # save to datasets directory
    # this is also available in the data repository, examples/social_cichlids/k_fold_predictions_predictions.csv
    save_dataset(
        dataset_cv,
        dataset_name="k_fold_predictions",
        directory="../../datasets/social_cichlids",
        observation_suffix="predictions",
    )
[14]:
# recreate the cross-validation dataset instead of running the k-fold code above: download
# k_fold_predictions_predictions.csv from the data repository into the directory of this notebook.
# The original dataset is also needed, because its trajectories are reused.
import shutil

dataset_directory = "../../datasets/social_cichlids"
# skip when the k-fold cell just ran, so freshly computed predictions are not overwritten
if not run_k_fold:
    # the k-fold dataset uses the original trajectories under its own dataset name
    shutil.copyfile(
        os.path.join(dataset_directory, "cichlids_trajectories.h5"),
        os.path.join(dataset_directory, "k_fold_predictions_trajectories.h5"),
    )
    # unlike `!cp`, this raises FileNotFoundError if the csv has not been downloaded yet
    shutil.copy("k_fold_predictions_predictions.csv", dataset_directory)
[15]:
# load the cross-validation predictions (stored as a dataset) instead of running the "run_k_fold" code above
cv_dataset_options = dict(
    directory="../../datasets/social_cichlids",
    target="dyad",
    observation_suffix="predictions",
    background_category="none",
)
dataset_cv = load_dataset("k_fold_predictions", **cv_dataset_options)
2025-06-24 15:19:53.370 [WARNING ] Loading categories (approach, chase, dart_bite, frontal_display, lateral_display, none, quiver) from observations file, specify categories argument if incomplete.
[16]:
# example group shown in the timeline and network figure (previously repeated as a literal)
example_group = "GH030861"
# dyadic interaction matrices for one group, both annotated and predicted
group_annotated = dataset_full.select(example_group)
group_predicted = dataset_cv.select(example_group)
# node positions: mean over axes 0 and 1 of each individual's keypoint array (presumably frames
# and keypoints — confirm), then adjusted by a repulsion step with min_distance=200
locations = np.array(
    [
        np.mean(group_annotated.trajectories[individual][cfg.key_keypoints], axis=(0, 1))
        for individual in group_annotated.individuals
    ]
)
locations = manuscript_utils.adjust_node_positions_repulsion_vectorized(locations, min_distance=200, step=1)
interaction_matrices_annotated = manuscript_utils.dyadic_interactions(group_annotated, kind="count")
counts_annotated = np.array(list(interaction_matrices_annotated.values())).ravel()
counts_annotated = counts_annotated[counts_annotated > 0]
interaction_matrices_predicted = manuscript_utils.dyadic_interactions(group_predicted, kind="count")
counts_predicted = np.array(list(interaction_matrices_predicted.values())).ravel()
counts_predicted = counts_predicted[counts_predicted > 0]
# only positive counts are kept, so the log-scaled color norm below is finite
counts = np.concatenate([counts_annotated, counts_predicted])
# non-background test-set observations of the example group
observations_group = dataset_test.select(example_group).observations
observations_group = observations_group[observations_group["category"] != dataset_test.background_category]
# shared gray-to-black color scale over log counts of both networks
cmap = LinearSegmentedColormap.from_list("gray_to_black", ["lightgray", "black"])
norm = Normalize(vmin=np.log(counts).min(), vmax=np.log(counts).max())
# actor with the most observations, and (up to) its three most frequent recipients
actor_counts = (
    observations_group
    .groupby("actor", as_index=False)
    .aggregate(count=("recipient", "count"))
    .sort_values("count", ascending=False)
)
actor = actor_counts["actor"].iloc[0]
recipient_counts = (
    observations_group[observations_group["actor"] == actor]
    .groupby("recipient", as_index=False)
    .aggregate(count=("recipient", "count"))
    .sort_values("count", ascending=False)
)
recipients = recipient_counts["recipient"].tolist()[:3]
[17]:
# Figure: classification timelines (top, one axes per foreground category, with the dyads of `actor`
# and each of its `recipients` overlaid) and the interaction networks of the example group from
# annotations (middle row) and cross-validated predictions (bottom row).
# NOTE(review): num_steps and num_categories are not used in this cell.
num_steps = 3
num_categories = len(dataset_test.categories)
figsize = (9, 6.5)
fig = plt.figure(figsize=figsize)
panel = vis.Panel(*figsize, extent=(0, 0, *figsize))
top, middle = panel.divide(sizes_absolute=[0, 2.5], orientation="vertical", spacing_absolute=1)
panels_timeline = top.divide(sizes=[1] * len(dataset_test.foreground_categories), spacing_absolute=0.15, orientation="vertical")
row_1, row_2 = middle.divide(sizes=[1, 1], orientation="vertical", spacing=0.1)
panels_row_1 = row_1.divide(sizes=[1] * len(group_annotated.foreground_categories), spacing=0.025, orientation="horizontal")
panels_row_2 = row_2.divide(sizes=[1] * len(group_annotated.foreground_categories), spacing=0.025, orientation="horizontal")
# each key is popped once with a None default, so only the first axes of each group gets a panel
# letter (A: timelines, B: annotated networks, C: predicted networks)
labels = {0: "A", 1: "B", 2: "C"}
axes_timeline = [panel.get_ax(fig, label=labels.pop(0, None)) for panel in panels_timeline]
axes_annotated = [panel.get_ax(fig, label=labels.pop(1, None), spines=[False] * 4) for panel in panels_row_1]
axes_predicted = [panel.get_ax(fig, label=labels.pop(2, None), spines=[False] * 4) for panel in panels_row_2]
# boolean mask selecting the foreground entries of category-ordered arrays (probability columns, labels, colors)
foreground = np.isin(dataset_test.categories, dataset_test.foreground_categories)
# overlay one timeline per recipient, each shifted by idx / 6 vertically and idx * 120 horizontally
# and drawn behind the previous ones (decreasing zorder)
for idx, recipient in enumerate(recipients):
dyad_results = test_result.classification_results["GH030861"].classification_results[(actor, recipient)] # GH030861: 6, 9
# x ticks every 30 * 60 frames, labelled in minutes (assumes 30 frames per second — confirm)
manuscript_utils.plot_classification_timeline_multiple(
dyad_results.predictions,
dataset_test.foreground_categories,
annotations=dyad_results.annotations,
timestamps=dyad_results.timestamps,
y_proba_smoothed=dyad_results.y_proba_smoothed[:, foreground],
interval=(-np.inf, np.inf),
x_tick_step=(30 * 60),
x_tick_conversion=lambda ticks: (np.asarray(ticks) / (30 * 60)).astype(int),
y_offset=idx / 6,
x_offset=idx * 120,
zorder=-idx,
axes=axes_timeline,
)
# replace the y labels of the timeline axes with colored category boxes
for label, ax, color in zip(
np.array(category_labels)[foreground],
axes_timeline,
np.array(colors)[foreground],
):
ax.set_ylabel(None)
# NOTE(review): 0.5 + 3 / (6 * 2) looks tied to the y offsets of idx / 6 above — confirm before
# changing the number of overlaid recipients
vis.add_ytick_box(0.5 + 3 / (6 * 2), 2, ax, color=color, text=label, offset_in_inches=0.25, **box_kwargs)
# NOTE(review): indentation is not visible here; presumably only the last (bottom) timeline axes gets the x label
ax.set_xlabel("Time (min)")
####
# one network per foreground category, colored with the shared cmap/norm;
# node face colors: actor black, its selected recipients light gray, all other individuals white
for idx, category in enumerate(group_annotated.foreground_categories):
manuscript_utils.draw_network(
axes_annotated[idx], interaction_matrices_annotated[category], locations, cmap, norm,
fc=[("black" if ind == actor else ("lightgray" if ind in recipients else "white")) for ind in group_annotated.individuals],
)
manuscript_utils.draw_network(
axes_predicted[idx], interaction_matrices_predicted[category], locations, cmap, norm,
fc=[("black" if ind == actor else ("lightgray" if ind in recipients else "white")) for ind in group_annotated.individuals],
)
axes_annotated[idx].set_xticks([])
axes_predicted[idx].set_xticks([])
axes_annotated[idx].set_yticks([])
axes_predicted[idx].set_yticks([])
# share the axis limits with the first network of the same row
if idx > 0:
axes_annotated[idx].sharex(axes_annotated[0])
axes_predicted[idx].sharex(axes_predicted[0])
axes_annotated[idx].sharey(axes_annotated[0])
axes_predicted[idx].sharey(axes_predicted[0])
axes_annotated[0].set_ylabel("Annotations", labelpad=15)
axes_predicted[0].set_ylabel("Predictions", labelpad=15)
# category name boxes above the annotated networks; these colors equal `colors` without the
# color of the background category "none"
for ax, color, category in zip(
axes_annotated,
["#66C2A5", "#B84D3D", "#FC8D62", "#4C6688", "#8DA0CB", "#EEA5CC"],
group_annotated.foreground_categories,
):
if category == "dart_bite":
category = "bite/dart"
category = category.replace("_", " ")
# box width: axes width minus 400 data units (presumably to leave a gap between neighboring boxes)
vis.add_xtick_box(np.mean(ax.get_xlim()), np.diff(ax.get_xlim())[0] - 400, ax, offset_in_inches=0.05, y="top", text=category, color=color, **box_kwargs)
Prediction-based F1 scores¶
[18]:
# 'Prediction'-based evaluation: predicted intervals vs. the annotated category they overlap most
scores = manuscript_utils.aggregate_scores(summary, "prediction_f1", categories=dataset_test.categories)
# postprocessing steps shown per category: model outputs, smoothed, thresholded
step_labels = ("model", "smooth", "thresh")
num_steps = len(step_labels)
num_categories = len(dataset_test.categories)
figsize = (5.25, 5.75)
fig = plt.figure(figsize=figsize)
panel = vis.Panel(*figsize, extent=(0, 0, *figsize))
top, bottom = panel.divide(sizes_absolute=[3, 0], spacing_absolute=1.25, orientation="vertical")
panels_top = top.divide(sizes_absolute=[3, 0], spacing_absolute=1.25, orientation="horizontal")
ax_confusion = panels_top[0].get_ax(fig, label="A", spines=[True] * 4)
right = panels_top[1].divide(sizes_absolute=[0.5, 2, 0], spacing=0, orientation="vertical")
right[0].get_ax(fig, label="B").axis("off")  # empty axes aligned with confusion matrix
ax_macro_f1 = right[1].get_ax(fig)  # actual axes for errorbars
ax_category_f1 = bottom.get_ax(fig, label="C")
#### A: confusion matrix over all evaluation runs (predicted intervals on the y axis)
plot_confusion_matrix(
    np.array([y["pred"]["prediction"] for y in confusion], dtype=object),
    np.array([y["true"]["prediction"] for y in confusion], dtype=object),
    show_colorbar=False,
    ax=ax_confusion,
)
ax_confusion.set_ylabel("Predicted intervals", labelpad=25)
ax_confusion.set_xlabel("Annotated category\n(max overlapping)", labelpad=25)
vis.add_xtick_boxes(range(len(category_labels)), 0.975, ax_confusion, labels=category_labels, colors=colors, **box_kwargs)
vis.add_ytick_boxes(range(len(category_labels)), 0.975, ax_confusion, labels=category_labels, colors=colors, **box_kwargs)
#### B: macro F1 per postprocessing step
manuscript_utils.plot_errorbars(
    ax_macro_f1,
    *scores["macro-all"].to_numpy().T,
    ylabel="'Prediction' macro F1",
)
#### C: F1 per category and postprocessing step
# num_steps neighboring positions per category plus one empty slot between categories; the spacing
# depends only on num_steps, so category groups never overlap regardless of the number of categories
group_spacing = num_steps + 1.0
x = [position for idx in range(num_categories) for position in np.arange(num_steps) + idx * group_spacing]
means, stds = pd.concat([scores[category] for category in dataset_test.categories]).to_numpy().T
manuscript_utils.plot_errorbars(
    ax_category_f1,
    means,
    stds,
    x=x,
    padding=1,
    xticklabels=step_labels * num_categories,
    ylabel="'Prediction' category F1",
)
# shaded span and label box for each category group
for (idx, category), color in zip(enumerate(category_labels), colors):
    x_left = x[num_steps * idx] - 0.5
    x_right = x[num_steps * (idx + 1) - 1] + 0.625
    ax_category_f1.axvspan(x_left, x_right, color=color, alpha=0.25, lw=0)
    vis.add_xtick_box(
        (x_left + x_right) / 2,
        x_right - x_left,
        ax_category_f1,
        y="top",
        text=category,
        color=color,
        **box_kwargs,
    )
Timestamp-based F1 scores¶
[19]:
# 'Frame'-based evaluation: annotated vs. predicted category per frame (timestamp)
scores = manuscript_utils.aggregate_scores(summary, "timestamp_f1", categories=dataset_test.categories)
# postprocessing steps shown per category: model outputs, smoothed, thresholded
step_labels = ("model", "smooth", "thresh")
num_steps = len(step_labels)
num_categories = len(dataset_test.categories)
figsize = (5.25, 5.75)
fig = plt.figure(figsize=figsize)
panel = vis.Panel(*figsize, extent=(0, 0, *figsize))
top, bottom = panel.divide(sizes_absolute=[3, 0], spacing_absolute=1.25, orientation="vertical")
panels_top = top.divide(sizes_absolute=[3, 0], spacing_absolute=1.25, orientation="horizontal")
ax_confusion = panels_top[0].get_ax(fig, label="A", spines=[True] * 4)
right = panels_top[1].divide(sizes_absolute=[0.5, 2, 0], spacing=0, orientation="vertical")
right[0].get_ax(fig, label="B").axis("off")  # empty axes aligned with confusion matrix
ax_macro_f1 = right[1].get_ax(fig)  # actual axes for errorbars
ax_category_f1 = bottom.get_ax(fig, label="C")
#### A: confusion matrix over all evaluation runs (annotated frames on the y axis)
plot_confusion_matrix(
    [y["true"]["timestamp"] for y in confusion],
    [y["pred"]["timestamp"] for y in confusion],
    show_colorbar=False,
    ax=ax_confusion,
)
ax_confusion.set_xlabel("Predicted frames", labelpad=25)
ax_confusion.set_ylabel("Annotated frames", labelpad=25)
vis.add_xtick_boxes(range(len(category_labels)), 0.975, ax_confusion, labels=category_labels, colors=colors, **box_kwargs)
vis.add_ytick_boxes(range(len(category_labels)), 0.975, ax_confusion, labels=category_labels, colors=colors, **box_kwargs)
#### B: macro F1 per postprocessing step
manuscript_utils.plot_errorbars(
    ax_macro_f1,
    *scores["macro-all"].to_numpy().T,
    ylabel="'Frame' macro F1",
)
#### C: F1 per category and postprocessing step
# num_steps neighboring positions per category plus one empty slot between categories; the spacing
# depends only on num_steps, so category groups never overlap regardless of the number of categories
group_spacing = num_steps + 1.0
x = [position for idx in range(num_categories) for position in np.arange(num_steps) + idx * group_spacing]
means, stds = pd.concat([scores[category] for category in dataset_test.categories]).to_numpy().T
manuscript_utils.plot_errorbars(
    ax_category_f1,
    means,
    stds,
    x=x,
    padding=1,
    xticklabels=step_labels * num_categories,
    ylabel="'Frame' category F1",
)
# shaded span and label box for each category group
for (idx, category), color in zip(enumerate(category_labels), colors):
    x_left = x[num_steps * idx] - 0.5
    x_right = x[num_steps * (idx + 1) - 1] + 0.625
    ax_category_f1.axvspan(x_left, x_right, color=color, alpha=0.25, lw=0)
    vis.add_xtick_box(
        (x_left + x_right) / 2,
        x_right - x_left,
        ax_category_f1,
        y="top",
        text=category,
        color=color,
        **box_kwargs,
    )
Annotation-based F1 scores¶
[20]:
# 'Annotation'-based evaluation: annotated intervals vs. the predicted category they overlap most
scores = manuscript_utils.aggregate_scores(summary, "annotation_f1", categories=dataset_test.categories)
# postprocessing steps shown per category: model outputs, smoothed, thresholded
step_labels = ("model", "smooth", "thresh")
num_steps = len(step_labels)
num_categories = len(dataset_test.categories)
figsize = (5.25, 5.75)
fig = plt.figure(figsize=figsize)
panel = vis.Panel(*figsize, extent=(0, 0, *figsize))
top, bottom = panel.divide(sizes_absolute=[3, 0], spacing_absolute=1.25, orientation="vertical")
panels_top = top.divide(sizes_absolute=[3, 0], spacing_absolute=1.25, orientation="horizontal")
ax_confusion = panels_top[0].get_ax(fig, label="A", spines=[True] * 4)
right = panels_top[1].divide(sizes_absolute=[0.5, 2, 0], spacing=0, orientation="vertical")
right[0].get_ax(fig, label="B").axis("off")  # empty axes aligned with confusion matrix
ax_macro_f1 = right[1].get_ax(fig)  # actual axes for errorbars
ax_category_f1 = bottom.get_ax(fig, label="C")
#### A: confusion matrix over all evaluation runs (annotated intervals on the y axis)
plot_confusion_matrix(
    [y["true"]["annotation"] for y in confusion],
    [y["pred"]["annotation"] for y in confusion],
    show_colorbar=False,
    ax=ax_confusion,
)
ax_confusion.set_xlabel("Predicted category\n(max overlapping)", labelpad=25)
ax_confusion.set_ylabel("Annotated intervals", labelpad=25)
vis.add_xtick_boxes(range(len(category_labels)), 0.975, ax_confusion, labels=category_labels, colors=colors, **box_kwargs)
vis.add_ytick_boxes(range(len(category_labels)), 0.975, ax_confusion, labels=category_labels, colors=colors, **box_kwargs)
#### B: macro F1 per postprocessing step
manuscript_utils.plot_errorbars(
    ax_macro_f1,
    *scores["macro-all"].to_numpy().T,
    ylabel="'Annotation' macro F1",
)
#### C: F1 per category and postprocessing step
# num_steps neighboring positions per category plus one empty slot between categories; the spacing
# depends only on num_steps, so category groups never overlap regardless of the number of categories
group_spacing = num_steps + 1.0
x = [position for idx in range(num_categories) for position in np.arange(num_steps) + idx * group_spacing]
means, stds = pd.concat([scores[category] for category in dataset_test.categories]).to_numpy().T
manuscript_utils.plot_errorbars(
    ax_category_f1,
    means,
    stds,
    x=x,
    padding=1,
    xticklabels=step_labels * num_categories,
    ylabel="'Annotation' category F1",
)
# shaded span and label box for each category group
for (idx, category), color in zip(enumerate(category_labels), colors):
    x_left = x[num_steps * idx] - 0.5
    x_right = x[num_steps * (idx + 1) - 1] + 0.625
    ax_category_f1.axvspan(x_left, x_right, color=color, alpha=0.25, lw=0)
    vis.add_xtick_box(
        (x_left + x_right) / 2,
        x_right - x_left,
        ax_category_f1,
        y="top",
        text=category,
        color=color,
        **box_kwargs,
    )
Predictions vs association¶
[21]:
# annotated observations of the test set and the corresponding predicted intervals
annotations, predictions = dataset_test.observations, test_result.predictions
[22]:
# (group, actor, recipient) identifiers of all dyads in the test set
dyads = [
    (group_identifier, *dyad_identifier)
    for group_identifier, group in dataset_test
    for dyad_identifier in group.identifiers
]
[23]:
import pandas as pd
import matplotlib.pyplot as plt
from vassi import features
[24]:
# average body length: per frame, the summed distances between keypoints 0-1 and 1-2 (element-wise
# pairs), averaged over frames and then over all individuals of the test set
average_body_length = np.mean(
    [
        features.keypoint_distances(trajectory, keypoints_1=(0, 1), keypoints_2=(1, 2), element_wise=True).sum(axis=1).mean()
        for _, group in dataset_test
        for trajectory in group.trajectories.values()
    ]
)
# just a one to one mapping, but could also group multiple categories
category_subsets = {
    category: [category]
    for category in dataset_test.foreground_categories
}
identifier_columns = ["group", "actor", "recipient"]
# per-dyad quantities that depend neither on the category subset nor on the distance factor;
# computed once instead of once per (distance factor, category subset) combination
dyad_durations = []
dyad_min_distances = []
for dyad in dyads:
    dyad_data = dataset_test.select(dyad[0]).select(dyad[1:])
    distances = features.keypoint_distances(
        dyad_data.trajectory,
        trajectory_other=dyad_data.trajectory_other,
        keypoints_1=(0, 1, 2),
        keypoints_2=(0, 1, 2),
        flat=True,
    )
    dyad_durations.append(len(dyad_data))
    # smallest keypoint distance between the two individuals, per row
    # (assumes rows are frames and columns keypoint pairs — confirm)
    dyad_min_distances.append(distances.min(axis=1))
annotations_grouped = annotations.groupby(identifier_columns)
predictions_grouped = predictions.groupby(identifier_columns)
# three average body length factors for a sensitivity analysis of proximity vs. predicted counts.
# After the loop, aggregated_counts holds the results of the last factor (3).
for distance_factor in [1, 5, 3]:
    association_threshold = distance_factor * average_body_length
    # number of frames in which each dyad is closer than the threshold
    durations_associated = [
        np.sum(min_distances < association_threshold) for min_distances in dyad_min_distances
    ]
    aggregated_counts = []
    for subset, categories_subset in category_subsets.items():
        counts = pd.DataFrame(
            {
                "body_length": distance_factor,
                "count_annotated": [
                    np.isin(annotations_grouped.get_group(dyad)["category"], categories_subset).sum()
                    for dyad in dyads
                ],
                "count_predicted": [
                    np.isin(predictions_grouped.get_group(dyad)["category"], categories_subset).sum()
                    for dyad in dyads
                ],
                "duration": dyad_durations,
                "duration_associated": durations_associated,
            },
        )
        counts[identifier_columns] = dyads
        counts["categories_subset"] = subset
        aggregated_counts.append(counts)
    aggregated_counts = pd.concat(aggregated_counts)
    aggregated_counts.to_csv(f"predictions_vs_association/aggregated_counts-{distance_factor}bl.csv", index=False)
[25]:
# interactive table of the aggregated counts; after the loop above, this holds the last distance factor (3)
Table(aggregated_counts)
[25]:
[26]:
# this is the only place vassi uses seaborn, so just install it for this visualization (regplot)
!pip install seaborn
Requirement already satisfied: seaborn in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (0.13.2)
Requirement already satisfied: numpy!=1.24.0,>=1.20 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from seaborn) (2.2.6)
Requirement already satisfied: pandas>=1.2 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from seaborn) (2.3.0)
Requirement already satisfied: matplotlib!=3.6.1,>=3.4 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from seaborn) (3.10.3)
Requirement already satisfied: contourpy>=1.0.1 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (1.3.2)
Requirement already satisfied: cycler>=0.10 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (4.58.4)
Requirement already satisfied: kiwisolver>=1.3.1 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (1.4.8)
Requirement already satisfied: packaging>=20.0 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (25.0)
Requirement already satisfied: pillow>=8 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (11.2.1)
Requirement already satisfied: pyparsing>=2.3.1 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (3.2.3)
Requirement already satisfied: python-dateutil>=2.7 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from pandas>=1.2->seaborn) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from pandas>=1.2->seaborn) (2025.2)
Requirement already satisfied: six>=1.5 in /home/paul/miniforge3/envs/vassi/lib/python3.13/site-packages (from python-dateutil>=2.7->matplotlib!=3.6.1,>=3.4->seaborn) (1.17.0)
[27]:
import seaborn as sns

# Per-dyad interaction counts, one column per category subset:
#   top row    -> annotated count vs. predicted count
#   bottom row -> annotated count vs. association time (minutes)
figsize = (9, 2.6)
fig = plt.figure(figsize=figsize)
panel = vis.Panel(*figsize, extent=(0, 0, *figsize))
top, bottom = panel.divide(sizes_absolute=[1, 1], spacing_absolute=0.5, orientation="vertical")
panels_count = top.divide(sizes_absolute=[1] * 6, spacing_absolute=0.5, orientation="horizontal")
panel_association = bottom.divide(sizes_absolute=[1] * 6, spacing_absolute=0.5, orientation="horizontal")

# Panel letters, paired with the sub-panels via zip (the lists are not consumed).
labels_count = ["A", "B", "C", "D", "E", "F"]
labels_association = ["G", "H", "I", "J", "K", "L"]
axes_count = [
    sub_panel.get_ax(fig, label=label, label_offset=(0.05, 0))
    for sub_panel, label in zip(panels_count, labels_count)
]
axes_association = [
    sub_panel.get_ax(fig, label=label, label_offset=(0.05, 0))
    for sub_panel, label in zip(panel_association, labels_association)
]
assert len(category_subsets) <= len(axes_count), (
    f"{len(category_subsets)} category subsets but only {len(axes_count)} panel columns"
)

# NOTE(review): dividing by 30 * 60 yields minutes only if "duration_associated"
# is a frame count at 30 fps — confirm against the dataset.
FRAMES_PER_MINUTE = 30 * 60
aggregated_counts["duration_associated_min"] = aggregated_counts["duration_associated"] / FRAMES_PER_MINUTE

# Limits shared by all panels of a row so the subsets are directly comparable.
max_association = aggregated_counts["duration_associated_min"].max()
max_count = aggregated_counts[["count_predicted", "count_annotated"]].to_numpy().max()

# Shared styling; copies are passed to seaborn so each call gets its own dict.
JITTER = 0.2
SCATTER_KWS = dict(s=5, color="k", linewidths=0)
LINE_KWS = dict(lw=1.5, color="gray", zorder=-1)

for idx, subset in enumerate(category_subsets):
    interaction_counts = aggregated_counts.loc[aggregated_counts["categories_subset"] == subset]
    ax_count = axes_count[idx]
    ax_association = axes_association[idx]
    # Jitter only spreads the plotted points (seaborn applies it after the fit),
    # separating overlapping values; it does not change the regression line.
    sns.regplot(
        data=interaction_counts,
        x="count_predicted",
        y="count_annotated",
        ax=ax_count,
        x_jitter=JITTER,
        y_jitter=JITTER,
        scatter_kws=dict(SCATTER_KWS),
        line_kws=dict(LINE_KWS),
    )
    sns.regplot(
        data=interaction_counts,
        x="duration_associated_min",
        y="count_annotated",
        ax=ax_association,
        y_jitter=JITTER,
        scatter_kws=dict(SCATTER_KWS),
        line_kws=dict(LINE_KWS),
    )
    # One unit of padding on each side so points at 0 and at the maximum stay visible.
    ax_count.set_xlim(-1, max_count + 1)
    ax_count.set_ylim(-1, max_count + 1)
    ax_association.set_xlim(-1, max_association + 1)
    ax_association.set_ylim(-1, max_count + 1)
    ax_count.set_ylabel("Count", labelpad=2)
    ax_count.set_xlabel("Count", labelpad=2)
    ax_association.set_ylabel("Count", labelpad=2)
    ax_association.set_xlabel("Time (min)", labelpad=2)
def get_xtick_box(ax, factor=1):
    """Describe the horizontal extent of ``ax`` for an x tick box.

    Returns a ``(center, extent, ax)`` tuple. ``center`` is the midpoint of the
    current x-limits, ``extent`` is the x-limit span multiplied by ``factor``,
    and ``ax`` is passed through unchanged so the tuple can be star-unpacked
    into ``vis.add_xtick_box``.
    """
    limits = np.asarray(ax.get_xlim())
    center = limits.mean()
    extent = (limits[1] - limits[0]) * factor
    return center, extent, ax
def get_ytick_box(ax, factor=1):
    """Describe the vertical extent of ``ax`` for a y tick box.

    Returns a ``(center, extent, ax)`` tuple. ``center`` is the midpoint of the
    current y-limits, ``extent`` is the y-limit span scaled by ``factor``, and
    ``ax`` is passed through for star-unpacking into ``vis.add_ytick_box``.
    """
    lower, upper = np.asarray(ax.get_ylim())
    return np.mean([lower, upper]), (upper - lower) * factor, ax
# Title each top-row panel with its category subset. "dart_bite" gets a custom
# display name; underscores become spaces.
for ax, subset, color in zip(axes_count, category_subsets, np.array(colors)[foreground]):
    category = ("bite/dart" if subset == "dart_bite" else subset).replace("_", " ")
    vis.add_xtick_box(
        *get_xtick_box(ax),
        y="top",
        text=category,
        color=color,
        offset_in_inches=0.1,
        **box_kwargs,
    )

# Row labels in the left margin. Each receives the first panel's y-center and
# 1.2x its y-range (presumably the box center and height — see vis.add_ytick_box).
for row_axes, row_text in (
    (axes_count, "Annotated\n~\nPredicted"),
    (axes_association, "Annotated\n~\nAssociation time"),
):
    vis.add_ytick_box(
        *get_ytick_box(row_axes[0], 1.2),
        text=row_text,
        width_in_inches=0.6,
        offset_in_inches=0.5,
        color="lightgrey",
        **box_kwargs,
    )
plt.show()
F1 scores summary¶
[28]:
# F1 scores at every level (timestamps, annotations, predictions) and for every
# (post)processing step, as macro averages (foreground / all) and per category.
scores_per_level = []
for f1 in ["timestamp", "annotation", "prediction"]:
    aggregated = manuscript_utils.aggregate_scores(
        summary, f"{f1}_f1", categories=dataset_test.categories
    ).reset_index()
    # Tag every row with its F1 level so all levels can be stacked into one table.
    aggregated["f1"] = f1
    scores_per_level.append(aggregated.set_index("f1"))
scores = pd.concat(scores_per_level).reset_index()

# Build the display table: two identifier columns, then one "mean±std" string
# column per metric.
formatted_columns = {
    "F1": scores["f1"],
    "Postprocessing step": scores["postprocessing_step"],
}
# The columns form a two-level index: the first two top-level names are the
# identifiers, and each metric name repeats once per statistic sub-column
# ("mean", "std", ...). Deduplicate them, preserving order, so each metric is
# formatted only once.
for column in scores.columns.droplevel(1)[2:].unique():
    # e.g. "macro-foreground" -> "Macro (foreground)", "dart_bite" -> "Dart bite"
    label = column.replace("-", " (")
    if "(" in label:
        label += ")"
    label = label.replace("_", " ").capitalize()
    statistics = scores[column]
    formatted_columns[label] = (
        statistics["mean"].map("{:.03f}".format)
        + "±"
        + statistics["std"].map("{:.03f}".format)
    )
formatted = pd.DataFrame(formatted_columns)

# Display with human-readable labels. NOTE(review): this result is not assigned,
# so `formatted` itself (and anything saved from it) keeps the raw labels such as
# "timestamp" — confirm that is intended.
formatted.map(lambda s: s.replace("_", " ").capitalize())
[28]:
|  | F1 | Postprocessing step | Macro (foreground) | Macro (all) | Approach | Chase | Dart bite | Frontal display | Lateral display | None | Quiver |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Timestamp | Model outputs | 0.247±0.003 | 0.354±0.003 | 0.237±0.010 | 0.315±0.007 | 0.298±0.007 | 0.317±0.004 | 0.219±0.004 | 0.995±0.000 | 0.095±0.008 |
| 1 | Timestamp | Smoothed | 0.271±0.005 | 0.374±0.005 | 0.262±0.017 | 0.337±0.008 | 0.327±0.010 | 0.353±0.008 | 0.242±0.009 | 0.996±0.000 | 0.101±0.019 |
| 2 | Timestamp | Thresholded | 0.289±0.005 | 0.390±0.004 | 0.283±0.017 | 0.378±0.010 | 0.316±0.009 | 0.394±0.012 | 0.258±0.020 | 0.997±0.000 | 0.102±0.021 |
| 3 | Annotation | Model outputs | 0.552±0.013 | 0.605±0.012 | 0.495±0.035 | 0.535±0.014 | 0.614±0.022 | 0.674±0.014 | 0.586±0.031 | 0.919±0.003 | 0.409±0.069 |
| 4 | Annotation | Smoothed | 0.521±0.012 | 0.576±0.010 | 0.487±0.032 | 0.537±0.016 | 0.550±0.012 | 0.664±0.011 | 0.571±0.024 | 0.907±0.003 | 0.318±0.058 |
| 5 | Annotation | Thresholded | 0.482±0.015 | 0.536±0.013 | 0.542±0.035 | 0.557±0.024 | 0.524±0.015 | 0.524±0.020 | 0.426±0.031 | 0.864±0.003 | 0.316±0.064 |
| 6 | Prediction | Model outputs | 0.150±0.005 | 0.217±0.004 | 0.196±0.013 | 0.162±0.020 | 0.179±0.012 | 0.171±0.010 | 0.134±0.010 | 0.623±0.002 | 0.057±0.008 |
| 7 | Prediction | Smoothed | 0.303±0.008 | 0.375±0.007 | 0.335±0.023 | 0.290±0.017 | 0.385±0.021 | 0.387±0.009 | 0.279±0.017 | 0.805±0.004 | 0.145±0.031 |
| 8 | Prediction | Thresholded | 0.383±0.009 | 0.453±0.008 | 0.396±0.029 | 0.368±0.025 | 0.413±0.020 | 0.543±0.021 | 0.424±0.038 | 0.868±0.003 | 0.156±0.035 |
[29]:
# Persist the F1 summary table without the positional index. This writes
# `formatted` as built in the previous cell, i.e. with the raw level labels
# ("timestamp", "annotation", "prediction") rather than the capitalized
# display version.
formatted.to_csv(path_or_buf="f1_scores.csv", index=False)