Skip to content

Commit 633bc31

Browse files
committed
optimize layout
1 parent 342fe73 commit 633bc31

1 file changed

Lines changed: 106 additions & 63 deletions

File tree

tlparser/viz.py

Lines changed: 106 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _plot_violin(
166166

167167
if len(metrics) == 3:
168168
# Create a layout with the third plot centered by spanning both columns
169-
fig = plt.figure(figsize=(8, 8))
169+
fig = plt.figure(figsize=(7, 8))
170170
gs = fig.add_gridspec(nrows=2, ncols=2)
171171
axes_list = [
172172
fig.add_subplot(gs[0, 0]),
@@ -176,9 +176,9 @@ def _plot_violin(
176176
fig.subplots_adjust(hspace=0.05, wspace=0.05)
177177
else:
178178
fig, axes = plt.subplots(
179-
nrows=3 if len(metrics) > 3 else 2,
179+
nrows=3 if len(metrics) > 3 else 1,
180180
ncols=2,
181-
figsize=(8, 11) if len(metrics) > 3 else (8, 7),
181+
figsize=(7, 12) if len(metrics) > 3 else (7, 4),
182182
sharex=False,
183183
sharey=False,
184184
)
@@ -248,7 +248,7 @@ def _plot_violin(
248248
)
249249
ax.text(
250250
x_shift,
251-
0.83,
251+
0.80,
252252
annotation_text,
253253
color="black",
254254
ha="center",
@@ -334,7 +334,7 @@ def plot_violin_engcompl(self, include_strip=False, palette_index=0):
334334
def plot_violin_reqtext(self, include_strip=False, palette_index=0):
335335
df_filtered = self.data[self.data["translation"] == "self"]
336336
# metrics = df_filtered.filter(like=".req_").columns.tolist()
337-
metrics = ["stats.req_word_count", "stats.req_sentence_count", "stats.req_len"]
337+
metrics = ["stats.req_word_count", "stats.req_sentence_count"]
338338
df_long = pd.melt(
339339
df_filtered,
340340
id_vars=["id", "type"],
@@ -427,8 +427,8 @@ def plot_pairplot_reqwords(self, include_trend: bool = False):
427427
"stats.agg.tops",
428428
"stats.asth",
429429
"stats.entropy.lops_tops",
430-
"stats.req_sentence_count",
431430
"stats.req_len",
431+
"stats.req_sentence_count",
432432
]
433433
y_metric = "stats.req_word_count"
434434

@@ -437,11 +437,22 @@ def plot_pairplot_reqwords(self, include_trend: bool = False):
437437
# Nothing to plot
438438
return ""
439439

440-
# Build figure 3x3
440+
# Build figure with 2 columns and dynamic rows
441+
ncols = 2
442+
n = len(available)
443+
nrows = math.ceil(n / ncols)
444+
fig_height = max(3, nrows * 3)
441445
fig, axes = plt.subplots(
442-
nrows=3, ncols=3, figsize=(8, 7), sharey=True, sharex=False
446+
nrows=nrows, ncols=ncols, figsize=(7, fig_height), sharey=True, sharex=False
443447
)
444-
axes = axes.flatten()
448+
# Match _plot_violin spacing
449+
plt.subplots_adjust(hspace=0.05, wspace=0.05)
450+
451+
if nrows * ncols == 1:
452+
axes = [axes]
453+
else:
454+
axes = axes.flatten()
455+
445456
base_color = "#ffaf2d"
446457

447458
for ax, x in zip(axes, available):
@@ -454,10 +465,12 @@ def plot_pairplot_reqwords(self, include_trend: bool = False):
454465
edgecolors="black",
455466
linewidths=0.5,
456467
)
457-
x_label = self.title_map.get(x, [x, ""])[0]
458-
y_label = "Requirement Words"
459-
ax.set_xlabel(x_label)
460-
ax.set_ylabel(y_label)
468+
469+
# Match _plot_violin: use subplot title above, empty x-label
470+
x_title = self.title_map.get(x, [x, ""])[0]
471+
ax.set_title(x_title, fontsize=plt.rcParams.get("axes.titlesize", 10))
472+
ax.set_xlabel("")
473+
ax.set_ylabel("Requirement Words")
461474

462475
if include_trend and df[x].nunique() > 1:
463476
sns.regplot(
@@ -469,43 +482,33 @@ def plot_pairplot_reqwords(self, include_trend: bool = False):
469482
color="black",
470483
line_kws={"linewidth": 1.2, "alpha": 0.8, "zorder": 5},
471484
)
472-
# Restore labels after regplot overlay
473-
ax.set_xlabel(x_label)
474-
ax.set_ylabel(y_label)
485+
# Keep labels consistent after regplot overlay
486+
ax.set_title(x_title, fontsize=plt.rcParams.get("axes.titlesize", 10))
487+
ax.set_xlabel("")
488+
ax.set_ylabel("Requirement Words")
489+
490+
fig.tight_layout()
475491

476-
# Hide any unused axes (if fewer than 9 x metrics are available)
492+
# If the last row contains a single plot, center it within the bottom-row span (no stretching)
493+
if n % ncols == 1 and len(axes) >= (nrows * ncols):
494+
idx = (nrows - 1) * ncols # first axis in the last row
495+
# Use both bottom axes to compute true span, even if the second is empty
496+
left_pos = axes[idx].get_position()
497+
right_pos = axes[idx + 1].get_position()
498+
left_edge = left_pos.x0
499+
right_edge = right_pos.x0 + right_pos.width
500+
x_center = (left_edge + right_edge) / 2.0
501+
502+
W = left_pos.width
503+
H = left_pos.height
504+
Y0 = left_pos.y0
505+
X0 = x_center - W / 2.0
506+
axes[idx].set_position([X0, Y0, W, H])
507+
508+
# Hide any unused axes (if the grid is larger than available metrics)
477509
for j in range(len(available), len(axes)):
478510
axes[j].set_visible(False)
479511

480-
fig.tight_layout()
481-
# Center items in the last row without stretching when not all columns are used
482-
n = len(available)
483-
if n > 0 and n % 3 != 0:
484-
r = n % 3 # 1 or 2 items in the last row
485-
last_row_start = (n // 3) * 3
486-
# Reference positions for the three columns in the last row
487-
left_pos = axes[last_row_start].get_position()
488-
mid_pos = axes[last_row_start + 1].get_position()
489-
right_pos = axes[last_row_start + 2].get_position()
490-
491-
if r == 1:
492-
# Move the single plot to the middle column
493-
axes[last_row_start].set_position(
494-
[mid_pos.x0, mid_pos.y0, mid_pos.width, mid_pos.height]
495-
)
496-
elif r == 2:
497-
# Place the two plots adjacent and centered as a group.
498-
W = left_pos.width
499-
# spacing between adjacent columns (gap between left and middle)
500-
S = mid_pos.x0 - (left_pos.x0 + left_pos.width)
501-
total = 2 * W + S
502-
left_x = 0.5 - total / 2
503-
axes[last_row_start].set_position(
504-
[left_x, left_pos.y0, W, left_pos.height]
505-
)
506-
axes[last_row_start + 1].set_position(
507-
[left_x + W + S, left_pos.y0, W, left_pos.height]
508-
)
509512
out = self.__get_file_name("pair_reqw")
510513
plt.savefig(out)
511514
plt.close()
@@ -725,7 +728,8 @@ def plot_dag_interactive(self):
725728
return outs
726729

727730
def plot_operator_bars(self):
728-
# Summarize operator counts across self translations and plot three bar charts side-by-side
731+
# Summarize operator counts across self translations and plot three bar charts:
732+
# (a) Temporal, (b) Logical on the first row, and (c) Comparison centered in the second row.
729733
df = self.data[self.data["translation"] == "self"].copy()
730734

731735
groups = [
@@ -760,12 +764,10 @@ def plot_operator_bars(self):
760764
),
761765
]
762766

763-
# Compute data and width ratios first so we can size columns proportionally
764-
bar_groups = [] # list of (title, labels, values)
765-
width_ratios = []
767+
# Prepare data per group
768+
bar_groups = []
766769
for title, mapping in groups:
767-
labels = []
768-
values = []
770+
labels, values = [], []
769771
cols = [col for col in mapping.values() if col in df.columns]
770772
if cols:
771773
sums = df[cols].sum()
@@ -776,27 +778,40 @@ def plot_operator_bars(self):
776778
if not labels:
777779
labels, values = ["n/a"], [0]
778780
bar_groups.append((title, labels, values))
779-
width_ratios.append(max(len(labels), 1))
780781

781-
fig, axes = plt.subplots(
782-
1,
783-
3,
784-
figsize=(9, 3.5),
785-
sharey=False,
786-
gridspec_kw={"width_ratios": width_ratios},
782+
# Compute separate y-limits: shared for first row, independent for third plot
783+
first_row_max = max(
784+
(max(vals) for _, _, vals in bar_groups[:2] if vals), default=0
787785
)
788-
axes = axes.flatten()
786+
third_max = max(bar_groups[2][2]) if bar_groups[2][2] else 0
787+
y_top_first = max(1, int(first_row_max * 1.15) + 1)
788+
y_top_third = max(1, int(third_max * 1.15) + 1)
789+
790+
# Make second row shorter
791+
second_row_ratio = 0.7 # < 1.0 makes the second row less high than the first
792+
fig = plt.figure(figsize=(6, 5))
793+
gs = fig.add_gridspec(nrows=2, ncols=2, height_ratios=[1.0, second_row_ratio])
794+
795+
ax00 = fig.add_subplot(gs[0, 0])
796+
ax01 = fig.add_subplot(gs[0, 1], sharey=ax00) # share y only within first row
797+
ax10 = fig.add_subplot(gs[1, 0]) # independent y
798+
ax11 = fig.add_subplot(gs[1, 1]) # placeholder to compute span
799+
axes = [ax00, ax01, ax10, ax11]
789800

790801
palette = sns.color_palette("tab10")
791-
alpha = 0.85 # add a bit of transparency to bar faces
802+
alpha = 0.85
792803

793-
for ax, (title, labels, values) in zip(axes, bar_groups):
804+
# Plot the three groups on the first three axes
805+
for idx, (ax, (title, labels, values)) in enumerate(zip(axes[:3], bar_groups)):
794806
base_colors = palette[: len(labels)]
795807
bar_colors = [(r, g, b, alpha) for (r, g, b) in base_colors]
796808
bars = ax.bar(labels, values, color=bar_colors, edgecolor="black")
797809
ax.set_title(title)
798810
ax.set_ylabel("Count")
799-
ax.set_ylim(0, max(values) * 1.15 + 1)
811+
if idx < 2:
812+
ax.set_ylim(0, y_top_first)
813+
else:
814+
ax.set_ylim(0, y_top_third)
800815

801816
for rect, val in zip(bars, values):
802817
ax.annotate(
@@ -810,6 +825,34 @@ def plot_operator_bars(self):
810825
)
811826

812827
fig.tight_layout()
828+
829+
# Stretch the third plot to match bar pixel width of first row and center within bottom-row span
830+
n_bars = [len(labels) for _, labels, _ in bar_groups]
831+
if n_bars[2] > 0 and max(n_bars[0], n_bars[1]) > 0:
832+
top_left = axes[0].get_position()
833+
834+
# Use both bottom axes to get the true span of the second row
835+
bottom_left = axes[2].get_position()
836+
bottom_right = axes[3].get_position()
837+
left_edge = bottom_left.x0
838+
right_edge = bottom_right.x0 + bottom_right.width
839+
span_total = right_edge - left_edge
840+
841+
W_ref = top_left.width
842+
N_ref = max(n_bars[0], n_bars[1])
843+
desired_W3 = W_ref * (n_bars[2] / N_ref)
844+
845+
W3 = min(desired_W3, span_total)
846+
H3 = bottom_left.height
847+
Y3 = bottom_left.y0
848+
849+
X_center = (left_edge + right_edge) / 2.0
850+
X3 = X_center - W3 / 2.0
851+
axes[2].set_position([X3, Y3, W3, H3])
852+
853+
# Hide the unused 4th axis after layout adjustments
854+
axes[3].set_visible(False)
855+
813856
out = self.__get_file_name("ops_bars")
814857
plt.savefig(out)
815858
plt.close()

0 commit comments

Comments
 (0)