|
15 | 15 | from sklearn.metrics.pairwise import cosine_similarity |
16 | 16 | import pandas as pd |
17 | 17 | from collections import defaultdict |
| 18 | +from sklearn.neighbors import NearestNeighbors |
| 19 | +import math |
18 | 20 |
|
19 | 21 | # Add parent directory to path to import evaluator |
20 | 22 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
@@ -125,13 +127,13 @@ def plot_confusion_matrices(model_names, data_root="/scratch/cv-course2025/group |
125 | 127 | output_dir (str): Directory to save plots |
126 | 128 | distance_measure (str): Distance measure for evaluation |
127 | 129 | """ |
128 | | - from sklearn.neighbors import NearestNeighbors |
129 | | - |
| 130 | + |
130 | 131 | n_models = len(model_names) |
131 | | - fig, axes = plt.subplots(1, n_models, figsize=(6 * n_models, 5)) |
132 | | - if n_models == 1: |
133 | | - axes = [axes] |
134 | | - |
| 132 | + n_cols = 2 |
| 133 | + n_rows = math.ceil(n_models / n_cols) |
| 134 | + fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows)) |
| 135 | + axes = axes.flatten() # Flatten in case of single row |
| 136 | + |
135 | 137 | for i, model_name in enumerate(model_names): |
136 | 138 | try: |
137 | 139 | print(f"Creating confusion matrix for {model_name}...") |
@@ -187,7 +189,11 @@ def plot_confusion_matrices(model_names, data_root="/scratch/cv-course2025/group |
187 | 189 | ax.text(0.5, 0.5, f'Error: {str(e)}', transform=ax.transAxes, |
188 | 190 | ha='center', va='center', fontsize=12) |
189 | 191 | ax.set_title(f'{model_name} - Error') |
190 | | - |
| 192 | + |
| 193 | + # Hide any unused subplots |
| 194 | + for j in range(len(model_names), len(axes)): |
| 195 | + fig.delaxes(axes[j]) |
| 196 | + |
191 | 197 | plt.suptitle(f'Confusion Matrices ({distance_measure} distance)', fontsize=16, fontweight='bold') |
192 | 198 | plt.tight_layout() |
193 | 199 |
|
|
0 commit comments