diff --git a/simopt/experiment_base.py b/simopt/experiment_base.py index b13ceab88..ff6827d24 100644 --- a/simopt/experiment_base.py +++ b/simopt/experiment_base.py @@ -1355,13 +1355,27 @@ class PlotType(Enum): SOLVE_TIME_CDF = "solve_time_cdf" CDF_SOLVABILITY = "cdf_solvability" QUANTILE_SOLVABILITY = "quantile_solvability" - DIFF_CDF_SOLVABILITY = "diff_cdf_solvability" - DIFF_QUANTILE_SOLVABILITY = "diff_quantile_solvability" + DIFF_CDF_SOLVABILITY = "difference_of_cdf_solvability" + DIFF_QUANTILE_SOLVABILITY = "difference_of_quantile_solvability" AREA = "area" BOX = "box" VIOLIN = "violin" TERMINAL_SCATTER = "terminal_scatter" + @staticmethod + def from_str(label: str) -> PlotType: + """Converts a string label to a PlotType enum.""" + # Reverse mapping from string to PlotType enum. + name = label.lower().replace(" ", "_") + inv_plot_type = {pt.value: pt for pt in PlotType} + if name in inv_plot_type: + return inv_plot_type[name] + error_msg = ( + f"Unknown plot type: {label} ({name}). " + f"Must be one of {[pt.value for pt in PlotType]}." + ) + raise ValueError(error_msg) + def bootstrap_procedure( experiments: list[list[ProblemSolver]], @@ -3621,35 +3635,29 @@ def save_plot( plot_name = plot_name + "_unnorm" # Reformat plot_name to be suitable as a string literal. - plot_name = plot_name.replace("\\", "") - plot_name = plot_name.replace("$", "") - plot_name = plot_name.replace(" ", "_") + plot_name = plot_name.replace("\\", "").replace("$", "").replace(" ", "_") # If the plot title is not provided, use the default title. if plot_title is None: plot_title = f"{solver_name}_{problem_name}_{plot_name}" - path_name = plot_dir / plot_title - # Check to make sure file does not override previous images - counter = 0 - while True: - # add extension to path name - extended_path_name = path_name.with_suffix(ext) - - # If file doesn't exist, break out of loop - if not extended_path_name.exists(): - break + # Read in the contents of the plot directory + existing_plots = [path.name for path in list(plot_dir.glob("*"))] + print(f"Existing plots: {existing_plots}") - # If file exists, increment counter and try again + counter = 0 + while (plot_title + ext) in existing_plots: + # If the plot title already exists, append a counter to the filename counter += 1 - path_name = plot_dir / f"{plot_title} ({counter})" + plot_title = f"{plot_title} ({counter})" + extended_path_name = plot_dir / (plot_title + ext) plt.savefig(extended_path_name, bbox_inches="tight") # save plot as pickle if save_as_pickle: fig = plt.gcf() - pickle_path = path_name.with_suffix(".pkl") + pickle_path = extended_path_name.with_suffix(".pkl") with pickle_path.open("wb") as pickle_file: pickle.dump(fig, pickle_file) # Return path_name for use in GUI. diff --git a/simopt/gui/new_experiment_window.py b/simopt/gui/new_experiment_window.py index 068b659eb..f5aa1b606 100644 --- a/simopt/gui/new_experiment_window.py +++ b/simopt/gui/new_experiment_window.py @@ -20,6 +20,7 @@ from simopt.base import Problem, Solver from simopt.data_farming_base import DATA_FARMING_DIR from simopt.experiment_base import ( + PlotType, ProblemSolver, ProblemsSolvers, create_design, @@ -4027,7 +4028,7 @@ def show_plot_options(self, plot_type_tk: tk.StringVar) -> None: subplot_type_options = [ "CDF Solvability", "Quantile Solvability", - "Difference of CDF Solvablility", + "Difference of CDF Solvability", "Difference of Quantile Solvability", ] self.subplot_type_var = tk.StringVar() @@ -4317,7 +4318,7 @@ def enable_ref_solver(self, plot_type_tk: tk.StringVar) -> None: if plot_type in ["CDF Solvability", "Quantile Solvability"]: self.ref_solver_menu.configure(state="disabled") elif plot_type in [ - "Difference of CDF Solvablility", + "Difference of CDF Solvability", "Difference of Quantile Solvability", ]: self.ref_solver_menu.configure(state="normal") @@ -4456,6 +4457,10 @@ def __plot_progress_curve(self) -> None: parameters["Quantile Probability"] = beta parameters["Number Bootstrap Samples"] = n_boot parameters["Confidence Level"] = con_level + # Lookup plot type enum for passing to plotting function + subplot_type_enum: PlotType = PlotType.from_str( + subplot_type.lower() + ) # create new plot for each problem for i in range(n_problems): prob_list = [] @@ -4463,7 +4468,7 @@ def __plot_progress_curve(self) -> None: prob_list.append(solver_group[i]) returned_path = plot_progress_curves( experiments=prob_list, - plot_type=subplot_type, # type: ignore + plot_type=subplot_type_enum, beta=beta, normalize=norm, all_in_one=all_in, @@ -4642,6 +4647,10 @@ def __plot_terminal_progress(self) -> None: parameters["Plot Type"] = subplot_type assert subplot_type in ["box", "violin"] parameters["Normalize Optimality Gaps"] = normalize_str + # Lookup plot type enum for passing to plotting function + subplot_type_enum: PlotType = PlotType.from_str( + subplot_type.lower() + ) # create a new plot for each problem for i in range(n_problems): prob_list = [] @@ -4649,7 +4658,7 @@ def __plot_terminal_progress(self) -> None: prob_list.append(solver_group[i]) returned_path = plot_terminal_progress( experiments=prob_list, - plot_type=subplot_type, # type: ignore + plot_type=subplot_type_enum, all_in_one=all_in, normalize=norm, save_as_pickle=True, @@ -4736,7 +4745,7 @@ def __plot_solvability_profile(self) -> None: subplot_types = { "CDF Solvability": "cdf_solvability", "Quantile Solvability": "quantile_solvability", - "Difference of CDF Solvablility": "diff_cdf_solvability", + "Difference of CDF Solvability": "diff_cdf_solvability", "Difference of Quantile Solvability": "diff_quantile_solvability", } subplot_type = self.subplot_type_var.get() @@ -4768,11 +4777,12 @@ def __plot_solvability_profile(self) -> None: "Difference of Quantile Solvability", ]: parameters["Quantile Probability"] = beta - + # Lookup plot type enum for passing to plotting function + subplot_type_enum = PlotType.from_str(subplot_type) if subplot_type in ["CDF Solvability", "Quantile Solvability"]: returned_path = plot_solvability_profiles( experiments=exp_sublist, - plot_type=plot_input, # type: ignore + plot_type=subplot_type_enum, all_in_one=all_in, n_bootstraps=n_boot, conf_level=con_level, @@ -4792,7 +4802,7 @@ def __plot_solvability_profile(self) -> None: parameters["Reference Solver"] = ref_solver returned_path = plot_solvability_profiles( experiments=exp_sublist, - plot_type=plot_input, # type: ignore + plot_type=subplot_type_enum, all_in_one=all_in, n_bootstraps=n_boot, conf_level=con_level,