diff --git a/docs/changes/1960.feature.md b/docs/changes/1960.feature.md new file mode 100644 index 0000000000..cf9b996326 --- /dev/null +++ b/docs/changes/1960.feature.md @@ -0,0 +1 @@ +Add functionality to overwrite simulation model parameters from the command line using `--overwrite_model_parameters`. diff --git a/src/simtools/applications/derive_incident_angle.py b/src/simtools/applications/derive_incident_angle.py index 000c9e39a3..1ce04134fc 100644 --- a/src/simtools/applications/derive_incident_angle.py +++ b/src/simtools/applications/derive_incident_angle.py @@ -155,6 +155,7 @@ def main(): output_dir, label_with_telescope, debug_plots=app_context.args.get("debug_plots", False), + model_version=app_context.args.get("model_version", None), ) calculator.save_model_parameters(results_by_offset) total = sum(len(t) for t in results_by_offset.values()) diff --git a/src/simtools/applications/derive_psf_parameters.py b/src/simtools/applications/derive_psf_parameters.py index 8d53849ea9..e342717739 100644 --- a/src/simtools/applications/derive_psf_parameters.py +++ b/src/simtools/applications/derive_psf_parameters.py @@ -207,6 +207,7 @@ def main(): site=app_context.args["site"], telescope_name=app_context.args["telescope"], model_version=app_context.args["model_version"], + overwrite_model_parameters=app_context.args.get("overwrite_model_parameters"), ) psf_opt.run_psf_optimization_workflow( diff --git a/src/simtools/applications/derive_pulse_shape_parameters.py b/src/simtools/applications/derive_pulse_shape_parameters.py index ce0b50dbce..c6faf21864 100644 --- a/src/simtools/applications/derive_pulse_shape_parameters.py +++ b/src/simtools/applications/derive_pulse_shape_parameters.py @@ -138,6 +138,7 @@ def main(): model_version=app_context.args["model_version"], site=site, telescope_name=app_context.args["telescope"], + overwrite_model_parameters=app_context.args.get("overwrite_model_parameters"), ) fadc_sum_bins = telescope_model.get_parameter_value("fadc_sum_bins") diff --git a/src/simtools/applications/validate_cumulative_psf.py b/src/simtools/applications/validate_cumulative_psf.py index a3af81473d..381aa6c38b 100644 --- a/src/simtools/applications/validate_cumulative_psf.py +++ b/src/simtools/applications/validate_cumulative_psf.py @@ -133,6 +133,7 @@ def main(): site=app_context.args["site"], telescope_name=app_context.args["telescope"], model_version=app_context.args["model_version"], + overwrite_model_parameters=app_context.args.get("overwrite_model_parameters"), ) if app_context.args.get("overwrite_model_parameters"): diff --git a/src/simtools/applications/validate_optics.py b/src/simtools/applications/validate_optics.py index 8ebd8a7bb7..75ce3b276d 100644 --- a/src/simtools/applications/validate_optics.py +++ b/src/simtools/applications/validate_optics.py @@ -125,6 +125,7 @@ def main(): site=app_context.args["site"], telescope_name=app_context.args["telescope"], model_version=app_context.args["model_version"], + overwrite_model_parameters=app_context.args.get("overwrite_model_parameters"), ) app_context.logger.info( diff --git a/src/simtools/camera/camera_efficiency.py b/src/simtools/camera/camera_efficiency.py index cf6c74078e..ee9a99341e 100644 --- a/src/simtools/camera/camera_efficiency.py +++ b/src/simtools/camera/camera_efficiency.py @@ -40,6 +40,7 @@ def __init__(self, config_data, label): model_version=config_data["model_version"], site=config_data["site"], telescope_name=config_data["telescope"], + overwrite_model_parameters=config_data.get("overwrite_model_parameters"), ) self.output_dir = self.io_handler.get_output_directory() diff --git a/src/simtools/model/model_utils.py b/src/simtools/model/model_utils.py index b426d0c1aa..4529b7958c 100644 --- a/src/simtools/model/model_utils.py +++ b/src/simtools/model/model_utils.py @@ -10,7 +10,12 @@ def initialize_simulation_models( - label, model_version, site, telescope_name, calibration_device_name=None + label, + model_version, + site, + telescope_name, + calibration_device_name=None, + overwrite_model_parameters=None, ): """ Initialize simulation models for a single telescope, site, and calibration device model. @@ -38,11 +43,13 @@ def initialize_simulation_models( telescope_name=telescope_name, model_version=model_version, label=label, + overwrite_model_parameters=overwrite_model_parameters, ) site_model = SiteModel( site=site, model_version=model_version, label=label, + overwrite_model_parameters=overwrite_model_parameters, ) if calibration_device_name is not None: calibration_model = CalibrationModel( @@ -50,6 +57,7 @@ def initialize_simulation_models( calibration_device_model_name=calibration_device_name, model_version=model_version, label=label, + overwrite_model_parameters=overwrite_model_parameters, ) else: calibration_model = None diff --git a/src/simtools/ray_tracing/incident_angles.py b/src/simtools/ray_tracing/incident_angles.py index f620863094..48d7324594 100644 --- a/src/simtools/ray_tracing/incident_angles.py +++ b/src/simtools/ray_tracing/incident_angles.py @@ -81,6 +81,7 @@ def __init__( site=config_data["site"], telescope_name=config_data["telescope"], model_version=config_data["model_version"], + overwrite_model_parameters=config_data.get("overwrite_model_parameters"), ) def _label_suffix(self): diff --git a/src/simtools/ray_tracing/mirror_panel_psf.py b/src/simtools/ray_tracing/mirror_panel_psf.py index d6d63031f7..69946ccd55 100644 --- a/src/simtools/ray_tracing/mirror_panel_psf.py +++ b/src/simtools/ray_tracing/mirror_panel_psf.py @@ -73,6 +73,7 @@ def _define_telescope_model(self, label): site=self.args_dict["site"], telescope_name=self.args_dict["telescope"], model_version=self.args_dict["model_version"], + overwrite_model_parameters=self.args_dict.get("overwrite_model_parameters"), ) if self.args_dict["mirror_list"] is not None: mirror_list_file = gen.find_file( diff --git a/src/simtools/schemas/application_workflow.metaschema.yml b/src/simtools/schemas/application_workflow.metaschema.yml index 361a4d556a..96e60570e6 100644 --- a/src/simtools/schemas/application_workflow.metaschema.yml +++ b/src/simtools/schemas/application_workflow.metaschema.yml @@ -142,6 +142,10 @@ definitions: description: | "Allowed tolerance for floating point comparison." type: number + scaling: + description: | + "Scaling factor to apply to the model parameter value before comparison." + type: number test_simtel_cfg_files: description: | "Reference file used for comparison of sim_telarray configuration files." diff --git a/src/simtools/simtel/simulator_light_emission.py b/src/simtools/simtel/simulator_light_emission.py index 04c0f2559a..47f095343e 100644 --- a/src/simtools/simtel/simulator_light_emission.py +++ b/src/simtools/simtel/simulator_light_emission.py @@ -48,6 +48,7 @@ def __init__(self, light_emission_config, label=None): telescope_name=light_emission_config.get("telescope"), calibration_device_name=light_emission_config.get("light_source"), model_version=light_emission_config.get("model_version"), + overwrite_model_parameters=light_emission_config.get("overwrite_model_parameters"), ) ) self.telescope_model.write_sim_telarray_config_file(additional_models=self.site_model) diff --git a/src/simtools/testing/validate_output.py b/src/simtools/testing/validate_output.py index a9437b5caa..bbc1b5df06 100644 --- a/src/simtools/testing/validate_output.py +++ b/src/simtools/testing/validate_output.py @@ -189,6 +189,7 @@ def _validate_model_parameter_json_file(config, model_parameter_validation): model_parameter["value"], reference_model_parameter[reference_parameter_name]["value"], model_parameter_validation["tolerance"], + model_parameter_validation.get("scaling", 1.0), ) @@ -275,8 +276,32 @@ def compare_json_or_yaml_files(file1, file2, tolerance=1.0e-2): return _comparison -def _compare_value_from_parameter_dict(data1, data2, tolerance=1.0e-5): - """Compare value fields given in different formats.""" +def _compare_value_from_parameter_dict(data_1, data_2, tolerance=1.0e-5, factor_1=1.0): + """ + Compare value fields given in different formats. + + Parameters + ---------- + data_1 : float, int, str, list, numpy.ndarray + First value or collection of values to compare. May be a scalar, + a sequence, a numpy array, or a string representation of a list. + data_2 : float, int, str, list, numpy.ndarray + Second value or collection of values to compare, with the same + allowed formats as ``data_2``. + tolerance : float, optional + Relative tolerance used when comparing numerical values via + ``numpy.allclose``. + factor1 : float, optional + Multiplicative factor applied to ``data_1`` before comparison. This + can be used to account for unit conversions or normalisation + differences between ``data_1`` and ``data_2``. + + Returns + ------- + bool + True if the two values are considered equal within the given + tolerance, False otherwise. + """ def _as_list(value): if isinstance(value, str): @@ -285,12 +310,13 @@ def _as_list(value): return value return [value] - _logger.info(f"Comparing values: {data1} and {data2} (tolerance: {tolerance})") + _logger.info(f"Comparing values: {data_1} and {data_2} (tolerance: {tolerance})") - _as_list_1 = _as_list(data1) - _as_list_2 = _as_list(data2) + _as_list_1 = _as_list(data_1) + _as_list_2 = _as_list(data_2) if isinstance(_as_list_1, str): return _as_list_1 == _as_list_2 + _as_list_1 = np.array(_as_list_1) * factor_1 return np.allclose(_as_list_1, _as_list_2, rtol=tolerance) diff --git a/src/simtools/visualization/plot_incident_angles.py b/src/simtools/visualization/plot_incident_angles.py index 9d2bb7c420..bb134e0928 100644 --- a/src/simtools/visualization/plot_incident_angles.py +++ b/src/simtools/visualization/plot_incident_angles.py @@ -351,6 +351,7 @@ def _plot_component_angles( out_path, bin_width_deg, log, + model_version=None, ): arrays = _gather_angle_arrays(results_by_offset, column, log) if not arrays: @@ -365,6 +366,17 @@ def _plot_component_angles( ax.set_title(f"Incident angle {title_suffix} vs off-axis angle") ax.grid(True, alpha=0.3) ax.legend() + if model_version: + ax.text( + 0.03, + 0.97, + f"Model version: {model_version}", + transform=ax.transAxes, + fontsize=8, + verticalalignment="top", + horizontalalignment="left", + bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.5}, + ) plt.tight_layout() plt.savefig(out_path, dpi=300) plt.close(fig) @@ -378,8 +390,30 @@ def plot_incident_angles( radius_bin_width_m=0.01, debug_plots=False, logger=None, + model_version=None, ): - """Plot overlaid histograms of focal, primary, secondary angles, and primary hit radius.""" + """Plot overlaid histograms of focal, primary, secondary angles, and primary hit radius. + + Parameters + ---------- + results_by_offset : dict + Mapping from off-axis angle to result tables containing angle and radius columns. + output_dir : path-like + Base output directory where the ``plots`` subdirectory will be created. + label : str + Label used to distinguish this set of plots in the output filenames. + bin_width_deg : float, optional + Bin width in degrees for the angle-of-incidence histograms. + radius_bin_width_m : float, optional + Bin width in meters for the primary mirror hit-radius histograms. + debug_plots : bool, optional + If True, generate additional diagnostic plots. + logger : logging.Logger or None, optional + Logger instance to use for messages. If None, a module-level logger is used. + model_version : str or None, optional + Semantic model version identifier to annotate the generated plots. If None, + no model version text is added to the figures. + """ log = logger or logging.getLogger(__name__) if not results_by_offset: log.warning("No results provided for multi-offset plot") @@ -402,6 +436,17 @@ def plot_incident_angles( ax.set_title("Incident angle distribution vs off-axis angle") ax.grid(True, alpha=0.3) ax.legend() + if model_version: + ax.text( + 0.03, + 0.97, + f"Model version: {model_version}", + transform=ax.transAxes, + fontsize=8, + verticalalignment="top", + horizontalalignment="left", + bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.5}, + ) plt.tight_layout() plt.savefig(out_dir / f"incident_angles_multi_{label}.png", dpi=300) plt.close(fig) @@ -414,6 +459,7 @@ def plot_incident_angles( out_path=out_dir / f"incident_angles_primary_multi_{label}.png", bin_width_deg=bin_width_deg, log=log, + model_version=model_version, ) _plot_component_angles( results_by_offset=results_by_offset, @@ -422,6 +468,7 @@ def plot_incident_angles( out_path=out_dir / f"incident_angles_secondary_multi_{label}.png", bin_width_deg=bin_width_deg, log=log, + model_version=model_version, ) # Debug plots diff --git a/tests/integration_tests/config/validate_camera_efficiency_lstn-02_overwrite.yml b/tests/integration_tests/config/validate_camera_efficiency_lstn-02_overwrite.yml new file mode 100644 index 0000000000..7ea72279c4 --- /dev/null +++ b/tests/integration_tests/config/validate_camera_efficiency_lstn-02_overwrite.yml @@ -0,0 +1,22 @@ +--- +applications: +- application: simtools-validate-camera-efficiency + configuration: + model_version: 6.0.2 + output_path: simtools-output + site: North + telescope: LSTN-02 + zenith_angle: 0. + parameter_version: 0.0.99 + write_reference_nsb_rate_as_parameter: true + overwrite_model_parameters: tests/resources/info_test_camera_efficiency_change_LSTN-02.yml + log_level: debug + integration_tests: + - model_parameter_validation: + parameter_file: nsb_pixel_rate/nsb_pixel_rate-0.0.99.json + reference_parameter_name: nsb_pixel_rate + tolerance: 1.e-1 + scaling: 10.0 # camera efficiency changed by factor 0.1 + test_name: LSTN_overwrite +schema_name: application_workflow.metaschema +schema_version: 0.4.0 diff --git a/tests/resources/info_test_camera_efficiency_change_LSTN-02.yml b/tests/resources/info_test_camera_efficiency_change_LSTN-02.yml new file mode 100644 index 0000000000..8b1247181e --- /dev/null +++ b/tests/resources/info_test_camera_efficiency_change_LSTN-02.yml @@ -0,0 +1,12 @@ +--- +model_version: "6.0.2" +model_update: "patch_update" +model_version_history: + - "6.0.2" +description: "test file" +changes: + LSTN-design: + camera_transmission: + version: "2.0.0" + value: 0.1 + unit: null diff --git a/tests/unit_tests/model/test_model_utils.py b/tests/unit_tests/model/test_model_utils.py index d2cf0ec9eb..ddf23ff7e5 100644 --- a/tests/unit_tests/model/test_model_utils.py +++ b/tests/unit_tests/model/test_model_utils.py @@ -47,9 +47,12 @@ def test_initialize_simulation_models(mocker, site, telescope_name): telescope_name=telescope_name, model_version=model_version, label=label, + overwrite_model_parameters=None, ) - mock_site_model.assert_called_once_with(site=site, model_version=model_version, label=label) + mock_site_model.assert_called_once_with( + site=site, model_version=model_version, label=label, overwrite_model_parameters=None + ) mock_tel_model.return_value.export_model_files.assert_called_once() mock_site_model.return_value.export_model_files.assert_called_once() @@ -80,6 +83,7 @@ def test_initialize_simulation_models_with_calibration_device(mocker): calibration_device_model_name=calibration_device_name, model_version=model_version, label=label, + overwrite_model_parameters=None, ) assert calibration_model == mock_cal_model.return_value diff --git a/tests/unit_tests/ray_tracing/test_mirror_panel_psf.py b/tests/unit_tests/ray_tracing/test_mirror_panel_psf.py index b0b7e83850..6483b70086 100644 --- a/tests/unit_tests/ray_tracing/test_mirror_panel_psf.py +++ b/tests/unit_tests/ray_tracing/test_mirror_panel_psf.py @@ -92,6 +92,7 @@ def test_define_telescope_model( site=args_dict["site"], telescope_name=args_dict["telescope"], model_version=args_dict["model_version"], + overwrite_model_parameters=None, ) mock_find_file.assert_not_called() tel.overwrite_model_parameter.assert_not_called() @@ -117,6 +118,7 @@ def test_define_telescope_model( site=args_dict["site"], telescope_name=args_dict["telescope"], model_version=args_dict["model_version"], + overwrite_model_parameters=None, ) mock_find_file.assert_called_once() assert tel.overwrite_model_parameter.call_count == 2 diff --git a/tests/unit_tests/simtel/test_simulator_light_emission.py b/tests/unit_tests/simtel/test_simulator_light_emission.py index 8dccef1b5e..fdcd431f56 100644 --- a/tests/unit_tests/simtel/test_simulator_light_emission.py +++ b/tests/unit_tests/simtel/test_simulator_light_emission.py @@ -1330,6 +1330,7 @@ def test___init__(tmp_test_directory): telescope_name="LSTN-01", calibration_device_name="calibration_device", model_version="6.0.0", + overwrite_model_parameters=None, ) # Verify telescope model config file was written diff --git a/tests/unit_tests/testing/test_validate_output.py b/tests/unit_tests/testing/test_validate_output.py index 08bd599f1f..0e157e8b85 100644 --- a/tests/unit_tests/testing/test_validate_output.py +++ b/tests/unit_tests/testing/test_validate_output.py @@ -652,7 +652,7 @@ def test_validate_model_parameter_json_file(mocker, output_path): mock_collect_data_from_file.assert_called_once_with( Path(output_path) / "test_telescope" / TEST_PARAM_JSON ) - mock_compare_value.assert_called_once_with([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], 1.0e-5) + mock_compare_value.assert_called_once_with([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], 1.0e-5, 1.0) def test_validate_model_parameter_json_file_mismatch(mocker, output_path): @@ -695,7 +695,7 @@ def test_validate_model_parameter_json_file_mismatch(mocker, output_path): mock_collect_data_from_file.assert_called_once_with( Path(output_path) / "test_telescope" / TEST_PARAM_JSON ) - mock_compare_value.assert_called_once_with([1.1, 2.1, 3.1], [1.0, 2.0, 3.0], 1.0e-5) + mock_compare_value.assert_called_once_with([1.1, 2.1, 3.1], [1.0, 2.0, 3.0], 1.0e-5, 1.0) def test_versions_match_semantics(): diff --git a/tests/unit_tests/visualization/test_plot_incident_angles.py b/tests/unit_tests/visualization/test_plot_incident_angles.py index 730bea35cc..e1841e05b2 100644 --- a/tests/unit_tests/visualization/test_plot_incident_angles.py +++ b/tests/unit_tests/visualization/test_plot_incident_angles.py @@ -620,3 +620,146 @@ def test_plot_radius_histograms_covers_continue_lines(tmp_test_directory): # Plot should still be produced thanks to the valid table assert out_path.exists() assert out_path.stat().st_size > 0 + + +def test_plot_incident_angles_with_model_version(tmp_test_directory): + out_dir = Path(tmp_test_directory) / "plots" + t = QTable() + t["angle_incidence_focal"] = np.array([0.1, 0.2, 0.3]) * u.deg + t["angle_incidence_primary"] = np.array([1.0, 1.1, 1.2]) * u.deg + t["angle_incidence_secondary"] = np.array([2.0, 2.1, 2.2]) * u.deg + results = {0.0: t} + pia.plot_incident_angles( + results, + tmp_test_directory, + "version_test", + model_version="1.0.0", + ) + assert (out_dir / "incident_angles_multi_version_test.png").exists() + assert (out_dir / "incident_angles_primary_multi_version_test.png").exists() + assert (out_dir / "incident_angles_secondary_multi_version_test.png").exists() + + +def test_plot_incident_angles_none_logger_uses_module_logger(tmp_test_directory, caplog): + caplog.set_level(logging.WARNING) + results = {} + pia.plot_incident_angles(results, tmp_test_directory, "no_logger", logger=None) + msgs = [r.message for r in caplog.records] + assert any("No results provided for multi-offset plot" in m for m in msgs) + + +def test_plot_incident_angles_custom_bin_widths(tmp_test_directory): + out_dir = Path(tmp_test_directory) / "plots" + t = QTable() + t["angle_incidence_focal"] = np.array([0.05, 0.15, 0.25, 0.35]) * u.deg + t["angle_incidence_primary"] = np.array([0.5, 1.5, 2.5, 3.5]) * u.deg + t["angle_incidence_secondary"] = np.array([1.0, 2.0, 3.0, 4.0]) * u.deg + results = {0.0: t} + pia.plot_incident_angles( + results, + tmp_test_directory, + "custom_bins", + bin_width_deg=0.2, + radius_bin_width_m=0.02, + ) + assert (out_dir / "incident_angles_multi_custom_bins.png").exists() + assert (out_dir / "incident_angles_primary_multi_custom_bins.png").exists() + assert (out_dir / "incident_angles_secondary_multi_custom_bins.png").exists() + + +def test_plot_incident_angles_debug_plots_true(tmp_test_directory): + out_dir = Path(tmp_test_directory) / "plots" + t0 = QTable() + t0["angle_incidence_focal"] = np.array([0.1, 0.2]) * u.deg + t0["angle_incidence_primary"] = np.array([1.0, 1.1]) * u.deg + t0["angle_incidence_secondary"] = np.array([2.0, 2.1]) * u.deg + t0["primary_hit_radius"] = np.array([0.1, 0.12]) * u.m + t0["secondary_hit_radius"] = np.array([0.05, 0.07]) * u.m + t0["primary_hit_x"] = np.array([0.0, 0.1]) * u.m + t0["primary_hit_y"] = np.array([0.0, -0.1]) * u.m + t0["secondary_hit_x"] = np.array([0.02, -0.02]) * u.m + t0["secondary_hit_y"] = np.array([0.03, 0.01]) * u.m + results = {0.0: t0} + pia.plot_incident_angles( + results, + tmp_test_directory, + "dbg_enabled", + debug_plots=True, + ) + assert (out_dir / "incident_angles_multi_dbg_enabled.png").exists() + assert (out_dir / "incident_angles_primary_multi_dbg_enabled.png").exists() + assert (out_dir / "incident_angles_secondary_multi_dbg_enabled.png").exists() + assert (out_dir / "incident_radius_primary_multi_dbg_enabled.png").exists() + assert (out_dir / "incident_radius_secondary_multi_dbg_enabled.png").exists() + assert (out_dir / "incident_primary_radius_vs_angle_multi_dbg_enabled.png").exists() + assert (out_dir / "incident_secondary_radius_vs_angle_multi_dbg_enabled.png").exists() + assert (out_dir / "incident_primary_xy_heatmap_off0_dbg_enabled.png").exists() + assert (out_dir / "incident_secondary_xy_heatmap_off0_dbg_enabled.png").exists() + + +def test_plot_incident_angles_debug_plots_false(tmp_test_directory): + out_dir = Path(tmp_test_directory) / "plots" + t = QTable() + t["angle_incidence_focal"] = np.array([0.1, 0.2]) * u.deg + t["angle_incidence_primary"] = np.array([1.0, 1.1]) * u.deg + t["angle_incidence_secondary"] = np.array([2.0, 2.1]) * u.deg + t["primary_hit_radius"] = np.array([0.1, 0.12]) * u.m + results = {0.0: t} + pia.plot_incident_angles( + results, + tmp_test_directory, + "dbg_disabled", + debug_plots=False, + ) + assert (out_dir / "incident_angles_multi_dbg_disabled.png").exists() + assert not any(out_dir.glob("incident_radius_primary_multi_*.png")) + + +def test_plot_incident_angles_multiple_offsets(tmp_test_directory): + out_dir = Path(tmp_test_directory) / "plots" + t0 = QTable() + t0["angle_incidence_focal"] = np.array([0.1, 0.2]) * u.deg + t0["angle_incidence_primary"] = np.array([1.0, 1.1]) * u.deg + t0["angle_incidence_secondary"] = np.array([2.0, 2.1]) * u.deg + t1 = QTable() + t1["angle_incidence_focal"] = np.array([0.3, 0.4]) * u.deg + t1["angle_incidence_primary"] = np.array([1.5, 1.6]) * u.deg + t1["angle_incidence_secondary"] = np.array([2.5, 2.6]) * u.deg + results = {0.0: t0, 1.0: t1} + pia.plot_incident_angles(results, tmp_test_directory, "multi_offset") + assert (out_dir / "incident_angles_multi_multi_offset.png").exists() + + +def test_plot_incident_angles_no_focal_angles(tmp_test_directory, caplog): + caplog.set_level(logging.WARNING) + out_dir = Path(tmp_test_directory) / "plots" + t = QTable() + t["angle_incidence_primary"] = np.array([1.0, 1.1]) * u.deg + t["angle_incidence_secondary"] = np.array([2.0, 2.1]) * u.deg + results = {0.0: t} + pia.plot_incident_angles(results, tmp_test_directory, "no_focal") + assert not (out_dir / "incident_angles_multi_no_focal.png").exists() + assert (out_dir / "incident_angles_primary_multi_no_focal.png").exists() + assert (out_dir / "incident_angles_secondary_multi_no_focal.png").exists() + + +def test_plot_incident_angles_only_focal_angles(tmp_test_directory): + out_dir = Path(tmp_test_directory) / "plots" + t = QTable() + t["angle_incidence_focal"] = np.array([0.1, 0.2, 0.3]) * u.deg + results = {0.0: t} + pia.plot_incident_angles(results, tmp_test_directory, "focal_only") + assert (out_dir / "incident_angles_multi_focal_only.png").exists() + assert not (out_dir / "incident_angles_primary_multi_focal_only.png").exists() + assert not (out_dir / "incident_angles_secondary_multi_focal_only.png").exists() + + +def test_plot_incident_angles_creates_output_directory(tmp_test_directory): + out_dir = Path(tmp_test_directory) / "plots" + assert not out_dir.exists() + t = QTable() + t["angle_incidence_focal"] = np.array([0.1]) * u.deg + results = {0.0: t} + pia.plot_incident_angles(results, tmp_test_directory, "creates_dir") + assert out_dir.exists() + assert out_dir.is_dir()