Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions src/eval/analyze_labeled_data_scored.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
confusion_matrix, ConfusionMatrixDisplay,
precision_score, recall_score, f1_score, accuracy_score
)

# set paths
path = Path(__file__).parent
img_dir = path / "imgs"
img_dir.mkdir(exist_ok=True)

AGE_CLASSES = [
"0-2", "3-9", "10-19", "20-29", "30-39",
"40-49", "50-59", "60-69", "more than 70"
]

def age_to_range(age):
if age <= 2: return "0-2"
elif age <= 9: return "3-9"
elif age <= 19: return "10-19"
elif age <= 29: return "20-29"
elif age <= 39: return "30-39"
elif age <= 49: return "40-49"
elif age <= 59: return "50-59"
elif age <= 69: return "60-69"
else: return "more than 70"

def simplify_age_bucket(age_range: str) -> str:
below_30 = {"0-2", "3-9", "10-19", "20-29"}
return "Below 30" if age_range in below_30 else "30 and above"

def write_model_summary(model_name, acc, prec, rec, f1):
summary = f"{model_name} Summary:\n"
summary += f"• Accuracy: {acc:.2f} - Proportion of correct predictions.\n"
summary += f"• Precision: {prec:.2f} - Of predicted labels, how many were correct.\n"
summary += f"• Recall: {rec:.2f} - Of actual labels, how many were found.\n"
summary += f"• F1 Score: {f1:.2f} - Balance between Precision & Recall.\n"
if acc > 0.75:
summary += "This model performs well overall.\n"
elif acc > 0.5:
summary += "Moderate performance.\n"
else:
summary += "Struggles with accurate predictions.\n"

with open(img_dir / f"{model_name}_summary.txt", "w") as f:
f.write(summary)

def plot_conf_matrix(y_pred, y_true, model_name):
cm = confusion_matrix(y_true, y_pred, labels=AGE_CLASSES)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=AGE_CLASSES)
disp.plot(xticks_rotation=45, cmap="Blues")
plt.title(f"Confusion Matrix - {model_name}\n(Darker diagonal = better accuracy)")
plt.tight_layout()
plt.savefig(img_dir / f"conf_matrix_{model_name}.png")
plt.clf()

def plot_simplified_conf_matrix(y_true, y_pred, model_name):
y_true_simple = y_true.apply(simplify_age_bucket)
y_pred_simple = y_pred.apply(simplify_age_bucket)
cm = confusion_matrix(y_true_simple, y_pred_simple, labels=["Below 30", "30 and above"])
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Below 30", "30 and above"])
disp.plot(cmap="Purples", xticks_rotation=0)
plt.title(f"Simplified Confusion Matrix - {model_name}")
plt.figtext(0.5, -0.05, "Off-diagonal cells = misclassifications", wrap=True, ha='center', fontsize=9)
plt.tight_layout()
plt.savefig(img_dir / f"simplified_conf_matrix_{model_name}.png")
plt.clf()

def plot_metrics(y_true, y_pred, model_name):
acc = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred, average="macro", zero_division=0)
rec = recall_score(y_true, y_pred, average="macro", zero_division=0)
f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)

metrics = {"Accuracy": acc, "Precision": prec, "Recall": rec, "F1-Score": f1}
sns.barplot(x=list(metrics.keys()), y=list(metrics.values()), palette="mako")
for i, (k, v) in enumerate(metrics.items()):
plt.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=10, color='black')

plt.ylim(0, 1.05)
plt.title(f"Performance Metrics - {model_name}")
plt.figtext(0.5, -0.05, "Scores closer to 1.0 indicate better performance", wrap=True, ha='center', fontsize=9)
plt.tight_layout()
plt.savefig(img_dir / f"metrics_{model_name}.png")
plt.clf()

write_model_summary(model_name, acc, prec, rec, f1)

def plot_misclassifications(y_true, y_pred, model_name):
df_mis = pd.DataFrame({"true": y_true, "pred": y_pred})
mis_df = df_mis[df_mis["true"] != df_mis["pred"]]
most_common = mis_df.groupby(["true", "pred"]).size().reset_index(name="count")
most_common = most_common.sort_values(by="count", ascending=False).head(10)

plt.figure(figsize=(10, 6))
sns.barplot(
x="count",
y=most_common.apply(lambda x: f"{x['true']} → {x['pred']}", axis=1),
data=most_common,
palette="rocket"
)
plt.xlabel("Count")
plt.ylabel("Misclassification")
plt.title(f"Top Misclassifications - {model_name}")
plt.tight_layout()
plt.savefig(img_dir / f"top_misclassifications_{model_name}.png")
plt.clf()

def plot_conf_heatmap(y_true, y_pred, model_name):
cm = confusion_matrix(y_true, y_pred, labels=AGE_CLASSES, normalize='true')
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, xticklabels=AGE_CLASSES, yticklabels=AGE_CLASSES, cmap="YlGnBu", fmt=".2f")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title(f"Confusion Heatmap - {model_name}")
plt.figtext(0.5, -0.02, "Brighter squares on the diagonal = better classification accuracy", ha='center', fontsize=9)
plt.tight_layout()
plt.savefig(img_dir / f"heatmap_{model_name}.png")
plt.clf()

def main():
df = pd.read_csv(path / "temp_output_labeled.csv")

for model in ["age_classify_v001", "fairface_classifier", "vit_age_classifier"]:
model_df = df[df["model_name"] == model]
if model_df.empty:
print(f"No data found for {model}")
continue

y_true = model_df["true_label"].apply(age_to_range)
y_pred = model_df["label"]

plot_conf_matrix(y_pred, y_true, model)
plot_simplified_conf_matrix(y_true, y_pred, model)
plot_metrics(y_true, y_pred, model)
plot_misclassifications(y_true, y_pred, model)
plot_conf_heatmap(y_true, y_pred, model)

print(f"All visualizations and summaries saved to {img_dir}")

if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions src/eval/eval_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from src.eval.analyze_labeled_data_raw import main as raw_main
from src.eval.score_labeled_data import main as run_evaluation
from eval.transform_scores import main as transform_outputs

from src.eval.analyze_labeled_data_scored import main as visualize_outputs

def main(eval_table: str="age_gender_labeled", raw_plots: bool=False) -> pd.DataFrame:
"""Orchestrate full evaluation pipeline.
Expand All @@ -21,7 +21,7 @@ def main(eval_table: str="age_gender_labeled", raw_plots: bool=False) -> pd.Data

# TODO: endpoint for visualizations here
# chart scores and true/predicted labels

visualize_outputs()
return df


Expand Down
Binary file added src/eval/imgs/conf_matrix_age_classify_v001.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added src/eval/imgs/conf_matrix_fairface_classifier.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added src/eval/imgs/conf_matrix_vit_age_classifier.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added src/eval/imgs/heatmap_age_classify_v001.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added src/eval/imgs/heatmap_fairface_classifier.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added src/eval/imgs/heatmap_vit_age_classifier.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added src/eval/imgs/metrics_age_classify_v001.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added src/eval/imgs/metrics_fairface_classifier.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added src/eval/imgs/metrics_vit_age_classifier.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading