Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions simopt/experiment_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,13 +1355,27 @@ class PlotType(Enum):
SOLVE_TIME_CDF = "solve_time_cdf"
CDF_SOLVABILITY = "cdf_solvability"
QUANTILE_SOLVABILITY = "quantile_solvability"
DIFF_CDF_SOLVABILITY = "diff_cdf_solvability"
DIFF_QUANTILE_SOLVABILITY = "diff_quantile_solvability"
DIFF_CDF_SOLVABILITY = "difference_of_cdf_solvability"
DIFF_QUANTILE_SOLVABILITY = "difference_of_quantile_solvability"
AREA = "area"
BOX = "box"
VIOLIN = "violin"
TERMINAL_SCATTER = "terminal_scatter"

@staticmethod
def from_str(label: str) -> PlotType:
"""Converts a string label to a PlotType enum."""
# Reverse mapping from string to PlotType enum.
name = label.lower().replace(" ", "_")
inv_plot_type = {pt.value: pt for pt in PlotType}
if name in inv_plot_type:
return inv_plot_type[name]
error_msg = (
f"Unknown plot type: {label} ({name}). "
f"Must be one of {[pt.value for pt in PlotType]}."
)
raise ValueError(error_msg)


def bootstrap_procedure(
experiments: list[list[ProblemSolver]],
Expand Down Expand Up @@ -3621,35 +3635,29 @@ def save_plot(
plot_name = plot_name + "_unnorm"

# Reformat plot_name to be suitable as a string literal.
plot_name = plot_name.replace("\\", "")
plot_name = plot_name.replace("$", "")
plot_name = plot_name.replace(" ", "_")
plot_name = plot_name.replace("\\", "").replace("$", "").replace(" ", "_")

# If the plot title is not provided, use the default title.
if plot_title is None:
plot_title = f"{solver_name}_{problem_name}_{plot_name}"
path_name = plot_dir / plot_title

# Check to make sure file does not override previous images
counter = 0
while True:
# add extension to path name
extended_path_name = path_name.with_suffix(ext)

# If file doesn't exist, break out of loop
if not extended_path_name.exists():
break
# Read in the contents of the plot directory
existing_plots = [path.name for path in list(plot_dir.glob("*"))]
print(f"Existing plots: {existing_plots}")

# If file exists, increment counter and try again
counter = 0
while (plot_title + ext) in existing_plots:
# If the plot title already exists, append a counter to the filename
counter += 1
path_name = plot_dir / f"{plot_title} ({counter})"
plot_title = f"{plot_title} ({counter})"
extended_path_name = plot_dir / (plot_title + ext)

plt.savefig(extended_path_name, bbox_inches="tight")

# save plot as pickle
if save_as_pickle:
fig = plt.gcf()
pickle_path = path_name.with_suffix(".pkl")
pickle_path = extended_path_name.with_suffix(".pkl")
with pickle_path.open("wb") as pickle_file:
pickle.dump(fig, pickle_file)
# Return path_name for use in GUI.
Expand Down
26 changes: 18 additions & 8 deletions simopt/gui/new_experiment_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from simopt.base import Problem, Solver
from simopt.data_farming_base import DATA_FARMING_DIR
from simopt.experiment_base import (
PlotType,
ProblemSolver,
ProblemsSolvers,
create_design,
Expand Down Expand Up @@ -4027,7 +4028,7 @@ def show_plot_options(self, plot_type_tk: tk.StringVar) -> None:
subplot_type_options = [
"CDF Solvability",
"Quantile Solvability",
"Difference of CDF Solvablility",
"Difference of CDF Solvability",
"Difference of Quantile Solvability",
]
self.subplot_type_var = tk.StringVar()
Expand Down Expand Up @@ -4317,7 +4318,7 @@ def enable_ref_solver(self, plot_type_tk: tk.StringVar) -> None:
if plot_type in ["CDF Solvability", "Quantile Solvability"]:
self.ref_solver_menu.configure(state="disabled")
elif plot_type in [
"Difference of CDF Solvablility",
"Difference of CDF Solvability",
"Difference of Quantile Solvability",
]:
self.ref_solver_menu.configure(state="normal")
Expand Down Expand Up @@ -4456,14 +4457,18 @@ def __plot_progress_curve(self) -> None:
parameters["Quantile Probability"] = beta
parameters["Number Bootstrap Samples"] = n_boot
parameters["Confidence Level"] = con_level
# Lookup plot type enum for passing to plotting function
subplot_type_enum: PlotType = PlotType.from_str(
subplot_type.lower()
)
# create new plot for each problem
for i in range(n_problems):
prob_list = []
for solver_group in exp_sublist:
prob_list.append(solver_group[i])
returned_path = plot_progress_curves(
experiments=prob_list,
plot_type=subplot_type, # type: ignore
plot_type=subplot_type_enum,
beta=beta,
normalize=norm,
all_in_one=all_in,
Expand Down Expand Up @@ -4642,14 +4647,18 @@ def __plot_terminal_progress(self) -> None:
parameters["Plot Type"] = subplot_type
assert subplot_type in ["box", "violin"]
parameters["Normalize Optimality Gaps"] = normalize_str
# Lookup plot type enum for passing to plotting function
subplot_type_enum: PlotType = PlotType.from_str(
subplot_type.lower()
)
# create a new plot for each problem
for i in range(n_problems):
prob_list = []
for solver_group in exp_sublist:
prob_list.append(solver_group[i])
returned_path = plot_terminal_progress(
experiments=prob_list,
plot_type=subplot_type, # type: ignore
plot_type=subplot_type_enum,
all_in_one=all_in,
normalize=norm,
save_as_pickle=True,
Expand Down Expand Up @@ -4736,7 +4745,7 @@ def __plot_solvability_profile(self) -> None:
subplot_types = {
"CDF Solvability": "cdf_solvability",
"Quantile Solvability": "quantile_solvability",
"Difference of CDF Solvablility": "diff_cdf_solvability",
"Difference of CDF Solvability": "diff_cdf_solvability",
"Difference of Quantile Solvability": "diff_quantile_solvability",
}
subplot_type = self.subplot_type_var.get()
Expand Down Expand Up @@ -4768,11 +4777,12 @@ def __plot_solvability_profile(self) -> None:
"Difference of Quantile Solvability",
]:
parameters["Quantile Probability"] = beta

# Lookup plot type enum for passing to plotting function
subplot_type_enum = PlotType.from_str(subplot_type)
if subplot_type in ["CDF Solvability", "Quantile Solvability"]:
returned_path = plot_solvability_profiles(
experiments=exp_sublist,
plot_type=plot_input, # type: ignore
plot_type=subplot_type_enum,
all_in_one=all_in,
n_bootstraps=n_boot,
conf_level=con_level,
Expand All @@ -4792,7 +4802,7 @@ def __plot_solvability_profile(self) -> None:
parameters["Reference Solver"] = ref_solver
returned_path = plot_solvability_profiles(
experiments=exp_sublist,
plot_type=plot_input, # type: ignore
plot_type=subplot_type_enum,
all_in_one=all_in,
n_bootstraps=n_boot,
conf_level=con_level,
Expand Down