Skip to content

Commit cb9b1c5

Browse files
committed
Adds fixed figure size and font size to make corner plot less awkward
1 parent 700f6b7 commit cb9b1c5

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

RATapi/utils/plotting.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from functools import partial, wraps
55
from math import ceil, floor, sqrt
66
from statistics import stdev
7-
from textwrap import fill
87
from typing import Callable, Literal, Optional, Union
98

109
import matplotlib
@@ -668,11 +667,15 @@ def plot_corner(
668667

669668
num_params = len(params)
670669

671-
fig, axes = plt.subplots(num_params, num_params, figsize=(2 * num_params, 2 * num_params))
670+
fig, axes = plt.subplots(num_params, num_params, figsize=(14, 10))
672671
# i is row, j is column
673672
for i, row_param in enumerate(params):
674673
for j, col_param in enumerate(params):
675674
current_axes: Axes = axes[i][j]
675+
current_axes.tick_params(which="both", labelsize="medium")
676+
current_axes.xaxis.offsetText.set_fontsize("small")
677+
current_axes.yaxis.offsetText.set_fontsize("small")
678+
current_axes.yaxis.offsetText.set_x(-1.5)
676679
if i == j: # diagonal: histograms
677680
plot_one_hist(results, param=row_param, smooth=smooth, axes=current_axes, **hist_kwargs)
678681
elif i > j: # lower triangle: 2d histograms
@@ -689,8 +692,6 @@ def plot_corner(
689692
# make labels invisible as titles cover that
690693
current_axes.set_ylabel("")
691694
current_axes.set_xlabel("")
692-
693-
fig.tight_layout()
694695
if return_fig:
695696
return fig
696697
plt.show(block=block)
@@ -776,7 +777,7 @@ def plot_one_hist(
776777
color="white",
777778
)
778779

779-
axes.set_title(fill(results.fitNames[param], 20)) # use `fill` to wrap long titles
780+
axes.set_title(results.fitNames[param], loc="left", fontsize="medium")
780781

781782
if estimated_density:
782783
dx = bins[1] - bins[0]
@@ -899,7 +900,7 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig
899900
"""
900901
nplots = len(indices)
901902
nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots))
902-
fig = plt.subplots(nrows, ncols, figsize=(2.5 * ncols, 2 * nrows))[0]
903+
fig = plt.subplots(nrows, ncols, figsize=(14, 10))[0]
903904
axs = fig.get_axes()
904905

905906
for plot_num, index in enumerate(indices):

0 commit comments

Comments
 (0)