Skip to content

Commit 7cd3e24

Browse files
committed
Fixed Invalid Type Passing, Infinite Loop when Saving Plots (#164)
* fixed infinite loop due to invalid existing file checks * fixed strings being passed to plot_type arg instead of PlotType enums
1 parent 29f3e72 commit 7cd3e24

2 files changed

Lines changed: 44 additions & 26 deletions

File tree

simopt/experiment_base.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,13 +1355,27 @@ class PlotType(Enum):
13551355
SOLVE_TIME_CDF = "solve_time_cdf"
13561356
CDF_SOLVABILITY = "cdf_solvability"
13571357
QUANTILE_SOLVABILITY = "quantile_solvability"
1358-
DIFF_CDF_SOLVABILITY = "diff_cdf_solvability"
1359-
DIFF_QUANTILE_SOLVABILITY = "diff_quantile_solvability"
1358+
DIFF_CDF_SOLVABILITY = "difference_of_cdf_solvability"
1359+
DIFF_QUANTILE_SOLVABILITY = "difference_of_quantile_solvability"
13601360
AREA = "area"
13611361
BOX = "box"
13621362
VIOLIN = "violin"
13631363
TERMINAL_SCATTER = "terminal_scatter"
13641364

1365+
@staticmethod
1366+
def from_str(label: str) -> PlotType:
1367+
"""Converts a string label to a PlotType enum."""
1368+
# Reverse mapping from string to PlotType enum.
1369+
name = label.lower().replace(" ", "_")
1370+
inv_plot_type = {pt.value: pt for pt in PlotType}
1371+
if name in inv_plot_type:
1372+
return inv_plot_type[name]
1373+
error_msg = (
1374+
f"Unknown plot type: {label} ({name}). "
1375+
f"Must be one of {[pt.value for pt in PlotType]}."
1376+
)
1377+
raise ValueError(error_msg)
1378+
13651379

13661380
def bootstrap_procedure(
13671381
experiments: list[list[ProblemSolver]],
@@ -3621,35 +3635,29 @@ def save_plot(
36213635
plot_name = plot_name + "_unnorm"
36223636

36233637
# Reformat plot_name to be suitable as a string literal.
3624-
plot_name = plot_name.replace("\\", "")
3625-
plot_name = plot_name.replace("$", "")
3626-
plot_name = plot_name.replace(" ", "_")
3638+
plot_name = plot_name.replace("\\", "").replace("$", "").replace(" ", "_")
36273639

36283640
# If the plot title is not provided, use the default title.
36293641
if plot_title is None:
36303642
plot_title = f"{solver_name}_{problem_name}_{plot_name}"
3631-
path_name = plot_dir / plot_title
36323643

3633-
# Check to make sure file does not override previous images
3634-
counter = 0
3635-
while True:
3636-
# add extension to path name
3637-
extended_path_name = path_name.with_suffix(ext)
3638-
3639-
# If file doesn't exist, break out of loop
3640-
if not extended_path_name.exists():
3641-
break
3644+
# Read in the contents of the plot directory
3645+
existing_plots = [path.name for path in list(plot_dir.glob("*"))]
3646+
print(f"Existing plots: {existing_plots}")
36423647

3643-
# If file exists, increment counter and try again
3648+
counter = 0
3649+
while (plot_title + ext) in existing_plots:
3650+
# If the plot title already exists, append a counter to the filename
36443651
counter += 1
3645-
path_name = plot_dir / f"{plot_title} ({counter})"
3652+
plot_title = f"{plot_title} ({counter})"
3653+
extended_path_name = plot_dir / (plot_title + ext)
36463654

36473655
plt.savefig(extended_path_name, bbox_inches="tight")
36483656

36493657
# save plot as pickle
36503658
if save_as_pickle:
36513659
fig = plt.gcf()
3652-
pickle_path = path_name.with_suffix(".pkl")
3660+
pickle_path = extended_path_name.with_suffix(".pkl")
36533661
with pickle_path.open("wb") as pickle_file:
36543662
pickle.dump(fig, pickle_file)
36553663
# Return path_name for use in GUI.

simopt/gui/new_experiment_window.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from simopt.base import Problem, Solver
2121
from simopt.data_farming_base import DATA_FARMING_DIR
2222
from simopt.experiment_base import (
23+
PlotType,
2324
ProblemSolver,
2425
ProblemsSolvers,
2526
create_design,
@@ -4027,7 +4028,7 @@ def show_plot_options(self, plot_type_tk: tk.StringVar) -> None:
40274028
subplot_type_options = [
40284029
"CDF Solvability",
40294030
"Quantile Solvability",
4030-
"Difference of CDF Solvablility",
4031+
"Difference of CDF Solvability",
40314032
"Difference of Quantile Solvability",
40324033
]
40334034
self.subplot_type_var = tk.StringVar()
@@ -4317,7 +4318,7 @@ def enable_ref_solver(self, plot_type_tk: tk.StringVar) -> None:
43174318
if plot_type in ["CDF Solvability", "Quantile Solvability"]:
43184319
self.ref_solver_menu.configure(state="disabled")
43194320
elif plot_type in [
4320-
"Difference of CDF Solvablility",
4321+
"Difference of CDF Solvability",
43214322
"Difference of Quantile Solvability",
43224323
]:
43234324
self.ref_solver_menu.configure(state="normal")
@@ -4456,14 +4457,18 @@ def __plot_progress_curve(self) -> None:
44564457
parameters["Quantile Probability"] = beta
44574458
parameters["Number Bootstrap Samples"] = n_boot
44584459
parameters["Confidence Level"] = con_level
4460+
# Lookup plot type enum for passing to plotting function
4461+
subplot_type_enum: PlotType = PlotType.from_str(
4462+
subplot_type.lower()
4463+
)
44594464
# create new plot for each problem
44604465
for i in range(n_problems):
44614466
prob_list = []
44624467
for solver_group in exp_sublist:
44634468
prob_list.append(solver_group[i])
44644469
returned_path = plot_progress_curves(
44654470
experiments=prob_list,
4466-
plot_type=subplot_type, # type: ignore
4471+
plot_type=subplot_type_enum,
44674472
beta=beta,
44684473
normalize=norm,
44694474
all_in_one=all_in,
@@ -4642,14 +4647,18 @@ def __plot_terminal_progress(self) -> None:
46424647
parameters["Plot Type"] = subplot_type
46434648
assert subplot_type in ["box", "violin"]
46444649
parameters["Normalize Optimality Gaps"] = normalize_str
4650+
# Lookup plot type enum for passing to plotting function
4651+
subplot_type_enum: PlotType = PlotType.from_str(
4652+
subplot_type.lower()
4653+
)
46454654
# create a new plot for each problem
46464655
for i in range(n_problems):
46474656
prob_list = []
46484657
for solver_group in exp_sublist:
46494658
prob_list.append(solver_group[i])
46504659
returned_path = plot_terminal_progress(
46514660
experiments=prob_list,
4652-
plot_type=subplot_type, # type: ignore
4661+
plot_type=subplot_type_enum,
46534662
all_in_one=all_in,
46544663
normalize=norm,
46554664
save_as_pickle=True,
@@ -4736,7 +4745,7 @@ def __plot_solvability_profile(self) -> None:
47364745
subplot_types = {
47374746
"CDF Solvability": "cdf_solvability",
47384747
"Quantile Solvability": "quantile_solvability",
4739-
"Difference of CDF Solvablility": "diff_cdf_solvability",
4748+
"Difference of CDF Solvability": "diff_cdf_solvability",
47404749
"Difference of Quantile Solvability": "diff_quantile_solvability",
47414750
}
47424751
subplot_type = self.subplot_type_var.get()
@@ -4768,11 +4777,12 @@ def __plot_solvability_profile(self) -> None:
47684777
"Difference of Quantile Solvability",
47694778
]:
47704779
parameters["Quantile Probability"] = beta
4771-
4780+
# Lookup plot type enum for passing to plotting function
4781+
subplot_type_enum = PlotType.from_str(subplot_type)
47724782
if subplot_type in ["CDF Solvability", "Quantile Solvability"]:
47734783
returned_path = plot_solvability_profiles(
47744784
experiments=exp_sublist,
4775-
plot_type=plot_input, # type: ignore
4785+
plot_type=subplot_type_enum,
47764786
all_in_one=all_in,
47774787
n_bootstraps=n_boot,
47784788
conf_level=con_level,
@@ -4792,7 +4802,7 @@ def __plot_solvability_profile(self) -> None:
47924802
parameters["Reference Solver"] = ref_solver
47934803
returned_path = plot_solvability_profiles(
47944804
experiments=exp_sublist,
4795-
plot_type=plot_input, # type: ignore
4805+
plot_type=subplot_type_enum,
47964806
all_in_one=all_in,
47974807
n_bootstraps=n_boot,
47984808
conf_level=con_level,

0 commit comments

Comments
 (0)