diff --git a/.gitignore b/.gitignore index 7ef18ed6a..d77e1145f 100644 --- a/.gitignore +++ b/.gitignore @@ -115,6 +115,4 @@ TAGS *~ ## Colors -**/data/colors_data/.extraction_complete -data/colors_data/filters/ -data/colors_data/stellar_models/ +/data/colors_data/** \ No newline at end of file diff --git a/colors/Makefile b/colors/Makefile index 532b58234..c7345804e 100644 --- a/colors/Makefile +++ b/colors/Makefile @@ -4,11 +4,12 @@ include ../make/defaults-module.mk MODULE_NAME := colors SRCS := public/colors_def.f90 \ + private/colors_utils.f90 \ public/colors_lib.f90 \ private/bolometric.f90 \ + private/colors_iteration.f90 \ private/colors_ctrls_io.f90 \ private/colors_history.f90 \ - private/colors_utils.f90 \ private/hermite_interp.f90 \ private/knn_interp.f90 \ private/linear_interp.f90 \ diff --git a/colors/README.rst b/colors/README.rst index d80216147..889c83e6c 100644 --- a/colors/README.rst +++ b/colors/README.rst @@ -1,18 +1,18 @@ -.. _custom_colors: +.. _colors: ****** Colors ****** -This test suite case demonstrates the functionality of the MESA ``colors`` module, a framework introduced in MESA r25.10.1 for calculating synthetic photometry and bolometric quantities during stellar evolution. +The MESA ``colors`` module calculates synthetic photometry and bolometric quantities during stellar evolution. What is MESA colors? ==================== -MESA colors is a post-processing and runtime module that allows users to generate "observer-ready" data directly from stellar evolution models. Instead of limiting output to theoretical quantities like Luminosity (:math:`L`) and Surface Temperature (:math:`T_{\rm eff}`), the colors module computes: +MESA colors is a post-processing and runtime module that allows users to generate "observer-ready" data directly from stellar evolution models. 
Instead of limiting output to theoretical quantities like Luminosity (L) and Surface Temperature (T_eff), the colors module computes: -* **Bolometric Magnitude** (:math:`M_{\rm bol}`) -* **Bolometric Flux** (:math:`F_{\rm bol}`) +* **Bolometric Magnitude** (M_bol) +* **Bolometric Flux** (F_bol) * **Synthetic Magnitudes** in specific photometric filters (e.g., Johnson V, Gaia G, 2MASS J). This bridges the gap between theoretical evolutionary tracks and observational color-magnitude diagrams (CMDs). @@ -20,9 +20,7 @@ This bridges the gap between theoretical evolutionary tracks and observational c How does the MESA colors module work? ===================================== -The module operates by coupling the stellar structure model with pre-computed grids of stellar atmospheres. - -1. **Interpolation**: At each timestep, the module takes the star's current surface parameters—Effective Temperature (:math:`T_{\rm eff}`), Surface Gravity (:math:`\log g`), and Metallicity ([M/H])—and queries a user-specified library of stellar atmospheres (defined in ``stellar_atm``). It interpolates within this grid to construct a specific Spectral Energy Distribution (SED) for the stars current features. +1. **Interpolation**: At each timestep, the module takes the star's current surface parameters—Effective Temperature (T_eff), Surface Gravity (log g), and Metallicity ([M/H])—and queries a user-specified library of stellar atmospheres (defined in ``stellar_atm``). It interpolates within this grid to construct a specific Spectral Energy Distribution (SED) for the star's current parameters. 2. **Convolution**: This specific SED is then convolved with filter transmission curves (defined in ``instrument``) to calculate the flux passing through each filter. @@ -33,129 +31,197 @@ Inlist Options & Parameters The colors module is controlled via the ``&colors`` namelist. Below is a detailed guide to the key parameters. 
+use_colors +---------- + +**Default:** ``.false.`` + +Master switch for the module. Must be set to ``.true.`` to enable any photometric output. + +**Example:** + +.. code-block:: fortran + + use_colors = .true. + + instrument ---------- -**Default:** `'/data/colors_data/filters/Generic/Johnson'` +**Default:** ``'data/colors_data/filters/Generic/Johnson'`` + +Path to the filter instrument directory, structured as ``facility/instrument``. + +* The directory must contain an index file with the same name as the instrument + (e.g., ``Johnson``), listing one filter filename per line. +* The module loads every ``.dat`` transmission curve listed in that index and + creates a corresponding history column for each. -This points to the directory containing the filter transmission curves you wish to use. The path must be structured as ``facility/instrument``. +.. rubric:: Note on paths -* The directory must contain a file named after the instrument (e.g., ``Johnson``) which acts as an index. -* The module will read every ``.dat`` file listed in that directory and create a corresponding history column for it. +All path parameters (``instrument``, ``stellar_atm``, ``vega_sed``) are resolved +using the same logic: + +* ``'data/colors_data/...'`` — no leading slash; ``$MESA_DIR`` is prepended. This + is the recommended form for all standard data paths. +* ``'/absolute/path/...'`` — tested on disk first; if found, used as-is. If not + found, ``$MESA_DIR`` is prepended (preserves backwards compatibility). +* ``'./local/path/...'`` or ``'../up/one/...'`` — used exactly as supplied, + relative to the MESA working directory. **Example:** .. 
code-block:: fortran - instrument = '/data/colors_data/filters/GAIA/GAIA' + instrument = 'data/colors_data/filters/GAIA/GAIA' stellar_atm ----------- -**Default:** `'/data/colors_data/stellar_models/Kurucz2003all/'` +**Default:** ``'data/colors_data/stellar_models/Kurucz2003all/'`` -Specifies the path to the directory containing the grid of stellar atmosphere models. This directory must contain: +Path to the directory containing the grid of stellar atmosphere models. Paths may be relative to ``$MESA_DIR``, relative to the working directory, or absolute. This directory must contain: -1. **lookup_table.csv**: A map linking filenames to physical parameters (:math:`T_{\rm eff}`, :math:`\log g`, [M/H]). +1. **lookup_table.csv**: A map linking filenames to physical parameters (T_eff, log g, [M/H]). 2. **SED files**: The actual spectra (text or binary format). 3. **flux_cube.bin**: (Optional but recommended) A binary cube for rapid interpolation. -The module queries this grid using the star's current parameters. If the star evolves outside the grid boundaries, the module may clamp to the nearest edge or extrapolate, depending on internal settings. +The module queries this grid using the star's current parameters. If the star evolves outside the grid boundaries, the module will clamp to the nearest edge. **Example:** .. code-block:: fortran - stellar_atm = '/data/colors_data/stellar_models/sg-SPHINX/' + stellar_atm = 'data/colors_data/stellar_models/sg-SPHINX/' distance -------- -**Default:** `3.0857d19` (10 parsecs in cm) +**Default:** ``3.0857d19`` (10 parsecs in cm) -The distance to the star in centimeters. +The distance to the star in centimetres, used to convert surface flux to observed flux. -* This value is used to convert surface flux to observed flux. -* **Default Behavior:** It defaults to 10 parsecs (:math:`3.0857 \times 10^{19}` cm), resulting in **Absolute Magnitudes**. 
-* **Custom Usage:** You can set this to a specific source distance (e.g., distance to Betelgeuse) to calculate Apparent Magnitudes. +* **Default Behaviour:** At 10 parsecs (3.0857 * 10^19 cm) the output is **Absolute Magnitudes**. +* **Custom Usage:** Set this to a specific source distance to calculate Apparent Magnitudes. **Example:** .. code-block:: fortran - distance = 5.1839d20 + distance = 5.1839d20 + make_csv -------- -**Default:** `.false.` +**Default:** ``.false.`` If set to ``.true.``, the module exports the full calculated SED at every profile interval. * **Destination:** Files are saved to the directory defined by ``colors_results_directory``. * **Format:** CSV files containing Wavelength vs. Flux. -* **Use Case:** useful for debugging or plotting the full spectrum of the star at a specific age. +* **Use Case:** Useful for debugging or plotting the full spectrum of the star at a specific evolutionary age. **Example:** .. code-block:: fortran - make_csv = .true. + make_csv = .true. + colors_results_directory ------------------------ -**Default:** `'SED'` +**Default:** ``'SED'`` -The folder where csv files (if ``make_csv = .true.``) and other debug outputs are saved. +The folder where CSV files (if ``make_csv = .true.``) and other outputs are saved. **Example:** .. code-block:: fortran - colors_results_directory = 'sed' + colors_results_directory = 'sed' mag_system ---------- -**Default:** `'Vega'` +**Default:** ``'Vega'`` Defines the zero-point system for magnitude calculations. Options are: * ``'AB'``: Based on a flat spectral flux density of 3631 Jy. * ``'ST'``: Based on a flat spectral flux density per unit wavelength. -* ``'Vega'``: Calibrated such that the star Vega has magnitude 0 in all bands. +* ``'Vega'``: Calibrated such that Vega has magnitude 0 in all bands. **Example:** .. 
code-block:: fortran - mag_system = 'AB' + mag_system = 'AB' vega_sed -------- -**Default:** `'/data/colors_data/stellar_models/vega_flam.csv'` +**Default:** ``'data/colors_data/stellar_models/vega_flam.csv'`` + +Required only if ``mag_system = 'Vega'``. Points to the reference SED file for Vega, used to compute photometric zero-points. Paths may be relative to ``$MESA_DIR``, relative to the working directory, or absolute. + +**Example:** + +.. code-block:: fortran + + vega_sed = '/path/to/my/vega_SED.csv' + +sed_per_model +------------- + +**Default:** ``.false.`` + +Requires ``make_csv = .true.``. If set to ``.true.``, each exported SED file is stamped with the model number, preserving one SED file per model rather than overwriting a single file. + +.. warning:: + + Enabling this feature will cause the ``colors_results_directory`` to grow very rapidly. Do not enable it without first ensuring you have sufficient storage. + +* **Destination:** Files are saved to the directory defined by ``colors_results_directory``. +* **Format:** CSV files containing Wavelength vs. Flux, with the model number as a filename suffix. +* **Use Case:** Useful for tracking the full SED evolution of the star over time. + +**Example:** + +.. code-block:: fortran + + sed_per_model = .true. + -Required only if ``mag_system = 'Vega'``. This points to the reference SED file for Vega. The default path points to a file provided with the MESA data distribution. +colors_per_newton_step +---------------------- + +**Default:** ``.false.`` + +If set to ``.true.``, the colors module computes synthetic photometry at every Newton iteration within each timestep, rather than only once per converged model. This is useful for studying rapid stellar variability or evolutionary phases where the stellar parameters change significantly within a single timestep (e.g., thermal pulses, shell flashes). + +.. 
warning:: + + Enabling this feature substantially increases the computational cost of the run, as photometric calculations are performed multiple times per timestep. It should only be used when sub-timestep resolution is scientifically required. **Example:** .. code-block:: fortran - vega_sed = '/another/file/for/vega_SED.csv' + colors_per_newton_step = .true. Data Preparation (SED_Tools) ============================ The ``colors`` module requires pre-processed stellar atmospheres and filter -profiles organized in a very specific directory structure. To automate this +profiles organised in a specific directory structure. To automate this entire workflow, we provide the dedicated repository: **Repository:** `SED_Tools `_ @@ -167,15 +233,14 @@ filter transmission curves from the following public archives: * `MAST BOSZ Stellar Atmosphere Library `_ * `MSG / Townsend Atmosphere Grids `_ -These sources provide heterogeneous formats and file organizations. SED_Tools -standardizes them into the exact structure required by MESA: +These sources provide heterogeneous formats and file organisations. SED_Tools +standardises them into the exact structure required by MESA: * ``lookup_table.csv`` -* Raw SED files (text or/and HDF5) +* Raw SED files (text and/or HDF5) * ``flux_cube.bin`` (binary cube for fast interpolation) * Filter index files and ``*.dat`` transmission curves - SED_Tools produces: .. code-block:: text @@ -207,19 +272,20 @@ This server provides a live view of: Defaults Reference ================== -Below are the default values for the colors module parameters as defined in ``colors.defaults``. These are used if you do not override them in your inlist. +Below are the default values for all user-facing ``colors`` module parameters as defined in ``colors.defaults``. .. code-block:: fortran use_colors = .false. 
- instrument = '/data/colors_data/filters/Generic/Johnson' - vega_sed = '/data/colors_data/stellar_models/vega_flam.csv' - stellar_atm = '/data/colors_data/stellar_models/Kurucz2003all/' + instrument = 'data/colors_data/filters/Generic/Johnson' + stellar_atm = 'data/colors_data/stellar_models/Kurucz2003all/' + vega_sed = 'data/colors_data/stellar_models/vega_flam.csv' distance = 3.0857d19 ! 10 parsecs in cm (Absolute Magnitude) make_csv = .false. colors_results_directory = 'SED' mag_system = 'Vega' - vega_sed = '/data/colors_data/stellar_models/vega_flam.csv' + sed_per_model = .false. + colors_per_newton_step = .false. Visual Summary of Data Flow =========================== @@ -237,8 +303,8 @@ Visual Summary of Data Flow | 1. Query Stellar Atmosphere Grid with input model | | 2. Interpolate grid to construct specific SED | | 3. Convolve SED with filters to generate band flux | - | 2. Apply distance flux dilution to generate bolometric flux -> Flux_bol | - | 4. Apply zero point (Vega/AB/ST) to generate magnitudes | + | 4. Apply distance flux dilution to generate bolometric flux -> Flux_bol | + | 5. Apply zero point (Vega/AB/ST) to generate magnitudes | | (Both bolometric and per filter) | +-------------------------------------------------------------------------+ | diff --git a/colors/defaults/colors.defaults b/colors/defaults/colors.defaults index 0bdb1719a..62cc0a761 100644 --- a/colors/defaults/colors.defaults +++ b/colors/defaults/colors.defaults @@ -1,43 +1,80 @@ ! ``colors`` module controls ! ========================== - ! The MESA/colors parameters are given default values here. - ! Colors User Parameters ! ---------------------- ! ``use_colors`` ! ~~~~~~~~~~~~~~ + ! Set to .true. to enable bolometric and synthetic photometry output. + ! ``instrument`` ! ~~~~~~~~~~~~~~ - ! ``vega_sed`` - ! ~~~~~~~~~~~~ + ! Path to the filter instrument directory (structured as facility/instrument). + ! 
The directory must contain an index file named after the instrument and + ! one .dat transmission curve per filter. Each filter becomes a history column. + ! ``stellar_atm`` ! ~~~~~~~~~~~~~~~ + ! Path to the stellar atmosphere model grid directory. + ! Must contain: lookup_table.csv, SED files, and optionally flux_cube.bin. + + ! ``vega_sed`` + ! ~~~~~~~~~~~~ + ! Path to the Vega reference SED file. Required when mag_system = 'Vega'. + ! Used to compute photometric zero-points for all filters. + ! ``distance`` ! ~~~~~~~~~~~~ + ! Distance to the star in cm. Determines whether output magnitudes are + ! absolute (default: 10 pc = 3.0857e19 cm) or apparent (set to source distance). + ! ``make_csv`` ! ~~~~~~~~~~~~ + ! If .true., exports the full SED as a CSV file at every profile interval. + ! Files are written to colors_results_directory. + ! ``colors_results_directory`` ! ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ! Directory where CSV outputs (make_csv, sed_per_model) are saved. + ! ``mag_system`` ! ~~~~~~~~~~~~~~ + ! Photometric zero-point system. Options: 'Vega', 'AB', 'ST'. + ! Vega : zero-point set so Vega has magnitude 0 in all bands. + ! AB : zero-point based on flat f_nu = 3631 Jy. + ! ST : zero-point based on flat f_lambda per unit wavelength. - ! If ``use_colors`` is true, the colors module is turned on, which will calculate - ! bolometric and synthetic magnitudes by interpolating stellar atmosphere model grids and convolving with photometric filter transmission curves. - ! Vega SED for Vega photometric system is used for photometric zero points. - ! Stellar distance is given in cm. + ! ``sed_per_model`` + ! ~~~~~~~~~~~~~~~~~ + ! Requires make_csv = .true. If .true., each SED output file is suffixed with + ! the model number, preserving the full SED history rather than overwriting. + ! WARNING: can produce very large numbers of files. Ensure adequate storage. + + ! ``colors_per_newton_step`` + ! ~~~~~~~~~~~~~~~~~~~~~~~~~~ + ! 
If .true., computes synthetic photometry at every Newton iteration within + ! each timestep rather than once per converged model. Useful for capturing + ! rapid parameter changes (e.g., thermal pulses, shell flashes). + ! WARNING: substantially increases computational cost. Only enable when + ! sub-timestep resolution is scientifically required. + + ! If ``use_colors`` is true, the colors module computes bolometric and synthetic + ! magnitudes by interpolating stellar atmosphere model grids and convolving with + ! photometric filter transmission curves. Output is appended to history.data. ! :: use_colors = .false. - instrument = '/data/colors_data/filters/Generic/Johnson' - stellar_atm = '/data/colors_data/stellar_models/Kurucz2003all/' - distance = 3.0857d19 + instrument = 'data/colors_data/filters/Generic/Johnson' + stellar_atm = 'data/colors_data/stellar_models/Kurucz2003all/' + vega_sed = 'data/colors_data/stellar_models/vega_flam.csv' + distance = 3.0857d19 ! 10 parsecs in cm -> Absolute Magnitudes make_csv = .false. colors_results_directory = 'SED' mag_system = 'Vega' - vega_sed = '/data/colors_data/stellar_models/vega_flam.csv' + sed_per_model = .false. + colors_per_newton_step = .false. ! Extra inlist controls @@ -46,7 +83,6 @@ ! One can split a colors inlist into pieces using the following parameters. ! It works recursively, so the extras can read extras too. - ! ``read_extra_colors_inlist(1..5)`` ! ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ! 
``extra_colors_inlist_name(1..5)`` diff --git a/colors/private/bolometric.f90 b/colors/private/bolometric.f90 index 6a8c5b36a..e8291306c 100644 --- a/colors/private/bolometric.f90 +++ b/colors/private/bolometric.f90 @@ -20,6 +20,7 @@ module bolometric use const_def, only: dp + use colors_def, only: Colors_General_Info use colors_utils, only: romberg_integration use hermite_interp, only: construct_sed_hermite use linear_interp, only: construct_sed_linear @@ -32,76 +33,59 @@ module bolometric contains - !**************************** - ! Calculate Bolometric Photometry Using Multiple SEDs - ! Accepts cached lookup table data instead of loading from file - !**************************** - subroutine calculate_bolometric(teff, log_g, metallicity, R, d, bolometric_magnitude, & - bolometric_flux, wavelengths, fluxes, sed_filepath, interpolation_radius, & - lu_file_names, lu_teff, lu_logg, lu_meta) + ! rq carries cached lookup table data and the preloaded flux cube (if available) + subroutine calculate_bolometric(rq, teff, log_g, metallicity, R, d, bolometric_magnitude, & + bolometric_flux, wavelengths, fluxes, sed_filepath, interpolation_radius) + type(Colors_General_Info), intent(inout) :: rq real(dp), intent(in) :: teff, log_g, metallicity, R, d character(len=*), intent(in) :: sed_filepath real(dp), intent(out) :: bolometric_magnitude, bolometric_flux, interpolation_radius real(dp), dimension(:), allocatable, intent(out) :: wavelengths, fluxes - ! Cached lookup table data (passed in from colors_settings) - character(len=100), intent(in) :: lu_file_names(:) - real(dp), intent(in) :: lu_teff(:), lu_logg(:), lu_meta(:) - character(len=32) :: interpolation_method interpolation_method = 'Hermite' ! or 'Linear' / 'KNN' later - ! Quantify how far (teff, log_g, metallicity) is from the grid points + ! 
how far (teff, log_g, metallicity) is from the nearest grid point interpolation_radius = compute_interp_radius(teff, log_g, metallicity, & - lu_teff, lu_logg, lu_meta) + rq%lu_teff, rq%lu_logg, rq%lu_meta) select case (interpolation_method) case ('Hermite', 'hermite', 'HERMITE') - call construct_sed_hermite(teff, log_g, metallicity, R, d, lu_file_names, & - lu_teff, lu_logg, lu_meta, sed_filepath, & - wavelengths, fluxes) + call construct_sed_hermite(rq, teff, log_g, metallicity, R, d, & + sed_filepath, wavelengths, fluxes) case ('Linear', 'linear', 'LINEAR') - call construct_sed_linear(teff, log_g, metallicity, R, d, lu_file_names, & - lu_teff, lu_logg, lu_meta, sed_filepath, & - wavelengths, fluxes) + call construct_sed_linear(rq, teff, log_g, metallicity, R, d, & + sed_filepath, wavelengths, fluxes) case ('KNN', 'knn', 'Knn') - call construct_sed_knn(teff, log_g, metallicity, R, d, lu_file_names, & - lu_teff, lu_logg, lu_meta, sed_filepath, & - wavelengths, fluxes) + call construct_sed_knn(rq, teff, log_g, metallicity, R, d, & + sed_filepath, wavelengths, fluxes) case default - ! Fallback: Hermite - call construct_sed_hermite(teff, log_g, metallicity, R, d, lu_file_names, & - lu_teff, lu_logg, lu_meta, sed_filepath, & - wavelengths, fluxes) + ! fallback: hermite + call construct_sed_hermite(rq, teff, log_g, metallicity, R, d, & + sed_filepath, wavelengths, fluxes) end select - ! Calculate bolometric flux and magnitude call calculate_bolometric_phot(wavelengths, fluxes, bolometric_magnitude, bolometric_flux) end subroutine calculate_bolometric - !**************************** - ! Calculate Bolometric Magnitude and Flux - !**************************** subroutine calculate_bolometric_phot(wavelengths, fluxes, bolometric_magnitude, bolometric_flux) real(dp), dimension(:), intent(inout) :: wavelengths, fluxes real(dp), intent(out) :: bolometric_magnitude, bolometric_flux integer :: i - ! Validate inputs and replace invalid values with 0 + ! 
zero out any invalid flux/wavelength values do i = 1, size(wavelengths) - 1 if (wavelengths(i) <= 0.0d0 .or. fluxes(i) < 0.0d0) then fluxes(i) = 0.0d0 end if end do - ! Integrate to get bolometric flux call romberg_integration(wavelengths, fluxes, bolometric_flux) - ! Validate and calculate magnitude if (bolometric_flux <= 0.0d0) then print *, "Error: Flux integration resulted in non-positive value." bolometric_magnitude = 99.0d0 @@ -113,22 +97,17 @@ subroutine calculate_bolometric_phot(wavelengths, fluxes, bolometric_magnitude, bolometric_magnitude = flux_to_magnitude(bolometric_flux) end subroutine calculate_bolometric_phot - !**************************** - ! Convert Flux to Magnitude - !**************************** real(dp) function flux_to_magnitude(flux) real(dp), intent(in) :: flux if (flux <= 0.0d0) then print *, "Error: Flux must be positive to calculate magnitude." flux_to_magnitude = 99.0d0 else - flux_to_magnitude = -2.5d0 * log10(flux) + flux_to_magnitude = -2.5d0*log10(flux) end if end function flux_to_magnitude - !-------------------------------------------------------------------- - ! Scalar metric: distance to nearest grid point in normalized space - !-------------------------------------------------------------------- + ! scalar metric: distance to nearest grid point in normalized space real(dp) function compute_interp_radius(teff, log_g, metallicity, & lu_teff, lu_logg, lu_meta) @@ -145,7 +124,7 @@ real(dp) function compute_interp_radius(teff, log_g, metallicity, & logical :: use_teff, use_logg, use_meta real(dp), parameter :: eps = 1.0d-12 - ! Detect dummy columns (entire axis is 0 or ±999) + ! detect dummy columns (entire axis is 0 or ±999) use_teff = .not. (all(lu_teff == 0.0d0) .or. & all(lu_teff == 999.0d0) .or. & all(lu_teff == -999.0d0)) @@ -158,29 +137,28 @@ real(dp) function compute_interp_radius(teff, log_g, metallicity, & all(lu_meta == 999.0d0) .or. & all(lu_meta == -999.0d0)) - ! 
Compute min/max for valid axes if (use_teff) then teff_min = minval(lu_teff) teff_max = maxval(lu_teff) teff_range = max(teff_max - teff_min, eps) - norm_teff = (teff - teff_min) / teff_range + norm_teff = (teff - teff_min)/teff_range end if if (use_logg) then logg_min = minval(lu_logg) logg_max = maxval(lu_logg) logg_range = max(logg_max - logg_min, eps) - norm_logg = (log_g - logg_min) / logg_range + norm_logg = (log_g - logg_min)/logg_range end if if (use_meta) then meta_min = minval(lu_meta) meta_max = maxval(lu_meta) meta_range = max(meta_max - meta_min, eps) - norm_meta = (metallicity - meta_min) / meta_range + norm_meta = (metallicity - meta_min)/meta_range end if - ! Find minimum distance to any grid point + ! find minimum distance to any grid point d_min = huge(1.0d0) n = size(lu_teff) @@ -188,17 +166,17 @@ real(dp) function compute_interp_radius(teff, log_g, metallicity, & d = 0.0d0 if (use_teff) then - grid_teff = (lu_teff(i) - teff_min) / teff_range + grid_teff = (lu_teff(i) - teff_min)/teff_range d = d + (norm_teff - grid_teff)**2 end if if (use_logg) then - grid_logg = (lu_logg(i) - logg_min) / logg_range + grid_logg = (lu_logg(i) - logg_min)/logg_range d = d + (norm_logg - grid_logg)**2 end if if (use_meta) then - grid_meta = (lu_meta(i) - meta_min) / meta_range + grid_meta = (lu_meta(i) - meta_min)/meta_range d = d + (norm_meta - grid_meta)**2 end if diff --git a/colors/private/colors_ctrls_io.f90 b/colors/private/colors_ctrls_io.f90 index dc560be8e..3c35c841e 100644 --- a/colors/private/colors_ctrls_io.f90 +++ b/colors/private/colors_ctrls_io.f90 @@ -35,11 +35,13 @@ module colors_ctrls_io character(len=256) :: vega_sed character(len=256) :: stellar_atm character(len=256) :: colors_results_directory - character(len=32) :: mag_system + character(len=256) :: mag_system real(dp) :: distance logical :: make_csv + logical :: sed_per_model logical :: use_colors + logical :: colors_per_newton_step namelist /colors/ & instrument, & @@ -47,15 +49,16 @@ 
module colors_ctrls_io stellar_atm, & distance, & make_csv, & + sed_per_model, & mag_system, & colors_results_directory, & use_colors, & + colors_per_newton_step, & read_extra_colors_inlist, & extra_colors_inlist_name contains -! read a "namelist" file and set parameters subroutine read_namelist(handle, inlist, ierr) integer, intent(in) :: handle character(len=*), intent(in) :: inlist @@ -148,7 +151,6 @@ end subroutine set_default_controls subroutine store_controls(rq, ierr) type(Colors_General_Info), pointer, intent(inout) :: rq - integer :: i integer, intent(out) :: ierr rq%instrument = instrument @@ -156,8 +158,10 @@ subroutine store_controls(rq, ierr) rq%stellar_atm = stellar_atm rq%distance = distance rq%make_csv = make_csv + rq%sed_per_model = sed_per_model rq%colors_results_directory = colors_results_directory rq%use_colors = use_colors + rq%colors_per_newton_step = colors_per_newton_step rq%mag_system = mag_system end subroutine store_controls @@ -192,8 +196,10 @@ subroutine set_controls_for_writing(rq) stellar_atm = rq%stellar_atm distance = rq%distance make_csv = rq%make_csv + sed_per_model = rq%sed_per_model colors_results_directory = rq%colors_results_directory use_colors = rq%use_colors + colors_per_newton_step = rq%colors_per_newton_step mag_system = rq%mag_system end subroutine set_controls_for_writing @@ -211,18 +217,17 @@ subroutine get_colors_controls(rq, name, val, ierr) ierr = 0 - ! First save current controls + ! save current controls call set_controls_for_writing(rq) - ! Write namelist to temporary file + ! write namelist to temporary file open (newunit=iounit, status='scratch') write (iounit, nml=colors) rewind (iounit) - ! Namelists get written in capitals + ! namelists get written in capitals upper_name = trim(StrUpCase(name))//'=' val = '' - ! 
Search for name inside namelist do read (iounit, '(A)', iostat=iostat) str ind = index(trim(str), trim(upper_name)) @@ -250,19 +255,17 @@ subroutine set_colors_controls(rq, name, val, ierr) ierr = 0 - ! First save current colors_controls + ! save current controls call set_controls_for_writing(rq) tmp = '' tmp = '&colors '//trim(name)//'='//trim(val)//' /' - ! Load into namelist read (tmp, nml=colors) - ! Add to colors call store_controls(rq, ierr) if (ierr /= 0) return end subroutine set_colors_controls -end module colors_ctrls_io +end module colors_ctrls_io \ No newline at end of file diff --git a/colors/private/colors_history.f90 b/colors/private/colors_history.f90 index e65c699e9..5537885c8 100644 --- a/colors/private/colors_history.f90 +++ b/colors/private/colors_history.f90 @@ -19,10 +19,10 @@ module colors_history - use const_def, only: dp, mesa_dir + use const_def, only: dp use utils_lib, only: mesa_error use colors_def, only: Colors_General_Info, get_colors_ptr, num_color_filters, color_filter_names - use colors_utils, only: remove_dat + use colors_utils, only: remove_dat, resolve_path use bolometric, only: calculate_bolometric use synthetic, only: calculate_synthetic @@ -39,7 +39,7 @@ integer function how_many_colors_history_columns(colors_handle) call get_colors_ptr(colors_handle, colors_settings, ierr) if (ierr /= 0) then write (*, *) 'failed in colors_ptr' - num_cols = 0 + how_many_colors_history_columns = 0 return end if @@ -53,19 +53,21 @@ integer function how_many_colors_history_columns(colors_handle) end function how_many_colors_history_columns subroutine data_for_colors_history_columns( & - t_eff, log_g, R, metallicity, & + t_eff, log_g, R, metallicity, model_number, & colors_handle, n, names, vals, ierr) real(dp), intent(in) :: t_eff, log_g, R, metallicity integer, intent(in) :: colors_handle, n character(len=80) :: names(n) real(dp) :: vals(n) integer, intent(out) :: ierr + integer, intent(in) :: model_number type(Colors_General_Info), pointer :: 
cs ! colors_settings integer :: i, filter_offset real(dp) :: d, bolometric_magnitude, bolometric_flux, interpolation_radius real(dp) :: zero_point - character(len=256) :: sed_filepath, filter_name + character(len=256) :: sed_filepath + character(len=80) :: filter_name logical :: make_sed real(dp), dimension(:), allocatable :: wavelengths, fluxes @@ -90,14 +92,12 @@ subroutine data_for_colors_history_columns( & end if d = cs%distance - sed_filepath = trim(mesa_dir)//cs%stellar_atm + sed_filepath = trim(resolve_path(cs%stellar_atm)) make_sed = cs%make_csv - ! Calculate bolometric magnitude using cached lookup table - call calculate_bolometric(t_eff, log_g, metallicity, R, d, & + call calculate_bolometric(cs, t_eff, log_g, metallicity, R, d, & bolometric_magnitude, bolometric_flux, wavelengths, fluxes, & - sed_filepath, interpolation_radius, & - cs%lu_file_names, cs%lu_teff, cs%lu_logg, cs%lu_meta) + sed_filepath, interpolation_radius) names(1) = "Mag_bol" vals(1) = bolometric_magnitude @@ -113,7 +113,7 @@ subroutine data_for_colors_history_columns( & names(i + filter_offset) = filter_name if (t_eff >= 0 .and. metallicity >= 0) then - ! Select precomputed zero-point based on magnitude system + ! pick the precomputed zero-point for the requested mag system select case (trim(cs%mag_system)) case ('VEGA', 'Vega', 'vega') zero_point = cs%filters(i)%vega_zero_point @@ -126,14 +126,15 @@ subroutine data_for_colors_history_columns( & zero_point = -1.0_dp end select - ! 
Calculate synthetic magnitude using cached filter data and precomputed zero-point vals(i + filter_offset) = calculate_synthetic(t_eff, log_g, metallicity, ierr, & wavelengths, fluxes, & cs%filters(i)%wavelengths, & cs%filters(i)%transmission, & zero_point, & color_filter_names(i), & - make_sed, cs%colors_results_directory) + make_sed, cs%sed_per_model, & + cs%colors_results_directory, model_number) + if (ierr /= 0) vals(i + filter_offset) = -1.0_dp else vals(i + filter_offset) = -1.0_dp @@ -145,9 +146,9 @@ subroutine data_for_colors_history_columns( & call mesa_error(__FILE__, __LINE__, 'colors: data_for_colors_history_columns array size mismatch') end if - ! Clean up allocated arrays from calculate_bolometric - if (allocated(wavelengths)) deallocate(wavelengths) - if (allocated(fluxes)) deallocate(fluxes) + ! clean up + if (allocated(wavelengths)) deallocate (wavelengths) + if (allocated(fluxes)) deallocate (fluxes) end subroutine data_for_colors_history_columns diff --git a/colors/private/colors_iteration.f90 b/colors/private/colors_iteration.f90 new file mode 100644 index 000000000..16cc5dca1 --- /dev/null +++ b/colors/private/colors_iteration.f90 @@ -0,0 +1,187 @@ +! *********************************************************************** +! +! Copyright (C) 2025 Niall Miller & The MESA Team +! +! This program is free software: you can redistribute it and/or modify +! it under the terms of the GNU Lesser General Public License +! as published by the Free Software Foundation, +! either version 3 of the License, or (at your option) any later version. +! +! This program is distributed in the hope that it will be useful, +! but WITHOUT ANY WARRANTY; without even the implied warranty of +! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +! See the GNU Lesser General Public License for more details. +! +! You should have received a copy of the GNU Lesser General Public License +! along with this program. If not, see . +! +! 
*********************************************************************** + +module colors_iteration + + use const_def, only: dp + use colors_def + use colors_utils, only: remove_dat, resolve_path + use bolometric, only: calculate_bolometric + use synthetic, only: calculate_synthetic + + implicit none + private + public :: write_iteration_colors, open_iteration_file, close_iteration_file + +contains + + subroutine open_iteration_file(colors_handle, ierr) + integer, intent(in) :: colors_handle + integer, intent(out) :: ierr + type(Colors_General_Info), pointer :: cs + character(len=512) :: filename + character(len=100) :: filter_name + integer :: i + + ierr = 0 + call get_colors_ptr(colors_handle, cs, ierr) + if (ierr /= 0) return + + if (cs%iteration_file_open) return ! already open + + filename = trim(cs%colors_results_directory)//'/iteration_colors.data' + call execute_command_line('mkdir -p "'//trim(cs%colors_results_directory)//'"', wait=.true.) + open (newunit=cs%iteration_output_unit, file=trim(filename), & + status='replace', action='write', iostat=ierr) + if (ierr /= 0) then + write (*, *) 'Error opening iteration colors file: ', trim(filename) + return + end if + + ! Write header + write (cs%iteration_output_unit, '(a)', advance='no') & + '# model iter star_age dt Teff' + write (cs%iteration_output_unit, '(a)', advance='no') & + ' log_g R Mag_bol Flux_bol' + do i = 1, num_color_filters + filter_name = trim(remove_dat(color_filter_names(i))) + write (cs%iteration_output_unit, '(2x,a14)', advance='no') trim(filter_name) + end do + write (cs%iteration_output_unit, *) + + cs%iteration_file_open = .true. 
+ + end subroutine open_iteration_file + + subroutine write_iteration_colors( & + colors_handle, model_number, iter, star_age, dt, & + t_eff, log_g, R, metallicity, ierr) + + integer, intent(in) :: colors_handle, model_number, iter + real(dp), intent(in) :: star_age, dt, t_eff, log_g, R, metallicity + integer, intent(out) :: ierr + + type(Colors_General_Info), pointer :: cs + real(dp) :: bolometric_magnitude, bolometric_flux, interpolation_radius + real(dp) :: magnitude, d, zero_point + character(len=256) :: sed_filepath + real(dp), dimension(:), allocatable :: wavelengths, fluxes + integer :: i, iounit + logical :: make_sed + + ierr = 0 + call get_colors_ptr(colors_handle, cs, ierr) + if (ierr /= 0) return + + ! check if per-iteration colors is enabled + if (.not. cs%colors_per_newton_step) return + if (.not. cs%use_colors) return + + ! verify data was loaded at initialization + if (.not. cs%lookup_loaded) then + write (*, *) 'colors_iteration error: lookup table not loaded' + ierr = -1 + return + end if + if (.not. cs%filters_loaded) then + write (*, *) 'colors_iteration error: filter data not loaded' + ierr = -1 + return + end if + + ! Open file if needed + if (.not. cs%iteration_file_open) then + call open_iteration_file(colors_handle, ierr) + if (ierr /= 0) return + end if + + iounit = cs%iteration_output_unit + + d = cs%distance + sed_filepath = trim(resolve_path(cs%stellar_atm)) + make_sed = .false. ! don't write individual SEDs for iteration output + + call calculate_bolometric(cs, t_eff, log_g, metallicity, R, d, & + bolometric_magnitude, bolometric_flux, wavelengths, fluxes, & + sed_filepath, interpolation_radius) + + ! Write basic data + write (iounit, '(i8, i6)', advance='no') model_number, iter + write (iounit, '(6(1pe15.7))', advance='no') & + star_age, dt, t_eff, log_g, R, bolometric_magnitude + write (iounit, '(1pe15.7)', advance='no') bolometric_flux + + ! 
calculate and write each filter magnitude + do i = 1, num_color_filters + if (t_eff >= 0 .and. metallicity >= 0 .and. & + allocated(wavelengths) .and. allocated(fluxes)) then + + ! pick the precomputed zero-point for the requested mag system + select case (trim(cs%mag_system)) + case ('VEGA', 'Vega', 'vega') + zero_point = cs%filters(i)%vega_zero_point + case ('AB', 'ab') + zero_point = cs%filters(i)%ab_zero_point + case ('ST', 'st') + zero_point = cs%filters(i)%st_zero_point + case default + zero_point = cs%filters(i)%vega_zero_point + end select + + magnitude = calculate_synthetic(t_eff, log_g, metallicity, ierr, & + wavelengths, fluxes, & + cs%filters(i)%wavelengths, & + cs%filters(i)%transmission, & + zero_point, & + color_filter_names(i), & + make_sed, cs%sed_per_model, & + cs%colors_results_directory, model_number) + + if (ierr /= 0) magnitude = -99.0_dp + else + magnitude = -99.0_dp + end if + + write (iounit, '(1pe15.7)', advance='no') magnitude + end do + + write (iounit, *) ! newline + + ! clean up + if (allocated(wavelengths)) deallocate (wavelengths) + if (allocated(fluxes)) deallocate (fluxes) + + end subroutine write_iteration_colors + + subroutine close_iteration_file(colors_handle, ierr) + integer, intent(in) :: colors_handle + integer, intent(out) :: ierr + type(Colors_General_Info), pointer :: cs + + ierr = 0 + call get_colors_ptr(colors_handle, cs, ierr) + if (ierr /= 0) return + + if (cs%iteration_file_open) then + close (cs%iteration_output_unit) + cs%iteration_file_open = .false. 
+ end if + end subroutine close_iteration_file + +end module colors_iteration \ No newline at end of file diff --git a/colors/private/colors_utils.f90 b/colors/private/colors_utils.f90 index 530e86af6..a2995943b 100644 --- a/colors/private/colors_utils.f90 +++ b/colors/private/colors_utils.f90 @@ -19,42 +19,33 @@ module colors_utils use const_def, only: dp, strlen, mesa_dir - use colors_def, only: Colors_General_Info + use colors_def, only: Colors_General_Info, sed_mem_cache_cap use utils_lib, only: mesa_error implicit none public :: dilute_flux, trapezoidal_integration, romberg_integration, & simpson_integration, load_sed, load_filter, load_vega_sed, & - load_lookup_table, remove_dat + load_lookup_table, remove_dat, load_flux_cube, build_unique_grids, & + build_grid_to_lu_map, & + find_containing_cell, find_interval, find_nearest_point, & + find_bracket_index, load_sed_cached, load_stencil contains - !--------------------------------------------------------------------------- - ! Apply dilution factor to convert surface flux to observed flux - !--------------------------------------------------------------------------- + ! apply dilution factor (R/d)^2 to convert surface flux to observed flux subroutine dilute_flux(surface_flux, R, d, calibrated_flux) real(dp), intent(in) :: surface_flux(:) real(dp), intent(in) :: R, d ! R = stellar radius, d = distance (both in the same units, e.g., cm) real(dp), intent(out) :: calibrated_flux(:) - ! Check that the output array has the same size as the input if (size(calibrated_flux) /= size(surface_flux)) then print *, "Error in dilute_flux: Output array must have the same size as input array." call mesa_error(__FILE__, __LINE__) end if - ! 
Apply the dilution factor (R/d)^2 to each element calibrated_flux = surface_flux*((R/d)**2) end subroutine dilute_flux - !########################################################### - !## MATHS - !########################################################### - - !**************************** - !Trapezoidal and Simpson Integration For Flux Calculation - !**************************** - subroutine trapezoidal_integration(x, y, result) real(dp), dimension(:), intent(in) :: x, y real(dp), intent(out) :: result @@ -65,7 +56,6 @@ subroutine trapezoidal_integration(x, y, result) n = size(x) sum = 0.0_dp - ! Validate input sizes if (size(x) /= size(y)) then print *, "Error: x and y arrays must have the same size." call mesa_error(__FILE__, __LINE__) @@ -76,7 +66,6 @@ subroutine trapezoidal_integration(x, y, result) call mesa_error(__FILE__, __LINE__) end if - ! Perform trapezoidal integration do i = 1, n - 1 sum = sum + 0.5_dp*(x(i + 1) - x(i))*(y(i + 1) + y(i)) end do @@ -94,7 +83,6 @@ subroutine simpson_integration(x, y, result) n = size(x) sum = 0.0_dp - ! Validate input sizes if (size(x) /= size(y)) then print *, "Error: x and y arrays must have the same size." call mesa_error(__FILE__, __LINE__) @@ -105,10 +93,10 @@ subroutine simpson_integration(x, y, result) call mesa_error(__FILE__, __LINE__) end if - ! Perform adaptive Simpson's rule + ! adaptive Simpson's rule do i = 1, n - 2, 2 - h1 = x(i + 1) - x(i) ! Step size for first interval - h2 = x(i + 2) - x(i + 1) ! Step size for second interval + h1 = x(i + 1) - x(i) + h2 = x(i + 2) - x(i + 1) f0 = y(i) f1 = y(i + 1) @@ -118,7 +106,7 @@ subroutine simpson_integration(x, y, result) sum = sum + (h1 + h2)/6.0_dp*(f0 + 4.0_dp*f1 + f2) end do - ! Handle the case where n is odd (last interval) + ! 
handle the last unpaired interval when n is even (an odd number of intervals) using the trapezoidal rule
Count the number of data lines n_rows = 0 do read (unit, '(A)', iostat=status) line @@ -212,14 +188,14 @@ subroutine load_vega_sed(filepath, wavelengths, flux) end do rewind (unit) - read (unit, '(A)', iostat=status) line ! Skip header again + read (unit, '(A)', iostat=status) line ! skip header again allocate (wavelengths(n_rows)) allocate (flux(n_rows)) i = 0 do - read (unit, *, iostat=status) temp_wave, temp_flux ! Ignore any extra columns + read (unit, *, iostat=status) temp_wave, temp_flux ! ignore any extra columns if (status /= 0) exit i = i + 1 wavelengths(i) = temp_wave @@ -229,9 +205,6 @@ subroutine load_vega_sed(filepath, wavelengths, flux) close (unit) end subroutine load_vega_sed - !**************************** - ! Load Filter File - !**************************** subroutine load_filter(directory, filter_wavelengths, filter_trans) character(len=*), intent(in) :: directory real(dp), dimension(:), allocatable, intent(out) :: filter_wavelengths, filter_trans @@ -240,22 +213,18 @@ subroutine load_filter(directory, filter_wavelengths, filter_trans) integer :: unit, n_rows, status, i real(dp) :: temp_wavelength, temp_trans - ! Open the file - unit = 20 - open (unit, file=trim(directory), status='OLD', action='READ', iostat=status) + open (newunit=unit, file=trim(directory), status='OLD', action='READ', iostat=status) if (status /= 0) then print *, "Error: Could not open file ", trim(directory) call mesa_error(__FILE__, __LINE__) end if - ! Skip header line read (unit, '(A)', iostat=status) line if (status /= 0) then print *, "Error: Could not read the file", trim(directory) call mesa_error(__FILE__, __LINE__) end if - ! Count rows in the file n_rows = 0 do read (unit, '(A)', iostat=status) line @@ -263,11 +232,10 @@ subroutine load_filter(directory, filter_wavelengths, filter_trans) n_rows = n_rows + 1 end do - ! Allocate arrays allocate (filter_wavelengths(n_rows)) allocate (filter_trans(n_rows)) - ! Rewind to the first non-comment line + ! 
rewind to the first non-comment line rewind (unit) do read (unit, '(A)', iostat=status) line @@ -278,7 +246,6 @@ subroutine load_filter(directory, filter_wavelengths, filter_trans) if (line(1:1) /= "#") exit end do - ! Read and parse data i = 0 do read (unit, *, iostat=status) temp_wavelength, temp_trans @@ -292,9 +259,7 @@ subroutine load_filter(directory, filter_wavelengths, filter_trans) close (unit) end subroutine load_filter - !**************************** - ! Load Lookup Table For Identifying Stellar Atmosphere Models - !**************************** + ! parses a csv lookup table mapping atmosphere grid parameters to SED filenames subroutine load_lookup_table(lookup_file, lookup_table, out_file_names, & out_logg, out_meta, out_teff) @@ -303,21 +268,19 @@ subroutine load_lookup_table(lookup_file, lookup_table, out_file_names, & character(len=100), allocatable, intent(inout) :: out_file_names(:) real(dp), allocatable, intent(inout) :: out_logg(:), out_meta(:), out_teff(:) - integer :: i, n_rows, status, unit + integer :: i, n_rows, status, unit, ios character(len=512) :: line character(len=*), parameter :: delimiter = "," character(len=100), allocatable :: columns(:), headers(:) + character(len=256) :: token integer :: logg_col, meta_col, teff_col - ! Open the file - unit = 10 - open (unit, file=lookup_file, status='old', action='read', iostat=status) + open (newunit=unit, file=lookup_file, status='old', action='read', iostat=status) if (status /= 0) then print *, "Error: Could not open file", lookup_file call mesa_error(__FILE__, __LINE__) end if - ! Read header line read (unit, '(A)', iostat=status) line if (status /= 0) then print *, "Error: Could not read header line" @@ -326,14 +289,30 @@ subroutine load_lookup_table(lookup_file, lookup_table, out_file_names, & call split_line(line, delimiter, headers) - ! Determine column indices for logg, meta, and teff + ! 
determine column indices -- try all plausible header name variants logg_col = get_column_index(headers, "logg") + if (logg_col < 0) logg_col = get_column_index(headers, "log_g") + if (logg_col < 0) logg_col = get_column_index(headers, "log(g)") + if (logg_col < 0) logg_col = get_column_index(headers, "log10g") + if (logg_col < 0) logg_col = get_column_index(headers, "log10_g") + teff_col = get_column_index(headers, "teff") + if (teff_col < 0) teff_col = get_column_index(headers, "t_eff") + if (teff_col < 0) teff_col = get_column_index(headers, "t(eff)") + if (teff_col < 0) teff_col = get_column_index(headers, "temperature") + if (teff_col < 0) teff_col = get_column_index(headers, "temp") meta_col = get_column_index(headers, "meta") - if (meta_col < 0) then - meta_col = get_column_index(headers, "feh") - end if + if (meta_col < 0) meta_col = get_column_index(headers, "feh") + if (meta_col < 0) meta_col = get_column_index(headers, "fe_h") + if (meta_col < 0) meta_col = get_column_index(headers, "[fe/h]") + if (meta_col < 0) meta_col = get_column_index(headers, "mh") + if (meta_col < 0) meta_col = get_column_index(headers, "[m/h]") + if (meta_col < 0) meta_col = get_column_index(headers, "m_h") + if (meta_col < 0) meta_col = get_column_index(headers, "z") + if (meta_col < 0) meta_col = get_column_index(headers, "logz") + if (meta_col < 0) meta_col = get_column_index(headers, "metallicity") + if (meta_col < 0) meta_col = get_column_index(headers, "metals") n_rows = 0 do @@ -343,14 +322,11 @@ subroutine load_lookup_table(lookup_file, lookup_table, out_file_names, & end do rewind (unit) - ! Skip header - read (unit, '(A)', iostat=status) line + read (unit, '(A)', iostat=status) line ! skip header - ! Allocate output arrays allocate (out_file_names(n_rows)) allocate (out_logg(n_rows), out_meta(n_rows), out_teff(n_rows)) - ! 
Read and parse the file i = 0 do read (unit, '(A)', iostat=status) line @@ -359,34 +335,40 @@ subroutine load_lookup_table(lookup_file, lookup_table, out_file_names, & call split_line(line, delimiter, columns) - ! Populate arrays out_file_names(i) = columns(1) + ! robust numeric parsing: never crash on bad/missing values if (logg_col > 0) then - if (columns(logg_col) /= "") then - read (columns(logg_col), *) out_logg(i) + token = trim(adjustl(columns(logg_col))) + if (len_trim(token) == 0 .or. token == '""') then + out_logg(i) = -999.0_dp else - out_logg(i) = -999.0 + read(token, *, iostat=ios) out_logg(i) + if (ios /= 0) out_logg(i) = -999.0_dp end if else out_logg(i) = -999.0 end if if (meta_col > 0) then - if (columns(meta_col) /= "") then - read (columns(meta_col), *) out_meta(i) + token = trim(adjustl(columns(meta_col))) + if (len_trim(token) == 0 .or. token == '""') then + out_meta(i) = 0.0_dp else - out_meta(i) = 0.0 + read(token, *, iostat=ios) out_meta(i) + if (ios /= 0) out_meta(i) = 0.0_dp end if else out_meta(i) = 0.0 end if if (teff_col > 0) then - if (columns(teff_col) /= "") then - read (columns(teff_col), *) out_teff(i) + token = trim(adjustl(columns(teff_col))) + if (len_trim(token) == 0 .or. token == '""') then + out_teff(i) = 0.0_dp else - out_teff(i) = 0.0 + read(token, *, iostat=ios) out_teff(i) + if (ios /= 0) out_teff(i) = 0.0_dp end if else out_teff(i) = 0.0 @@ -404,10 +386,10 @@ function get_column_index(headers, target) result(index) character(len=100) :: clean_header, clean_target index = -1 - clean_target = trim(adjustl(target)) ! Clean the target string + clean_target = trim(adjustl(target)) do i = 1, size(headers) - clean_header = trim(adjustl(headers(i))) ! Clean each header + clean_header = trim(adjustl(headers(i))) if (clean_header == clean_target) then index = i exit @@ -450,74 +432,66 @@ subroutine append_token(tokens, token) else n = size(tokens) allocate (temp(n)) - temp = tokens ! 
Backup the current tokens - deallocate (tokens) ! Deallocate the old array - allocate (tokens(n + 1)) ! Allocate with one extra space - tokens(1:n) = temp ! Restore old tokens - tokens(n + 1) = token ! Add the new token + temp = tokens + deallocate (tokens) + allocate (tokens(n + 1)) + tokens(1:n) = temp + tokens(n + 1) = token deallocate (temp) ! unsure if this is till needed. end if end subroutine append_token end subroutine load_lookup_table + + + subroutine load_sed(directory, index, wavelengths, flux) character(len=*), intent(in) :: directory integer, intent(in) :: index real(dp), dimension(:), allocatable, intent(out) :: wavelengths, flux character(len=512) :: line - integer :: unit, n_rows, status, i + integer :: unit, n_rows, status, i, header_lines real(dp) :: temp_wavelength, temp_flux - ! Open the file - unit = 20 - open (unit, file=trim(directory), status='OLD', action='READ', iostat=status) + + header_lines = 0 + open (newunit=unit, file=trim(directory), status='OLD', action='READ', iostat=status) if (status /= 0) then print *, "Error: Could not open file ", trim(directory) call mesa_error(__FILE__, __LINE__) end if - ! Skip header lines do read (unit, '(A)', iostat=status) line - if (status /= 0) then - print *, "Error: Could not read the file", trim(directory) - call mesa_error(__FILE__, __LINE__) - end if + if (status /= 0) exit if (line(1:1) /= "#") exit + header_lines = header_lines + 1 end do - ! Count rows in the file - n_rows = 0 + ! count data rows (we've already read one; count it) + n_rows = 1 do read (unit, '(A)', iostat=status) line if (status /= 0) exit n_rows = n_rows + 1 end do - ! Allocate arrays allocate (wavelengths(n_rows)) allocate (flux(n_rows)) - ! Rewind to the first non-comment line rewind (unit) - do + ! 
skip exactly header_lines lines + do i = 1, header_lines read (unit, '(A)', iostat=status) line - if (status /= 0) then - print *, "Error: Could not rewind file", trim(directory) - call mesa_error(__FILE__, __LINE__) - end if - if (line(1:1) /= "#") exit end do - ! Read and parse data i = 0 do read (unit, *, iostat=status) temp_wavelength, temp_flux if (status /= 0) exit i = i + 1 - ! Convert f_lambda to f_nu wavelengths(i) = temp_wavelength flux(i) = temp_flux end do @@ -526,28 +500,20 @@ subroutine load_sed(directory, index, wavelengths, flux) end subroutine load_sed - !----------------------------------------------------------------------- - ! Helper function for file names - !----------------------------------------------------------------------- - function remove_dat(path) result(base) - ! Extracts the portion of the string before the first dot + ! returns the portion of the string before the first dot character(len=*), intent(in) :: path character(len=strlen) :: base integer :: first_dot - ! Find the position of the first dot first_dot = 0 do while (first_dot < len_trim(path) .and. path(first_dot + 1:first_dot + 1) /= '.') first_dot = first_dot + 1 end do - ! Check if a dot was found if (first_dot < len_trim(path)) then - ! Extract the part before the dot base = path(:first_dot) else - ! No dot found, return the input string base = path end if end function remove_dat @@ -564,6 +530,39 @@ function basename(path) result(name) name = path(i + 1:) end function basename + function resolve_path(path) result(full_path) + use const_def, only: mesa_dir + character(len=*), intent(in) :: path + character(len=512) :: full_path + character(len=:), allocatable :: p + logical :: exists + integer :: n + + exists = .false. + p = trim(adjustl(path)) + n = len_trim(p) + + if (n >= 2 .and. p(1:2) == './') then + full_path = p + else if (n >= 3 .and. p(1:3) == '../') then + full_path = p + + else if (n >= 1 .and. p(1:1) == '/') then + inquire (file=p, exist=exists) + if (.not. 
exists) inquire (file=p//'/.', exist=exists) + + if (exists) then + full_path = p + else + write (*, *) trim(p), " not found. Trying ", trim(mesa_dir)//trim(p) + full_path = trim(mesa_dir)//trim(p) + end if + + else + full_path = trim(mesa_dir)//'/'//trim(p) + end if + end function resolve_path + subroutine read_strings_from_file(colors_settings, strings, n, ierr) character(len=512) :: filename character(len=100), allocatable, intent(out) :: strings(:) @@ -574,14 +573,11 @@ subroutine read_strings_from_file(colors_settings, strings, n, ierr) ierr = 0 - filename = trim(mesa_dir)//trim(colors_settings%instrument)//"/"// & + filename = trim(resolve_path(colors_settings%instrument))//"/"// & trim(basename(colors_settings%instrument)) - !filename = trim(mesa_dir)//trim(colors_settings%instrument)//"/" - n = 0 - unit = 10 - open (unit, file=filename, status='old', action='read', iostat=status) + open (newunit=unit, file=filename, status='old', action='read', iostat=status) if (status /= 0) then ierr = -1 print *, "Error: Could not open file", filename @@ -602,4 +598,527 @@ subroutine read_strings_from_file(colors_settings, strings, n, ierr) close (unit) end subroutine read_strings_from_file -end module colors_utils + ! load flux cube from binary file into handle at initialization. + ! if the file cannot be opened or the large flux_cube array cannot be + ! allocated, sets cube_loaded = .false. so the runtime will fall back to + ! loading individual SED files via the lookup table. + ! grids and wavelengths are always loaded (small); only the 4-D cube + ! allocation is treated as the fallback trigger. + subroutine load_flux_cube(rq, stellar_model_dir) + type(Colors_General_Info), intent(inout) :: rq + character(len=*), intent(in) :: stellar_model_dir + + character(len=512) :: bin_filename + integer :: unit, status, n_teff, n_logg, n_meta, n_lambda + real(dp) :: cube_mb + + rq%cube_loaded = .false. 
+ + bin_filename = trim(resolve_path(stellar_model_dir))//'/flux_cube.bin' + + open (newunit=unit, file=trim(bin_filename), status='OLD', & + access='STREAM', form='UNFORMATTED', iostat=status) + if (status /= 0) then + ! no binary cube available -- will use individual SED files + write (*, '(a)') 'colors: no flux_cube.bin found; using per-file SED loading' + return + end if + + read (unit, iostat=status) n_teff, n_logg, n_meta, n_lambda + if (status /= 0) then + close (unit) + return + end if + + ! attempt the large allocation first -- this is the one that may fail. + ! doing it before the small grid allocations avoids partial cleanup. + allocate (rq%cube_flux(n_teff, n_logg, n_meta, n_lambda), stat=status) + if (status /= 0) then + if (allocated(rq%cube_flux)) deallocate (rq%cube_flux) + cube_mb = real(n_teff, dp)*n_logg*n_meta*n_lambda*8.0_dp/(1024.0_dp**2) + write (*, '(a,f0.1,a)') & + 'colors: flux cube allocation failed (', cube_mb, & + ' MB); falling back to per-file SED loading' + close (unit) + return + end if + + ! grid arrays are small -- always expected to succeed + allocate (rq%cube_teff_grid(n_teff), stat=status) + if (status /= 0) goto 900 + + allocate (rq%cube_logg_grid(n_logg), stat=status) + if (status /= 0) goto 900 + + allocate (rq%cube_meta_grid(n_meta), stat=status) + if (status /= 0) goto 900 + + allocate (rq%cube_wavelengths(n_lambda), stat=status) + if (status /= 0) goto 900 + + read (unit, iostat=status) rq%cube_teff_grid + if (status /= 0) goto 900 + + read (unit, iostat=status) rq%cube_logg_grid + if (status /= 0) goto 900 + + read (unit, iostat=status) rq%cube_meta_grid + if (status /= 0) goto 900 + + read (unit, iostat=status) rq%cube_wavelengths + if (status /= 0) goto 900 + + read (unit, iostat=status) rq%cube_flux + if (status /= 0) goto 900 + + close (unit) + rq%cube_loaded = .true. 
+ + cube_mb = real(n_teff, dp)*n_logg*n_meta*n_lambda*8.0_dp/(1024.0_dp**2) + write (*, '(a,i0,a,i0,a,i0,a,i0,a,f0.1,a)') & + 'colors: flux cube loaded (', & + n_teff, ' x ', n_logg, ' x ', n_meta, ' x ', n_lambda, & + ', ', cube_mb, ' MB)' + return + + ! error cleanup -- deallocate everything that may have been allocated +900 continue + write (*, '(a)') 'colors: error reading flux_cube.bin; falling back to per-file SED loading' + if (allocated(rq%cube_flux)) deallocate (rq%cube_flux) + if (allocated(rq%cube_teff_grid)) deallocate (rq%cube_teff_grid) + if (allocated(rq%cube_logg_grid)) deallocate (rq%cube_logg_grid) + if (allocated(rq%cube_meta_grid)) deallocate (rq%cube_meta_grid) + if (allocated(rq%cube_wavelengths)) deallocate (rq%cube_wavelengths) + close (unit) + + end subroutine load_flux_cube + + ! build unique sorted grids from lookup table and store on handle. + ! called once at init so the fallback interpolation path never rebuilds these. + subroutine build_unique_grids(rq) + type(Colors_General_Info), intent(inout) :: rq + logical :: found + + if (rq%unique_grids_built) return + if (.not. rq%lookup_loaded) return + + call extract_unique_sorted(rq%lu_teff, rq%u_teff) + call extract_unique_sorted(rq%lu_logg, rq%u_logg) + call extract_unique_sorted(rq%lu_meta, rq%u_meta) + + rq%unique_grids_built = .true. + + contains + + subroutine extract_unique_sorted(arr, unique) + real(dp), intent(in) :: arr(:) + real(dp), allocatable, intent(out) :: unique(:) + real(dp), allocatable :: buf(:) + integer :: ii, jj, nn, nnu + real(dp) :: sw + + nn = size(arr) + allocate (buf(nn)) + nnu = 0 + + do ii = 1, nn + found = .false. + do jj = 1, nnu + if (abs(arr(ii) - buf(jj)) < 1.0e-10_dp) then + found = .true. + exit + end if + end do + if (.not. found) then + nnu = nnu + 1 + buf(nnu) = arr(ii) + end if + end do + + ! insertion sort (grids are small) + do ii = 2, nnu + sw = buf(ii) + jj = ii - 1 + do while (jj >= 1 .and. 
buf(jj) > sw) + buf(jj + 1) = buf(jj) + jj = jj - 1 + end do + buf(jj + 1) = sw + end do + + allocate (unique(nnu)) + unique = buf(1:nnu) + deallocate (buf) + end subroutine extract_unique_sorted + + end subroutine build_unique_grids + + ! build a 3-D mapping from unique-grid indices to lookup-table rows. + ! grid_to_lu(i_t, i_g, i_m) = row in lu_* whose parameters match + ! (u_teff(i_t), u_logg(i_g), u_meta(i_m)). + ! called once at init; replaces the O(n_lu) nearest-neighbour scan + ! that previously ran per corner at runtime. + subroutine build_grid_to_lu_map(rq) + type(Colors_General_Info), intent(inout) :: rq + + integer :: nt, ng, nm, n_lu + integer :: it, ig, im, idx + integer :: best_idx + real(dp) :: best_dist, dist + real(dp), parameter :: tol = 1.0e-10_dp + + if (rq%grid_map_built) return + if (.not. rq%unique_grids_built) return + if (.not. rq%lookup_loaded) return + + nt = size(rq%u_teff) + ng = size(rq%u_logg) + nm = size(rq%u_meta) + n_lu = size(rq%lu_teff) + + allocate (rq%grid_to_lu(nt, ng, nm)) + rq%grid_to_lu = 0 + + do it = 1, nt + do ig = 1, ng + do im = 1, nm + best_idx = 1 + best_dist = huge(1.0_dp) + do idx = 1, n_lu + dist = abs(rq%lu_teff(idx) - rq%u_teff(it)) + & + abs(rq%lu_logg(idx) - rq%u_logg(ig)) + & + abs(rq%lu_meta(idx) - rq%u_meta(im)) + if (dist < best_dist) then + best_dist = dist + best_idx = idx + end if + if (best_dist < tol) exit ! exact match found + end do + rq%grid_to_lu(it, ig, im) = best_idx + end do + end do + end do + + rq%grid_map_built = .true. 
+ + end subroutine build_grid_to_lu_map + + subroutine find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, & + i_x, i_y, i_z, t_x, t_y, t_z) + real(dp), intent(in) :: x_val, y_val, z_val + real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:) + integer, intent(out) :: i_x, i_y, i_z + real(dp), intent(out) :: t_x, t_y, t_z + + call find_interval(x_grid, x_val, i_x, t_x) + call find_interval(y_grid, y_val, i_y, t_y) + call find_interval(z_grid, z_val, i_z, t_z) + end subroutine find_containing_cell + + ! find the interval in a sorted array containing a value. + ! returns index i such that x(i) <= val <= x(i+1) and fractional position t in [0,1]. + ! detects dummy axes (all zeros / 999 / -999) and collapses them to i=1, t=0. + subroutine find_interval(x, val, i, t) + real(dp), intent(in) :: x(:), val + integer, intent(out) :: i + real(dp), intent(out) :: t + + integer :: n, lo, hi, mid + logical :: dummy_axis + + n = size(x) + + ! detect dummy axis: all values == 0, 999, or -999 + dummy_axis = all(x == 0.0_dp) .or. all(x == 999.0_dp) .or. all(x == -999.0_dp) + + if (dummy_axis) then + i = 1 + t = 0.0_dp + return + end if + + if (val <= x(1)) then + i = 1 + t = 0.0_dp + return + else if (val >= x(n)) then + i = n - 1 + t = 1.0_dp + return + end if + + lo = 1 + hi = n + do while (hi - lo > 1) + mid = (lo + hi)/2 + if (val >= x(mid)) then + lo = mid + else + hi = mid + end if + end do + + i = lo + if (abs(x(i + 1) - x(i)) < 1.0e-30_dp) then + t = 0.0_dp ! 
degenerate interval -- no interpolation needed + else + t = (val - x(i))/(x(i + 1) - x(i)) + end if + end subroutine find_interval + + subroutine find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, & + i_x, i_y, i_z) + real(dp), intent(in) :: x_val, y_val, z_val + real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:) + integer, intent(out) :: i_x, i_y, i_z + + i_x = minloc(abs(x_val - x_grid), 1) + i_y = minloc(abs(y_val - y_grid), 1) + i_z = minloc(abs(z_val - z_grid), 1) + end subroutine find_nearest_point + + ! returns idx such that grid(idx) <= val < grid(idx+1), clamped to bounds + subroutine find_bracket_index(grid, val, idx) + real(dp), intent(in) :: grid(:), val + integer, intent(out) :: idx + + integer :: n, lo, hi, mid + + n = size(grid) + if (n < 2) then + idx = 1 + return + end if + + if (val <= grid(1)) then + idx = 1 + return + else if (val >= grid(n)) then + idx = n - 1 + return + end if + + lo = 1 + hi = n + do while (hi - lo > 1) + mid = (lo + hi)/2 + if (val >= grid(mid)) then + lo = mid + else + hi = mid + end if + end do + idx = lo + end subroutine find_bracket_index + + ! fallback SED cache and stencil loader + + ! load a stencil sub-cube of SED fluxes for the given index ranges. + ! uses load_sed_cached so repeated visits to the same grid point are + ! served from memory rather than disk. + subroutine load_stencil(rq, resolved_dir, lo_t, hi_t, lo_g, hi_g, lo_m, hi_m) + type(Colors_General_Info), intent(inout) :: rq + character(len=*), intent(in) :: resolved_dir + integer, intent(in) :: lo_t, hi_t, lo_g, hi_g, lo_m, hi_m + + integer :: st, sg, sm, n_lambda, lu_idx + integer :: it, ig, im + real(dp), dimension(:), allocatable :: sed_flux + + st = hi_t - lo_t + 1 + sg = hi_g - lo_g + 1 + sm = hi_m - lo_m + 1 + + ! 
free previous stencil flux data + if (allocated(rq%stencil_fluxes)) deallocate (rq%stencil_fluxes) + + n_lambda = 0 + + do it = lo_t, hi_t + do ig = lo_g, hi_g + do im = lo_m, hi_m + lu_idx = rq%grid_to_lu(it, ig, im) + call load_sed_cached(rq, resolved_dir, lu_idx, sed_flux) + + if (n_lambda == 0) then + n_lambda = size(sed_flux) + allocate (rq%stencil_fluxes(st, sg, sm, n_lambda)) + else if (size(sed_flux) /= n_lambda) then + ! SED files have inconsistent wavelength counts — this is a + ! fatal grid inconsistency; crash with a clear message. + write (*, '(a,i0,a,i0,a,i0)') & + 'colors ERROR: SED at lu_idx=', lu_idx, & + ' has ', size(sed_flux), ' wavelength points; expected ', n_lambda + write (*, '(a)') & + 'colors: bt-settl (or other) grid files have non-uniform wavelength grids.' + call mesa_error(__FILE__, __LINE__) + end if + + rq%stencil_fluxes(it - lo_t + 1, ig - lo_g + 1, im - lo_m + 1, :) = & + sed_flux(1:n_lambda) + if (allocated(sed_flux)) deallocate (sed_flux) + end do + end do + end do + + ! set stencil wavelengths from the canonical copy on the handle + if (allocated(rq%stencil_wavelengths)) deallocate (rq%stencil_wavelengths) + allocate (rq%stencil_wavelengths(n_lambda)) + rq%stencil_wavelengths = rq%fallback_wavelengths(1:n_lambda) + + end subroutine load_stencil + + ! retrieve an SED flux from the memory cache, or load from disk on miss. + ! uses a bounded circular buffer (sed_mem_cache_cap slots). + ! on the first disk read, the wavelength array is stored once on the handle + ! as rq%fallback_wavelengths -- all SEDs in a given atmosphere grid share + ! the same wavelengths, so only flux is cached and returned. 
+ subroutine load_sed_cached(rq, resolved_dir, lu_idx, flux) + type(Colors_General_Info), intent(inout) :: rq + character(len=*), intent(in) :: resolved_dir + integer, intent(in) :: lu_idx + real(dp), dimension(:), allocatable, intent(out) :: flux + + integer :: slot, n_lam, status + character(len=512) :: filepath + real(dp), dimension(:), allocatable :: sed_wave + real(dp), dimension(:), allocatable :: flux_interp + + ! initialise the cache on first call + if (.not. rq%sed_mcache_init) then + allocate (rq%sed_mcache_keys(sed_mem_cache_cap)) + rq%sed_mcache_keys = 0 ! 0 means empty slot + rq%sed_mcache_count = 0 + rq%sed_mcache_next = 1 + rq%sed_mcache_nlam = 0 + rq%sed_mcache_init = .true. + end if + + ! search for a cache hit (linear scan over a small array) + do slot = 1, rq%sed_mcache_count + if (rq%sed_mcache_keys(slot) == lu_idx) then + ! hit -- return cached flux + n_lam = rq%sed_mcache_nlam + allocate (flux(n_lam)) + flux = rq%sed_mcache_data(:, slot) + return + end if + end do + + ! miss -- load from disk + filepath = trim(resolved_dir)//'/'//trim(rq%lu_file_names(lu_idx)) + call load_sed(filepath, lu_idx, sed_wave, flux) + + ! store the canonical wavelength array on the handle (once only) + if (.not. rq%fallback_wavelengths_set) then + n_lam = size(sed_wave) + allocate (rq%fallback_wavelengths(n_lam)) + rq%fallback_wavelengths = sed_wave + rq%fallback_wavelengths_set = .true. + end if + + + + ! store flux in the cache + n_lam = size(flux) + if (rq%sed_mcache_nlam == 0) then + rq%sed_mcache_nlam = n_lam + allocate (rq%sed_mcache_data(n_lam, sed_mem_cache_cap), stat=status) + if (status /= 0) then + write (*, '(a,f0.1,a)') 'colors: SED memory cache allocation failed (', & + real(n_lam, dp)*sed_mem_cache_cap*8.0_dp/1024.0_dp**2, ' MB); cache disabled' + rq%sed_mcache_init = .false. ! fall through to disk-only path + if (allocated(sed_wave)) deallocate (sed_wave) + return + end if + else if (n_lam /= rq%sed_mcache_nlam) then + ! 
non-uniform wavelength grids across SED files: interpolate onto the canonical grid + if (.not. rq%fallback_wavelengths_set) then + allocate (rq%fallback_wavelengths(n_lam)) + rq%fallback_wavelengths = sed_wave + rq%fallback_wavelengths_set = .true. + end if + + allocate (flux_interp(rq%sed_mcache_nlam), stat=status) + if (status /= 0) then + write (*, '(a)') 'colors: WARNING: could not allocate interpolation buffer; SED cache disabled' + rq%sed_mcache_init = .false. + if (allocated(sed_wave)) deallocate (sed_wave) + return + end if + + call interp_linear_internal(rq%fallback_wavelengths, sed_wave, flux, flux_interp) + + deallocate (flux) + allocate (flux(rq%sed_mcache_nlam)) + flux = flux_interp + deallocate (flux_interp) + + n_lam = rq%sed_mcache_nlam + end if + + ! write to the next slot (circular) + slot = rq%sed_mcache_next + rq%sed_mcache_keys(slot) = lu_idx + rq%sed_mcache_data(:, slot) = flux(1:n_lam) + + ! advance the circular pointer + if (rq%sed_mcache_count < sed_mem_cache_cap) & + rq%sed_mcache_count = rq%sed_mcache_count + 1 + rq%sed_mcache_next = mod(slot, sed_mem_cache_cap) + 1 + + if (allocated(sed_wave)) deallocate (sed_wave) + end subroutine load_sed_cached + + + ! 
simple 1D linear interpolation (np.interp-like): clamps to endpoints + subroutine interp_linear_internal(x_out, x_in, y_in, y_out) + real(dp), intent(in) :: x_out(:), x_in(:), y_in(:) + real(dp), intent(out) :: y_out(:) + integer :: i, n_in, n_out, lo, hi, mid + real(dp) :: t, denom + + n_in = size(x_in) + n_out = size(x_out) + + if (n_out <= 0) return + + if (n_in <= 0) then + y_out = 0.0_dp + return + end if + + if (n_in == 1) then + y_out = y_in(1) + return + end if + + do i = 1, n_out + if (x_out(i) <= x_in(1)) then + y_out(i) = y_in(1) + else if (x_out(i) >= x_in(n_in)) then + y_out(i) = y_in(n_in) + else + lo = 1 + hi = n_in + do while (hi - lo > 1) + mid = (lo + hi)/2 + if (x_out(i) >= x_in(mid)) then + lo = mid + else + hi = mid + end if + end do + + denom = x_in(lo + 1) - x_in(lo) + if (abs(denom) <= 0.0_dp) then + y_out(i) = y_in(lo) + else + t = (x_out(i) - x_in(lo))/denom + y_out(i) = y_in(lo) + t*(y_in(lo + 1) - y_in(lo)) + end if + end if + end do + + end subroutine interp_linear_internal +end module colors_utils \ No newline at end of file diff --git a/colors/private/hermite_interp.f90 b/colors/private/hermite_interp.f90 index 43c6019f3..958801d4b 100644 --- a/colors/private/hermite_interp.f90 +++ b/colors/private/hermite_interp.f90 @@ -23,7 +23,9 @@ module hermite_interp use const_def, only: dp - use colors_utils, only: dilute_flux + use colors_def, only: Colors_General_Info + use colors_utils, only: dilute_flux, find_containing_cell, find_interval, & + find_nearest_point, find_bracket_index, load_stencil implicit none private @@ -32,166 +34,309 @@ module hermite_interp contains !--------------------------------------------------------------------------- - ! Main entry point: Construct a SED using Hermite tensor interpolation + ! Main entry point: Construct a SED using Hermite tensor interpolation. + ! Data loading strategy is determined by rq%cube_loaded (set at init): + ! cube_loaded = .true. -> use the preloaded 4-D cube on the handle + ! 
cube_loaded = .false. -> load individual SED files via the lookup table !--------------------------------------------------------------------------- - subroutine construct_sed_hermite(teff, log_g, metallicity, R, d, file_names, & - lu_teff, lu_logg, lu_meta, stellar_model_dir, & - wavelengths, fluxes) + subroutine construct_sed_hermite(rq, teff, log_g, metallicity, R, d, & + stellar_model_dir, wavelengths, fluxes) + type(Colors_General_Info), intent(inout) :: rq real(dp), intent(in) :: teff, log_g, metallicity, R, d - real(dp), intent(in) :: lu_teff(:), lu_logg(:), lu_meta(:) character(len=*), intent(in) :: stellar_model_dir - character(len=100), intent(in) :: file_names(:) real(dp), dimension(:), allocatable, intent(out) :: wavelengths, fluxes - integer :: i, n_lambda, status, n_teff, n_logg, n_meta + integer :: n_lambda real(dp), dimension(:), allocatable :: interp_flux, diluted_flux - real(dp), dimension(:, :, :, :), allocatable :: precomputed_flux_cube - real(dp), dimension(:, :, :), allocatable :: flux_cube_lambda - ! Parameter grids - real(dp), allocatable :: teff_grid(:), logg_grid(:), meta_grid(:) - character(len=256) :: bin_filename + if (rq%cube_loaded) then + ! ---- Fast path: use preloaded cube from handle ---- + n_lambda = size(rq%cube_wavelengths) + + ! Copy wavelengths to output + allocate (wavelengths(n_lambda)) + wavelengths = rq%cube_wavelengths + + ! Vectorised interpolation over all wavelengths in one pass — + ! cell location is computed once and reused, no per-wavelength + ! allocation or 3-D slice extraction needed. + allocate (interp_flux(n_lambda)) + call hermite_interp_vector(teff, log_g, metallicity, & + rq%cube_teff_grid, rq%cube_logg_grid, & + rq%cube_meta_grid, & + rq%cube_flux, n_lambda, interp_flux) + else + ! ---- Fallback path: load individual SED files from lookup table ---- + call construct_sed_from_files(rq, teff, log_g, metallicity, & + stellar_model_dir, interp_flux, wavelengths) + n_lambda = size(wavelengths) + end if - ! 
Construct the binary filename - bin_filename = trim(stellar_model_dir)//'/flux_cube.bin' + ! Apply distance dilution to get observed flux + allocate (diluted_flux(n_lambda)) + call dilute_flux(interp_flux, R, d, diluted_flux) + fluxes = diluted_flux - ! Load the data from binary file - call load_binary_data(bin_filename, teff_grid, logg_grid, meta_grid, & - wavelengths, precomputed_flux_cube, status) + end subroutine construct_sed_hermite - n_teff = size(teff_grid) - n_logg = size(logg_grid) - n_meta = size(meta_grid) - n_lambda = size(wavelengths) + !--------------------------------------------------------------------------- + ! Fallback: Build a local sub-cube from individual SED files with enough + ! context for Hermite derivative computation, then interpolate all + ! wavelengths in a single pass. + ! + ! Correctness guarantee + ! --------------------- + ! The cube path passes the FULL grid arrays to hermite_tensor_interp3d, + ! so compute_derivatives_at_point can use centred differences at interior + ! nodes. To replicate this exactly, we load not just the 2x2x2 cell + ! corners but also one extra grid point on each side when available + ! (the "derivative stencil"). The resulting sub-grid has 2-4 points per + ! axis, and the interpolator sees the same derivative context it would + ! get from the full cube. + ! + ! Performance + ! ----------- + ! * grid_to_lu map -- O(1) lookup per stencil point (vs O(n_lu) scan) + ! * SED memory cache -- file I/O only on first visit to a grid point + ! * Stencil cache -- no work at all when the cell hasn't changed + ! 
   !  * Vectorised wavelength loop -- cell geometry computed once, reused
   !---------------------------------------------------------------------------
   subroutine construct_sed_from_files(rq, teff, log_g, metallicity, &
                                       stellar_model_dir, interp_flux, wavelengths)
      use colors_utils, only: resolve_path, build_grid_to_lu_map
      type(Colors_General_Info), intent(inout) :: rq
      real(dp), intent(in) :: teff, log_g, metallicity  ! current surface parameters
      character(len=*), intent(in) :: stellar_model_dir  ! directory of SED files
      real(dp), dimension(:), allocatable, intent(out) :: interp_flux, wavelengths

      integer :: i_t, i_g, i_m  ! bracketing indices in unique grids
      integer :: lo_t, hi_t, lo_g, hi_g, lo_m, hi_m  ! stencil bounds
      integer :: nt, ng, nm, n_lambda
      ! NOTE(review): it, ig, im, lu_idx and i appear unused in this
      ! subroutine (the stencil loop lives in load_stencil) -- candidates
      ! for removal.
      integer :: it, ig, im, lu_idx, i
      character(len=512) :: resolved_dir
      logical :: need_reload

      resolved_dir = trim(resolve_path(stellar_model_dir))

      ! Ensure the grid-to-lu mapping exists (built once, then reused)
      if (.not. rq%grid_map_built) call build_grid_to_lu_map(rq)

      ! Find bracketing cell in the unique grids
      call find_bracket_index(rq%u_teff, teff, i_t)
      call find_bracket_index(rq%u_logg, log_g, i_g)
      call find_bracket_index(rq%u_meta, metallicity, i_m)

      ! Check if the stencil cache is still valid for this cell:
      ! same cell indices on all three axes means nothing to reload
      need_reload = .true.
      if (rq%stencil_valid .and. &
          i_t == rq%stencil_i_t .and. &
          i_g == rq%stencil_i_g .and. &
          i_m == rq%stencil_i_m) then
         need_reload = .false.
      end if

      if (need_reload) then
         ! Determine the extended stencil bounds:
         ! For each axis, include one point before and after the cell
         ! when available, so that centred differences match the cube.
         nt = size(rq%u_teff)
         ng = size(rq%u_logg)
         nm = size(rq%u_meta)

         ! an axis with fewer than 2 points collapses to a single index
         if (nt < 2) then
            lo_t = 1; hi_t = 1
         else
            lo_t = max(1, i_t - 1)
            hi_t = min(nt, i_t + 2)
         end if

         if (ng < 2) then
            lo_g = 1; hi_g = 1
         else
            lo_g = max(1, i_g - 1)
            hi_g = min(ng, i_g + 2)
         end if

         if (nm < 2) then
            lo_m = 1; hi_m = 1
         else
            lo_m = max(1, i_m - 1)
            hi_m = min(nm, i_m + 2)
         end if

         ! Load SEDs for every stencil point (using memory cache)
         call load_stencil(rq, resolved_dir, lo_t, hi_t, lo_g, hi_g, lo_m, hi_m)

         ! Store subgrid arrays on the handle
         if (allocated(rq%stencil_teff)) deallocate(rq%stencil_teff)
         if (allocated(rq%stencil_logg)) deallocate(rq%stencil_logg)
         if (allocated(rq%stencil_meta)) deallocate(rq%stencil_meta)

         allocate (rq%stencil_teff(hi_t - lo_t + 1))
         allocate (rq%stencil_logg(hi_g - lo_g + 1))
         allocate (rq%stencil_meta(hi_m - lo_m + 1))
         rq%stencil_teff = rq%u_teff(lo_t:hi_t)
         rq%stencil_logg = rq%u_logg(lo_g:hi_g)
         rq%stencil_meta = rq%u_meta(lo_m:hi_m)

         ! remember which cell the stencil covers for the next-call check
         rq%stencil_i_t = i_t
         rq%stencil_i_g = i_g
         rq%stencil_i_m = i_m
         rq%stencil_valid = .true.
      end if

      ! Copy wavelengths to output
      n_lambda = size(rq%stencil_wavelengths)
      allocate (wavelengths(n_lambda))
      wavelengths = rq%stencil_wavelengths

      ! Interpolate all wavelengths using precomputed stencil
      allocate (interp_flux(n_lambda))
      call hermite_interp_vector(teff, log_g, metallicity, &
                                 rq%stencil_teff, rq%stencil_logg, rq%stencil_meta, &
                                 rq%stencil_fluxes, n_lambda, interp_flux)

   end subroutine construct_sed_from_files
   ! The cell location (i_x, i_y, i_z, t_x, t_y, t_z) depends only on
   ! (teff, logg, meta) and the sub-grids — not on wavelength. Computing
   ! it once and reusing across all n_lambda samples eliminates redundant
   ! binary searches and basis-function evaluations.
   !---------------------------------------------------------------------------
   subroutine hermite_interp_vector(x_val, y_val, z_val, &
                                    x_grid, y_grid, z_grid, &
                                    f_values_4d, n_lambda, result_flux)
      real(dp), intent(in) :: x_val, y_val, z_val  ! query point (teff, logg, [M/H])
      real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)  ! sorted axis grids
      real(dp), intent(in) :: f_values_4d(:,:,:,:)  ! (nx, ny, nz, n_lambda)
      integer, intent(in) :: n_lambda
      real(dp), intent(out) :: result_flux(n_lambda)  ! interpolated flux per wavelength

      integer :: i_x, i_y, i_z
      real(dp) :: t_x, t_y, t_z
      real(dp) :: dx, dy, dz
      integer :: nx, ny, nz
      integer :: ix, iy, iz, lam
      real(dp) :: h_x(2), h_y(2), h_z(2)  ! value basis functions h00/h01
      real(dp) :: hx_d(2), hy_d(2), hz_d(2)  ! derivative basis functions h10/h11
      real(dp) :: val, df_dx, df_dy, df_dz, s
      real(dp) :: wx, wy, wz, wxd, wyd, wzd

      nx = size(x_grid)
      ny = size(y_grid)
      nz = size(z_grid)

      ! Find containing cell (done once for all wavelengths)
      call find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                                i_x, i_y, i_z, t_x, t_y, t_z)

      ! If outside grid, use nearest point for all wavelengths
      if (i_x < 1 .or. i_x >= nx .or. &
          i_y < 1 .or. i_y >= ny .or. &
          i_z < 1 .or. i_z >= nz) then

         call find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                                 i_x, i_y, i_z)
         do lam = 1, n_lambda
            result_flux(lam) = f_values_4d(i_x, i_y, i_z, lam)
         end do
         return
      end if

      ! Grid cell spacing (may be zero on a collapsed/dummy axis; the
      ! derivative routine guards against that)
      dx = x_grid(i_x + 1) - x_grid(i_x)
      dy = y_grid(i_y + 1) - y_grid(i_y)
      dz = z_grid(i_z + 1) - z_grid(i_z)

      ! Precompute Hermite basis functions (same for all wavelengths)
      h_x = [h00(t_x), h01(t_x)]
      hx_d = [h10(t_x), h11(t_x)]
      h_y = [h00(t_y), h01(t_y)]
      hy_d = [h10(t_y), h11(t_y)]
      h_z = [h00(t_z), h01(t_z)]
      hz_d = [h10(t_z), h11(t_z)]

      ! Loop over wavelengths — the hot loop.
      ! Tensor-product sum over the 8 cell corners: value term plus one
      ! derivative term per axis, each scaled by the cell spacing.
      do lam = 1, n_lambda
         s = 0.0_dp
         do iz = 0, 1
            wz = h_z(iz + 1)
            wzd = hz_d(iz + 1)
            do iy = 0, 1
               wy = h_y(iy + 1)
               wyd = hy_d(iy + 1)
               do ix = 0, 1
                  wx = h_x(ix + 1)
                  wxd = hx_d(ix + 1)

                  val = f_values_4d(i_x + ix, i_y + iy, i_z + iz, lam)

                  call compute_derivatives_at_point_4d( &
                     f_values_4d, i_x + ix, i_y + iy, i_z + iz, lam, &
                     nx, ny, nz, dx, dy, dz, df_dx, df_dy, df_dz)

                  s = s + wx*wy*wz * val &
                        + wxd*wy*wz * dx * df_dx &
                        + wx*wyd*wz * dy * df_dy &
                        + wx*wy*wzd * dz * df_dz
               end do
            end do
         end do
         result_flux(lam) = s
      end do

   end subroutine hermite_interp_vector

   !---------------------------------------------------------------------------
   ! Compute derivatives directly from the 4-D array at a given wavelength,
   ! avoiding the need to extract a 3-D slice first.
   ! Centred differences at interior nodes, one-sided at the boundaries;
   ! a degenerate (zero-spacing) axis yields a zero derivative.
   !---------------------------------------------------------------------------
   subroutine compute_derivatives_at_point_4d(f4d, i, j, k, lam, nx, ny, nz, &
                                              dx, dy, dz, df_dx, df_dy, df_dz)
      real(dp), intent(in) :: f4d(:,:,:,:)  ! (nx, ny, nz, n_lambda)
      integer, intent(in) :: i, j, k, lam, nx, ny, nz  ! node indices and axis sizes
      real(dp), intent(in) :: dx, dy, dz  ! grid spacings per axis
      real(dp), intent(out) :: df_dx, df_dy, df_dz

      ! x derivative
      if (dx < 1.0e-30_dp) then
         df_dx = 0.0_dp  ! degenerate axis
      else if (i > 1 .and. i < nx) then
         df_dx = (f4d(i + 1, j, k, lam) - f4d(i - 1, j, k, lam)) / (2.0_dp * dx)
      else if (i == 1) then
         df_dx = (f4d(i + 1, j, k, lam) - f4d(i, j, k, lam)) / dx
      else
         df_dx = (f4d(i, j, k, lam) - f4d(i - 1, j, k, lam)) / dx
      end if

      ! y derivative
      if (dy < 1.0e-30_dp) then
         df_dy = 0.0_dp  ! degenerate axis
      else if (j > 1 .and. j < ny) then
         df_dy = (f4d(i, j + 1, k, lam) - f4d(i, j - 1, k, lam)) / (2.0_dp * dy)
      else if (j == 1) then
         df_dy = (f4d(i, j + 1, k, lam) - f4d(i, j, k, lam)) / dy
      else
         df_dy = (f4d(i, j, k, lam) - f4d(i, j - 1, k, lam)) / dy
      end if

      ! z derivative
      if (dz < 1.0e-30_dp) then
         df_dz = 0.0_dp  ! degenerate axis
      else if (k > 1 .and. k < nz) then
         df_dz = (f4d(i, j, k + 1, lam) - f4d(i, j, k - 1, lam)) / (2.0_dp * dz)
      else if (k == 1) then
         df_dz = (f4d(i, j, k + 1, lam) - f4d(i, j, k, lam)) / dz
      else
         df_dz = (f4d(i, j, k, lam) - f4d(i, j, k - 1, lam)) / dz
      end if

   end subroutine compute_derivatives_at_point_4d
Close file and return success - close (unit) - end subroutine load_binary_data function hermite_tensor_interp3d(x_val, y_val, z_val, x_grid, y_grid, & z_grid, f_values) result(f_interp) @@ -246,11 +391,11 @@ function hermite_tensor_interp3d(x_val, y_val, z_val, x_grid, y_grid, & end do ! Precompute Hermite basis functions and derivatives - h_x = [h00(t_x), h01(t_x)] + h_x = [h00(t_x), h01(t_x)] hx_d = [h10(t_x), h11(t_x)] - h_y = [h00(t_y), h01(t_y)] + h_y = [h00(t_y), h01(t_y)] hy_d = [h10(t_y), h11(t_y)] - h_z = [h00(t_z), h01(t_z)] + h_z = [h00(t_z), h01(t_z)] hz_d = [h10(t_z), h11(t_z)] ! Final interpolation sum @@ -258,10 +403,10 @@ function hermite_tensor_interp3d(x_val, y_val, z_val, x_grid, y_grid, & do iz = 1, 2 do iy = 1, 2 do ix = 1, 2 - sum = sum + h_x(ix)*h_y(iy)*h_z(iz)*values(ix, iy, iz) - sum = sum + hx_d(ix)*h_y(iy)*h_z(iz)*dx*dx_values(ix, iy, iz) - sum = sum + h_x(ix)*hy_d(iy)*h_z(iz)*dy*dy_values(ix, iy, iz) - sum = sum + h_x(ix)*h_y(iy)*hz_d(iz)*dz*dz_values(ix, iy, iz) + sum = sum + h_x(ix)*h_y(iy)*h_z(iz) * values(ix, iy, iz) + sum = sum + hx_d(ix)*h_y(iy)*h_z(iz) * dx * dx_values(ix, iy, iz) + sum = sum + h_x(ix)*hy_d(iy)*h_z(iz) * dy * dy_values(ix, iy, iz) + sum = sum + h_x(ix)*h_y(iy)*hz_d(iz) * dz * dz_values(ix, iy, iz) end do end do end do @@ -269,94 +414,9 @@ function hermite_tensor_interp3d(x_val, y_val, z_val, x_grid, y_grid, & f_interp = sum end function hermite_tensor_interp3d - !--------------------------------------------------------------------------- - ! Find the cell containing the interpolation point - !--------------------------------------------------------------------------- - subroutine find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, & - i_x, i_y, i_z, t_x, t_y, t_z) - real(dp), intent(in) :: x_val, y_val, z_val - real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:) - integer, intent(out) :: i_x, i_y, i_z - real(dp), intent(out) :: t_x, t_y, t_z - - ! 
Find x interval - call find_interval(x_grid, x_val, i_x, t_x) - - ! Find y interval - call find_interval(y_grid, y_val, i_y, t_y) - - ! Find z interval - call find_interval(z_grid, z_val, i_z, t_z) - end subroutine find_containing_cell - - !--------------------------------------------------------------------------- - ! Find the interval in a sorted array containing a value - !--------------------------------------------------------------------------- - - subroutine find_interval(x, val, i, t) - real(dp), intent(in) :: x(:), val - integer, intent(out) :: i - real(dp), intent(out) :: t - - integer :: n, lo, hi, mid - logical :: dummy_axis - - n = size(x) - - ! Detect dummy axis: all values == 0, 999, or -999 - dummy_axis = all(x == 0.0_dp) .or. all(x == 999.0_dp) .or. all(x == -999.0_dp) - - if (dummy_axis) then - ! Collapse axis: always use first point, no interpolation - i = 1 - t = 0.0_dp - return - end if - - ! ---------- ORIGINAL CODE BELOW ---------------- - - if (val <= x(1)) then - i = 1 - t = 0.0_dp - return - else if (val >= x(n)) then - i = n - 1 - t = 1.0_dp - return - end if - - lo = 1 - hi = n - do while (hi - lo > 1) - mid = (lo + hi)/2 - if (val >= x(mid)) then - lo = mid - else - hi = mid - end if - end do - - i = lo - t = (val - x(i))/(x(i + 1) - x(i)) - end subroutine find_interval - - !--------------------------------------------------------------------------- - ! Find the nearest grid point - !--------------------------------------------------------------------------- - subroutine find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, & - i_x, i_y, i_z) - real(dp), intent(in) :: x_val, y_val, z_val - real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:) - integer, intent(out) :: i_x, i_y, i_z - - ! 
Find nearest grid points using intrinsic minloc - i_x = minloc(abs(x_val - x_grid), 1) - i_y = minloc(abs(y_val - y_grid), 1) - i_z = minloc(abs(z_val - z_grid), 1) - end subroutine find_nearest_point !--------------------------------------------------------------------------- - ! Compute derivatives at a grid point + ! Compute derivatives at a grid point (3-D version, used by scalar path) !--------------------------------------------------------------------------- subroutine compute_derivatives_at_point(f, i, j, k, nx, ny, nz, dx, dy, dz, & df_dx, df_dy, df_dz) @@ -366,7 +426,9 @@ subroutine compute_derivatives_at_point(f, i, j, k, nx, ny, nz, dx, dy, dz, & real(dp), intent(out) :: df_dx, df_dy, df_dz ! Compute x derivative using centered differences where possible - if (i > 1 .and. i < nx) then + if (dx < 1.0e-30_dp) then + df_dx = 0.0_dp ! degenerate axis + else if (i > 1 .and. i < nx) then df_dx = (f(i + 1, j, k) - f(i - 1, j, k))/(2.0_dp*dx) else if (i == 1) then df_dx = (f(i + 1, j, k) - f(i, j, k))/dx @@ -375,7 +437,9 @@ subroutine compute_derivatives_at_point(f, i, j, k, nx, ny, nz, dx, dy, dz, & end if ! Compute y derivative using centered differences where possible - if (j > 1 .and. j < ny) then + if (dy < 1.0e-30_dp) then + df_dy = 0.0_dp ! degenerate axis + else if (j > 1 .and. j < ny) then df_dy = (f(i, j + 1, k) - f(i, j - 1, k))/(2.0_dp*dy) else if (j == 1) then df_dy = (f(i, j + 1, k) - f(i, j, k))/dy @@ -384,7 +448,9 @@ subroutine compute_derivatives_at_point(f, i, j, k, nx, ny, nz, dx, dy, dz, & end if ! Compute z derivative using centered differences where possible - if (k > 1 .and. k < nz) then + if (dz < 1.0e-30_dp) then + df_dz = 0.0_dp ! degenerate axis + else if (k > 1 .and. 
k < nz) then df_dz = (f(i, j, k + 1) - f(i, j, k - 1))/(2.0_dp*dz) else if (k == 1) then df_dz = (f(i, j, k + 1) - f(i, j, k))/dz @@ -420,4 +486,4 @@ function h11(t) result(h) h = t**3 - t**2 end function h11 -end module hermite_interp +end module hermite_interp \ No newline at end of file diff --git a/colors/private/knn_interp.f90 b/colors/private/knn_interp.f90 index 172e92070..ecd179792 100644 --- a/colors/private/knn_interp.f90 +++ b/colors/private/knn_interp.f90 @@ -17,99 +17,265 @@ ! ! *********************************************************************** -! *********************************************************************** -! K-Nearest Neighbors interpolation module for spectral energy distributions (SEDs) -! *********************************************************************** +! knn interpolation for SEDs +! +! data-loading strategy selected by rq%cube_loaded: +! .true. -> extract neighbor SEDs from the preloaded 4-D cube +! .false. -> load individual SED files via the lookup table (fallback) module knn_interp use const_def, only: dp - use colors_utils, only: dilute_flux, load_sed + use colors_def, only: Colors_General_Info + use colors_utils, only: dilute_flux, load_sed_cached use utils_lib, only: mesa_error implicit none private - public :: construct_sed_knn, load_sed, interpolate_array, dilute_flux + public :: construct_sed_knn, interpolate_array contains - !--------------------------------------------------------------------------- - ! Main entry point: Construct a SED using KNN interpolation - !--------------------------------------------------------------------------- - subroutine construct_sed_knn(teff, log_g, metallicity, R, d, file_names, & - lu_teff, lu_logg, lu_meta, stellar_model_dir, & - wavelengths, fluxes) + ! main entry point -- construct a SED using KNN interpolation + ! 
   ! strategy controlled by rq%cube_loaded (set at init)
   subroutine construct_sed_knn(rq, teff, log_g, metallicity, R, d, &
                                stellar_model_dir, wavelengths, fluxes)
      type(Colors_General_Info), intent(inout) :: rq
      real(dp), intent(in) :: teff, log_g, metallicity, R, d  ! surface params, radius, distance
      character(len=*), intent(in) :: stellar_model_dir
      real(dp), dimension(:), allocatable, intent(out) :: wavelengths, fluxes

      integer :: n_lambda
      real(dp), dimension(:), allocatable :: interp_flux, diluted_flux

      if (rq%cube_loaded) then
         ! fast path: extract neighbors from preloaded cube
         call construct_sed_from_cube(rq, teff, log_g, metallicity, &
                                      interp_flux, wavelengths)
         n_lambda = size(wavelengths)
      else
         ! fallback path: load individual SED files
         call construct_sed_from_files(rq, teff, log_g, metallicity, &
                                       stellar_model_dir, interp_flux, wavelengths)
         n_lambda = size(wavelengths)
      end if

      ! apply distance dilution to convert surface flux to observed flux
      allocate (diluted_flux(n_lambda))
      call dilute_flux(interp_flux, R, d, diluted_flux)
      fluxes = diluted_flux

   end subroutine construct_sed_knn

   ! cube path: find 4 nearest grid points, extract their SEDs from
   ! cube_flux, blend by inverse-distance weighting (IDW)
   subroutine construct_sed_from_cube(rq, teff, log_g, metallicity, &
                                      interp_flux, wavelengths)
      type(Colors_General_Info), intent(inout) :: rq
      real(dp), intent(in) :: teff, log_g, metallicity
      real(dp), dimension(:), allocatable, intent(out) :: interp_flux, wavelengths

      integer :: n_lambda, k
      integer, dimension(4) :: nbr_it, nbr_ig, nbr_im  ! neighbor indices per axis
      real(dp), dimension(4) :: distances, weights
      real(dp) :: sum_weights

      n_lambda = size(rq%cube_wavelengths)
      allocate (wavelengths(n_lambda))
      wavelengths = rq%cube_wavelengths

      ! find the 4 nearest grid points in the structured cube
      call get_closest_grid_points(teff, log_g, metallicity, &
                                   rq%cube_teff_grid, rq%cube_logg_grid, &
                                   rq%cube_meta_grid, &
                                   nbr_it, nbr_ig, nbr_im, distances)

      ! compute inverse-distance weights; clamp an exact-hit distance to a
      ! tiny value to avoid division by zero
      do k = 1, 4
         if (distances(k) == 0.0_dp) distances(k) = 1.0e-10_dp
         weights(k) = 1.0_dp/distances(k)
      end do
      sum_weights = sum(weights)
      weights = weights/sum_weights

      ! blend neighbor SEDs from cube
      allocate (interp_flux(n_lambda))
      interp_flux = 0.0_dp
      do k = 1, 4
         interp_flux = interp_flux + weights(k)* &
                       rq%cube_flux(nbr_it(k), nbr_ig(k), nbr_im(k), :)
      end do

   end subroutine construct_sed_from_cube

   ! fallback path: find 4 nearest models in the lookup table, load SEDs
   ! (via the memory cache), blend by inverse-distance weighting
   subroutine construct_sed_from_files(rq, teff, log_g, metallicity, &
                                       stellar_model_dir, interp_flux, wavelengths)
      use colors_utils, only: resolve_path
      type(Colors_General_Info), intent(inout) :: rq
      real(dp), intent(in) :: teff, log_g, metallicity
      character(len=*), intent(in) :: stellar_model_dir
      real(dp), dimension(:), allocatable, intent(out) :: interp_flux, wavelengths

      integer, dimension(4) :: closest_indices
      real(dp), dimension(:), allocatable :: temp_flux, common_wavelengths
      real(dp), dimension(:, :), allocatable :: model_fluxes
      real(dp), dimension(4) :: weights, distances
      integer :: i, n_points
      real(dp) :: sum_weights
      character(len=512) :: resolved_dir

      resolved_dir = trim(resolve_path(stellar_model_dir))

      ! get the four closest stellar models from the flat lookup table
      call get_closest_stellar_models(teff, log_g, metallicity, &
                                      rq%lu_teff, rq%lu_logg, rq%lu_meta, &
                                      closest_indices)

      ! load the first SED to define the wavelength grid (using cache)
      call load_sed_cached(rq, resolved_dir, closest_indices(1), temp_flux)

      ! get wavelengths from canonical copy on the handle
      if (rq%fallback_wavelengths_set) then
         n_points = size(rq%fallback_wavelengths)
         allocate (common_wavelengths(n_points))
         common_wavelengths = rq%fallback_wavelengths
      else
         ! should not happen -- load_sed_cached sets this on first call
         print *, 'KNN fallback: wavelengths not set after first SED load'
         call mesa_error(__FILE__, __LINE__)
      end if

      ! flux array for the 4 neighbor models, n_points samples each
      allocate (model_fluxes(4, n_points))
      model_fluxes(1, :) = temp_flux(1:n_points)
      if (allocated(temp_flux)) deallocate (temp_flux)

      ! load and store remaining SEDs
      do i = 2, 4
         call load_sed_cached(rq, resolved_dir, closest_indices(i), temp_flux)
         model_fluxes(i, :) = temp_flux(1:n_points)
         if (allocated(temp_flux)) deallocate (temp_flux)
      end do

      ! compute distances and weights for the four models
      ! (unnormalised Euclidean distance in (teff, logg, [M/H]) space)
      do i = 1, 4
         distances(i) = sqrt((rq%lu_teff(closest_indices(i)) - teff)**2 + &
                             (rq%lu_logg(closest_indices(i)) - log_g)**2 + &
                             (rq%lu_meta(closest_indices(i)) - metallicity)**2)
         if (distances(i) == 0.0_dp) distances(i) = 1.0e-10_dp  ! avoid divide-by-zero
         weights(i) = 1.0_dp/distances(i)
      end do

      sum_weights = sum(weights)
      weights = weights/sum_weights

      allocate (wavelengths(n_points))
      wavelengths = common_wavelengths

      allocate (interp_flux(n_points))
      interp_flux = 0.0_dp

      ! weighted combination of model fluxes (still at the stellar surface;
      ! distance dilution is applied by the caller)
      do i = 1, 4
         interp_flux = interp_flux + weights(i)*model_fluxes(i, :)
      end do

   end subroutine construct_sed_from_files
find the 4 closest grid points in the structured cube (normalised euclidean distance) + subroutine get_closest_grid_points(teff, log_g, metallicity, & + teff_grid, logg_grid, meta_grid, & + nbr_it, nbr_ig, nbr_im, distances) + real(dp), intent(in) :: teff, log_g, metallicity + real(dp), intent(in) :: teff_grid(:), logg_grid(:), meta_grid(:) + integer, dimension(4), intent(out) :: nbr_it, nbr_ig, nbr_im + real(dp), dimension(4), intent(out) :: distances + + integer :: it, ig, im, j + real(dp) :: dist, norm_teff, norm_logg, norm_meta + real(dp) :: teff_min, teff_max, logg_min, logg_max, meta_min, meta_max + real(dp) :: scaled_t, scaled_g, scaled_m, dt, dg, dm + logical :: use_teff_dim, use_logg_dim, use_meta_dim + + distances = huge(1.0_dp) + nbr_it = 1; nbr_ig = 1; nbr_im = 1 + + ! normalisation ranges + teff_min = minval(teff_grid); teff_max = maxval(teff_grid) + logg_min = minval(logg_grid); logg_max = maxval(logg_grid) + meta_min = minval(meta_grid); meta_max = maxval(meta_grid) + + ! detect dummy axes + use_teff_dim = .not. (all(teff_grid == 0.0_dp) .or. & + all(teff_grid == 999.0_dp) .or. all(teff_grid == -999.0_dp)) + use_logg_dim = .not. (all(logg_grid == 0.0_dp) .or. & + all(logg_grid == 999.0_dp) .or. all(logg_grid == -999.0_dp)) + use_meta_dim = .not. (all(meta_grid == 0.0_dp) .or. & + all(meta_grid == 999.0_dp) .or. all(meta_grid == -999.0_dp)) + + ! normalised target values + norm_teff = 0.0_dp; norm_logg = 0.0_dp; norm_meta = 0.0_dp + if (use_teff_dim .and. teff_max - teff_min > 0.0_dp) & + norm_teff = (teff - teff_min)/(teff_max - teff_min) + if (use_logg_dim .and. logg_max - logg_min > 0.0_dp) & + norm_logg = (log_g - logg_min)/(logg_max - logg_min) + if (use_meta_dim .and. meta_max - meta_min > 0.0_dp) & + norm_meta = (metallicity - meta_min)/(meta_max - meta_min) + + do it = 1, size(teff_grid) + if (use_teff_dim .and. 
teff_max - teff_min > 0.0_dp) then + scaled_t = (teff_grid(it) - teff_min)/(teff_max - teff_min) + else + scaled_t = 0.0_dp + end if + dt = 0.0_dp + if (use_teff_dim) dt = (scaled_t - norm_teff)**2 + + do ig = 1, size(logg_grid) + if (use_logg_dim .and. logg_max - logg_min > 0.0_dp) then + scaled_g = (logg_grid(ig) - logg_min)/(logg_max - logg_min) + else + scaled_g = 0.0_dp + end if + dg = 0.0_dp + if (use_logg_dim) dg = (scaled_g - norm_logg)**2 + + do im = 1, size(meta_grid) + if (use_meta_dim .and. meta_max - meta_min > 0.0_dp) then + scaled_m = (meta_grid(im) - meta_min)/(meta_max - meta_min) + else + scaled_m = 0.0_dp + end if + dm = 0.0_dp + if (use_meta_dim) dm = (scaled_m - norm_meta)**2 + + dist = dt + dg + dm + + ! insert into sorted top-4 if closer + do j = 1, 4 + if (dist < distances(j)) then + if (j < 4) then + distances(j + 1:4) = distances(j:3) + nbr_it(j + 1:4) = nbr_it(j:3) + nbr_ig(j + 1:4) = nbr_ig(j:3) + nbr_im(j + 1:4) = nbr_im(j:3) + end if + distances(j) = dist + nbr_it(j) = it + nbr_ig(j) = ig + nbr_im(j) = im + exit + end if + end do + end do + end do + end do + + ! convert squared distances to actual distances for weighting + do j = 1, 4 + distances(j) = sqrt(distances(j)) + end do + + end subroutine get_closest_grid_points - !--------------------------------------------------------------------------- - ! Identify the four closest stellar models - !--------------------------------------------------------------------------- + ! find the four closest stellar models in the flat lookup table subroutine get_closest_stellar_models(teff, log_g, metallicity, lu_teff, & lu_logg, lu_meta, closest_indices) real(dp), intent(in) :: teff, log_g, metallicity @@ -129,7 +295,7 @@ subroutine get_closest_stellar_models(teff, log_g, metallicity, lu_teff, & min_distances = huge(1.0) indices = -1 - ! Find min and max for normalization + ! 
find min and max for normalisation teff_min = minval(lu_teff) teff_max = maxval(lu_teff) logg_min = minval(lu_logg) @@ -137,7 +303,6 @@ subroutine get_closest_stellar_models(teff, log_g, metallicity, lu_teff, & meta_min = minval(lu_meta) meta_max = maxval(lu_meta) - ! Allocate and scale lookup table values allocate (scaled_lu_teff(n), scaled_lu_logg(n), scaled_lu_meta(n)) if (teff_max - teff_min > 0.0_dp) then @@ -152,17 +317,16 @@ subroutine get_closest_stellar_models(teff, log_g, metallicity, lu_teff, & scaled_lu_meta = (lu_meta - meta_min)/(meta_max - meta_min) end if - ! Normalize input parameters + ! normalise input parameters norm_teff = (teff - teff_min)/(teff_max - teff_min) norm_logg = (log_g - logg_min)/(logg_max - logg_min) norm_meta = (metallicity - meta_min)/(meta_max - meta_min) - ! Detect dummy axes once (outside the loop) + ! detect dummy axes -- skip degenerate dimensions in distance calc use_teff_dim = .not. (all(lu_teff == 0.0_dp) .or. all(lu_teff == 999.0_dp) .or. all(lu_teff == -999.0_dp)) use_logg_dim = .not. (all(lu_logg == 0.0_dp) .or. all(lu_logg == 999.0_dp) .or. all(lu_logg == -999.0_dp)) use_meta_dim = .not. (all(lu_meta == 0.0_dp) .or. all(lu_meta == 999.0_dp) .or. all(lu_meta == -999.0_dp)) - ! Find closest models do i = 1, n teff_dist = 0.0_dp logg_dist = 0.0_dp @@ -180,7 +344,7 @@ subroutine get_closest_stellar_models(teff, log_g, metallicity, lu_teff, & meta_dist = scaled_lu_meta(i) - norm_meta end if - ! Compute distance using only valid dimensions + ! compute distance using only valid dimensions distance = 0.0_dp if (use_teff_dim) distance = distance + teff_dist**2 if (use_logg_dim) distance = distance + logg_dist**2 @@ -188,7 +352,7 @@ subroutine get_closest_stellar_models(teff, log_g, metallicity, lu_teff, & do j = 1, 4 if (distance < min_distances(j)) then - ! Shift larger distances down + ! 
shift larger distances down if (j < 4) then min_distances(j + 1:4) = min_distances(j:3) indices(j + 1:4) = indices(j:3) @@ -203,15 +367,12 @@ subroutine get_closest_stellar_models(teff, log_g, metallicity, lu_teff, & closest_indices = indices end subroutine get_closest_stellar_models - !--------------------------------------------------------------------------- - ! Linear interpolation (binary search version for efficiency) - !--------------------------------------------------------------------------- + ! linear interpolation -- binary search subroutine linear_interpolate(x, y, x_val, y_val) real(dp), intent(in) :: x(:), y(:), x_val real(dp), intent(out) :: y_val integer :: low, high, mid - ! Validate input sizes if (size(x) < 2) then print *, "Error: x array has fewer than 2 points." y_val = 0.0_dp @@ -224,7 +385,7 @@ subroutine linear_interpolate(x, y, x_val, y_val) return end if - ! Handle out-of-bounds cases + ! handle out-of-bounds cases if (x_val <= x(1)) then y_val = y(1) return @@ -233,7 +394,7 @@ subroutine linear_interpolate(x, y, x_val, y_val) return end if - ! Binary search to find the proper interval [x(low), x(low+1)] + ! binary search to find interval [x(low), x(low+1)] low = 1 high = size(x) do while (high - low > 1) @@ -245,19 +406,15 @@ subroutine linear_interpolate(x, y, x_val, y_val) end if end do - ! Linear interpolation between x(low) and x(low+1) y_val = y(low) + (y(low + 1) - y(low))/(x(low + 1) - x(low))*(x_val - x(low)) end subroutine linear_interpolate - !--------------------------------------------------------------------------- - ! Array interpolation for SED construction - !--------------------------------------------------------------------------- + ! array interpolation for SED/filter alignment subroutine interpolate_array(x_in, y_in, x_out, y_out) real(dp), intent(in) :: x_in(:), y_in(:), x_out(:) real(dp), intent(out) :: y_out(:) integer :: i - ! Validate input sizes if (size(x_in) < 2 .or. 
size(y_in) < 2) then print *, "Error: x_in or y_in arrays have fewer than 2 points." call mesa_error(__FILE__, __LINE__) @@ -278,4 +435,4 @@ subroutine interpolate_array(x_in, y_in, x_out, y_out) end do end subroutine interpolate_array -end module knn_interp +end module knn_interp \ No newline at end of file diff --git a/colors/private/linear_interp.f90 b/colors/private/linear_interp.f90 index f5aa09207..86f54f8d1 100644 --- a/colors/private/linear_interp.f90 +++ b/colors/private/linear_interp.f90 @@ -17,13 +17,18 @@ ! ! *********************************************************************** -! *********************************************************************** -! Linear interpolation module for spectral energy distributions (SEDs) -! *********************************************************************** +! linear interpolation for SEDs +! +! data-loading strategy selected by rq%cube_loaded: +! .true. -> use the preloaded 4-D flux cube on the handle +! .false. -> load individual SED files via the lookup table (fallback) module linear_interp use const_def, only: dp - use colors_utils, only: dilute_flux + use colors_def, only: Colors_General_Info + use colors_utils, only: dilute_flux, find_containing_cell, find_interval, & + find_nearest_point, find_bracket_index, & + load_sed_cached, load_stencil use utils_lib, only: mesa_error implicit none @@ -32,223 +37,238 @@ module linear_interp contains - !--------------------------------------------------------------------------- - ! Main entry point: Construct a SED using linear interpolation - !--------------------------------------------------------------------------- - - subroutine construct_sed_linear(teff, log_g, metallicity, R, d, file_names, & - lu_teff, lu_logg, lu_meta, stellar_model_dir, & - wavelengths, fluxes) - + ! main entry point -- construct a SED using trilinear interpolation + ! 
strategy controlled by rq%cube_loaded (set at init) + subroutine construct_sed_linear(rq, teff, log_g, metallicity, R, d, & + stellar_model_dir, wavelengths, fluxes) + type(Colors_General_Info), intent(inout) :: rq real(dp), intent(in) :: teff, log_g, metallicity, R, d - real(dp), intent(in) :: lu_teff(:), lu_logg(:), lu_meta(:) character(len=*), intent(in) :: stellar_model_dir - character(len=100), intent(in) :: file_names(:) real(dp), dimension(:), allocatable, intent(out) :: wavelengths, fluxes - integer :: i, n_lambda, status, n_teff, n_logg, n_meta + integer :: n_lambda real(dp), dimension(:), allocatable :: interp_flux, diluted_flux - real(dp), dimension(:, :, :, :), allocatable :: precomputed_flux_cube - real(dp), dimension(:, :, :), allocatable :: flux_cube_lambda - real(dp) :: min_flux, max_flux, mean_flux, progress_pct - - ! Parameter grids - real(dp), allocatable :: teff_grid(:), logg_grid(:), meta_grid(:) - character(len=256) :: bin_filename, clean_path - logical :: file_exists - - ! Clean up any double slashes in the path - clean_path = trim(stellar_model_dir) - if (clean_path(len_trim(clean_path):len_trim(clean_path)) == '/') then - bin_filename = trim(clean_path)//'flux_cube.bin' - else - bin_filename = trim(clean_path)//'/flux_cube.bin' - end if - - ! Check if file exists first - INQUIRE (file=bin_filename, EXIST=file_exists) - if (.not. file_exists) then - print *, 'Missing required binary file for interpolation' - call mesa_error(__FILE__, __LINE__) - end if + if (rq%cube_loaded) then + ! fast path: use preloaded cube from handle + n_lambda = size(rq%cube_wavelengths) - ! Load the data from binary file - call load_binary_data(bin_filename, teff_grid, logg_grid, meta_grid, & - wavelengths, precomputed_flux_cube, status) + allocate (wavelengths(n_lambda)) + wavelengths = rq%cube_wavelengths - if (status /= 0) then - print *, 'Binary data loading error' - call mesa_error(__FILE__, __LINE__) + ! 
vectorised interpolation -- cell location computed once, reused across n_lambda + allocate (interp_flux(n_lambda)) + call trilinear_interp_vector(teff, log_g, metallicity, & + rq%cube_teff_grid, rq%cube_logg_grid, & + rq%cube_meta_grid, & + rq%cube_flux, n_lambda, interp_flux) + else + ! fallback path: load individual SED files from lookup table + call construct_sed_from_files(rq, teff, log_g, metallicity, & + stellar_model_dir, interp_flux, wavelengths) + n_lambda = size(wavelengths) end if - n_teff = size(teff_grid) - n_logg = size(logg_grid) - n_meta = size(meta_grid) - n_lambda = size(wavelengths) - - ! Allocate space for interpolated flux - allocate (interp_flux(n_lambda)) - allocate (flux_cube_lambda(n_teff, n_logg, n_meta)) - - ! Perform trilinear interpolation for each wavelength - do i = 1, n_lambda + allocate (diluted_flux(n_lambda)) + call dilute_flux(interp_flux, R, d, diluted_flux) + fluxes = diluted_flux - ! Extract the 3D grid for this wavelength - flux_cube_lambda = precomputed_flux_cube(:, :, :, i) + end subroutine construct_sed_linear - ! Simple trilinear interpolation at the target parameters - interp_flux(i) = trilinear_interp(teff, log_g, metallicity, & - teff_grid, logg_grid, meta_grid, flux_cube_lambda) - end do + ! fallback: build a 2x2x2 sub-cube from SED files, then trilinear-interpolate + ! unlike hermite, no derivative context needed -- stencil is exactly the 2x2x2 cell corners + subroutine construct_sed_from_files(rq, teff, log_g, metallicity, & + stellar_model_dir, interp_flux, wavelengths) + use colors_utils, only: resolve_path, build_grid_to_lu_map + type(Colors_General_Info), intent(inout) :: rq + real(dp), intent(in) :: teff, log_g, metallicity + character(len=*), intent(in) :: stellar_model_dir + real(dp), dimension(:), allocatable, intent(out) :: interp_flux, wavelengths + + integer :: i_t, i_g, i_m ! bracketing indices in unique grids + integer :: lo_t, hi_t, lo_g, hi_g, lo_m, hi_m ! 
stencil bounds + integer :: nt, ng, nm, n_lambda + character(len=512) :: resolved_dir + logical :: need_reload + + resolved_dir = trim(resolve_path(stellar_model_dir)) + + ! ensure the grid-to-lu mapping exists (built once, then reused) + if (.not. rq%grid_map_built) call build_grid_to_lu_map(rq) + + ! find bracketing cell in the unique grids + call find_bracket_index(rq%u_teff, teff, i_t) + call find_bracket_index(rq%u_logg, log_g, i_g) + call find_bracket_index(rq%u_meta, metallicity, i_m) + + ! check if the stencil cache is still valid for this cell + need_reload = .true. + if (rq%stencil_valid .and. & + i_t == rq%stencil_i_t .and. & + i_g == rq%stencil_i_g .and. & + i_m == rq%stencil_i_m) then + need_reload = .false. + end if + if (need_reload) then + ! trilinear needs exactly the 2x2x2 cell corners -- no extension + nt = size(rq%u_teff) + ng = size(rq%u_logg) + nm = size(rq%u_meta) - deallocate(flux_cube_lambda) + if (nt < 2) then + lo_t = 1; hi_t = 1 + else + lo_t = i_t + hi_t = min(nt, i_t + 1) + end if - ! Calculate statistics for validation - min_flux = minval(interp_flux) - max_flux = maxval(interp_flux) - mean_flux = sum(interp_flux)/n_lambda + if (ng < 2) then + lo_g = 1; hi_g = 1 + else + lo_g = i_g + hi_g = min(ng, i_g + 1) + end if - ! Apply distance dilution to get observed flux - allocate (diluted_flux(n_lambda)) - call dilute_flux(interp_flux, R, d, diluted_flux) - fluxes = diluted_flux + if (nm < 2) then + lo_m = 1; hi_m = 1 + else + lo_m = i_m + hi_m = min(nm, i_m + 1) + end if - ! Calculate statistics after dilution - min_flux = minval(diluted_flux) - max_flux = maxval(diluted_flux) - mean_flux = sum(diluted_flux)/n_lambda + ! load SEDs for every stencil point (using memory cache) + call load_stencil(rq, resolved_dir, lo_t, hi_t, lo_g, hi_g, lo_m, hi_m) + + ! 
store subgrid arrays on the handle + if (allocated(rq%stencil_teff)) deallocate (rq%stencil_teff) + if (allocated(rq%stencil_logg)) deallocate (rq%stencil_logg) + if (allocated(rq%stencil_meta)) deallocate (rq%stencil_meta) + + allocate (rq%stencil_teff(hi_t - lo_t + 1)) + allocate (rq%stencil_logg(hi_g - lo_g + 1)) + allocate (rq%stencil_meta(hi_m - lo_m + 1)) + rq%stencil_teff = rq%u_teff(lo_t:hi_t) + rq%stencil_logg = rq%u_logg(lo_g:hi_g) + rq%stencil_meta = rq%u_meta(lo_m:hi_m) + + rq%stencil_i_t = i_t + rq%stencil_i_g = i_g + rq%stencil_i_m = i_m + rq%stencil_valid = .true. + end if - end subroutine construct_sed_linear + n_lambda = size(rq%stencil_wavelengths) + allocate (wavelengths(n_lambda)) + wavelengths = rq%stencil_wavelengths - !--------------------------------------------------------------------------- - ! Load data from binary file - !--------------------------------------------------------------------------- - subroutine load_binary_data(filename, teff_grid, logg_grid, meta_grid, & - wavelengths, flux_cube, status) - character(len=*), intent(in) :: filename - real(dp), allocatable, intent(out) :: teff_grid(:), logg_grid(:), meta_grid(:) - real(dp), allocatable, intent(out) :: wavelengths(:) - real(dp), allocatable, intent(out) :: flux_cube(:, :, :, :) - integer, intent(out) :: status - - integer :: unit, n_teff, n_logg, n_meta, n_lambda - - unit = 99 - status = 0 - - ! Open the binary file - open (unit=unit, file=filename, status='OLD', ACCESS='STREAM', FORM='UNFORMATTED', iostat=status) - if (status /= 0) then - !print *, 'Error opening binary file:', trim(filename) - return - end if + allocate (interp_flux(n_lambda)) + call trilinear_interp_vector(teff, log_g, metallicity, & + rq%stencil_teff, rq%stencil_logg, rq%stencil_meta, & + rq%stencil_fluxes, n_lambda, interp_flux) - ! 
Read dimensions - read (unit, iostat=status) n_teff, n_logg, n_meta, n_lambda - if (status /= 0) then - !print *, 'Error reading dimensions from binary file' - close (unit) - return - end if + end subroutine construct_sed_from_files - ! Allocate arrays based on dimensions - allocate (teff_grid(n_teff), STAT=status) - if (status /= 0) then - !print *, 'Error allocating teff_grid array' - close (unit) - return - end if + ! vectorised trilinear interpolation over all wavelengths + ! cell location depends only on (teff, logg, meta) -- computed once, reused across n_lambda + subroutine trilinear_interp_vector(x_val, y_val, z_val, & + x_grid, y_grid, z_grid, & + f_values_4d, n_lambda, result_flux) + real(dp), intent(in) :: x_val, y_val, z_val + real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:) + real(dp), intent(in) :: f_values_4d(:, :, :, :) ! (nx, ny, nz, n_lambda) + integer, intent(in) :: n_lambda + real(dp), intent(out) :: result_flux(n_lambda) - allocate (logg_grid(n_logg), STAT=status) - if (status /= 0) then - !print *, 'Error allocating logg_grid array' - close (unit) - return - end if + integer :: i_x, i_y, i_z, lam + real(dp) :: t_x, t_y, t_z + integer :: nx, ny, nz + real(dp) :: c000, c001, c010, c011, c100, c101, c110, c111 + real(dp) :: c00, c01, c10, c11, c0, c1 + real(dp) :: lin_result, log_result + real(dp), parameter :: tiny_value = 1.0e-10_dp - allocate (meta_grid(n_meta), STAT=status) - if (status /= 0) then - !print *, 'Error allocating meta_grid array' - close (unit) - return - end if + nx = size(x_grid) + ny = size(y_grid) + nz = size(z_grid) - allocate (wavelengths(n_lambda), STAT=status) - if (status /= 0) then - !print *, 'Error allocating wavelengths array' - close (unit) - return - end if + ! 
locate the cell once + call find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, & + i_x, i_y, i_z, t_x, t_y, t_z) - allocate (flux_cube(n_teff, n_logg, n_meta, n_lambda), STAT=status) - if (status /= 0) then - !print *, 'Error allocating flux_cube array' - close (unit) - return - end if + ! boundary safety check + if (i_x < 1) i_x = 1 + if (i_y < 1) i_y = 1 + if (i_z < 1) i_z = 1 + if (i_x >= nx) i_x = max(1, nx - 1) + if (i_y >= ny) i_y = max(1, ny - 1) + if (i_z >= nz) i_z = max(1, nz - 1) + + ! clamp interpolation parameters to [0,1] + t_x = max(0.0_dp, min(1.0_dp, t_x)) + t_y = max(0.0_dp, min(1.0_dp, t_y)) + t_z = max(0.0_dp, min(1.0_dp, t_z)) + + ! loop over wavelengths with the same cell location + do lam = 1, n_lambda + ! get the 8 corners of the cube with safety floor + c000 = max(tiny_value, f_values_4d(i_x, i_y, i_z, lam)) + c001 = max(tiny_value, f_values_4d(i_x, i_y, i_z + 1, lam)) + c010 = max(tiny_value, f_values_4d(i_x, i_y + 1, i_z, lam)) + c011 = max(tiny_value, f_values_4d(i_x, i_y + 1, i_z + 1, lam)) + c100 = max(tiny_value, f_values_4d(i_x + 1, i_y, i_z, lam)) + c101 = max(tiny_value, f_values_4d(i_x + 1, i_y, i_z + 1, lam)) + c110 = max(tiny_value, f_values_4d(i_x + 1, i_y + 1, i_z, lam)) + c111 = max(tiny_value, f_values_4d(i_x + 1, i_y + 1, i_z + 1, lam)) + + ! standard linear interpolation first (safer) + c00 = c000*(1.0_dp - t_x) + c100*t_x + c01 = c001*(1.0_dp - t_x) + c101*t_x + c10 = c010*(1.0_dp - t_x) + c110*t_x + c11 = c011*(1.0_dp - t_x) + c111*t_x - ! Read grid arrays - read (unit, iostat=status) teff_grid - if (status /= 0) then - !print *, 'Error reading teff_grid' - GOTO 999 ! Cleanup and return - end if + c0 = c00*(1.0_dp - t_y) + c10*t_y + c1 = c01*(1.0_dp - t_y) + c11*t_y - read (unit, iostat=status) logg_grid - if (status /= 0) then - !print *, 'Error reading logg_grid' - GOTO 999 ! 
Cleanup and return - end if + lin_result = c0*(1.0_dp - t_z) + c1*t_z - read (unit, iostat=status) meta_grid - if (status /= 0) then - !print *, 'Error reading meta_grid' - GOTO 999 ! Cleanup and return - end if + ! if valid, try log-space interpolation (smoother for flux) + if (lin_result > tiny_value) then + c00 = log(c000)*(1.0_dp - t_x) + log(c100)*t_x + c01 = log(c001)*(1.0_dp - t_x) + log(c101)*t_x + c10 = log(c010)*(1.0_dp - t_x) + log(c110)*t_x + c11 = log(c011)*(1.0_dp - t_x) + log(c111)*t_x - read (unit, iostat=status) wavelengths - if (status /= 0) then - !print *, 'Error reading wavelengths' - GOTO 999 ! Cleanup and return - end if + c0 = c00*(1.0_dp - t_y) + c10*t_y + c1 = c01*(1.0_dp - t_y) + c11*t_y - ! Read flux cube - read (unit, iostat=status) flux_cube - if (status /= 0) then - !print *, 'Error reading flux_cube' - GOTO 999 ! Cleanup and return - end if + log_result = c0*(1.0_dp - t_z) + c1*t_z - ! Close file and return success - close (unit) - return + ! only use log-space result if valid + if (log_result == log_result) then ! NaN check + lin_result = exp(log_result) + end if + end if -999 CONTINUE - ! Cleanup on error - close (unit) - return + ! final sanity check -- fall back to nearest neighbour + if (lin_result /= lin_result .or. lin_result <= 0.0_dp) then + call find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, & + i_x, i_y, i_z) + lin_result = max(tiny_value, f_values_4d(i_x, i_y, i_z, lam)) + end if -! After reading the grid arrays -!print *, 'Teff grid min/max:', minval(teff_grid), maxval(teff_grid) -!print *, 'logg grid min/max:', minval(logg_grid), maxval(logg_grid) -!print *, 'meta grid min/max:', minval(meta_grid), maxval(meta_grid) + result_flux(lam) = lin_result + end do - end subroutine load_binary_data + end subroutine trilinear_interp_vector - !--------------------------------------------------------------------------- - ! 
Simple trilinear interpolation function - !--------------------------------------------------------------------------- -!--------------------------------------------------------------------------- -! Log-space trilinear interpolation function with normalization -!--------------------------------------------------------------------------- + ! scalar trilinear interpolation (external callers / single-wavelength use) + ! retained for backward compatibility function trilinear_interp(x_val, y_val, z_val, x_grid, y_grid, z_grid, f_values) result(f_interp) real(dp), intent(in) :: x_val, y_val, z_val real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:) real(dp), intent(in) :: f_values(:, :, :) real(dp) :: f_interp - ! Compute log-space result real(dp) :: log_result integer :: i_x, i_y, i_z real(dp) :: t_x, t_y, t_z @@ -256,24 +276,24 @@ function trilinear_interp(x_val, y_val, z_val, x_grid, y_grid, z_grid, f_values) real(dp) :: c00, c01, c10, c11, c0, c1 real(dp), parameter :: tiny_value = 1.0e-10_dp - ! Find containing cell and parameter values using binary search + ! find containing cell and parameter values using binary search call find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, & i_x, i_y, i_z, t_x, t_y, t_z) - ! Boundary safety check - if (i_x < lbound(x_grid,1)) i_x = lbound(x_grid,1) - if (i_y < lbound(y_grid,1)) i_y = lbound(y_grid,1) - if (i_z < lbound(z_grid,1)) i_z = lbound(z_grid,1) - if (i_x >= ubound(x_grid,1)) i_x = ubound(x_grid,1) - 1 - if (i_y >= ubound(y_grid,1)) i_y = ubound(y_grid,1) - 1 - if (i_z >= ubound(z_grid,1)) i_z = ubound(z_grid,1) - 1 + ! boundary safety check + if (i_x < lbound(x_grid, 1)) i_x = lbound(x_grid, 1) + if (i_y < lbound(y_grid, 1)) i_y = lbound(y_grid, 1) + if (i_z < lbound(z_grid, 1)) i_z = lbound(z_grid, 1) + if (i_x >= ubound(x_grid, 1)) i_x = ubound(x_grid, 1) - 1 + if (i_y >= ubound(y_grid, 1)) i_y = ubound(y_grid, 1) - 1 + if (i_z >= ubound(z_grid, 1)) i_z = ubound(z_grid, 1) - 1 - ! 
Force interpolation parameters to be in [0,1] + ! clamp interpolation parameters to [0,1] t_x = max(0.0_dp, MIN(1.0_dp, t_x)) t_y = max(0.0_dp, MIN(1.0_dp, t_y)) t_z = max(0.0_dp, MIN(1.0_dp, t_z)) - ! Get the corners of the cube with safety checks + ! get the corners of the cube with safety checks c000 = max(tiny_value, f_values(i_x, i_y, i_z)) c001 = max(tiny_value, f_values(i_x, i_y, i_z + 1)) c010 = max(tiny_value, f_values(i_x, i_y + 1, i_z)) @@ -283,7 +303,7 @@ function trilinear_interp(x_val, y_val, z_val, x_grid, y_grid, z_grid, f_values) c110 = max(tiny_value, f_values(i_x + 1, i_y + 1, i_z)) c111 = max(tiny_value, f_values(i_x + 1, i_y + 1, i_z + 1)) - ! Try standard linear interpolation first (safer) + ! try standard linear interpolation first (safer) c00 = c000*(1.0_dp - t_x) + c100*t_x c01 = c001*(1.0_dp - t_x) + c101*t_x c10 = c010*(1.0_dp - t_x) + c110*t_x @@ -294,9 +314,8 @@ function trilinear_interp(x_val, y_val, z_val, x_grid, y_grid, z_grid, f_values) f_interp = c0*(1.0_dp - t_z) + c1*t_z - ! If the linear result is valid and non-zero, try log space + ! if valid, try log-space interpolation (smoother for flux) if (f_interp > tiny_value) then - ! Perform log-space interpolation c00 = log(c000)*(1.0_dp - t_x) + log(c100)*t_x c01 = log(c001)*(1.0_dp - t_x) + log(c101)*t_x c10 = log(c010)*(1.0_dp - t_x) + log(c110)*t_x @@ -307,133 +326,17 @@ function trilinear_interp(x_val, y_val, z_val, x_grid, y_grid, z_grid, f_values) log_result = c0*(1.0_dp - t_z) + c1*t_z - ! Only use the log-space result if it's valid + ! only use the log-space result if it's valid if (log_result == log_result) then ! NaN check f_interp = EXP(log_result) end if end if - ! Final sanity check + ! final sanity check if (f_interp /= f_interp .or. f_interp <= 0.0_dp) then - ! 
If we somehow still got an invalid result, use nearest neighbor call find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, i_x, i_y, i_z) f_interp = max(tiny_value, f_values(i_x, i_y, i_z)) end if end function trilinear_interp - !--------------------------------------------------------------------------- - ! Find the cell containing the interpolation point - !--------------------------------------------------------------------------- - subroutine find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, & - i_x, i_y, i_z, t_x, t_y, t_z) - real(dp), intent(in) :: x_val, y_val, z_val - real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:) - integer, intent(out) :: i_x, i_y, i_z - real(dp), intent(out) :: t_x, t_y, t_z - - ! Find x interval - call find_interval(x_grid, x_val, i_x, t_x) - - ! Find y interval - call find_interval(y_grid, y_val, i_y, t_y) - - ! Find z interval - call find_interval(z_grid, z_val, i_z, t_z) - end subroutine find_containing_cell - - !--------------------------------------------------------------------------- - ! Find the interval in a sorted array containing a value - !--------------------------------------------------------------------------- - subroutine find_interval(x, val, i, t) - real(dp), intent(in) :: x(:), val - integer, intent(out) :: i - real(dp), intent(out) :: t - - integer :: n, lo, hi, mid - logical :: dummy_axis - - n = size(x) - - ! Detect dummy axis - dummy_axis = all(x == 0.0_dp) .or. all(x == 999.0_dp) .or. all(x == -999.0_dp) - - if (dummy_axis) then - ! Collapse: use the first element of the axis, no interpolation - i = 1 - t = 0.0_dp - return - end if - - ! 
--- ORIGINAL CODE BELOW --- - if (val <= x(1)) then - i = 1 - t = 0.0_dp - return - else if (val >= x(n)) then - i = n - 1 - t = 1.0_dp - return - end if - - lo = 1 - hi = n - do while (hi - lo > 1) - mid = (lo + hi)/2 - if (val >= x(mid)) then - lo = mid - else - hi = mid - end if - end do - - i = lo - t = (val - x(i))/(x(i + 1) - x(i)) - end subroutine find_interval - - !--------------------------------------------------------------------------- - ! Find the nearest grid point - !--------------------------------------------------------------------------- - subroutine find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, & - i_x, i_y, i_z) - real(dp), intent(in) :: x_val, y_val, z_val - real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:) - integer, intent(out) :: i_x, i_y, i_z - - integer :: i - real(dp) :: min_dist, dist - - ! Find nearest x grid point - min_dist = abs(x_val - x_grid(1)) - i_x = 1 - do i = 2, size(x_grid) - dist = abs(x_val - x_grid(i)) - if (dist < min_dist) then - min_dist = dist - i_x = i - end if - end do - - ! Find nearest y grid point - min_dist = abs(y_val - y_grid(1)) - i_y = 1 - do i = 2, size(y_grid) - dist = abs(y_val - y_grid(i)) - if (dist < min_dist) then - min_dist = dist - i_y = i - end if - end do - - ! 
Find nearest z grid point - min_dist = abs(z_val - z_grid(1)) - i_z = 1 - do i = 2, size(z_grid) - dist = abs(z_val - z_grid(i)) - if (dist < min_dist) then - min_dist = dist - i_z = i - end if - end do - end subroutine find_nearest_point - -end module linear_interp +end module linear_interp \ No newline at end of file diff --git a/colors/private/shared_funcs.f90 b/colors/private/shared_funcs.f90 deleted file mode 100644 index 59d371cb7..000000000 --- a/colors/private/shared_funcs.f90 +++ /dev/null @@ -1,535 +0,0 @@ - -MODULE shared_funcs - USE const_def, ONLY: dp, strlen - USE utils_lib, ONLY: mesa_error - IMPLICIT NONE - - PRIVATE - PUBLIC :: dilute_flux, trapezoidalintegration, rombergintegration, SimpsonIntegration, loadsed, & - loadfilter, loadvegased, load_lookuptable, remove_dat - -CONTAINS - - !--------------------------------------------------------------------------- - ! Apply dilution factor to convert surface flux to observed flux - !--------------------------------------------------------------------------- - SUBROUTINE dilute_flux(surface_flux, R, d, calibrated_flux) - REAL(dp), INTENT(IN) :: surface_flux(:) - REAL(dp), INTENT(IN) :: R, d ! R = stellar radius, d = distance (both in the same units, e.g., cm) - REAL(dp), INTENT(OUT) :: calibrated_flux(:) - - ! Check that the output array has the same size as the input - IF (SIZE(calibrated_flux) /= SIZE(surface_flux)) THEN - PRINT *, "Error in dilute_flux: Output array must have the same size as input array." - CALL mesa_error(__FILE__, __LINE__) - END IF - - ! 
Apply the dilution factor (R/d)^2 to each element - calibrated_flux = surface_flux*((R/d)**2) - END SUBROUTINE dilute_flux - - !########################################################### - !## MATHS - !########################################################### - - !**************************** - !Trapezoidal and Simpson Integration For Flux Calculation - !**************************** - - SUBROUTINE trapezoidalintegration(x, y, result) - REAL(DP), DIMENSION(:), INTENT(IN) :: x, y - REAL(DP), INTENT(OUT) :: result - - INTEGER :: i, n - REAL(DP) :: sum - - n = SIZE(x) - sum = 0.0_dp - - ! Validate input sizes - IF (SIZE(x) /= SIZE(y)) THEN - PRINT *, "Error: x and y arrays must have the same size." - CALL mesa_error(__FILE__, __LINE__) - END IF - - IF (SIZE(x) < 2) THEN - PRINT *, "Error: x and y arrays must have at least 2 elements." - CALL mesa_error(__FILE__, __LINE__) - END IF - - ! Perform trapezoidal integration - DO i = 1, n - 1 - sum = sum + 0.5_dp*(x(i + 1) - x(i))*(y(i + 1) + y(i)) - END DO - - result = sum - END SUBROUTINE trapezoidalintegration - - SUBROUTINE SimpsonIntegration(x, y, result) - REAL(DP), DIMENSION(:), INTENT(IN) :: x, y - REAL(DP), INTENT(OUT) :: result - - INTEGER :: i, n - REAL(DP) :: sum, h1, h2, f1, f2, f0 - - n = SIZE(x) - sum = 0.0_DP - - ! Validate input sizes - IF (SIZE(x) /= SIZE(y)) THEN - PRINT *, "Error: x and y arrays must have the same size." - CALL mesa_error(__FILE__, __LINE__) - END IF - - IF (SIZE(x) < 2) THEN - PRINT *, "Error: x and y arrays must have at least 2 elements." - CALL mesa_error(__FILE__, __LINE__) - END IF - - ! Perform adaptive Simpson's rule - DO i = 1, n - 2, 2 - h1 = x(i + 1) - x(i) ! Step size for first interval - h2 = x(i + 2) - x(i + 1) ! Step size for second interval - - f0 = y(i) - f1 = y(i + 1) - f2 = y(i + 2) - - ! Simpson's rule: (h/3) * (f0 + 4f1 + f2) - sum = sum + (h1 + h2)/6.0_DP*(f0 + 4.0_DP*f1 + f2) - END DO - - ! 
Handle the case where n is odd (last interval) - IF (MOD(n, 2) == 0) THEN - sum = sum + 0.5_DP*(x(n) - x(n - 1))*(y(n) + y(n - 1)) - END IF - - result = sum - END SUBROUTINE SimpsonIntegration - - SUBROUTINE rombergintegration(x, y, result) - REAL(DP), DIMENSION(:), INTENT(IN) :: x, y - REAL(DP), INTENT(OUT) :: result - - INTEGER :: i, j, k, n, m - REAL(DP), DIMENSION(:), ALLOCATABLE :: R - REAL(DP) :: h, sum, factor - - n = SIZE(x) - m = INT(LOG(REAL(n, DP))/LOG(2.0_DP)) + 1 ! Number of refinement levels - - ! Validate input sizes - IF (SIZE(x) /= SIZE(y)) THEN - PRINT *, "Error: x and y arrays must have the same size." - CALL mesa_error(__FILE__, __LINE__) - END IF - - IF (n < 2) THEN - PRINT *, "Error: x and y arrays must have at least 2 elements." - CALL mesa_error(__FILE__, __LINE__) - END IF - - ALLOCATE (R(m)) - - ! Compute initial trapezoidal rule estimate - h = x(n) - x(1) - R(1) = 0.5_DP*h*(y(1) + y(n)) - - ! Refinement using Romberg's method - DO j = 2, m - sum = 0.0_DP - DO i = 1, 2**(j - 2) - sum = sum + y(1 + (2*i - 1)*(n - 1)/(2**(j - 1))) - END DO - - h = h/2.0_DP - R(j) = 0.5_DP*R(j - 1) + h*sum - - ! Richardson extrapolation - factor = 4.0_DP - DO k = j, 2, -1 - R(k - 1) = (factor*R(k) - R(k - 1))/(factor - 1.0_DP) - factor = factor*4.0_DP - END DO - END DO - - result = R(1) - DEALLOCATE (R) - END SUBROUTINE rombergintegration - - !----------------------------------------------------------------------- - ! File I/O functions - !----------------------------------------------------------------------- - - !**************************** - ! 
Load Vega SED for Zero Point Calculation - !**************************** - SUBROUTINE loadvegased(filepath, wavelengths, flux) - CHARACTER(LEN=*), INTENT(IN) :: filepath - REAL(dp), DIMENSION(:), ALLOCATABLE, INTENT(OUT) :: wavelengths, flux - CHARACTER(LEN=512) :: line - INTEGER :: unit, n_rows, status, i - REAL(dp) :: temp_wave, temp_flux - - unit = 20 - OPEN (unit, FILE=TRIM(filepath), STATUS='OLD', ACTION='READ', IOSTAT=status) - IF (status /= 0) THEN - PRINT *, "Error: Could not open Vega SED file ", TRIM(filepath) - CALL mesa_error(__FILE__, __LINE__) - END IF - - ! Skip header line - READ (unit, '(A)', IOSTAT=status) line - IF (status /= 0) THEN - PRINT *, "Error: Could not read header from Vega SED file ", TRIM(filepath) - CALL mesa_error(__FILE__, __LINE__) - END IF - - ! Count the number of data lines - n_rows = 0 - DO - READ (unit, '(A)', IOSTAT=status) line - IF (status /= 0) EXIT - n_rows = n_rows + 1 - END DO - - REWIND (unit) - READ (unit, '(A)', IOSTAT=status) line ! Skip header again - - ALLOCATE (wavelengths(n_rows)) - ALLOCATE (flux(n_rows)) - - i = 0 - DO - READ (unit, *, IOSTAT=status) temp_wave, temp_flux ! Ignore any extra columns - IF (status /= 0) EXIT - i = i + 1 - wavelengths(i) = temp_wave - flux(i) = temp_flux - END DO - - CLOSE (unit) - END SUBROUTINE loadvegased - - !**************************** - ! Load Filter File - !**************************** - SUBROUTINE loadfilter(directory, filter_wavelengths, filter_trans) - CHARACTER(LEN=*), INTENT(IN) :: directory - REAL(dp), DIMENSION(:), ALLOCATABLE, INTENT(OUT) :: filter_wavelengths, filter_trans - - CHARACTER(LEN=512) :: line - INTEGER :: unit, n_rows, status, i - REAL(dp) :: temp_wavelength, temp_trans - - ! Open the file - unit = 20 - OPEN (unit, FILE=TRIM(directory), STATUS='OLD', ACTION='READ', IOSTAT=status) - IF (status /= 0) THEN - PRINT *, "Error: Could not open file ", TRIM(directory) - CALL mesa_error(__FILE__, __LINE__) - END IF - - ! 
Skip header line - READ (unit, '(A)', IOSTAT=status) line - IF (status /= 0) THEN - PRINT *, "Error: Could not read the file", TRIM(directory) - CALL mesa_error(__FILE__, __LINE__) - END IF - - ! Count rows in the file - n_rows = 0 - DO - READ (unit, '(A)', IOSTAT=status) line - IF (status /= 0) EXIT - n_rows = n_rows + 1 - END DO - - ! Allocate arrays - ALLOCATE (filter_wavelengths(n_rows)) - ALLOCATE (filter_trans(n_rows)) - - ! Rewind to the first non-comment line - REWIND (unit) - DO - READ (unit, '(A)', IOSTAT=status) line - IF (status /= 0) THEN - PRINT *, "Error: Could not rewind file", TRIM(directory) - CALL mesa_error(__FILE__, __LINE__) - END IF - IF (line(1:1) /= "#") EXIT - END DO - - ! Read and parse data - i = 0 - DO - READ (unit, *, IOSTAT=status) temp_wavelength, temp_trans - IF (status /= 0) EXIT - i = i + 1 - - filter_wavelengths(i) = temp_wavelength - filter_trans(i) = temp_trans - END DO - - CLOSE (unit) - END SUBROUTINE loadfilter - - !**************************** - ! Load Lookup Table For Identifying Stellar Atmosphere Models - !**************************** - SUBROUTINE load_lookuptable(lookup_file, lookup_table, out_file_names, out_logg, out_meta, out_teff) - CHARACTER(LEN=*), INTENT(IN) :: lookup_file - REAL, DIMENSION(:, :), ALLOCATABLE, INTENT(OUT) :: lookup_table - CHARACTER(LEN=100), ALLOCATABLE, INTENT(INOUT) :: out_file_names(:) - REAL(dp), ALLOCATABLE, INTENT(INOUT) :: out_logg(:), out_meta(:), out_teff(:) - - INTEGER :: i, n_rows, status, unit - CHARACTER(LEN=512) :: line - CHARACTER(LEN=*), PARAMETER :: delimiter = "," - CHARACTER(LEN=100), ALLOCATABLE :: columns(:), headers(:) - INTEGER :: logg_col, meta_col, teff_col - - ! Open the file - unit = 10 - OPEN (unit, FILE=lookup_file, STATUS='old', ACTION='read', IOSTAT=status) - IF (status /= 0) THEN - PRINT *, "Error: Could not open file", lookup_file - CALL mesa_error(__FILE__, __LINE__) - END IF - - ! 
Read header line - READ (unit, '(A)', IOSTAT=status) line - IF (status /= 0) THEN - PRINT *, "Error: Could not read header line" - CALL mesa_error(__FILE__, __LINE__) - END IF - - CALL splitline(line, delimiter, headers) - - ! Determine column indices for logg, meta, and teff - logg_col = getcolumnindex(headers, "logg") - teff_col = getcolumnindex(headers, "teff") - - meta_col = getcolumnindex(headers, "meta") - IF (meta_col < 0) THEN - meta_col = getcolumnindex(headers, "feh") - END IF - - n_rows = 0 - DO - READ (unit, '(A)', IOSTAT=status) line - IF (status /= 0) EXIT - n_rows = n_rows + 1 - END DO - REWIND (unit) - - ! Skip header - READ (unit, '(A)', IOSTAT=status) line - - ! Allocate output arrays - ALLOCATE (out_file_names(n_rows)) - ALLOCATE (out_logg(n_rows), out_meta(n_rows), out_teff(n_rows)) - - ! Read and parse the file - i = 0 - DO - READ (unit, '(A)', IOSTAT=status) line - IF (status /= 0) EXIT - i = i + 1 - - CALL splitline(line, delimiter, columns) - - ! Populate arrays - out_file_names(i) = columns(1) - - IF (logg_col > 0) THEN - IF (columns(logg_col) /= "") THEN - READ (columns(logg_col), *) out_logg(i) - ELSE - out_logg(i) = 0.0 - END IF - ELSE - out_logg(i) = 0.0 - END IF - - IF (meta_col > 0) THEN - IF (columns(meta_col) /= "") THEN - READ (columns(meta_col), *) out_meta(i) - ELSE - out_meta(i) = 0.0 - END IF - ELSE - out_meta(i) = 0.0 - END IF - - IF (teff_col > 0) THEN - IF (columns(teff_col) /= "") THEN - READ (columns(teff_col), *) out_teff(i) - ELSE - out_teff(i) = 0.0 - END IF - ELSE - out_teff(i) = 0.0 - END IF - END DO - - CLOSE (unit) - - CONTAINS - - FUNCTION getcolumnindex(headers, target) RESULT(index) - CHARACTER(LEN=100), INTENT(IN) :: headers(:) - CHARACTER(LEN=*), INTENT(IN) :: target - INTEGER :: index, i - CHARACTER(LEN=100) :: clean_header, clean_target - - index = -1 - clean_target = TRIM(ADJUSTL(target)) ! Clean the target string - - DO i = 1, SIZE(headers) - clean_header = TRIM(ADJUSTL(headers(i))) ! 
Clean each header - IF (clean_header == clean_target) THEN - index = i - EXIT - END IF - END DO - END FUNCTION getcolumnindex - - SUBROUTINE splitline(line, delimiter, tokens) - CHARACTER(LEN=*), INTENT(IN) :: line, delimiter - CHARACTER(LEN=100), ALLOCATABLE, INTENT(OUT) :: tokens(:) - INTEGER :: num_tokens, pos, start, len_delim - - len_delim = LEN_TRIM(delimiter) - start = 1 - num_tokens = 0 - IF (ALLOCATED(tokens)) DEALLOCATE (tokens) - - DO - pos = INDEX(line(start:), delimiter) - - IF (pos == 0) EXIT - num_tokens = num_tokens + 1 - CALL AppendToken(tokens, line(start:start + pos - 2)) - start = start + pos + len_delim - 1 - END DO - - num_tokens = num_tokens + 1 - CALL AppendToken(tokens, line(start:)) - END SUBROUTINE splitline - - SUBROUTINE AppendToken(tokens, token) - CHARACTER(LEN=*), INTENT(IN) :: token - CHARACTER(LEN=100), ALLOCATABLE, INTENT(INOUT) :: tokens(:) - CHARACTER(LEN=100), ALLOCATABLE :: temp(:) - INTEGER :: n - - IF (.NOT. ALLOCATED(tokens)) THEN - ALLOCATE (tokens(1)) - tokens(1) = token - ELSE - n = SIZE(tokens) - ALLOCATE (temp(n)) - temp = tokens ! Backup the current tokens - DEALLOCATE (tokens) ! Deallocate the old array - ALLOCATE (tokens(n + 1)) ! Allocate with one extra space - tokens(1:n) = temp ! Restore old tokens - tokens(n + 1) = token ! Add the new token - DEALLOCATE (temp) ! Clean up temporary array - END IF - END SUBROUTINE AppendToken - - END SUBROUTINE load_lookuptable - - SUBROUTINE loadsed(directory, index, wavelengths, flux) - CHARACTER(LEN=*), INTENT(IN) :: directory - INTEGER, INTENT(IN) :: index - REAL(DP), DIMENSION(:), ALLOCATABLE, INTENT(OUT) :: wavelengths, flux - - CHARACTER(LEN=512) :: line - INTEGER :: unit, n_rows, status, i - REAL(DP) :: temp_wavelength, temp_flux - - ! Open the file - unit = 20 - OPEN (unit, FILE=TRIM(directory), STATUS='OLD', ACTION='READ', IOSTAT=status) - IF (status /= 0) THEN - PRINT *, "Error: Could not open file ", TRIM(directory) - CALL mesa_error(__FILE__, __LINE__) - END IF - - ! 
Skip header lines - DO - READ (unit, '(A)', IOSTAT=status) line - IF (status /= 0) THEN - PRINT *, "Error: Could not read the file", TRIM(directory) - CALL mesa_error(__FILE__, __LINE__) - END IF - IF (line(1:1) /= "#") EXIT - END DO - - ! Count rows in the file - n_rows = 0 - DO - READ (unit, '(A)', IOSTAT=status) line - IF (status /= 0) EXIT - n_rows = n_rows + 1 - END DO - - ! Allocate arrays - ALLOCATE (wavelengths(n_rows)) - ALLOCATE (flux(n_rows)) - - ! Rewind to the first non-comment line - REWIND (unit) - DO - READ (unit, '(A)', IOSTAT=status) line - IF (status /= 0) THEN - PRINT *, "Error: Could not rewind file", TRIM(directory) - CALL mesa_error(__FILE__, __LINE__) - END IF - IF (line(1:1) /= "#") EXIT - END DO - - ! Read and parse data - i = 0 - DO - READ (unit, *, IOSTAT=status) temp_wavelength, temp_flux - IF (status /= 0) EXIT - i = i + 1 - ! Convert f_lambda to f_nu - wavelengths(i) = temp_wavelength - flux(i) = temp_flux - END DO - - CLOSE (unit) - - END SUBROUTINE loadsed - - !----------------------------------------------------------------------- - ! Helper function for file names - !----------------------------------------------------------------------- - - function remove_dat(path) result(base) - ! Extracts the portion of the string before the first dot - character(len=*), intent(in) :: path - character(len=strlen) :: base - integer :: first_dot - - ! Find the position of the first dot - first_dot = 0 - do while (first_dot < len_trim(path) .and. path(first_dot + 1:first_dot + 1) /= '.') - first_dot = first_dot + 1 - end do - - ! Check if a dot was found - if (first_dot < len_trim(path)) then - ! Extract the part before the dot - base = path(:first_dot) - else - ! 
No dot found, return the input string - base = path - end if - end function remove_dat - -END MODULE shared_funcs diff --git a/colors/private/synthetic.f90 b/colors/private/synthetic.f90 index 85dadb2ad..174a82f29 100644 --- a/colors/private/synthetic.f90 +++ b/colors/private/synthetic.f90 @@ -28,21 +28,16 @@ module synthetic private public :: calculate_synthetic - ! Export zero-point computation functions for precomputation at initialization public :: compute_vega_zero_point, compute_ab_zero_point, compute_st_zero_point contains - !**************************** - ! Calculate Synthetic Photometry Using SED and Filter - ! Uses precomputed zero-point from filter data - !**************************** + ! calculate synthetic magnitude; uses precomputed zero-point from filter_data real(dp) function calculate_synthetic(temperature, gravity, metallicity, ierr, & wavelengths, fluxes, & filter_wavelengths, filter_trans, & zero_point_flux, & - filter_name, make_sed, colors_results_directory) - ! Input arguments + filter_name, make_sed, sed_per_model, colors_results_directory, model_number) real(dp), intent(in) :: temperature, gravity, metallicity character(len=*), intent(in) :: filter_name, colors_results_directory integer, intent(out) :: ierr @@ -50,11 +45,12 @@ real(dp) function calculate_synthetic(temperature, gravity, metallicity, ierr, & real(dp), dimension(:), intent(in) :: wavelengths, fluxes real(dp), dimension(:), intent(in) :: filter_wavelengths, filter_trans real(dp), intent(in) :: zero_point_flux ! precomputed at initialization - logical, intent(in) :: make_sed + logical, intent(in) :: make_sed, sed_per_model + integer, intent(in) :: model_number - ! 
Local variables real(dp), dimension(:), allocatable :: convolved_flux, filter_on_sed_grid character(len=256) :: csv_file + character(len=20) :: model_str character(len=1000) :: line real(dp) :: synthetic_flux integer :: max_size, i @@ -62,27 +58,33 @@ real(dp) function calculate_synthetic(temperature, gravity, metallicity, ierr, & ierr = 0 - ! Allocate working arrays - allocate(convolved_flux(size(wavelengths))) - allocate(filter_on_sed_grid(size(wavelengths))) + allocate (convolved_flux(size(wavelengths))) + allocate (filter_on_sed_grid(size(wavelengths))) - ! Interpolate filter onto SED wavelength grid call interpolate_array(filter_wavelengths, filter_trans, wavelengths, filter_on_sed_grid) - ! Convolve SED with filter - convolved_flux = fluxes * filter_on_sed_grid + convolved_flux = fluxes*filter_on_sed_grid - ! Write SED to CSV if requested if (make_sed) then if (.not. folder_exists(trim(colors_results_directory))) call mkdir(trim(colors_results_directory)) - csv_file = trim(colors_results_directory)//'/'//trim(remove_dat(filter_name))//'_SED.csv' + + ! track model number internally when write_sed_per_model is enabled + if (sed_per_model) then + + write (model_str, '(I8.8)') model_number + csv_file = trim(colors_results_directory)//'/'//trim(remove_dat(filter_name))//'_SED_'//trim(model_str)//'.csv' + + else + csv_file = trim(colors_results_directory)//'/'//trim(remove_dat(filter_name))//'_SED.csv' + end if max_size = max(size(wavelengths), size(filter_wavelengths)) open (unit=10, file=csv_file, status='REPLACE', action='write', iostat=ierr) if (ierr /= 0) then print *, "Error opening file for writing" - deallocate(convolved_flux, filter_on_sed_grid) + deallocate (convolved_flux, filter_on_sed_grid) + calculate_synthetic = huge(1.0_dp) return end if @@ -103,12 +105,10 @@ real(dp) function calculate_synthetic(temperature, gravity, metallicity, ierr, & close (10) end if - ! 
Calculate synthetic flux using photon-counting integration call calculate_synthetic_flux(wavelengths, convolved_flux, filter_on_sed_grid, synthetic_flux) - ! Calculate magnitude if (zero_point_flux > 0.0_dp .and. synthetic_flux > 0.0_dp) then - calculate_synthetic = -2.5d0 * log10(synthetic_flux / zero_point_flux) + calculate_synthetic = -2.5d0*log10(synthetic_flux/zero_point_flux) else if (zero_point_flux <= 0.0_dp) then print *, "Error: Zero point flux is zero or negative for filter ", trim(filter_name) @@ -119,34 +119,29 @@ real(dp) function calculate_synthetic(temperature, gravity, metallicity, ierr, & calculate_synthetic = huge(1.0_dp) end if - deallocate(convolved_flux, filter_on_sed_grid) + deallocate (convolved_flux, filter_on_sed_grid) end function calculate_synthetic - !**************************** - ! Calculate Synthetic Flux (photon-counting integration) - !**************************** + ! photon-counting synthetic flux (numerator and denominator weighted by lambda) subroutine calculate_synthetic_flux(wavelengths, convolved_flux, filter_on_sed_grid, synthetic_flux) real(dp), dimension(:), intent(in) :: wavelengths, convolved_flux, filter_on_sed_grid real(dp), intent(out) :: synthetic_flux real(dp) :: integrated_flux, integrated_filter - ! Photon-counting: weight by wavelength - call romberg_integration(wavelengths, convolved_flux * wavelengths, integrated_flux) - call romberg_integration(wavelengths, filter_on_sed_grid * wavelengths, integrated_filter) + ! photon-counting: weight by wavelength + call romberg_integration(wavelengths, convolved_flux*wavelengths, integrated_flux) + call romberg_integration(wavelengths, filter_on_sed_grid*wavelengths, integrated_filter) if (integrated_filter > 0.0_dp) then - synthetic_flux = integrated_flux / integrated_filter + synthetic_flux = integrated_flux/integrated_filter else print *, "Error: Integrated filter transmission is zero." 
synthetic_flux = -1.0_dp end if end subroutine calculate_synthetic_flux - !**************************** - ! Compute Vega Zero Point Flux - ! Called once at initialization, result cached in filter_data - !**************************** + ! vega zero-point -- called once at init, result cached in filter_data real(dp) function compute_vega_zero_point(vega_wave, vega_flux, filt_wave, filt_trans) real(dp), dimension(:), intent(in) :: vega_wave, vega_flux real(dp), dimension(:), intent(in) :: filt_wave, filt_trans @@ -154,34 +149,28 @@ real(dp) function compute_vega_zero_point(vega_wave, vega_flux, filt_wave, filt_ real(dp) :: int_flux, int_filter real(dp), allocatable :: filt_on_vega_grid(:), conv_flux(:) - allocate(filt_on_vega_grid(size(vega_wave))) - allocate(conv_flux(size(vega_wave))) + allocate (filt_on_vega_grid(size(vega_wave))) + allocate (conv_flux(size(vega_wave))) - ! Interpolate filter onto Vega wavelength grid call interpolate_array(filt_wave, filt_trans, vega_wave, filt_on_vega_grid) - ! Convolve Vega with filter - conv_flux = vega_flux * filt_on_vega_grid + conv_flux = vega_flux*filt_on_vega_grid - ! Photon-counting integration - call romberg_integration(vega_wave, vega_wave * conv_flux, int_flux) - call romberg_integration(vega_wave, vega_wave * filt_on_vega_grid, int_filter) + ! photon-counting integration + call romberg_integration(vega_wave, vega_wave*conv_flux, int_flux) + call romberg_integration(vega_wave, vega_wave*filt_on_vega_grid, int_filter) if (int_filter > 0.0_dp) then - compute_vega_zero_point = int_flux / int_filter + compute_vega_zero_point = int_flux/int_filter else compute_vega_zero_point = -1.0_dp end if - deallocate(filt_on_vega_grid, conv_flux) + deallocate (filt_on_vega_grid, conv_flux) end function compute_vega_zero_point - !**************************** - ! Compute AB Zero Point Flux - ! f_nu = 3631 Jy = 3.631e-20 erg/s/cm^2/Hz - ! f_lambda = f_nu * c / lambda^2 - ! 
Called once at initialization, result cached in filter_data - !**************************** + ! ab zero-point -- f_nu = 3631 Jy, converted to f_lambda = f_nu * c / lambda^2 + ! called once at init, result cached in filter_data real(dp) function compute_ab_zero_point(filt_wave, filt_trans) real(dp), dimension(:), intent(in) :: filt_wave, filt_trans @@ -189,57 +178,53 @@ real(dp) function compute_ab_zero_point(filt_wave, filt_trans) real(dp), allocatable :: ab_sed_flux(:) integer :: i - allocate(ab_sed_flux(size(filt_wave))) + allocate (ab_sed_flux(size(filt_wave))) - ! Construct AB spectrum (f_lambda) on the filter wavelength grid - ! 3631 Jy = 3.631E-20 erg/s/cm^2/Hz - ! clight in cm/s, wavelength in Angstroms, need to convert + ! construct ab spectrum (f_lambda) on the filter wavelength grid + ! 3631 Jy = 3.631e-20 erg/s/cm^2/Hz; clight in cm/s, wavelength in angstroms do i = 1, size(filt_wave) if (filt_wave(i) > 0.0_dp) then - ab_sed_flux(i) = 3.631d-20 * ((clight * 1.0d8) / (filt_wave(i)**2)) + ab_sed_flux(i) = 3.631d-20*((clight*1.0d8)/(filt_wave(i)**2)) else ab_sed_flux(i) = 0.0_dp end if end do - ! Photon-counting integration - call romberg_integration(filt_wave, ab_sed_flux * filt_trans * filt_wave, int_flux) - call romberg_integration(filt_wave, filt_wave * filt_trans, int_filter) + ! photon-counting integration + call romberg_integration(filt_wave, ab_sed_flux*filt_trans*filt_wave, int_flux) + call romberg_integration(filt_wave, filt_wave*filt_trans, int_filter) if (int_filter > 0.0_dp) then - compute_ab_zero_point = int_flux / int_filter + compute_ab_zero_point = int_flux/int_filter else compute_ab_zero_point = -1.0_dp end if - deallocate(ab_sed_flux) + deallocate (ab_sed_flux) end function compute_ab_zero_point - !**************************** - ! Compute ST Zero Point Flux - ! f_lambda = 3.63e-9 erg/s/cm^2/A (Constant) - ! Called once at initialization, result cached in filter_data - !**************************** + ! 
st zero-point -- f_lambda = 3.63e-9 erg/s/cm^2/A (flat spectrum) + ! called once at init, result cached in filter_data real(dp) function compute_st_zero_point(filt_wave, filt_trans) real(dp), dimension(:), intent(in) :: filt_wave, filt_trans real(dp) :: int_flux, int_filter real(dp), allocatable :: st_sed_flux(:) - allocate(st_sed_flux(size(filt_wave))) + allocate (st_sed_flux(size(filt_wave))) st_sed_flux = 3.63d-9 - ! Photon-counting integration - call romberg_integration(filt_wave, st_sed_flux * filt_trans * filt_wave, int_flux) - call romberg_integration(filt_wave, filt_wave * filt_trans, int_filter) + ! photon-counting integration + call romberg_integration(filt_wave, st_sed_flux*filt_trans*filt_wave, int_flux) + call romberg_integration(filt_wave, filt_wave*filt_trans, int_filter) if (int_filter > 0.0_dp) then - compute_st_zero_point = int_flux / int_filter + compute_st_zero_point = int_flux/int_filter else compute_st_zero_point = -1.0_dp end if - deallocate(st_sed_flux) + deallocate (st_sed_flux) end function compute_st_zero_point end module synthetic \ No newline at end of file diff --git a/colors/public/colors_def.f90 b/colors/public/colors_def.f90 index ec37f2326..f3b237635 100644 --- a/colors/public/colors_def.f90 +++ b/colors/public/colors_def.f90 @@ -23,21 +23,24 @@ module colors_def implicit none - ! Make everything in this module public by default public - ! Type to hold individual filter data + ! max number of SEDs to keep in the memory cache + ! each slot holds one wavelength array (~1200 doubles ~ 10 KB), + ! so 256 slots ~ 2.5 MB -- negligible even when the full cube + ! cannot be allocated + integer, parameter :: sed_mem_cache_cap = 256 + type :: filter_data character(len=100) :: name real(dp), allocatable :: wavelengths(:) real(dp), allocatable :: transmission(:) - ! Precomputed zero-point fluxes (computed once at initialization) + ! 
precomputed zero-point fluxes, computed once at init real(dp) :: vega_zero_point = -1.0_dp real(dp) :: ab_zero_point = -1.0_dp real(dp) :: st_zero_point = -1.0_dp end type filter_data - ! Colors Module control parameters type :: Colors_General_Info character(len=256) :: instrument character(len=256) :: vega_sed @@ -47,26 +50,77 @@ module colors_def real(dp) :: metallicity real(dp) :: distance logical :: make_csv + logical :: sed_per_model logical :: use_colors integer :: handle logical :: in_use - ! Cached lookup table data + logical :: colors_per_newton_step + integer :: iteration_output_unit + logical :: iteration_file_open + + ! cached lookup table data logical :: lookup_loaded = .false. character(len=100), allocatable :: lu_file_names(:) real(dp), allocatable :: lu_logg(:) real(dp), allocatable :: lu_meta(:) real(dp), allocatable :: lu_teff(:) - ! Cached Vega SED + ! cached vega SED logical :: vega_loaded = .false. real(dp), allocatable :: vega_wavelengths(:) real(dp), allocatable :: vega_fluxes(:) - ! Cached filter data (includes precomputed zero-points) + ! cached filter data (includes precomputed zero-points) logical :: filters_loaded = .false. type(filter_data), allocatable :: filters(:) + ! cached flux cube + logical :: cube_loaded = .false. + real(dp), allocatable :: cube_flux(:, :, :, :) ! (n_teff, n_logg, n_meta, n_lambda) + real(dp), allocatable :: cube_teff_grid(:) + real(dp), allocatable :: cube_logg_grid(:) + real(dp), allocatable :: cube_meta_grid(:) + real(dp), allocatable :: cube_wavelengths(:) + + ! unique sorted grids (built once from lookup table at init) + logical :: unique_grids_built = .false. + real(dp), allocatable :: u_teff(:), u_logg(:), u_meta(:) + + ! grid_to_lu(i_t, i_g, i_m) gives the lookup-table row index for + ! (u_teff(i_t), u_logg(i_g), u_meta(i_m)) -- avoids O(n_lu) + ! nearest-neighbour searches at runtime + logical :: grid_map_built = .false. + integer, allocatable :: grid_to_lu(:, :, :) + + ! 
fallback-path caches (used only when cube_loaded == .false.) + + ! stencil cache: the extended neighbourhood around the current + ! interpolation cell, includes derivative-context points + ! (i-1 .. i+2 per axis, clamped to boundaries) so that + ! hermite_tensor_interp3d gives the same result as the cube path + logical :: stencil_valid = .false. + integer :: stencil_i_t = -1, stencil_i_g = -1, stencil_i_m = -1 + real(dp), allocatable :: stencil_fluxes(:, :, :, :) ! (st, sg, sm, n_lambda) + real(dp), allocatable :: stencil_wavelengths(:) ! (n_lambda) + real(dp), allocatable :: stencil_teff(:) ! subgrid values (st) + real(dp), allocatable :: stencil_logg(:) ! subgrid values (sg) + real(dp), allocatable :: stencil_meta(:) ! subgrid values (sm) + + ! canonical wavelength grid for fallback SEDs (set once on first disk + ! read -- all SEDs in a given atmosphere grid share the same wavelengths) + logical :: fallback_wavelengths_set = .false. + real(dp), allocatable :: fallback_wavelengths(:) ! (n_lambda) + + ! bounded SED memory cache (circular buffer, keyed by lu index) + ! avoids re-reading text files for SEDs we've already parsed + logical :: sed_mcache_init = .false. + integer :: sed_mcache_count = 0 + integer :: sed_mcache_next = 1 + integer :: sed_mcache_nlam = 0 + integer, allocatable :: sed_mcache_keys(:) ! (sed_mem_cache_cap) + real(dp), allocatable :: sed_mcache_data(:, :) ! (n_lambda, sed_mem_cache_cap) + end type Colors_General_Info ! Global filter name list (shared across handles) @@ -104,6 +158,18 @@ subroutine colors_def_init(colors_cache_dir_in) colors_handles(i)%lookup_loaded = .false. colors_handles(i)%vega_loaded = .false. colors_handles(i)%filters_loaded = .false. + colors_handles(i)%cube_loaded = .false. + colors_handles(i)%unique_grids_built = .false. + colors_handles(i)%grid_map_built = .false. + colors_handles(i)%stencil_valid = .false. 
+ colors_handles(i)%stencil_i_t = -1 + colors_handles(i)%stencil_i_g = -1 + colors_handles(i)%stencil_i_m = -1 + colors_handles(i)%sed_mcache_init = .false. + colors_handles(i)%sed_mcache_count = 0 + colors_handles(i)%sed_mcache_next = 1 + colors_handles(i)%sed_mcache_nlam = 0 + colors_handles(i)%fallback_wavelengths_set = .false. end do colors_temp_cache_dir = trim(mesa_temp_caches_dir)//'/colors_cache' @@ -149,36 +215,85 @@ subroutine free_colors_cache(handle) if (handle < 1 .or. handle > max_colors_handles) return - ! Free lookup table arrays if (allocated(colors_handles(handle)%lu_file_names)) & - deallocate(colors_handles(handle)%lu_file_names) + deallocate (colors_handles(handle)%lu_file_names) if (allocated(colors_handles(handle)%lu_logg)) & - deallocate(colors_handles(handle)%lu_logg) + deallocate (colors_handles(handle)%lu_logg) if (allocated(colors_handles(handle)%lu_meta)) & - deallocate(colors_handles(handle)%lu_meta) + deallocate (colors_handles(handle)%lu_meta) if (allocated(colors_handles(handle)%lu_teff)) & - deallocate(colors_handles(handle)%lu_teff) + deallocate (colors_handles(handle)%lu_teff) colors_handles(handle)%lookup_loaded = .false. - ! Free Vega SED arrays if (allocated(colors_handles(handle)%vega_wavelengths)) & - deallocate(colors_handles(handle)%vega_wavelengths) + deallocate (colors_handles(handle)%vega_wavelengths) if (allocated(colors_handles(handle)%vega_fluxes)) & - deallocate(colors_handles(handle)%vega_fluxes) + deallocate (colors_handles(handle)%vega_fluxes) colors_handles(handle)%vega_loaded = .false. - ! 
Free filter data arrays if (allocated(colors_handles(handle)%filters)) then do i = 1, size(colors_handles(handle)%filters) if (allocated(colors_handles(handle)%filters(i)%wavelengths)) & - deallocate(colors_handles(handle)%filters(i)%wavelengths) + deallocate (colors_handles(handle)%filters(i)%wavelengths) if (allocated(colors_handles(handle)%filters(i)%transmission)) & - deallocate(colors_handles(handle)%filters(i)%transmission) + deallocate (colors_handles(handle)%filters(i)%transmission) end do - deallocate(colors_handles(handle)%filters) + deallocate (colors_handles(handle)%filters) end if colors_handles(handle)%filters_loaded = .false. + if (allocated(colors_handles(handle)%cube_flux)) & + deallocate (colors_handles(handle)%cube_flux) + if (allocated(colors_handles(handle)%cube_teff_grid)) & + deallocate (colors_handles(handle)%cube_teff_grid) + if (allocated(colors_handles(handle)%cube_logg_grid)) & + deallocate (colors_handles(handle)%cube_logg_grid) + if (allocated(colors_handles(handle)%cube_meta_grid)) & + deallocate (colors_handles(handle)%cube_meta_grid) + if (allocated(colors_handles(handle)%cube_wavelengths)) & + deallocate (colors_handles(handle)%cube_wavelengths) + colors_handles(handle)%cube_loaded = .false. + + if (allocated(colors_handles(handle)%u_teff)) & + deallocate (colors_handles(handle)%u_teff) + if (allocated(colors_handles(handle)%u_logg)) & + deallocate (colors_handles(handle)%u_logg) + if (allocated(colors_handles(handle)%u_meta)) & + deallocate (colors_handles(handle)%u_meta) + colors_handles(handle)%unique_grids_built = .false. + + if (allocated(colors_handles(handle)%grid_to_lu)) & + deallocate (colors_handles(handle)%grid_to_lu) + colors_handles(handle)%grid_map_built = .false. 
+ + if (allocated(colors_handles(handle)%stencil_fluxes)) & + deallocate (colors_handles(handle)%stencil_fluxes) + if (allocated(colors_handles(handle)%stencil_wavelengths)) & + deallocate (colors_handles(handle)%stencil_wavelengths) + if (allocated(colors_handles(handle)%stencil_teff)) & + deallocate (colors_handles(handle)%stencil_teff) + if (allocated(colors_handles(handle)%stencil_logg)) & + deallocate (colors_handles(handle)%stencil_logg) + if (allocated(colors_handles(handle)%stencil_meta)) & + deallocate (colors_handles(handle)%stencil_meta) + colors_handles(handle)%stencil_valid = .false. + colors_handles(handle)%stencil_i_t = -1 + colors_handles(handle)%stencil_i_g = -1 + colors_handles(handle)%stencil_i_m = -1 + + if (allocated(colors_handles(handle)%sed_mcache_keys)) & + deallocate (colors_handles(handle)%sed_mcache_keys) + if (allocated(colors_handles(handle)%sed_mcache_data)) & + deallocate (colors_handles(handle)%sed_mcache_data) + colors_handles(handle)%sed_mcache_init = .false. + colors_handles(handle)%sed_mcache_count = 0 + colors_handles(handle)%sed_mcache_next = 1 + colors_handles(handle)%sed_mcache_nlam = 0 + + if (allocated(colors_handles(handle)%fallback_wavelengths)) & + deallocate (colors_handles(handle)%fallback_wavelengths) + colors_handles(handle)%fallback_wavelengths_set = .false. + end subroutine free_colors_cache subroutine get_colors_ptr(handle, rq, ierr) @@ -196,10 +311,8 @@ end subroutine get_colors_ptr subroutine do_free_colors_tables integer :: i - ! Free the filter names array - if (allocated(color_filter_names)) deallocate(color_filter_names) + if (allocated(color_filter_names)) deallocate (color_filter_names) - ! 
Free cached data for all handles do i = 1, max_colors_handles call free_colors_cache(i) end do diff --git a/colors/public/colors_lib.f90 b/colors/public/colors_lib.f90 index fbc9566d3..81461ed8a 100644 --- a/colors/public/colors_lib.f90 +++ b/colors/public/colors_lib.f90 @@ -22,8 +22,10 @@ module colors_lib use const_def, only: dp, strlen, mesa_dir use bolometric, only: calculate_bolometric use synthetic, only: calculate_synthetic - use colors_utils, only: read_strings_from_file, load_lookup_table, load_filter, load_vega_sed + use colors_utils, only: read_strings_from_file, load_lookup_table, load_filter, load_vega_sed, & + resolve_path, load_flux_cube, build_unique_grids, build_grid_to_lu_map use colors_history, only: how_many_colors_history_columns, data_for_colors_history_columns + use colors_iteration, only: write_iteration_colors, open_iteration_file, close_iteration_file implicit none @@ -33,21 +35,23 @@ module colors_lib public :: alloc_colors_handle, alloc_colors_handle_using_inlist, free_colors_handle public :: colors_ptr public :: colors_setup_tables, colors_setup_hooks - ! Main functions + ! per-iteration colors routines (called from star) + public :: write_iteration_colors, open_iteration_file, close_iteration_file public :: calculate_bolometric, calculate_synthetic public :: how_many_colors_history_columns, data_for_colors_history_columns - ! Old bolometric correction functions that MESA expects (stub implementations, remove later): + ! old bolometric correction stubs that MESA expects (remove later): public :: get_bc_id_by_name, get_lum_band_by_id, get_abs_mag_by_id public :: get_bc_by_id, get_bc_name_by_id, get_bc_by_name public :: get_abs_bolometric_mag, get_abs_mag_by_name, get_bcs_all public :: get_lum_band_by_name + contains ! call this routine to initialize the colors module. ! only needs to be done once at start of run. - ! Reads data from the 'colors' directory in the data_dir. - ! 
If use_cache is true and there is a 'colors/cache' directory, it will try that first. - ! If it doesn't find what it needs in the cache, + ! reads data from the 'colors' directory in the data_dir. + ! if use_cache is true and there is a 'colors/cache' directory, it will try that first. + ! if it doesn't find what it needs in the cache, ! it reads the data and writes the cache for next time. subroutine colors_init(use_cache, colors_cache_dir, ierr) use colors_def, only: colors_def_init, colors_use_cache, colors_is_initialized @@ -80,6 +84,7 @@ integer function alloc_colors_handle_using_inlist(inlist, ierr) result(handle) character(len=*), intent(in) :: inlist ! empty means just use defaults. integer, intent(out) :: ierr ! 0 means AOK. ierr = 0 + handle = -1 if (.not. colors_is_initialized) then ierr = -1 return @@ -124,66 +129,70 @@ subroutine colors_setup_tables(handle, ierr) type(Colors_General_Info), pointer :: rq character(len=256) :: lookup_file, filter_dir, filter_filepath, vega_filepath - REAL, allocatable :: lookup_table(:,:) ! unused but required by load_lookup_table + REAL, allocatable :: lookup_table(:, :) ! unused but required by load_lookup_table integer :: i ierr = 0 call get_colors_ptr(handle, rq, ierr) if (ierr /= 0) return - ! Read filter names from instrument directory call read_strings_from_file(rq, color_filter_names, num_color_filters, ierr) if (ierr /= 0) return - ! ========================================= - ! Load lookup table (stellar atmosphere grid) - ! ========================================= + ! load lookup table (stellar atmosphere grid) if (.not. rq%lookup_loaded) then - lookup_file = trim(mesa_dir)//trim(rq%stellar_atm)//'/lookup_table.csv' + lookup_file = trim(resolve_path(rq%stellar_atm))//'/lookup_table.csv' call load_lookup_table(lookup_file, lookup_table, & rq%lu_file_names, rq%lu_logg, rq%lu_meta, rq%lu_teff) rq%lookup_loaded = .true. 
- if (allocated(lookup_table)) deallocate(lookup_table) + if (allocated(lookup_table)) deallocate (lookup_table) + + ! build unique sorted grids once for fallback interpolation + call build_unique_grids(rq) + + ! build the grid-to-lu mapping for O(1) stencil lookups + call build_grid_to_lu_map(rq) end if - ! ========================================= - ! Load Vega SED (needed for Vega mag system) - ! ========================================= + ! load vega SED if (.not. rq%vega_loaded) then - vega_filepath = trim(mesa_dir)//trim(rq%vega_sed) + vega_filepath = trim(resolve_path(rq%vega_sed)) call load_vega_sed(vega_filepath, rq%vega_wavelengths, rq%vega_fluxes) rq%vega_loaded = .true. end if - ! ========================================= - ! Load all filter transmission curves and precompute zero-points - ! ========================================= + ! load filter transmission curves and precompute zero-points if (.not. rq%filters_loaded) then - filter_dir = trim(mesa_dir)//trim(rq%instrument) + filter_dir = trim(resolve_path(rq%instrument)) - allocate(rq%filters(num_color_filters)) + allocate (rq%filters(num_color_filters)) do i = 1, num_color_filters rq%filters(i)%name = color_filter_names(i) filter_filepath = trim(filter_dir)//'/'//trim(color_filter_names(i)) call load_filter(filter_filepath, rq%filters(i)%wavelengths, rq%filters(i)%transmission) - ! Precompute zero-points for all magnitude systems - ! These are constant for each filter and never need recalculation + ! 
precompute zero-points for all magnitude systems rq%filters(i)%vega_zero_point = compute_vega_zero_point( & - rq%vega_wavelengths, rq%vega_fluxes, & - rq%filters(i)%wavelengths, rq%filters(i)%transmission) + rq%vega_wavelengths, rq%vega_fluxes, & + rq%filters(i)%wavelengths, rq%filters(i)%transmission) rq%filters(i)%ab_zero_point = compute_ab_zero_point( & - rq%filters(i)%wavelengths, rq%filters(i)%transmission) + rq%filters(i)%wavelengths, rq%filters(i)%transmission) rq%filters(i)%st_zero_point = compute_st_zero_point( & - rq%filters(i)%wavelengths, rq%filters(i)%transmission) + rq%filters(i)%wavelengths, rq%filters(i)%transmission) end do rq%filters_loaded = .true. end if + ! try to load the full flux cube -- if allocation fails, cube_loaded stays + ! .false. and we fall back to loading individual SED files + if (.not. rq%cube_loaded) then + call load_flux_cube(rq, rq%stellar_atm) + end if + end subroutine colors_setup_tables subroutine colors_setup_hooks(handle, ierr) @@ -200,9 +209,7 @@ subroutine colors_setup_hooks(handle, ierr) end subroutine colors_setup_hooks - !----------------------------------------------------------------------- - ! Bolometric correction interface (stub implementations) - !----------------------------------------------------------------------- + ! bolometric correction stubs (legacy mesa interface) real(dp) function get_bc_by_name(name, log_Teff, log_g, M_div_h, ierr) character(len=*), intent(in) :: name @@ -244,7 +251,7 @@ end function get_bc_name_by_id real(dp) function get_abs_bolometric_mag(lum) use const_def, only: dp - real(dp), intent(in) :: lum ! Luminosity in lsun units + real(dp), intent(in) :: lum ! luminosity in lsun get_abs_bolometric_mag = -99.9d0 end function get_abs_bolometric_mag @@ -254,7 +261,7 @@ real(dp) function get_abs_mag_by_name(name, log_Teff, log_g, M_div_h, lum, ierr) real(dp), intent(in) :: log_Teff ! log10 of surface temp real(dp), intent(in) :: M_div_h ! [M/H] real(dp), intent(in) :: log_g ! 
log_10 of surface gravity - real(dp), intent(in) :: lum ! Luminosity in lsun units + real(dp), intent(in) :: lum ! luminosity in lsun integer, intent(inout) :: ierr ierr = 0 @@ -266,7 +273,7 @@ real(dp) function get_abs_mag_by_id(id, log_Teff, log_g, M_div_h, lum, ierr) real(dp), intent(in) :: log_Teff ! log10 of surface temp real(dp), intent(in) :: log_g ! log_10 of surface gravity real(dp), intent(in) :: M_div_h ! [M/H] - real(dp), intent(in) :: lum ! Luminosity in lsun units + real(dp), intent(in) :: lum ! luminosity in lsun integer, intent(inout) :: ierr ierr = 0 @@ -289,7 +296,7 @@ real(dp) function get_lum_band_by_name(name, log_Teff, log_g, M_div_h, lum, ierr real(dp), intent(in) :: log_Teff ! log10 of surface temp real(dp), intent(in) :: M_div_h ! [M/H] real(dp), intent(in) :: log_g ! log_10 of surface gravity - real(dp), intent(in) :: lum ! Total luminosity in lsun units + real(dp), intent(in) :: lum ! luminosity in lsun integer, intent(inout) :: ierr ierr = 0 @@ -301,7 +308,7 @@ real(dp) function get_lum_band_by_id(id, log_Teff, log_g, M_div_h, lum, ierr) real(dp), intent(in) :: log_Teff ! log10 of surface temp real(dp), intent(in) :: log_g ! log_10 of surface gravity real(dp), intent(in) :: M_div_h ! [M/H] - real(dp), intent(in) :: lum ! Total luminosity in lsun units + real(dp), intent(in) :: lum ! luminosity in lsun integer, intent(inout) :: ierr ierr = 0 diff --git a/star/job/run_star_support.f90 b/star/job/run_star_support.f90 index f0870d79f..8f697ab22 100644 --- a/star/job/run_star_support.f90 +++ b/star/job/run_star_support.f90 @@ -477,9 +477,13 @@ end subroutine extras_controls end if if (dbg) write(*,*) 'call extras_startup' + call s% extras_startup(id, restart, ierr) if (failed('extras_startup',ierr)) return + call star_setup_colors_iteration_hook(id, ierr) + if (ierr /= 0) ierr = 0 ! Colors hook should not be fatal as the colors hook is optional + if (s% job% profile_starting_model .and. .not. 
restart) then call star_set_vars(id, 0d0, ierr) if (failed('star_set_vars',ierr)) return diff --git a/star/private/history.f90 b/star/private/history.f90 index d6faffd84..f51806a03 100644 --- a/star/private/history.f90 +++ b/star/private/history.f90 @@ -267,8 +267,10 @@ subroutine do_history_info(s, write_flag, ierr) end if colors_col_names(1:num_colors_cols) = 'unknown' colors_col_vals(1:num_colors_cols) = -1d99 - call data_for_colors_history_columns(s%T(1), log10(s%grav(1)), s%R(1), s%kap_rq%Zbase, & - s% colors_handle, num_colors_cols, colors_col_names, colors_col_vals, ierr) + + call data_for_colors_history_columns(s%T(1), safe_log10(s%grav(1)), s%R(1), s%kap_rq%Zbase, & + s% model_number, s% colors_handle, num_colors_cols, colors_col_names, colors_col_vals, ierr) + if (ierr /= 0) then call dealloc return diff --git a/star/private/init.f90 b/star/private/init.f90 index 98bcc1ec0..a6d364a8c 100644 --- a/star/private/init.f90 +++ b/star/private/init.f90 @@ -31,6 +31,7 @@ module init public :: do_starlib_shutdown public :: set_kap_and_eos_handles public :: set_colors_handles + public :: setup_colors_iteration_hook public :: load_zams_model public :: create_pre_ms_model public :: create_initial_model @@ -116,6 +117,62 @@ subroutine set_colors_handles(id, ierr) end if end subroutine set_colors_handles + + subroutine colors_solver_monitor_wrapper( & + id, iter, passed_tol_tests, & + correction_norm, max_correction, & + residual_norm, max_residual, ierr) + use colors_lib, only: write_iteration_colors + use const_def, only: dp + integer, intent(in) :: id, iter + logical, intent(in) :: passed_tol_tests + real(dp), intent(in) :: correction_norm, max_correction + real(dp), intent(in) :: residual_norm, max_residual + integer, intent(out) :: ierr + type(star_info), pointer :: s + + ierr = 0 + call get_star_ptr(id, s, ierr) + if (ierr /= 0) return + if (s% colors_handle <= 0) return + + call write_iteration_colors( & + s% colors_handle, & + s% model_number, & + iter, & + s% 
star_age, & + s% dt, & + s% Teff, & + safe_log10(s% grav(1)), & + s% R(1), & + s%kap_rq%Zbase, & + ierr) + end subroutine colors_solver_monitor_wrapper + + + subroutine setup_colors_iteration_hook(id, ierr) + use colors_lib, only: colors_ptr + use colors_def, only: Colors_General_Info + integer, intent(in) :: id + integer, intent(out) :: ierr + type(star_info), pointer :: s + type(Colors_General_Info), pointer :: cs + + ierr = 0 + call get_star_ptr(id, s, ierr) + if (ierr /= 0) return + + if (s% colors_handle > 0) then + call colors_ptr(s% colors_handle, cs, ierr) + if (ierr /= 0) return + if (cs% colors_per_newton_step .and. cs% use_colors) then + s% use_other_solver_monitor = .true. + s% other_solver_monitor => colors_solver_monitor_wrapper + end if + end if + end subroutine setup_colors_iteration_hook + + subroutine do_star_init( & my_mesa_dir, chem_isotopes_filename, & net_reaction_filename, jina_reaclib_filename, & diff --git a/star/public/star_lib.f90 b/star/public/star_lib.f90 index 150bcc182..49e20731f 100644 --- a/star/public/star_lib.f90 +++ b/star/public/star_lib.f90 @@ -298,6 +298,14 @@ subroutine star_set_colors_handles(id, ierr) end subroutine star_set_colors_handles + subroutine star_setup_colors_iteration_hook(id, ierr) + use init, only: setup_colors_iteration_hook + integer, intent(in) :: id + integer, intent(out) :: ierr + call setup_colors_iteration_hook(id, ierr) + end subroutine star_setup_colors_iteration_hook + + subroutine star_set_net(id, new_net_name, ierr) use net, only: set_net integer, intent(in) :: id diff --git a/star/test_suite/custom_colors/README.rst b/star/test_suite/custom_colors/README.rst index 1395a9e9e..4206bc597 100644 --- a/star/test_suite/custom_colors/README.rst +++ b/star/test_suite/custom_colors/README.rst @@ -1,242 +1,217 @@ .. 
_custom_colors: -************* -custom_colors -************* +****** +Colors +****** -This test suite was tested against SDK 25.12.1 +This test suite case demonstrates the functionality of the MESA ``colors`` module. -This test suite case demonstrates the functionality of the MESA ``colors`` module, a framework introduced in MESA r25.10.1 for calculating synthetic photometry and bolometric quantities during stellar evolution. +Running the Test Suite +====================== -What is MESA colors? -==================== - -MESA colors is a post-processing and runtime module that allows users to generate "observer-ready" data directly from stellar evolution models. Instead of limiting output to theoretical quantities like Luminosity (:math:`L`) and Surface Temperature (:math:`T_{\rm eff}`), the colors module computes: - -* **Bolometric Magnitude** (:math:`M_{\rm bol}`) -* **Bolometric Flux** (:math:`F_{\rm bol}`) -* **Synthetic Magnitudes** in specific photometric filters (e.g., Johnson V, Gaia G, 2MASS J). - -This bridges the gap between theoretical evolutionary tracks and observational color-magnitude diagrams (CMDs). - -How does the MESA colors module work? -===================================== -The module operates by coupling the stellar structure model with pre-computed grids of stellar atmospheres. +The ``custom_colors`` test suite case evolves a 7 M☉ star from the pre-main sequence through core hydrogen burning, up to the point of X_He,c < 0.01. -1. **Interpolation**: At each timestep, the module takes the star's current surface parameters—Effective Temperature (:math:`T_{\rm eff}`), Surface Gravity (:math:`\log g`), and Metallicity ([M/H])—and queries a user-specified library of stellar atmospheres (defined in ``stellar_atm``). It interpolates within this grid to construct a specific Spectral Energy Distribution (SED) for the stars current features. 
+The test is not scientifically rigorous—the pre-MS relaxation settings are aggressive and the mesh is fine—its purpose is solely to exercise the colors module across a wide range of stellar parameters (T_eff, log g) in a reasonable wall-clock time. -2. **Convolution**: This specific SED is then convolved with filter transmission curves (defined in ``instrument``) to calculate the flux passing through each filter. -3. **Integration**: The fluxes are converted into magnitudes using the user-selected magnitude system (AB, ST, or Vega). +The ``custom_colors`` test suite is a standard MESA work directory. Before running it for the first time—or after making changes to the ``colors`` module source—the binary must be compiled. -The Test Suite -============== +``make clean`` + Removes all previously compiled object files and binaries from the ``build/`` directory. Run this before recompiling to ensure a clean state, particularly after switching MESA versions or modifying source files. -This test suite evolves a complete stellar model (from the pre–main sequence onward) while the ``colors`` module runs *continuously* in the background. -At every timestep, MESA computes synthetic photometry by interpolating a stellar atmosphere grid and convolving the resulting SED with the filters you specify in the inlist. +``make`` + Compiles the test suite and links it against the installed MESA libraries (including ``libcolors``). When it completes successfully it produces the ``build/bin/star`` executable. -During the run, the module automatically appends new photometric columns to the ``history.data`` file. -You **do not** need to list these in ``history_columns.list``—the module detects the available filters by inspecting the directory defined in the ``instrument`` parameter. +``./rn`` + Runs the compiled stellar evolution model. MESA evolves the star according to the parameters in ``inlist_colors``, writing history and profile data to ``LOGS/`` and photometric outputs to ``SED/``. 
-What the Test Suite Produces ----------------------------- +A typical run workflow is: -The standard output of the test suite includes: +.. code-block:: bash -* ``Mag_bol`` - The bolometric magnitude computed from the star's instantaneous bolometric flux. + make clean # wipe any previous build + make # compile + ./rn # run -* ``Flux_bol`` - The bolometric flux (cgs units) after distance dilution is applied. - The default distance is 10 pc, producing **absolute magnitudes** unless changed. +If ``make`` completes without errors and ``./rn`` begins printing timestep output, the colors module is working correctly. -* ``[Filter_Name]`` - A synthetic magnitude column for **every** filter in the instrument directory. -For example, if your filter directory contains: -.. code-block:: text +What is MESA colors? +==================== - filters/Generic/Johnson/ - B.dat - V.dat - R.dat - Johnson +MESA colors is a runtime module that allows users to generate observer-ready data directly from stellar evolution models. Instead of limiting output to theoretical quantities like Luminosity (L) and Surface Temperature (T_eff), the colors module computes: -then your history file will include: +* **Bolometric Magnitude** (M_bol) +* **Bolometric Flux** (F_bol) +* **Synthetic Magnitudes** in specific photometric filters (e.g., Johnson V, Gaia G, 2MASS J). -``B``, ``V``, ``R`` +How does the MESA colors module work? +===================================== -as new magnitude columns generated automatically at runtime. +1. At each timestep, the module takes the star's current surface parameters—Effective Temperature (T_eff), Surface Gravity (log g), and Metallicity ([M/H])—and queries a user-specified library of stellar atmospheres (defined in ``stellar_atm``). It interpolates within this grid to construct a specific Spectral Energy Distribution (SED) for the star's current parameters. -What the Test Suite Actually Does ---------------------------------- +2. 
This SED is then convolved with filter transmission curves (defined in ``instrument``) to calculate the flux passing through each filter. -The provided ``inlist_colors`` configures a full evolution run with the colors module activated: +3. The fluxes are converted into magnitudes using the user-selected magnitude system (AB, ST, or Vega). -* Starts from a **pre–main-sequence model** -* Evolves the model through multiple phases while computing synthetic photometry -* Uses the Johnson filter set -* Uses the Kurucz2003all atmosphere grid -* Outputs bolometric and filter-specific magnitudes to ``history.data`` every step +Inlist Options & Parameters +=========================== -Because the test suite's inlist also defines a set of PGSTAR panels, you automatically -get real-time plots of: +The colors module is controlled via the ``&colors`` namelist. Below is a detailed guide to the key parameters. -* HR diagram (log L vs. log Teff) -* A light curve based on any synthetic magnitude (the test suite uses ``V``) +use_colors +---------- -Real-Time Visualization (Enabled by Default) --------------------------------------------- +**Default:** ``.false.`` -The test suite's ``pgstar`` block is configured so that, as the star evolves: +Master switch for the module. Must be set to ``.true.`` to enable any photometric output. -* Panel 1: HR diagram (theoretical) -* Panel 2: Light curve in the Johnson V band +**Example:** -These update automatically as the model runs. +.. code-block:: fortran -Purpose of the Test Suite -------------------------- + use_colors = .true. -This test problem is designed to demonstrate: -1. That MESA can compute synthetic photometry **at runtime**, without external tools. -2. How atmosphere grids and filters affect magnitude evolution. -3. How to configure your own inlists for scientific use. -4. How the ``colors`` module integrates with PGSTAR visualization. -5. 
How synthetic magnitudes appear in ``history.data`` and how to use them for CMDs, light curves, and population modeling. +instrument +---------- +**Default:** ``'data/colors_data/filters/Generic/Johnson'`` +Path to the filter instrument directory, structured as ``facility/instrument``. -Inlist Options & Parameters -=========================== +* The directory must contain an index file with the same name as the instrument + (e.g., ``Johnson``), listing one filter filename per line. +* The module loads every ``.dat`` transmission curve listed in that index and + creates a corresponding history column for each. -The colors module is controlled via the ``&colors`` namelist. Below is a detailed guide to the key parameters. +.. rubric:: Note on paths -instrument ----------- -**Default:** `'/data/colors_data/filters/Generic/Johnson'` +All path parameters (``instrument``, ``stellar_atm``, ``vega_sed``) are resolved +using the same logic: -This points to the directory containing the filter transmission curves you wish to use. The path must be structured as ``facility/instrument``. - -* The directory must contain a file named after the instrument (e.g., ``Johnson``) which acts as an index. -* The module will read every ``.dat`` file listed in that directory and create a corresponding history column for it. +* ``'data/colors_data/...'`` — no leading slash; ``$MESA_DIR`` is prepended. This + is the recommended form for all standard data paths. +* ``'/absolute/path/...'`` — tested on disk first; if found, used as-is. If not + found, ``$MESA_DIR`` is prepended (preserves backwards compatibility). +* ``'./local/path/...'`` or ``'../up/one/...'`` — used exactly as supplied, + relative to the MESA working directory. **Example:** .. 
code-block:: fortran - instrument = '/data/colors_data/filters/GAIA/GAIA' + instrument = 'data/colors_data/filters/GAIA/GAIA' stellar_atm ----------- -**Default:** `'/data/colors_data/stellar_models/Kurucz2003all/'` +**Default:** ``'data/colors_data/stellar_models/Kurucz2003all/'`` -Specifies the path to the directory containing the grid of stellar atmosphere models. This directory must contain: +Path to the directory containing the grid of stellar atmosphere models. Paths may be relative to ``$MESA_DIR``, relative to the working directory, or absolute. This directory must contain: -1. **lookup_table.csv**: A map linking filenames to physical parameters (:math:`T_{\rm eff}`, :math:`\log g`, [M/H]). +1. **lookup_table.csv**: A map linking filenames to physical parameters (T_eff, log g, [M/H]). 2. **SED files**: The actual spectra (text or binary format). 3. **flux_cube.bin**: (Optional but recommended) A binary cube for rapid interpolation. -The module queries this grid using the star's current parameters. If the star evolves outside the grid boundaries, the module may clamp to the nearest edge or extrapolate, depending on internal settings. +The module queries this grid using the star's current parameters. If the star evolves outside the grid boundaries, the module will clamp to the nearest edge. **Example:** .. code-block:: fortran - stellar_atm = '/data/colors_data/stellar_models/sg-SPHINX/' + stellar_atm = 'data/colors_data/stellar_models/sg-SPHINX/' distance -------- -**Default:** `3.0857d19` (10 parsecs in cm) +**Default:** ``3.0857d19`` (10 parsecs in cm) -The distance to the star in centimeters. +The distance to the star in centimetres, used to convert surface flux to observed flux. -* This value is used to convert surface flux to observed flux. -* **Default Behavior:** It defaults to 10 parsecs (:math:`3.0857 \times 10^{19}` cm), resulting in **Absolute Magnitudes**.
-* **Custom Usage:** You can set this to a specific source distance (e.g., distance to Betelgeuse) to calculate Apparent Magnitudes. +* **Default Behaviour:** At 10 parsecs (3.0857 * 10^19 cm) the output is **Absolute Magnitudes**. +* **Custom Usage:** Set this to a specific source distance to calculate Apparent Magnitudes. **Example:** .. code-block:: fortran - distance = 5.1839d20 + distance = 5.1839d20 + make_csv -------- -**Default:** `.false.` +**Default:** ``.false.`` If set to ``.true.``, the module exports the full calculated SED at every profile interval. * **Destination:** Files are saved to the directory defined by ``colors_results_directory``. * **Format:** CSV files containing Wavelength vs. Flux. -* **Use Case:** useful for debugging or plotting the full spectrum of the star at a specific age. +* **Use Case:** Useful for debugging or plotting the full spectrum of the star at a specific evolutionary age. **Example:** .. code-block:: fortran - make_csv = .true. + make_csv = .true. colors_results_directory ------------------------ -**Default:** `'SED'` +**Default:** ``'SED'`` -The folder where csv files (if ``make_csv = .true.``) and other debug outputs are saved. +The folder where CSV files (if ``make_csv = .true.``) and other outputs are saved. **Example:** .. code-block:: fortran - colors_results_directory = 'sed' + colors_results_directory = 'sed' mag_system ---------- -**Default:** `'Vega'` +**Default:** ``'Vega'`` Defines the zero-point system for magnitude calculations. Options are: * ``'AB'``: Based on a flat spectral flux density of 3631 Jy. * ``'ST'``: Based on a flat spectral flux density per unit wavelength. -* ``'Vega'``: Calibrated such that the star Vega has magnitude 0 in all bands. +* ``'Vega'``: Calibrated such that Vega has magnitude 0 in all bands. **Example:** .. 
code-block:: fortran - mag_system = 'AB' + mag_system = 'AB' vega_sed -------- -**Default:** `'/data/colors_data/stellar_models/vega_flam.csv'` +**Default:** ``'data/colors_data/stellar_models/vega_flam.csv'`` -Required only if ``mag_system = 'Vega'``. This points to the reference SED file for Vega. The default path points to a file provided with the MESA data distribution. +Required only if ``mag_system = 'Vega'``. Points to the reference SED file for Vega, used to compute photometric zero-points. Paths may be relative to ``$MESA_DIR``, relative to the working directory, or absolute. **Example:** .. code-block:: fortran - vega_sed = '/another/file/for/vega_SED.csv' + vega_sed = '/path/to/my/vega_SED.csv' Data Preparation (SED_Tools) ============================ The ``colors`` module requires pre-processed stellar atmospheres and filter -profiles organized in a very specific directory structure. To automate this +profiles organised in a specific directory structure. To automate this entire workflow, we provide the dedicated repository: **Repository:** `SED_Tools `_ @@ -248,15 +223,14 @@ filter transmission curves from the following public archives: * `MAST BOSZ Stellar Atmosphere Library `_ * `MSG / Townsend Atmosphere Grids `_ -These sources provide heterogeneous formats and file organizations. SED_Tools -standardizes them into the exact structure required by MESA: +These sources provide heterogeneous formats and file organisations. SED_Tools +standardises them into the exact structure required by MESA: * ``lookup_table.csv`` -* Raw SED files (text or/and HDF5) +* Raw SED files (text and/or HDF5) * ``flux_cube.bin`` (binary cube for fast interpolation) * Filter index files and ``*.dat`` transmission curves - SED_Tools produces: .. code-block:: text @@ -288,19 +262,18 @@ This server provides a live view of: Defaults Reference ================== -Below are the default values for the colors module parameters as defined in ``colors.defaults``. 
These are used if you do not override them in your inlist. +Below are the default values for all user-facing ``colors`` module parameters as defined in ``colors.defaults``. .. code-block:: fortran use_colors = .false. - instrument = '/data/colors_data/filters/Generic/Johnson' - vega_sed = '/data/colors_data/stellar_models/vega_flam.csv' - stellar_atm = '/data/colors_data/stellar_models/Kurucz2003all/' + instrument = 'data/colors_data/filters/Generic/Johnson' + stellar_atm = 'data/colors_data/stellar_models/Kurucz2003all/' + vega_sed = 'data/colors_data/stellar_models/vega_flam.csv' distance = 3.0857d19 ! 10 parsecs in cm (Absolute Magnitude) make_csv = .false. colors_results_directory = 'SED' mag_system = 'Vega' - vega_sed = '/data/colors_data/stellar_models/vega_flam.csv' Visual Summary of Data Flow =========================== @@ -318,8 +291,8 @@ Visual Summary of Data Flow | 1. Query Stellar Atmosphere Grid with input model | | 2. Interpolate grid to construct specific SED | | 3. Convolve SED with filters to generate band flux | - | 2. Apply distance flux dilution to generate bolometric flux -> Flux_bol | - | 4. Apply zero point (Vega/AB/ST) to generate magnitudes | + | 4. Apply distance flux dilution to generate bolometric flux -> Flux_bol | + | 5. Apply zero point (Vega/AB/ST) to generate magnitudes | | (Both bolometric and per filter) | +-------------------------------------------------------------------------+ | @@ -332,60 +305,208 @@ Visual Summary of Data Flow +----------------------+ - +===================== +===================== Python Helper Scripts ===================== -A collection of Python scripts is provided in the ``python_helpers/`` directory to assist with real-time monitoring, visualization, and analysis of the colors module output. +All Python helpers live in ``python_helpers/`` and are run from that directory. +All paths default to ``../LOGS/`` and ``../SED/`` relative to that location, matching the standard test suite layout. 
+ +plot_history_live.py +-------------------- + +**Purpose:** Live-updating four-panel diagnostic viewer for ``history.data``, designed to run *during* a MESA simulation. + +**What it shows:** + +* **Top-left**: Color–magnitude diagram (CMD) constructed automatically from whichever filters are present in the history file. Filter priority for color index selection follows: Gaia (Gbp−Grp), then Johnson (B−R or B−V), then Sloan (g−r), with a fallback to the first and last available filter. +* **Top-right**: Classical HR diagram (Teff vs. log L). +* **Bottom-left**: Color index as a function of stellar age. +* **Bottom-right**: Light curves for all available filter bands simultaneously. + +All four panels are color-coded by MESA's ``phase_of_evolution`` integer if that column is present in the history file. If it is absent, points are colored by stellar age using a compressed inferno colormap that emphasises the most recent evolution. + +**How to use:** + +.. code-block:: bash + + cd python_helpers + python plot_history_live.py -Dependencies: - * ``matplotlib`` - * ``numpy`` - * ``mesa_reader`` (ensure this is installed/accessible) - * ``ffmpeg`` (optional, for movie generation) +The script polls ``../LOGS/history.data`` every 0.1 seconds and updates the plot whenever the file changes. It will print a change notification for the first five updates, then go silent to avoid log spam. Close the window to exit. -HISTORY_check.py + +plot_history.py +--------------- + +**Purpose:** Single-shot (non-live) version of the history viewer. Reads the completed ``history.data`` once and renders the same four-panel figure. Intended for post-run analysis and as a shared library imported by ``plot_history_live.py``, ``movie_history.py``, and ``movie_cmd_3d.py``. + +**How to use:** + +.. code-block:: bash + + cd python_helpers + python plot_history.py + +Produces a static figure and then calls ``plt.show()``. 
The script also exports ``MesaView``, ``read_header_columns``, and ``setup_hr_diagram_params`` for use by the other scripts. + +**Note:** The first 5 model rows are skipped (``MesaView`` skip=5) to avoid noisy pre-MS relaxation artifacts at the very start of the run. + +plot_sed_live.py ---------------- -**Usage:** ``python python_helpers/HISTORY_check.py`` +**Purpose:** Live-updating SED viewer that monitors the ``SED/`` directory for CSV files written by the colors module (requires ``make_csv = .true.`` in ``inlist_colors``). -A real-time dashboard that monitors your ``history.data`` file as MESA runs. It automatically refreshes when new data is written. +**What it shows:** A single plot with a logarithmic wavelength axis showing: -**Plots:** - 1. **HR Diagram (Color vs. Magnitude):** Points colored by evolutionary phase. - 2. **Theoretical HR (Teff vs. Log L):** Standard theoretical track. - 3. **Color Evolution:** Color index vs. Age. - 4. **Light Curves:** Absolute magnitude vs. Age for all filters. +* The full stellar SED (black line, plotted once from the first file found). +* The filter-convolved flux for each band (one colored line per filter). +* The Vega reference SED if a ``VEGA_*`` file is present. +* Colored background bands marking the X-ray, UV, optical, and IR regions of the electromagnetic spectrum. +* A text box in the corner showing the stellar mass, metallicity, and distance parsed from ``inlist_colors``. -**Requirements:** Ensure ``phase_of_evolution`` is present in your ``history_columns.list`` for phase-based coloring. +The x-axis auto-scales to the wavelength range where the SED has flux above 1% of its peak; the y-axis scales to the range of convolved fluxes. + +**How to use:** + +.. code-block:: bash + + cd python_helpers + python plot_sed_live.py + +Runs live, refreshing every 0.1 s. Close the window or press Ctrl-C to stop. 
Can also be configured to save a video instead of displaying live by setting ``save_video=True`` in the ``SEDChecker`` constructor. -SED_check.py ------------- -**Usage:** ``python python_helpers/SED_check.py`` -Monitors the ``colors_results_directory`` (default: ``SED/``) for new CSV output. +plot_sed.py +----------- + +**Purpose:** Single-shot SED plot that reads all ``*SED.csv`` files in ``../SED/`` and overlays them in one figure. Simpler than ``plot_sed_live.py``—no live monitoring, no EM region shading, no inlist parsing. + +**How to use:** + +.. code-block:: bash + + cd python_helpers + python plot_sed.py + +Displays the combined SED figure. The x-axis is cropped to 0–60000 Å by default; edit the ``xlim`` argument in ``main()`` to change this. + + +plot_cmd_3d.py +-------------- + +**Purpose:** Interactive 3D scatter plot of the CMD with a user-selectable third axis. Useful for visualising how any history column correlates with the photometric evolution of the star. + +**How to use:** + +.. code-block:: bash + + cd python_helpers + python plot_cmd_3d.py + +On launch, the script prints all available columns from ``history.data`` and prompts you to type the name of the column to use for the Z axis (default: ``Interp_rad``, the interpolation radius in the atmosphere grid). Press Enter to accept the default or type any other column name. A rotatable 3D matplotlib window then opens showing color index vs. magnitude vs. your chosen column. + + +movie_history.py +---------------- + +**Purpose:** Renders the same four-panel display as ``plot_history_live.py`` into an MP4 video, with one frame per model row. Points accumulate from left to right in time, so the full evolutionary track builds up across the video. + +**How to use:** + +.. code-block:: bash + + cd python_helpers + python movie_history.py + +Writes ``history.mp4`` in the current directory at 24 fps, 150 dpi. Requires ``ffmpeg`` to be installed and accessible on ``$PATH``. 
Progress is shown with a ``tqdm`` progress bar if that package is available. + + +movie_cmd_3d.py +--------------- + +**Purpose:** Creates a 3D rotation video that starts looking straight down the ``Interp_rad`` axis (showing a plain CMD) and rotates to an oblique 3D perspective over 10 seconds, with 1-second holds at each end. + +**How to use:** + +.. code-block:: bash + + cd python_helpers + python movie_cmd_3d.py + +Writes ``cmd_interp_rad_rotation.mp4`` in the current directory at 30 fps. Requires ``ffmpeg``. -**Features:** - * Plots the full high-resolution stellar spectrum (black line). - * Plots the filter-convolved fluxes (colored lines corresponding to filters). - * Displays stellar parameters (Mass, Z, Distance) parsed from your inlist. - * Updates automatically if ``make_csv = .true.`` in your inlist and MESA overwrites the files. -interactive_cmd_3d.py ---------------------- +plot_newton_iter.py +------------------- -**Usage:** ``python python_helpers/interactive_cmd_3d.py`` +**Purpose:** Plots the per-Newton-iteration photometry data written to ``SED/iteration_colors.data`` by the colors module's solver-monitor hook. This file records the photometry at every Newton iteration within each timestep, capturing sub-timestep variability during convergence. -Generates an interactive 3D Color-Magnitude Diagram. +The script supports both an **interactive mode** (a terminal UI with a filterable column picker) and a **batch mode** driven by command-line arguments. Both modes support arbitrary column expressions (e.g., ``B-V``, ``(V-U)/(B-V)``, ``Teff/1000``) in addition to named columns. When a ``history.data`` file is present, the history track is overlaid on the plot in grey for comparison. Plots are saved as both PDF and JPG. -* **X-axis:** Color Index (e.g., B-V or Gbp-Grp). -* **Y-axis:** Magnitude (e.g., V or G). -* **Z-axis:** User-selectable column from the history file (e.g., ``Interp_rad``, ``star_age``, ``log_R``). 
The script will prompt you to choose a column at runtime. +**Interactive mode:** -Movie Makers +.. code-block:: bash + + cd python_helpers + python plot_newton_iter.py + +You will be prompted to select the plot type (2D scatter, 2D line, or 3D scatter), then the X, Y, (optionally Z), and color axes from a grid display. The picker supports substring filtering (``/text``), negative filtering (``!text``), and regex filtering (``//pattern``). You can also type a column name directly. + + +movie_newton_iter.py +-------------------- + +**Purpose:** Creates an animated MP4 showing the Newton iteration data accumulating point by point over time. Iterations are sorted by model number then iteration number, so the animation follows the physical sequence of the simulation. History file points are overlaid incrementally—a new history point appears each time a model's full set of iterations is complete. Outliers are removed via iterative sigma clipping before rendering. Axis limits expand smoothly as new data arrives. + +The script supports the same interactive and batch modes as ``plot_newton_iter.py``, and imports all shared functionality (terminal UI, data loading, expression parsing) directly from that module. + +**Interactive mode:** + +.. code-block:: bash + + cd python_helpers + python movie_newton_iter.py + +You will be prompted for X, Y, and color axes, video duration, FPS, output filename, and sigma-clipping threshold. + +Requires ``ffmpeg``. For GIF output, ``pillow`` can be used instead by changing the writer in the source. + + +plot_zero_points.py +------------------- + +**Purpose:** Runs MESA three times in sequence with ``mag_system`` set to ``'Vega'``, ``'AB'``, and ``'ST'`` respectively, then overlays the resulting CMDs in a single comparison figure. Useful for quantifying the offsets between magnitude systems for a given set of filters. + +**How to use:** + +.. 
code-block:: bash
+
+    cd python_helpers
+    python plot_zero_points.py
+
+The script must be run from within ``python_helpers/`` (it changes directory to ``../`` before calling ``./rn``). It temporarily modifies ``inlist_colors`` to set the magnitude system and to disable PGstar, restores the original file after each run, and saves each LOGS directory as ``LOGS_Vega``, ``LOGS_AB``, and ``LOGS_ST``. The comparison figure is saved to ``mag_system_comparison.png``.
+
+.. warning::
+
+    This script runs the full simulation three times. For large or slow models, consider reducing ``initial_mass`` or loosening ``varcontrol_target`` and ``mesh_delta_coeff`` inside the ``FAST_RUN_OVERRIDES`` dictionary at the top of the script before running.
+
+
+run_batch.py
 ------------
-Scripts to generate MP4 animations of your run. Requires ``ffmpeg``.
+**Purpose:** Runs MESA in batch over a list of parameter files, each of which defines a different stellar type (M dwarf, Pop III, OB star, etc.). After each run, the history file is renamed to avoid being overwritten by the next run.
+
+**How to use:**
+
+Edit ``param_options`` at the top of the script to list the ``extra_controls_inlist_name`` files you want to loop over, then run:
+
+.. code-block:: bash
+
+    cd python_helpers
+    python run_batch.py
+
+The script: (1) comments out all parameter entries in ``inlist_1.0``, (2) uncomments one entry at a time, (3) runs ``./clean``, ``./mk``, and ``./rn`` in the MESA work directory, (4) renames the resulting ``history.data`` to a parameter-specific filename (e.g. ``history_<param>.data``), and (5) re-comments all entries at the end. If a run fails, a warning is printed and the loop continues with the next parameter set. 
-* **make_history_movie.py**: Creates ``history.mp4``, an animated version of the ``HISTORY_check.py`` dashboard showing the evolution over time. 
-* **make_CMD_InterpRad_movie.py**: Creates ``cmd_interp_rad_rotation.mp4``, a rotating 3D view of the Color-Magnitude Diagram with the Interpolation Radius on the Z-axis. 
Useful for visualizing grid coverage and interpolation quality. +**Note:** The MESA work directory is assumed to be one level above the location of ``run_batch.py`` (i.e., ``../``). The inlist to modify is ``inlist_1.0``. Adjust these paths at the top of the script if your layout differs. diff --git a/star/test_suite/custom_colors/inlist_colors b/star/test_suite/custom_colors/inlist_colors index cfdf11b54..3d546a538 100644 --- a/star/test_suite/custom_colors/inlist_colors +++ b/star/test_suite/custom_colors/inlist_colors @@ -57,15 +57,15 @@ ! This points to a directory containing Johnson filter transmission curves. ! The Johnson directory should contain files like: ! - U.dat, B.dat, V.dat, R.dat, I.dat, J.dat, M.dat (filter transmission curves) - instrument = '/data/colors_data/filters/Generic/Johnson' + instrument = 'data/colors_data/filters/Generic/Johnson' ! Stellar atmosphere model directory ! Should contain: ! - lookup_table.csv: mapping file with columns for filename, Teff, log_g, metallicity ! - Individual SED files: wavelength vs surface flux for each stellar model ! - flux_cube.bin: pre-computed interpolation grid (if using linear/hermite interpolation) - stellar_atm = '/data/colors_data/stellar_models/Kurucz2003all/' - + stellar_atm = 'data/colors_data/stellar_models/Kurucz2003all' + ! Physical parameters for synthetic photometry distance = 3.0857d19 ! Distance to star in cm (1 pc = 3.0857e16 m for absolute mags) @@ -80,7 +80,7 @@ ! Vega spectrum for zero-point calibration ! This file should contain wavelength (Å) and flux (erg/s/cm²/Å) columns ! Used to define the zero-point of the magnitude system (Vega = 0.0 mag) - vega_sed = '/data/colors_data/stellar_models/vega_flam.csv' + vega_sed = 'data/colors_data/stellar_models/vega_flam.csv' / ! end of colors namelist @@ -105,7 +105,7 @@ xa_central_lower_limit(1) = 1d-2 ! Fail-safe max age (just in case) - max_age = 1d12 + max_age = 1d8 ! --- OUTPUT CONTROL --- ! 
High resolution history for smooth light curves diff --git a/star/test_suite/custom_colors/python_helpers/make_CMD_InterpRad_movie.py b/star/test_suite/custom_colors/python_helpers/movie_cmd_3d.py similarity index 96% rename from star/test_suite/custom_colors/python_helpers/make_CMD_InterpRad_movie.py rename to star/test_suite/custom_colors/python_helpers/movie_cmd_3d.py index 905c61bf0..c14a96004 100644 --- a/star/test_suite/custom_colors/python_helpers/make_CMD_InterpRad_movie.py +++ b/star/test_suite/custom_colors/python_helpers/movie_cmd_3d.py @@ -6,7 +6,7 @@ import mesa_reader as mr from matplotlib.animation import FFMpegWriter, FuncAnimation from mpl_toolkits.mplot3d import Axes3D # noqa: F401 -from static_HISTORY_check import MesaView, read_header_columns, setup_hr_diagram_params +from plot_history import MesaView, read_header_columns, setup_hr_diagram_params def make_cmd_rotation_video( diff --git a/star/test_suite/custom_colors/python_helpers/make_history_movie.py b/star/test_suite/custom_colors/python_helpers/movie_history.py similarity index 98% rename from star/test_suite/custom_colors/python_helpers/make_history_movie.py rename to star/test_suite/custom_colors/python_helpers/movie_history.py index f83bc5397..f0acb2726 100644 --- a/star/test_suite/custom_colors/python_helpers/make_history_movie.py +++ b/star/test_suite/custom_colors/python_helpers/movie_history.py @@ -4,8 +4,8 @@ import matplotlib.pyplot as plt import numpy as np -from HISTORY_check import HistoryChecker # uses static_HISTORY_check under the hood from matplotlib.animation import FFMpegWriter +from plot_history_live import HistoryChecker # uses static_HISTORY_check under the hood do_tqdm = True try: diff --git a/star/test_suite/custom_colors/python_helpers/movie_newton_iter.py b/star/test_suite/custom_colors/python_helpers/movie_newton_iter.py new file mode 100644 index 000000000..a85e96a0e --- /dev/null +++ b/star/test_suite/custom_colors/python_helpers/movie_newton_iter.py @@ -0,0 
+1,768 @@
+#!/usr/bin/env python3
+"""
+movie_newton_iter.py — Video plotter for MESA Colors per-iteration output
+
+Creates an animated video showing Newton iterations accumulating over time.
+Points appear one by one, sorted by model number then iteration number.
+History file points appear incrementally as each timestep completes.
+
+Author: Niall Miller (2025)
+"""
+
+import argparse
+import os
+import re
+import sys
+from typing import List, Optional, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib import cm
+from matplotlib.animation import FFMpegWriter, FuncAnimation, PillowWriter
+from matplotlib.colors import Normalize
+
+# Import shared functionality from static plotter
+from plot_newton_iter import (
+    CYAN,  # Terminal UI; Data loading; Prompting
+    DIM,
+    GREEN,
+    RESET,
+    load_history_data,
+    load_iteration_data,
+    print_error,
+    print_header,
+    print_info,
+    print_subheader,
+    print_success,
+    prompt_yes_no,
+    resolve_axis,
+    resolve_history_axis,
+    term_width,
+)
+
+# ============================================================================
+# SIGMA CLIPPING
+# ============================================================================
+
+
+def sigma_clip_mask(
+    data_arrays: List[np.ndarray], sigma: float = 3.0, max_iter: int = 5
+) -> np.ndarray:
+    """
+    Create a mask identifying outliers using iterative sigma clipping.
+
+    Clips based on ALL provided arrays - a point is kept only if it's
+    within sigma of the median in ALL arrays.
+ """ + n_points = len(data_arrays[0]) + mask = np.ones(n_points, dtype=bool) + + for iteration in range(max_iter): + prev_mask = mask.copy() + + for arr in data_arrays: + valid_data = arr[mask] + if len(valid_data) < 3: + continue + + median = np.median(valid_data) + mad = np.median(np.abs(valid_data - median)) + std_equiv = mad * 1.4826 + + if std_equiv > 0: + deviation = np.abs(arr - median) + mask &= deviation < sigma * std_equiv + + if np.array_equal(mask, prev_mask): + break + + return mask + + +# ============================================================================ +# VIDEO-SPECIFIC DATA HANDLING +# ============================================================================ + + +def sort_by_model_and_iter( + data: np.ndarray, column_names: List[str] +) -> Tuple[np.ndarray, np.ndarray]: + """Sort data by model number, then by iteration number.""" + model_idx = column_names.index("model") + iter_idx = column_names.index("iter") + sort_indices = np.lexsort((data[:, iter_idx], data[:, model_idx])) + return data[sort_indices], sort_indices + + +def get_model_completion_indices(model_data: np.ndarray) -> np.ndarray: + """ + Find the index where each model's iterations end. + Returns array of indices - the last point for each unique model. 
+ """ + indices = [] + unique_models = np.unique(model_data) + + for model in unique_models: + model_mask = model_data == model + model_indices = np.where(model_mask)[0] + # Last index for this model + indices.append(model_indices[-1]) + + return np.array(sorted(indices)) + + +# ============================================================================ +# VIDEO CREATION +# ============================================================================ + + +def create_video( + x_data: np.ndarray, + y_data: np.ndarray, + color_data: np.ndarray, + model_data: np.ndarray, + iter_data: np.ndarray, + x_label: str, + y_label: str, + color_label: str, + output_path: str, + duration: float = 30.0, + fps: int = 30, + cmap: str = "viridis", + point_size: int = 20, + alpha: float = 0.7, + flip_y: bool = False, + dpi: int = 150, + sigma_clip: float = 3.0, + history_x: Optional[np.ndarray] = None, + history_y: Optional[np.ndarray] = None, +) -> None: + """ + Create animated video of points appearing one by one. + + History points appear incrementally as each model's iterations complete. + Axis limits dynamically expand with smooth easing. 
+ """ + + # ========================================= + # SIGMA CLIPPING - Remove outliers + # ========================================= + print_info(f"Applying {sigma_clip}-sigma clipping to remove outliers...") + + valid_mask = sigma_clip_mask([x_data, y_data], sigma=sigma_clip) + n_removed = np.sum(~valid_mask) + + if n_removed > 0: + print_info( + f"Removed {n_removed} outlier points ({100 * n_removed / len(x_data):.1f}%)" + ) + x_data = x_data[valid_mask] + y_data = y_data[valid_mask] + color_data = color_data[valid_mask] + model_data = model_data[valid_mask] + iter_data = iter_data[valid_mask] + else: + print_info("No outliers detected") + + n_points = len(x_data) + + if n_points == 0: + print_error("No points remaining after sigma clipping!") + return + + # Find where each model ends (for incremental history plotting) + model_completion_indices = get_model_completion_indices(model_data) + n_models = len(model_completion_indices) + + # Check history data availability + has_history = ( + history_x is not None and history_y is not None and len(history_x) >= n_models + ) + + if has_history: + print_info( + f"History data: {len(history_x)} points available, {n_models} will be shown incrementally" + ) + # Truncate history to match number of models in iteration data + history_x = history_x[:n_models] + history_y = history_y[:n_models] + else: + print_info("No matching history data for incremental overlay") + + # Calculate frames and timing + total_frames = int(duration * fps) + + if n_points <= total_frames: + points_per_frame = 1 + total_frames = n_points + else: + points_per_frame = max(1, n_points // total_frames) + total_frames = (n_points + points_per_frame - 1) // points_per_frame + + actual_duration = total_frames / fps + + print_info(f"Total points: {n_points}") + print_info(f"Points per frame: {points_per_frame}") + print_info(f"Total frames: {total_frames}") + print_info(f"Video duration: {actual_duration:.1f}s at {fps} fps") + + # 
========================================= + # PRECOMPUTE AXIS LIMITS + # ========================================= + print_info("Precomputing dynamic axis limits...") + + # Final limits including history data + all_x = [x_data] + all_y = [y_data] + if has_history: + all_x.append(history_x) + all_y.append(history_y) + + all_x = np.concatenate(all_x) + all_y = np.concatenate(all_y) + + x_min_final, x_max_final = np.min(all_x), np.max(all_x) + y_min_final, y_max_final = np.min(all_y), np.max(all_y) + + x_range_final = x_max_final - x_min_final + y_range_final = y_max_final - y_min_final + + pad_frac = 0.08 + x_min_final -= x_range_final * pad_frac + x_max_final += x_range_final * pad_frac + y_min_final -= y_range_final * pad_frac + y_max_final += y_range_final * pad_frac + + # ========================================= + # SET UP FIGURE + # ========================================= + fig, ax = plt.subplots(figsize=(10, 8)) + + ax.set_xlabel(x_label, fontsize=12) + ax.set_ylabel(y_label, fontsize=12) + ax.grid(True, alpha=0.3) + + norm = Normalize(vmin=np.min(color_data), vmax=np.max(color_data)) + colormap = cm.get_cmap(cmap) + + sm = cm.ScalarMappable(cmap=colormap, norm=norm) + sm.set_array([]) + cbar = plt.colorbar(sm, ax=ax, pad=0.02) + cbar.set_label(color_label, fontsize=12) + + title = ax.set_title("", fontsize=14, fontweight="bold") + + # Initialize scatter plots + scatter = ax.scatter([], [], c=[], cmap=cmap, norm=norm, s=point_size, alpha=alpha) + scatter_history = ax.scatter( + [], + [], + c="black", + marker="x", + s=point_size * 0.5, + linewidths=2, + label="History", + zorder=10, + ) + ax.legend(loc="upper right") + + plt.tight_layout() + + def get_dynamic_limits( + n_show: int, n_history_show: int + ) -> Tuple[float, float, float, float]: + """Compute axis limits that smoothly expand to contain all visible points.""" + x_visible = list(x_data[:n_show]) + y_visible = list(y_data[:n_show]) + + if has_history and n_history_show > 0: + 
x_visible.extend(history_x[:n_history_show]) + y_visible.extend(history_y[:n_history_show]) + + x_visible = np.array(x_visible) + y_visible = np.array(y_visible) + + x_min_curr = np.min(x_visible) + x_max_curr = np.max(x_visible) + y_min_curr = np.min(y_visible) + y_max_curr = np.max(y_visible) + + x_range_curr = max(x_max_curr - x_min_curr, x_range_final * 0.01) + y_range_curr = max(y_max_curr - y_min_curr, y_range_final * 0.01) + + x_min_padded = x_min_curr - x_range_curr * pad_frac + x_max_padded = x_max_curr + x_range_curr * pad_frac + y_min_padded = y_min_curr - y_range_curr * pad_frac + y_max_padded = y_max_curr + y_range_curr * pad_frac + + # Progress and easing + progress = n_show / n_points + ease = 1 - (1 - progress) ** 2 + + x_min = x_min_padded + (x_min_final - x_min_padded) * ease + x_max = x_max_padded + (x_max_final - x_max_padded) * ease + y_min = y_min_padded + (y_min_final - y_min_padded) * ease + y_max = y_max_padded + (y_max_final - y_max_padded) * ease + + return x_min, x_max, y_min, y_max + + def init(): + scatter.set_offsets(np.empty((0, 2))) + scatter.set_array(np.array([])) + scatter_history.set_offsets(np.empty((0, 2))) + ax.set_xlim(x_min_final, x_max_final) + if flip_y: + ax.set_ylim(y_max_final, y_min_final) + else: + ax.set_ylim(y_min_final, y_max_final) + return scatter, scatter_history, title + + def update(frame): + n_show = min((frame + 1) * points_per_frame, n_points) + + # Update Newton iteration scatter + x_show = x_data[:n_show] + y_show = y_data[:n_show] + c_show = color_data[:n_show] + + scatter.set_offsets(np.column_stack([x_show, y_show])) + scatter.set_array(c_show) + + # Count how many models have completed (for incremental history) + n_history_show = 0 + if has_history: + # A model is complete when we've shown PAST its completion index + # (history X appears one point after the model's last iteration) + n_history_show = np.sum(model_completion_indices < n_show - 1) + + if n_history_show > 0: + 
scatter_history.set_offsets( + np.column_stack( + [history_x[:n_history_show], history_y[:n_history_show]] + ) + ) + else: + scatter_history.set_offsets(np.empty((0, 2))) + + # Update dynamic axis limits + x_min, x_max, y_min, y_max = get_dynamic_limits(n_show, n_history_show) + ax.set_xlim(x_min, x_max) + if flip_y: + ax.set_ylim(y_max, y_min) + else: + ax.set_ylim(y_min, y_max) + + # Update title + current_model = int(model_data[n_show - 1]) + current_iter = int(iter_data[n_show - 1]) + title.set_text( + f"Newton Iteration Colors — Model {current_model}, Iter {current_iter}\n" + f"Points: {n_show}/{n_points} | Timesteps: {n_history_show}" + ) + + return scatter, scatter_history, title + + # Create animation + print_info("Generating animation frames...") + anim = FuncAnimation( + fig, + update, + frames=total_frames, + init_func=init, + blit=False, + interval=1000 / fps, + ) + + # Save video + print_info(f"Encoding video to {output_path}...") + + if output_path.endswith(".gif"): + writer = PillowWriter(fps=fps) + else: + try: + writer = FFMpegWriter( + fps=fps, metadata={"title": "Newton Iterations"}, bitrate=5000 + ) + except Exception: + print_info("FFmpeg not available, falling back to GIF output") + output_path = output_path.rsplit(".", 1)[0] + ".gif" + writer = PillowWriter(fps=fps) + + anim.save(output_path, writer=writer, dpi=dpi) + plt.close(fig) + + print_success(f"Saved: {output_path}") + + +# ============================================================================ +# INTERACTIVE PROMPTS +# ============================================================================ + + +def prompt_axis_choice( + column_names: List[str], + data: np.ndarray, + label: str, +) -> Optional[Tuple[np.ndarray, str]]: + """Prompt user to select an axis column or expression.""" + N = len(column_names) + + print_subheader(f"{label} ({CYAN}{N}{RESET} columns)") + + col_width = max(len(s) for s in column_names) + 8 + cols = max(1, min(3, term_width() // col_width)) + + for i, 
name in enumerate(column_names): + end = "\n" if (i + 1) % cols == 0 else "" + print(f" [{GREEN}{i:2d}{RESET}] {name:<{col_width - 8}}", end=end) + if N % cols != 0: + print() + + print( + f"\n{DIM}Enter: column number | column name | expression (e.g. B-R, [9]-[13], Teff/1000){RESET}" + ) + + while True: + inp = input(f"\n{CYAN}>{RESET} ").strip() + + if not inp: + continue + + if inp.lower() == "q": + return None + + try: + arr, lbl = resolve_axis(inp, column_names, data) + return arr, lbl + except ValueError as e: + print_error(str(e)) + + +def prompt_duration(n_points: int, default: float = 30.0) -> Tuple[float, int]: + """Prompt user for video duration, showing points info.""" + print_subheader("Video Duration") + print_info(f"Total data points: {n_points}") + + while True: + inp = input(f"Video duration in seconds {DIM}[{default}]{RESET}: ").strip() + + if not inp: + duration = default + break + + try: + duration = float(inp) + if duration <= 0: + print_error("Duration must be positive") + continue + break + except ValueError: + print_error("Invalid number") + + default_fps = 30 + while True: + inp = input(f"Frames per second {DIM}[{default_fps}]{RESET}: ").strip() + + if not inp: + fps = default_fps + break + + try: + fps = int(inp) + if fps <= 0: + print_error("FPS must be positive") + continue + break + except ValueError: + print_error("Invalid number") + + total_frames = int(duration * fps) + points_per_frame = max(1, n_points // total_frames) + points_per_second = points_per_frame * fps + + print_info(f"≈ {points_per_second} points/second ({points_per_frame} points/frame)") + + return duration, fps + + +# ============================================================================ +# MAIN WORKFLOWS +# ============================================================================ + + +def run_interactive(filepath: str, history_file: str = "../LOGS/history.data") -> None: + """Run interactive mode.""" + print_header("MESA Colors — Newton Iteration Video 
Maker") + + print_info(f"Loading: {filepath}") + try: + column_names, data = load_iteration_data(filepath) + except Exception as e: + print_error(f"Failed to load file: {e}") + sys.exit(1) + + if "model" not in column_names or "iter" not in column_names: + print_error("Data must have 'model' and 'iter' columns") + sys.exit(1) + + data, _ = sort_by_model_and_iter(data, column_names) + print_success(f"Loaded {data.shape[0]} data points, {data.shape[1]} columns") + print_success("Data sorted by model number, then iteration") + + # Load history file + md = load_history_data(history_file) + + model_idx = column_names.index("model") + iter_idx = column_names.index("iter") + model_data = data[:, model_idx] + iter_data = data[:, iter_idx] + + # Select X axis + print(f"\n{DIM}Select X-axis (or enter expression like 'B-R'):{RESET}") + result = prompt_axis_choice(column_names, data, "X-axis") + if result is None: + return + x_data, x_label = result + print_success(f"X-axis: {x_label}") + + # Select Y axis + print(f"\n{DIM}Select Y-axis:{RESET}") + result = prompt_axis_choice(column_names, data, "Y-axis") + if result is None: + return + y_data, y_label = result + print_success(f"Y-axis: {y_label}") + + flip_y = prompt_yes_no("Flip Y axis?", default=False) + + # Select color axis + print(f"\n{DIM}Select Color axis:{RESET}") + result = prompt_axis_choice(column_names, data, "Color") + if result is None: + return + color_data, color_label = result + print_success(f"Color: {color_label}") + + # Prompt for duration + duration, fps = prompt_duration(len(x_data)) + + # Output filename + safe_x = re.sub(r"[^\w\-]", "_", x_label) + safe_y = re.sub(r"[^\w\-]", "_", y_label) + default_output = f"newton_iter_{safe_y}_vs_{safe_x}.mp4" + + out_inp = input(f"Output filename {DIM}[{default_output}]{RESET}: ").strip() + output_path = out_inp if out_inp else default_output + + # Sigma clipping + sigma_inp = input(f"Sigma clipping threshold {DIM}[3.0]{RESET}: ").strip() + try: + sigma_clip = 
float(sigma_inp) if sigma_inp else 3.0 + except ValueError: + sigma_clip = 3.0 + print_info("Invalid input, using default 3.0") + + # Get history data for overlay + history_x, history_y = None, None + if md is not None: + history_x = resolve_history_axis(x_label, md) + history_y = resolve_history_axis(y_label, md) + + if history_x is not None and history_y is not None: + print_info(f"History data: {len(history_x)} points available") + else: + missing = [] + if history_x is None: + missing.append(x_label) + if history_y is None: + missing.append(y_label) + print_info(f"Could not find history columns for: {', '.join(missing)}") + + # Create video + print_subheader("Generating Video") + create_video( + x_data, + y_data, + color_data, + model_data, + iter_data, + x_label, + y_label, + color_label, + output_path, + duration=duration, + fps=fps, + flip_y=flip_y, + sigma_clip=sigma_clip, + history_x=history_x, + history_y=history_y, + ) + + print(f"\n{GREEN}Done!{RESET}") + + +def run_batch( + filepath: str, + x_col: str, + y_col: str, + color_col: str, + output: Optional[str] = None, + duration: float = 30.0, + fps: int = 30, + cmap: str = "viridis", + flip_y: bool = False, + sigma_clip: float = 3.0, + history_file: str = "../LOGS/history.data", +) -> None: + """Run in batch mode.""" + column_names, data = load_iteration_data(filepath) + + if "model" not in column_names or "iter" not in column_names: + print_error("Data must have 'model' and 'iter' columns") + sys.exit(1) + + data, _ = sort_by_model_and_iter(data, column_names) + + md = load_history_data(history_file) + + model_idx = column_names.index("model") + iter_idx = column_names.index("iter") + model_data = data[:, model_idx] + iter_data = data[:, iter_idx] + + x_data, x_label = resolve_axis(x_col, column_names, data) + y_data, y_label = resolve_axis(y_col, column_names, data) + color_data, color_label = resolve_axis(color_col, column_names, data) + + if output is None: + safe_x = re.sub(r"[^\w\-]", "_", x_label) 
+ safe_y = re.sub(r"[^\w\-]", "_", y_label) + output = f"newton_iter_{safe_y}_vs_{safe_x}.mp4" + + history_x, history_y = None, None + if md is not None: + history_x = resolve_history_axis(x_label, md) + history_y = resolve_history_axis(y_label, md) + + create_video( + x_data, + y_data, + color_data, + model_data, + iter_data, + x_label, + y_label, + color_label, + output, + duration=duration, + fps=fps, + cmap=cmap, + flip_y=flip_y, + sigma_clip=sigma_clip, + history_x=history_x, + history_y=history_y, + ) + + +# ============================================================================ +# CLI ENTRY +# ============================================================================ + + +def main(): + parser = argparse.ArgumentParser( + description="Create video of MESA Colors per-iteration data", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s # Interactive mode + %(prog)s -f SED/iteration_colors.data # Specify file + %(prog)s -x Teff -y V -c model # Batch mode + %(prog)s -x Teff -y "V-U" -c iter # With color index expression + %(prog)s -x Teff -y V -c model --flip-y # Flip Y axis (for magnitudes) + %(prog)s -x Teff -y V -c model -d 60 # 60 second video + %(prog)s -x Teff -y V -c model --sigma 2.5 # Stricter outlier removal + """, + ) + + parser.add_argument( + "-f", + "--file", + default="SED/iteration_colors.data", + help="Path to iteration colors data file (default: SED/iteration_colors.data)", + ) + + parser.add_argument( + "--history", + default="../LOGS/history.data", + help="Path to MESA history file (default: ../LOGS/history.data)", + ) + + parser.add_argument("-x", "--x-col", help="X-axis column name or expression") + parser.add_argument("-y", "--y-col", help="Y-axis column name or expression") + parser.add_argument( + "-c", "--color-col", help="Color axis column name or expression" + ) + parser.add_argument("-o", "--output", help="Output video filename") + parser.add_argument( + "-d", + "--duration", + type=float, 
+ default=30.0, + help="Target video duration in seconds (default: 30)", + ) + parser.add_argument( + "--fps", type=int, default=30, help="Frames per second (default: 30)" + ) + parser.add_argument( + "--cmap", default="viridis", help="Matplotlib colormap (default: viridis)" + ) + parser.add_argument("--flip-y", action="store_true", help="Flip Y axis") + parser.add_argument( + "--sigma", + type=float, + default=3.0, + help="Sigma clipping threshold for outlier removal (default: 3.0)", + ) + parser.add_argument( + "--list-columns", action="store_true", help="List columns and exit" + ) + + args = parser.parse_args() + + if not os.path.exists(args.file): + alt_path = os.path.join("..", args.file) + if os.path.exists(alt_path): + args.file = alt_path + else: + print_error(f"File not found: {args.file}") + sys.exit(1) + + if args.list_columns: + column_names, data = load_iteration_data(args.file) + print_header("Available Columns") + for i, name in enumerate(column_names): + print(f" [{GREEN}{i:2d}{RESET}] {name}") + print(f"\n{DIM}Total: {len(column_names)} columns, {data.shape[0]} rows{RESET}") + return + + if args.x_col and args.y_col and args.color_col: + run_batch( + filepath=args.file, + x_col=args.x_col, + y_col=args.y_col, + color_col=args.color_col, + output=args.output, + duration=args.duration, + fps=args.fps, + cmap=args.cmap, + flip_y=args.flip_y, + sigma_clip=args.sigma, + history_file=args.history, + ) + else: + run_interactive(args.file, args.history) + + +if __name__ == "__main__": + main() diff --git a/star/test_suite/custom_colors/python_helpers/interactive_cmd_3d.py b/star/test_suite/custom_colors/python_helpers/plot_cmd_3d.py similarity index 97% rename from star/test_suite/custom_colors/python_helpers/interactive_cmd_3d.py rename to star/test_suite/custom_colors/python_helpers/plot_cmd_3d.py index d69635200..9b2cf8512 100644 --- a/star/test_suite/custom_colors/python_helpers/interactive_cmd_3d.py +++ 
b/star/test_suite/custom_colors/python_helpers/plot_cmd_3d.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt import mesa_reader as mr from mpl_toolkits.mplot3d import Axes3D # noqa: F401 -from static_HISTORY_check import MesaView, read_header_columns, setup_hr_diagram_params +from plot_history import MesaView, read_header_columns, setup_hr_diagram_params def get_z_axis_selection(available_columns, default="Interp_rad"): diff --git a/star/test_suite/custom_colors/python_helpers/static_HISTORY_check.py b/star/test_suite/custom_colors/python_helpers/plot_history.py similarity index 78% rename from star/test_suite/custom_colors/python_helpers/static_HISTORY_check.py rename to star/test_suite/custom_colors/python_helpers/plot_history.py index 1a5780158..fd9858b99 100644 --- a/star/test_suite/custom_colors/python_helpers/static_HISTORY_check.py +++ b/star/test_suite/custom_colors/python_helpers/plot_history.py @@ -113,45 +113,100 @@ def read_header_columns(history_file): def setup_hr_diagram_params(md, filter_columns): """Set up parameters for HR diagram based on available filters.""" - if "Gbp" in filter_columns and "Grp" in filter_columns and "G" in filter_columns: - hr_color = md.Gbp - md.Grp - hr_mag = md.G + + # Normalize filter names for case-insensitive matching + fc_lower = [f.lower() for f in filter_columns] + + def has(name): + return name.lower() in fc_lower + + def get(name): + # return actual filter name (case preserved) if present + for f in filter_columns: + if f.lower() == name.lower(): + try: + return getattr(md, f) + except AttributeError: + return md.data(f) + return None + + # --- Gaia-like (Gbp, Grp, G) --- + if has("gbp") and has("grp") and has("g"): + hr_color = get("gbp") - get("grp") + hr_mag = get("g") hr_xlabel = "Gbp - Grp" hr_ylabel = "G" color_index = hr_color - elif "V" in filter_columns and "B" in filter_columns and "R" in filter_columns: - hr_color = md.B - md.R - hr_mag = md.V - hr_xlabel = "B - R" - hr_ylabel = "V" - color_index = 
hr_color - else: - if len(filter_columns) >= 2: - # Use the first two filters - f1 = filter_columns[0] - f2 = filter_columns[1] - # Retrieve the data using getattr or data method - try: - col1 = getattr(md, f1) - col2 = getattr(md, f2) - except AttributeError: - col1 = md.data(f1) - col2 = md.data(f2) - - hr_color = col1 - col2 - hr_mag = col1 - hr_xlabel = f"{f1} - {f2}" - hr_ylabel = f1 + # --- Johnson-like broadbands --- + elif has("v"): + # B-R if present + if has("b") and has("r"): + hr_color = get("b") - get("r") + hr_mag = get("v") + hr_xlabel = "B - R" + hr_ylabel = "V" color_index = hr_color - else: - # Default values if not enough filters - print("Warning: Not enough filter columns to construct color index") - hr_color = np.zeros_like(md.Teff) - hr_mag = np.zeros_like(md.Teff) - hr_xlabel = "Color Index" - hr_ylabel = "Magnitude" + # B-V if only B present + elif has("b"): + hr_color = get("b") - get("v") + hr_mag = get("v") + hr_xlabel = "B - V" + hr_ylabel = "V" + color_index = hr_color + # V-R if only R present + elif has("r"): + hr_color = get("v") - get("r") + hr_mag = get("v") + hr_xlabel = "V - R" + hr_ylabel = "V" color_index = hr_color + else: + # no recognized pair with V + hr_color = None + + # --- Sloan-like g-r if no V branch matched --- + elif has("g") and has("r"): + hr_color = get("g") - get("r") + hr_mag = get("g") + hr_xlabel = "g - r" + hr_ylabel = "g" + color_index = hr_color + + else: + hr_color = None + + # If we matched a branch with a valid hr_color + if hr_color is not None: + return hr_color, hr_mag, hr_xlabel, hr_ylabel, color_index + + # --- FALLBACK: use first and last filter if nothing above matched --- + if len(filter_columns) >= 2: + f1 = filter_columns[0] + f2 = filter_columns[-1] + try: + col1 = getattr(md, f1) + except AttributeError: + col1 = md.data(f1) + try: + col2 = getattr(md, f2) + except AttributeError: + col2 = md.data(f2) + + hr_color = col1 - col2 + hr_mag = col1 + hr_xlabel = f"{f1} - {f2}" + hr_ylabel = f1 + 
color_index = hr_color + + else: + # Not enough filters, fallback to flat arrays + print("Warning: Not enough filter columns to construct color index") + hr_color = np.zeros_like(md.Teff) + hr_mag = np.zeros_like(md.Teff) + hr_xlabel = "Color Index" + hr_ylabel = "Magnitude" + color_index = hr_color return hr_color, hr_mag, hr_xlabel, hr_ylabel, color_index diff --git a/star/test_suite/custom_colors/python_helpers/HISTORY_check.py b/star/test_suite/custom_colors/python_helpers/plot_history_live.py similarity index 99% rename from star/test_suite/custom_colors/python_helpers/HISTORY_check.py rename to star/test_suite/custom_colors/python_helpers/plot_history_live.py index 1ec53db49..0b19a8f45 100644 --- a/star/test_suite/custom_colors/python_helpers/HISTORY_check.py +++ b/star/test_suite/custom_colors/python_helpers/plot_history_live.py @@ -12,8 +12,8 @@ from matplotlib.animation import FuncAnimation # Import functions from static version for consistency -from static_HISTORY_check import MesaView # get_mesa_phase_info, -from static_HISTORY_check import read_header_columns, setup_hr_diagram_params +from plot_history import MesaView # get_mesa_phase_info, +from plot_history import read_header_columns, setup_hr_diagram_params def age_colormap_colors(ages, cmap_name="inferno", recent_fraction=0.25, stretch=5.0): diff --git a/star/test_suite/custom_colors/python_helpers/plot_newton_iter.py b/star/test_suite/custom_colors/python_helpers/plot_newton_iter.py new file mode 100644 index 000000000..bcf7b36bf --- /dev/null +++ b/star/test_suite/custom_colors/python_helpers/plot_newton_iter.py @@ -0,0 +1,1077 @@ +#!/usr/bin/env python3 +""" +plot_newton_iter.py — Interactive plotter for MESA Colors per-iteration output + +Plots stellar photometry data from Newton solver iterations with an +interactive column picker and colored terminal UI. 
+ +Author: Niall Miller (2025) +""" + +import argparse +import os +import re +import shutil +import sys +from typing import List, Optional, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np + +try: + import mesa_reader as mr + + MESA_READER_AVAILABLE = True +except ImportError: + MESA_READER_AVAILABLE = False + + +# ============================================================================ +# TERMINAL UI +# ============================================================================ + + +def use_color() -> bool: + """Check if we should use ANSI colors.""" + return sys.stdout.isatty() and ("NO_COLOR" not in os.environ) + + +# ANSI color codes +BOLD = "\x1b[1m" if use_color() else "" +DIM = "\x1b[2m" if use_color() else "" +CYAN = "\x1b[36m" if use_color() else "" +YELL = "\x1b[33m" if use_color() else "" +GREEN = "\x1b[32m" if use_color() else "" +RED = "\x1b[31m" if use_color() else "" +RESET = "\x1b[0m" if use_color() else "" + + +def term_width() -> int: + """Get terminal width.""" + return shutil.get_terminal_size().columns + + +def print_header(title: str) -> None: + """Print a styled header.""" + width = min(70, term_width()) + print(f"\n{BOLD}{CYAN}{'═' * width}{RESET}") + print(f"{BOLD}{CYAN} {title}{RESET}") + print(f"{BOLD}{CYAN}{'═' * width}{RESET}\n") + + +def print_subheader(title: str) -> None: + """Print a styled subheader.""" + width = min(70, term_width()) + print(f"\n{BOLD}{title}{RESET}") + print(f"{DIM}{'─' * width}{RESET}") + + +def print_success(msg: str) -> None: + """Print a success message.""" + print(f"{GREEN}✓{RESET} {msg}") + + +def print_error(msg: str) -> None: + """Print an error message.""" + print(f"{RED}✗{RESET} {msg}") + + +def print_info(msg: str) -> None: + """Print an info message.""" + print(f"{CYAN}ℹ{RESET} {msg}") + + +# ============================================================================ +# COLUMN PICKER +# ============================================================================ + + 
+def prompt_choice( + options: Sequence[str], + label: str, + allow_back: bool = False, + max_cols: int = 3, + filter_enabled: bool = True, +) -> Optional[int]: + """ + Interactive column picker with grid display and filtering. + + Returns: + Selected index (0-based), None for quit, -1 for back + """ + if not options: + print(f"No {label} options available.") + return None + + labels = list(options) + N = len(labels) + max_label_len = max(len(s) for s in labels) + 2 + filt: Optional[Tuple[str, str]] = None # (kind, pattern) + + def apply_filter(indices: List[int]) -> List[int]: + if filt is None: + return indices + kind, patt = filt + p = patt.lower() + if kind == "substr": + return [i for i in indices if p in labels[i].lower()] + elif kind == "neg": + return [i for i in indices if p not in labels[i].lower()] + else: # regex + rx = re.compile(patt, re.I) + return [i for i in indices if rx.search(labels[i])] + + def highlight(s: str) -> str: + if not use_color() or filt is None: + return s + kind, patt = filt + if kind != "substr" or not patt: + return s + rx = re.compile(re.escape(patt), re.I) + return rx.sub(lambda m: f"{YELL}{m.group(0)}{RESET}", s) + + def grid_print(visible_ids: List[int]) -> None: + width = max(70, term_width()) + col_w = 8 + max_label_len + cols = max(1, min(max_cols, width // col_w)) + + cells = [] + for i in visible_ids: + cell = f"[{GREEN}{i:2d}{RESET}] {highlight(labels[i])}" + cells.append(cell) + + # Pad to fill grid + while len(cells) % cols: + cells.append("") + + rows = [cells[k : k + cols] for k in range(0, len(cells), cols)] + + print_subheader(f"{label} ({CYAN}{N}{RESET} columns)") + for row in rows: + print(" " + "".join(cell.ljust(col_w) for cell in row)) + + all_idx = list(range(N)) + + while True: + kept = apply_filter(all_idx) + grid_print(kept) + + # Show filter status + if filt: + kind, patt = filt + if kind == "substr": + print(f"\n{DIM}Filter: /{patt}{RESET}") + elif kind == "neg": + print(f"\n{DIM}Filter: 
!{patt}{RESET}") + else: + print(f"\n{DIM}Filter: //{patt}{RESET}") + + # Show controls + controls = f"\n{DIM}Enter column number" + if filter_enabled: + controls += " | /text !text //regex | clear" + if allow_back: + controls += " | b=back" + controls += f" | q=quit{RESET}" + print(controls) + + inp = input(f"{CYAN}>{RESET} ").strip() + + if not inp: + continue + + # Quit + if inp.lower() == "q": + return None + + # Back + if inp.lower() == "b" and allow_back: + return -1 + + # Clear filter + if inp.lower() == "clear": + filt = None + continue + + # Substring filter: /text + if inp.startswith("/") and not inp.startswith("//"): + filt = ("substr", inp[1:]) + continue + + # Negative filter: !text + if inp.startswith("!"): + filt = ("neg", inp[1:]) + continue + + # Regex filter: //pattern + if inp.startswith("//"): + try: + re.compile(inp[2:]) + filt = ("regex", inp[2:]) + except re.error: + print_error("Invalid regex pattern") + continue + + # Try to parse as number + try: + idx = int(inp) + if 0 <= idx < N: + return idx + else: + print_error(f"Index must be between 0 and {N - 1}") + except ValueError: + # Try to match by name + matches = [i for i, lbl in enumerate(labels) if inp.lower() == lbl.lower()] + if len(matches) == 1: + return matches[0] + elif len(matches) > 1: + print_error(f"Ambiguous: {len(matches)} columns match '{inp}'") + else: + print_error(f"Invalid input: '{inp}'") + + +def prompt_yes_no(prompt: str, default: bool = True) -> bool: + """Prompt for yes/no with default.""" + suffix = "[Y/n]" if default else "[y/N]" + inp = input(f"{prompt} {DIM}{suffix}{RESET} ").strip().lower() + if not inp: + return default + return inp in ("y", "yes") + + +# ============================================================================ +# DATA LOADING +# ============================================================================ + + +def load_iteration_data(filepath: str) -> Tuple[List[str], np.ndarray]: + """Load iteration colors data file.""" + + # Read header to 
get column names + with open(filepath, "r") as f: + header_line = f.readline().strip() + + # Parse column names (handle the # at the start) + if header_line.startswith("#"): + header_line = header_line[1:].strip() + + column_names = header_line.split() + + # Load the data + data = np.loadtxt(filepath, comments="#") + + # Handle single-row data + if data.ndim == 1: + data = data.reshape(1, -1) + + return column_names, data + + +def load_history_data( + history_file: str = "../LOGS/history.data", +) -> Optional[mr.MesaData]: + """Load MESA history file using mesa_reader.""" + if not MESA_READER_AVAILABLE: + print_error("mesa_reader not available. Cannot load history file.") + return None + + if not os.path.exists(history_file): + # Try alternative paths + alt_paths = ["LOGS/history.data", "./history.data"] + for alt in alt_paths: + if os.path.exists(alt): + history_file = alt + break + else: + print_error(f"History file not found: {history_file}") + return None + + try: + md = mr.MesaData(history_file) + print_success( + f"Loaded history file: {history_file} ({len(md.model_number)} models)" + ) + return md + except Exception as e: + print_error(f"Failed to load history file: {e}") + return None + + +def get_history_column(md: mr.MesaData, col_name: str) -> Optional[np.ndarray]: + """Get a column from the history data, trying various name mappings.""" + # Direct match + try: + return getattr(md, col_name) + except AttributeError: + pass + + # Try common name mappings between iteration_colors and history + name_mappings = { + "Teff": ["Teff", "log_Teff"], + "log_Teff": ["log_Teff", "Teff"], + "log_g": ["log_g", "log_surf_g"], + "R": ["radius", "log_R"], + "L": ["luminosity", "log_L"], + } + + if col_name in name_mappings: + for alt_name in name_mappings[col_name]: + try: + return getattr(md, alt_name) + except AttributeError: + continue + + return None + + +# ============================================================================ +# EXPRESSION PARSING +# 
============================================================================ + + +def is_expression(s: str) -> bool: + """Check if string contains math operators (is an expression).""" + # Look for operators not at the start (to allow negative numbers) + return bool(re.search(r"(? Tuple[np.ndarray, str]: + """ + Parse and evaluate a column expression. + + Supports: + - Column names: V, Teff, log_g + - Column indices: [0], [14], [15] + - Math operators: +, -, *, / + - Parentheses: (V-U)/(B-V) + - Constants: 2.5, 1000 + + Examples: + "V-U" -> data[:, V_idx] - data[:, U_idx] + "[15]-[14]" -> data[:, 15] - data[:, 14] + "Teff/1000" -> data[:, Teff_idx] / 1000 + "2.5*log_g" -> 2.5 * data[:, log_g_idx] + + Returns: + (result_array, label_string) + """ + original_expr = expr + expr = expr.strip() + + # Build a safe evaluation namespace + namespace = {} + + # First, replace [N] index references with placeholder variable names + def replace_index(match): + idx = int(match.group(1)) + if idx < 0 or idx >= len(column_names): + raise ValueError( + f"Column index [{idx}] out of range (0-{len(column_names) - 1})" + ) + var_name = f"__col_{idx}__" + namespace[var_name] = data[:, idx] + return var_name + + expr = re.sub(r"\[(\d+)\]", replace_index, expr) + + # Sort column names by length (longest first) to avoid partial matches + sorted_names = sorted(column_names, key=len, reverse=True) + + # Replace column names with placeholder variable names + for i, name in enumerate(sorted_names): + # Find the original index + orig_idx = column_names.index(name) + var_name = f"__col_{orig_idx}__" + + # Use word boundaries to avoid partial matches + # But column names might have special chars, so escape them + pattern = r"\b" + re.escape(name) + r"\b" + if re.search(pattern, expr): + namespace[var_name] = data[:, orig_idx] + expr = re.sub(pattern, var_name, expr) + + # Add safe math functions + safe_funcs = { + "abs": np.abs, + "sqrt": np.sqrt, + "log": np.log, + "log10": np.log10, + "exp": 
np.exp, + "sin": np.sin, + "cos": np.cos, + "tan": np.tan, + "pi": np.pi, + } + namespace.update(safe_funcs) + + # Validate expression only contains safe characters + # Allow: digits, letters, underscores, operators, parentheses, dots, spaces + if not re.match(r"^[\w\s\+\-\*/\(\)\.\,]+$", expr): + raise ValueError(f"Invalid characters in expression: {original_expr}") + + # Evaluate + try: + result = eval(expr, {"__builtins__": {}}, namespace) + except Exception as e: + raise ValueError(f"Failed to evaluate expression '{original_expr}': {e}") + + # Ensure result is array + if np.isscalar(result): + result = np.full(data.shape[0], result) + + return result, original_expr + + +def resolve_axis( + spec: Union[int, str], column_names: List[str], data: np.ndarray +) -> Tuple[np.ndarray, str]: + """ + Resolve an axis specification to data array and label. + + Args: + spec: Column index (int), column name (str), or expression (str) + column_names: List of column names + data: Data array + + Returns: + (data_array, label_string) + """ + # If it's already an integer index + if isinstance(spec, int): + if 0 <= spec < len(column_names): + return data[:, spec], column_names[spec] + raise ValueError(f"Column index {spec} out of range") + + spec = str(spec).strip() + + # Check if it's a simple column name + if spec in column_names: + idx = column_names.index(spec) + return data[:, idx], spec + + # Check if it's a simple index like "5" or "[5]" + match = re.match(r"^\[?(\d+)\]?$", spec) + if match: + idx = int(match.group(1)) + if 0 <= idx < len(column_names): + return data[:, idx], column_names[idx] + raise ValueError(f"Column index {idx} out of range") + + # Check case-insensitive column name match + for i, name in enumerate(column_names): + if name.lower() == spec.lower(): + return data[:, i], name + + # Must be an expression + return parse_expression(spec, column_names, data) + + +def resolve_history_axis( + spec: str, + md: mr.MesaData, +) -> Optional[np.ndarray]: + """ + 
Resolve an axis specification to history data array. + Handles simple columns and expressions like "V-U". + """ + spec = str(spec).strip() + + # Check if it's an expression (contains operators) + if is_expression(spec): + # Parse expression for history data + return parse_history_expression(spec, md) + + # Simple column name + return get_history_column(md, spec) + + +def parse_history_expression(expr: str, md: mr.MesaData) -> Optional[np.ndarray]: + """Parse and evaluate an expression using history data.""" + expr = expr.strip() + + # Get all available column names from history + history_cols = md.bulk_names + + # Build namespace + namespace = {} + + # Sort column names by length (longest first) to avoid partial matches + sorted_names = sorted(history_cols, key=len, reverse=True) + + # Replace column names with placeholder variable names + for name in sorted_names: + var_name = f"__hist_{name}__" + pattern = r"\b" + re.escape(name) + r"\b" + if re.search(pattern, expr): + try: + namespace[var_name] = getattr(md, name) + expr = re.sub(pattern, var_name, expr) + except AttributeError: + continue + + # Add safe math functions + safe_funcs = { + "abs": np.abs, + "sqrt": np.sqrt, + "log": np.log, + "log10": np.log10, + "exp": np.exp, + "sin": np.sin, + "cos": np.cos, + "tan": np.tan, + "pi": np.pi, + } + namespace.update(safe_funcs) + + # Validate and evaluate + if not re.match(r"^[\w\s\+\-\*/\(\)\.\,]+$", expr): + return None + + try: + result = eval(expr, {"__builtins__": {}}, namespace) + return result + except Exception: + return None + + +# ============================================================================ +# PLOTTING +# ============================================================================ + + +def create_plot( + x_data: np.ndarray, + y_data: np.ndarray, + color_data: np.ndarray, + x_label: str, + y_label: str, + color_label: str, + z_data: Optional[np.ndarray] = None, + z_label: Optional[str] = None, + cmap: str = "viridis", + point_size: int = 
20, + alpha: float = 0.7, + flip_y: bool = False, + history_x: Optional[np.ndarray] = None, + history_y: Optional[np.ndarray] = None, + history_z: Optional[np.ndarray] = None, +) -> plt.Figure: + """Create the plot with the provided data arrays. + + Args: + history_x, history_y, history_z: Data from MESA history file to overlay + """ + + # Create figure + if z_data is not None: + # 3D plot + fig = plt.figure(figsize=(12, 9)) + ax = fig.add_subplot(111, projection="3d") + + scatter = ax.scatter( + x_data, y_data, z_data, c=color_data, cmap=cmap, s=point_size, alpha=alpha + ) + + # Plot history data with black X markers + if history_x is not None and history_y is not None and history_z is not None: + ax.scatter( + history_x, + history_y, + history_z, + c="black", + marker="x", + s=point_size * 0.5, + linewidths=2, + label="History", + zorder=10, + ) + ax.legend(loc="best") + + ax.set_xlabel(x_label, fontsize=12, labelpad=10) + ax.set_ylabel(y_label, fontsize=12, labelpad=10) + ax.set_zlabel(z_label, fontsize=12, labelpad=10) + + title = f"Newton Iteration Colors\n{z_label} vs {y_label} vs {x_label}" + + else: + # 2D plot + fig, ax = plt.subplots(figsize=(10, 8)) + + scatter = ax.scatter( + x_data, y_data, c=color_data, cmap=cmap, s=point_size, alpha=alpha + ) + + # Plot history data with black X markers + if history_x is not None and history_y is not None: + ax.scatter( + history_x, + history_y, + c="black", + marker="x", + s=point_size * 1.1, + linewidths=1, + label="History", + zorder=10, + ) + ax.legend(loc="best") + + ax.set_xlabel(x_label, fontsize=12) + ax.set_ylabel(y_label, fontsize=12) + ax.grid(True, alpha=0.3) + + title = f"Newton Iteration Colors\n{y_label} vs {x_label}" + + if flip_y: + ax.invert_yaxis() + + # Add colorbar + cbar = plt.colorbar(scatter, ax=ax, pad=0.02) + cbar.set_label(color_label, fontsize=12) + + ax.set_title(title, fontsize=14, fontweight="bold") + plt.tight_layout() + + return fig + + +def save_figure( + fig: plt.Figure, base_name: 
str, formats: List[str], dpi: int = 300 +) -> None: + """Save figure in multiple formats.""" + for fmt in formats: + path = f"{base_name}.{fmt}" + fig.savefig(path, dpi=dpi, bbox_inches="tight") + print_success(f"Saved: {path}") + + +# ============================================================================ +# MAIN WORKFLOWS +# ============================================================================ + + +def prompt_axis( + column_names: List[str], + data: np.ndarray, + label: str, + allow_back: bool = True, +) -> Optional[Tuple[np.ndarray, str]]: + """ + Prompt user for axis selection - can be column or expression. + + Returns: + (data_array, label) or None for quit, or "back" string for back + """ + while True: + # Show column picker + result = prompt_choice(column_names, label, allow_back=allow_back) + + if result is None: + return None + if result == -1: + return "back" + + # User selected a column by index + return data[:, result], column_names[result] + + +def prompt_axis_or_expr( + column_names: List[str], + data: np.ndarray, + label: str, + allow_back: bool = True, +) -> Optional[Tuple[np.ndarray, str]]: + """ + Prompt user for axis - can be column selection OR custom expression. 
+ + Returns: + (data_array, label) or None for quit, or "back" string for back + """ + N = len(column_names) + max_label_len = max(len(s) for s in column_names) + 2 + + def grid_print() -> None: + width = max(70, term_width()) + col_w = 8 + max_label_len + cols = max(1, min(3, width // col_w)) + + cells = [] + for i, name in enumerate(column_names): + cell = f"[{GREEN}{i:2d}{RESET}] {name}" + cells.append(cell) + + while len(cells) % cols: + cells.append("") + + rows = [cells[k : k + cols] for k in range(0, len(cells), cols)] + + print_subheader(f"{label} ({CYAN}{N}{RESET} columns)") + for row in rows: + print(" " + "".join(cell.ljust(col_w) for cell in row)) + + while True: + grid_print() + + print(f"\n{DIM}Enter column number, name, or expression (e.g., B-V)") + controls = "b=back | " if allow_back else "" + controls += "q=quit" + print(f"{controls}{RESET}") + + inp = input(f"{CYAN}>{RESET} ").strip() + + if not inp: + continue + + # Quit + if inp.lower() == "q": + return None + + # Back + if inp.lower() == "b" and allow_back: + return "back" + + # Try to parse as number (column index) + try: + idx = int(inp) + if 0 <= idx < N: + return data[:, idx], column_names[idx] + else: + print_error(f"Index must be between 0 and {N - 1}") + continue + except ValueError: + pass + + # Try exact column name match + if inp in column_names: + idx = column_names.index(inp) + return data[:, idx], column_names[idx] + + # Try case-insensitive column name match + for i, name in enumerate(column_names): + if name.lower() == inp.lower(): + return data[:, i], name + + # Try as expression + if is_expression(inp): + try: + result_data, result_label = parse_expression(inp, column_names, data) + return result_data, result_label + except ValueError as e: + print_error(str(e)) + continue + + print_error(f"Invalid input: '{inp}'") + + +def run_interactive(filepath: str, history_file: str = "../LOGS/history.data") -> None: + """Run the interactive column picker and plotting workflow.""" + + 
print_header("MESA Colors Newton Iteration Plotter") + print_info(f"Loading: {filepath}") + + # Load iteration data + column_names, data = load_iteration_data(filepath) + print_success(f"Loaded {data.shape[0]} rows, {len(column_names)} columns") + + # Load history file + md = load_history_data(history_file) + + # Select plot type + print_subheader("Select Plot Type") + print(f" [{GREEN}2{RESET}] 2D scatter plot") + print(f" [{GREEN}3{RESET}] 3D scatter plot") + + while True: + inp = input(f"{CYAN}>{RESET} ").strip() + if inp == "2": + plot_type = 2 + break + elif inp == "3": + plot_type = 3 + break + elif inp.lower() == "q": + print("Goodbye!") + return + print_error("Enter 2 or 3") + + # Select X axis + result = prompt_axis_or_expr(column_names, data, "X-axis", allow_back=False) + if result is None: + return + x_data, x_label = result + print_success(f"X-axis: {x_label}") + + # Select Y axis + result = prompt_axis_or_expr(column_names, data, "Y-axis", allow_back=True) + if result is None: + return + if result == "back": + return run_interactive(filepath, history_file) + y_data, y_label = result + print_success(f"Y-axis: {y_label}") + + flip_y = prompt_yes_no("Flip Y axis?", default=False) + + # Select Z axis (for 3D) + z_data, z_label = None, None + if plot_type == 3: + result = prompt_axis_or_expr(column_names, data, "Z-axis", allow_back=True) + if result is None: + return + if result == "back": + return run_interactive(filepath, history_file) + z_data, z_label = result + print_success(f"Z-axis: {z_label}") + + # Select color axis + result = prompt_axis_or_expr(column_names, data, "Color axis", allow_back=True) + if result is None: + return + if result == "back": + return run_interactive(filepath, history_file) + color_data, color_label = result + print_success(f"Color: {color_label}") + + # Generate output filename + base_name = "newton_iter" + axes = [x_label, y_label] + if z_label is not None: + axes.append(z_label) + # Sanitize labels for filename + safe_axes 
= [re.sub(r"[^\w\-]", "_", a) for a in reversed(axes)] + base_name += "_" + "_vs_".join(safe_axes) + + # Create and save plot + print_subheader("Generating Plot") + + # Get history data for overlay + history_x, history_y, history_z = None, None, None + if md is not None: + history_x = resolve_history_axis(x_label, md) + history_y = resolve_history_axis(y_label, md) + if z_label is not None: + history_z = resolve_history_axis(z_label, md) + + if history_x is not None and history_y is not None: + print_info(f"History data: {len(history_x)} points will be overlaid") + else: + missing = [] + if history_x is None: + missing.append(x_label) + if history_y is None: + missing.append(y_label) + print_info(f"Could not find history columns for: {', '.join(missing)}") + + fig = create_plot( + x_data, + y_data, + color_data, + x_label, + y_label, + color_label, + z_data, + z_label, + flip_y=flip_y, + history_x=history_x, + history_y=history_y, + history_z=history_z, + ) + save_figure(fig, base_name, ["pdf", "jpg"]) + + # Show interactive plot + print_info("Displaying interactive plot (close window to exit)") + plt.show() + + print(f"\n{GREEN}Done!{RESET}") + + +def run_batch( + filepath: str, + x_col: str, + y_col: str, + color_col: str, + z_col: Optional[str] = None, + output: Optional[str] = None, + formats: List[str] = ["pdf", "jpg"], + no_show: bool = False, + cmap: str = "viridis", + history_file: str = "../LOGS/history.data", +) -> None: + """Run in batch mode with specified columns or expressions.""" + + # Load data + column_names, data = load_iteration_data(filepath) + + # Load history file + md = load_history_data(history_file) + + # Resolve each axis (supports columns and expressions) + x_data, x_label = resolve_axis(x_col, column_names, data) + y_data, y_label = resolve_axis(y_col, column_names, data) + color_data, color_label = resolve_axis(color_col, column_names, data) + + z_data, z_label = None, None + if z_col: + z_data, z_label = resolve_axis(z_col, 
column_names, data) + + # Generate output name + if output is None: + axes = [x_label, y_label] + if z_label is not None: + axes.append(z_label) + # Sanitize labels for filename + safe_axes = [re.sub(r"[^\w\-]", "_", a) for a in reversed(axes)] + output = "newton_iter_" + "_vs_".join(safe_axes) + + # Get history data for overlay + history_x, history_y, history_z = None, None, None + if md is not None: + history_x = resolve_history_axis(x_label, md) + history_y = resolve_history_axis(y_label, md) + if z_label is not None: + history_z = resolve_history_axis(z_label, md) + + # Create plot + fig = create_plot( + x_data, + y_data, + color_data, + x_label, + y_label, + color_label, + z_data, + z_label, + cmap=cmap, + history_x=history_x, + history_y=history_y, + history_z=history_z, + ) + save_figure(fig, output, formats) + + if not no_show: + plt.show() + + +# ============================================================================ +# CLI ENTRY +# ============================================================================ + + +def main(): + parser = argparse.ArgumentParser( + description="Plot MESA Colors per-iteration data", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s # Interactive mode + %(prog)s -f SED/iteration_colors.data # Specify file + %(prog)s -x iter -y Teff -c model # Batch mode with columns + %(prog)s -x 1 -y 4 -c 0 --no-show # Use column indices + %(prog)s -x iter -y Teff -z R -c model # 3D plot + +Expression examples (color indices, ratios, etc.): + %(prog)s -x iter -y "V-U" -c model # Color index V-U + %(prog)s -x iter -y "[9]-[13]" -c Teff # Using column indices + %(prog)s -x "Teff/1000" -y log_g -c model # Scaled temperature + %(prog)s -x iter -y "B-V" -c "U-B" # Color-color diagram + +Supported expression syntax: + - Column names: V, Teff, log_g, Mag_bol + - Column indices: [0], [14], [15] + - Operators: +, -, *, / + - Parentheses: (V-U)/(B-V) + - Functions: sqrt(), log10(), abs() + """, + ) + + 
parser.add_argument( + "-f", + "--file", + default="SED/iteration_colors.data", + help="Path to iteration colors data file (default: SED/iteration_colors.data)", + ) + + parser.add_argument( + "--history", + default="../LOGS/history.data", + help="Path to MESA history file (default: ../LOGS/history.data)", + ) + + parser.add_argument( + "-x", "--x-col", help="X-axis column name or index (enables batch mode)" + ) + + parser.add_argument("-y", "--y-col", help="Y-axis column name or index") + + parser.add_argument( + "-z", "--z-col", help="Z-axis column name or index (for 3D plots)" + ) + + parser.add_argument("-c", "--color-col", help="Color axis column name or index") + + parser.add_argument( + "-o", "--output", help="Output filename base (without extension)" + ) + + parser.add_argument( + "--formats", + default="pdf,jpg", + help="Output formats, comma-separated (default: pdf,jpg)", + ) + + parser.add_argument( + "--cmap", default="viridis", help="Matplotlib colormap (default: viridis)" + ) + + parser.add_argument( + "--no-show", action="store_true", help="Don't display interactive plot" + ) + + parser.add_argument( + "--list-columns", action="store_true", help="List available columns and exit" + ) + + args = parser.parse_args() + + # Check file exists + if not os.path.exists(args.file): + # Try with ../ prefix + alt_path = os.path.join("..", args.file) + if os.path.exists(alt_path): + args.file = alt_path + else: + print_error(f"File not found: {args.file}") + sys.exit(1) + + # List columns mode + if args.list_columns: + column_names, data = load_iteration_data(args.file) + print_header("Available Columns") + for i, name in enumerate(column_names): + print(f" [{GREEN}{i:2d}{RESET}] {name}") + print(f"\n{DIM}Total: {len(column_names)} columns, {data.shape[0]} rows{RESET}") + return + + # Batch mode if columns specified + if args.x_col and args.y_col and args.color_col: + formats = [f.strip() for f in args.formats.split(",")] + run_batch( + filepath=args.file, + 
x_col=args.x_col, + y_col=args.y_col, + color_col=args.color_col, + z_col=args.z_col, + output=args.output, + formats=formats, + no_show=args.no_show, + cmap=args.cmap, + history_file=args.history, + ) + else: + # Interactive mode + run_interactive(args.file, args.history) + + +if __name__ == "__main__": + main() diff --git a/star/test_suite/custom_colors/python_helpers/static_SED_check.py b/star/test_suite/custom_colors/python_helpers/plot_sed.py similarity index 100% rename from star/test_suite/custom_colors/python_helpers/static_SED_check.py rename to star/test_suite/custom_colors/python_helpers/plot_sed.py diff --git a/star/test_suite/custom_colors/python_helpers/SED_check.py b/star/test_suite/custom_colors/python_helpers/plot_sed_live.py similarity index 100% rename from star/test_suite/custom_colors/python_helpers/SED_check.py rename to star/test_suite/custom_colors/python_helpers/plot_sed_live.py diff --git a/star/test_suite/custom_colors/python_helpers/zero_point_check.py b/star/test_suite/custom_colors/python_helpers/plot_zero_points.py similarity index 100% rename from star/test_suite/custom_colors/python_helpers/zero_point_check.py rename to star/test_suite/custom_colors/python_helpers/plot_zero_points.py diff --git a/star/test_suite/custom_colors/python_helpers/batch_run.py b/star/test_suite/custom_colors/python_helpers/run_batch.py similarity index 100% rename from star/test_suite/custom_colors/python_helpers/batch_run.py rename to star/test_suite/custom_colors/python_helpers/run_batch.py diff --git a/star/test_suite/custom_colors/python_helpers/test_paths.py b/star/test_suite/custom_colors/python_helpers/test_paths.py new file mode 100644 index 000000000..3165727cb --- /dev/null +++ b/star/test_suite/custom_colors/python_helpers/test_paths.py @@ -0,0 +1,829 @@ +#!/usr/bin/env python3 +""" +path_test.py +============ +Tests all path resolution cases for the MESA colors module's instrument, +stellar_atm, vega_sed, and colors_results_directory 
parameters.
+
+For each test the script:
+  1. Copies (not symlinks) the data to the appropriate location
+  2. Pre-creates the output directory (belt-and-suspenders alongside the
+     Fortran mkdir-p fix in open_iteration_file)
+  3. Rewrites inlist_colors with the test paths
+  4. Runs 'make run' long enough to pass colors initialisation
+  5. Saves three proof files to the test output dir:
+       proof_plot.png    — CMD + HR diagram + filter light curves
+       proof_newton.png  — Teff and magnitude vs Newton iteration number
+       proof_sed.png     — sample filter SED convolution curves
+  6. Does NOT clean up — everything is left on disk for inspection
+
+Run from the custom_colors test suite directory:
+    python3 test_paths.py
+
+Requirements:
+  - Run './mk' first to build
+  - MESA_DIR environment variable must be set
+  - mesa_reader, matplotlib, pandas, numpy must be importable
+"""
+
+import os
+import re
+import shutil
+import subprocess
+import textwrap
+import time
+from pathlib import Path
+
+import matplotlib
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import numpy as np
+
+try:
+    import pandas as pd
+
+    HAVE_PD = True
+except ImportError:
+    HAVE_PD = False
+    print("WARNING: pandas not found — SED plots will be skipped.")
+
+try:
+    import mesa_reader as mr
+
+    HAVE_MR = True
+except ImportError:
+    HAVE_MR = False
+    print("WARNING: mesa_reader not found — history-based plots will be skipped.")
+
+# =============================================================================
+# CONFIGURATION
+# =============================================================================
+os.chdir("../")
+
+_mesa = os.environ.get("MESA_DIR", "")
+
+SOURCE_INSTRUMENT = os.path.join(_mesa, "data/colors_data/filters/Generic/Johnson")
+SOURCE_STELLAR_ATM = os.path.join(
+    _mesa, "data/colors_data/stellar_models/Kurucz2003all"
+)
+SOURCE_VEGA_SED = os.path.join(_mesa, "data/colors_data/stellar_models/vega_flam.csv")
+
+INLIST_COLORS = "inlist_colors"
+LOGS_DIR = "LOGS"
+RUN_TIMEOUT = 45
+
+SUCCESS_MARKERS = ["step 1", "step 1", " 1 "] +FAILURE_MARKERS = [ + "Error: Could not open file", + "colors_utils.f90", + "ERROR STOP", + "STOP 1", + "Backtrace", + "failed to find", + "ERROR: failed", +] + +# ============================================================================= +# HELPERS +# ============================================================================= + +INLIST_BACKUP = INLIST_COLORS + ".bak" +CWD = os.getcwd() + +MESA_DIR = os.environ.get("MESA_DIR", "") +if not MESA_DIR: + raise EnvironmentError("MESA_DIR is not set.") + + +def backup_inlist(): + shutil.copy2(INLIST_COLORS, INLIST_BACKUP) + + +def restore_inlist(): + shutil.copy2(INLIST_BACKUP, INLIST_COLORS) + + +def resolve_outdir(output_dir: str) -> Path: + """Return an absolute Path for output_dir (handles relative and ../).""" + if os.path.isabs(output_dir): + return Path(output_dir) + return Path(os.path.normpath(os.path.join(CWD, output_dir))) + + +def patch_inlist(instrument, stellar_atm, vega_sed, colors_results_directory): + with open(INLIST_COLORS, "r") as f: + text = f.read() + + def replace_param(text, key, value): + pattern = rf"(^\s*{re.escape(key)}\s*=\s*)['\"].*?['\"]" + replacement = rf"\g<1>'{value}'" + new_text, n = re.subn(pattern, replacement, text, flags=re.MULTILINE) + if n == 0: + raise ValueError(f"Could not find parameter '{key}' in {INLIST_COLORS}") + return new_text + + text = replace_param(text, "instrument", instrument) + text = replace_param(text, "stellar_atm", stellar_atm) + text = replace_param(text, "vega_sed", vega_sed) + text = replace_param(text, "colors_results_directory", colors_results_directory) + + with open(INLIST_COLORS, "w") as f: + f.write(text) + + +def copy_data(src, dst): + dst_path = Path(dst) + dst_path.parent.mkdir(parents=True, exist_ok=True) + if dst_path.exists() or dst_path.is_symlink(): + shutil.rmtree(dst) if dst_path.is_dir() else dst_path.unlink() + shutil.copytree(src, dst) if Path(src).is_dir() else shutil.copy2(src, dst) 
+ + +def run_star(timeout): + try: + proc = subprocess.Popen( + ["make", "run"], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + try: + output, _ = proc.communicate(timeout=timeout) + except subprocess.TimeoutExpired: + proc.kill() + output, _ = proc.communicate() + except FileNotFoundError: + return False, "ERROR: 'make' not found." + + found_failure = any(m in output for m in FAILURE_MARKERS) + found_success = any(m in output for m in SUCCESS_MARKERS) + return found_success and not found_failure, output + + +def print_result(name, success, output): + status = "✓ PASS" if success else "✗ FAIL" + print(f"\n{'=' * 70}") + print(f" {status} | {name}") + print(f"{'=' * 70}") + lines = output.strip().splitlines() + relevant = [ + l + for l in lines + if any( + kw.lower() in l.lower() + for kw in [ + "color", + "filter", + "error", + "step", + "instrument", + "could not open", + "backtrace", + "stop", + "results", + ] + ) + ] + if relevant: + print(" Relevant output:") + for l in relevant[:25]: + print(f" {l}") + else: + print(" (no recognisable output — showing tail)") + for l in lines[-15:]: + print(f" {l}") + + +# ============================================================================= +# PROOF PLOTS +# ============================================================================= + +_PHASE_MAP = { + -1: ("Relax", "#C0C0C0"), + 1: ("Starting", "#E6E6FA"), + 2: ("Pre-MS", "#FF69B4"), + 3: ("ZAMS", "#00FF00"), + 4: ("IAMS", "#0000FF"), + 5: ("TAMS", "#FF8C00"), + 6: ("He-Burn", "#8A2BE2"), + 7: ("ZACHeB", "#9932CC"), + 8: ("TACHeB", "#BA55D3"), + 9: ("TP-AGB", "#8B0000"), + 10: ("C-Burn", "#FF4500"), + 14: ("WDCS", "#708090"), +} + + +def _read_filter_columns(history_file): + with open(history_file) as f: + for line in f: + if "model_number" in line: + cols = line.split() + try: + idx = cols.index("Interp_rad") + return cols[idx + 1 :] + except ValueError: + pass + return [] + + +def _phase_colors(md_raw, skip): + if hasattr(md_raw, 
"phase_of_evolution"): + codes = md_raw.phase_of_evolution[skip:] + else: + codes = np.zeros(len(md_raw.model_number) - skip, dtype=int) + names = [_PHASE_MAP.get(int(c), ("Unknown", "#808080"))[0] for c in codes] + colors = [_PHASE_MAP.get(int(c), ("Unknown", "#808080"))[1] for c in codes] + return names, colors + + +def _best_color_pair(filter_columns, md_raw, skip): + fc = {f.lower(): f for f in filter_columns} + + def get(name): + real = fc.get(name.lower()) + if real is None: + return None + try: + v = getattr(md_raw, real) + except AttributeError: + try: + v = md_raw.data(real) + except Exception: + return None + return v[skip:] if isinstance(v, np.ndarray) and len(v) > skip else v + + for blue, red, mag_name in [ + ("gbp", "grp", "g"), + ("b", "r", "v"), + ("b", "v", "v"), + ("v", "r", "v"), + ("g", "r", "g"), + ]: + b, r, m = get(blue), get(red), get(mag_name) + if b is not None and r is not None and m is not None: + return b - r, m, f"{blue.upper()} − {red.upper()}", mag_name.upper() + + if len(filter_columns) >= 2: + d0, d1 = get(filter_columns[0]), get(filter_columns[-1]) + if d0 is not None and d1 is not None: + return ( + d0 - d1, + d0, + f"{filter_columns[0]} − {filter_columns[-1]}", + filter_columns[0], + ) + + return None, None, "Color", "Mag" + + +def plot_history_result(test_name, out_path, success): + """ + proof_plot.png — 3 panels: CMD, HR diagram, filter light curves. + Also copies history.data into the output directory. 
+ """ + if not HAVE_MR: + return + + history_file = os.path.join(LOGS_DIR, "history.data") + if not os.path.exists(history_file): + print(f" [proof_plot] {history_file} not found — skipping.") + return + + shutil.copy2(history_file, out_path / "history.data") + + try: + md_raw = mr.MesaData(history_file) + except Exception as e: + print(f" [proof_plot] mesa_reader failed: {e}") + return + + skip = 5 + filter_columns = _read_filter_columns(history_file) + phase_names, p_cols = _phase_colors(md_raw, skip) + color_idx, mag, xlabel, ylabel = _best_color_pair(filter_columns, md_raw, skip) + + fig, axes = plt.subplots(1, 3, figsize=(16, 5)) + fig.suptitle(f"{test_name} — {'PASS ✓' if success else 'FAIL ✗'}", fontsize=13) + + ax = axes[0] + if color_idx is not None: + ax.scatter(color_idx, mag, c=p_cols, s=15, alpha=0.7, edgecolors="none") + ax.set_xlabel(xlabel, fontsize=11) + ax.set_ylabel(ylabel, fontsize=11) + ax.invert_yaxis() + ax.xaxis.set_ticks_position("top") + ax.xaxis.set_label_position("top") + else: + ax.text( + 0.5, 0.5, "No colour pair", ha="center", va="center", transform=ax.transAxes + ) + ax.set_title("Colour–Magnitude", fontsize=11) + ax.grid(True, alpha=0.3) + + ax = axes[1] + try: + ax.scatter( + md_raw.Teff[skip:], + md_raw.log_L[skip:], + c=p_cols, + s=15, + alpha=0.7, + edgecolors="none", + ) + ax.set_xlabel("Teff (K)", fontsize=11) + ax.set_ylabel("log L/L☉", fontsize=11) + ax.invert_xaxis() + ax.xaxis.set_ticks_position("top") + ax.xaxis.set_label_position("top") + except Exception as e: + ax.text( + 0.5, + 0.5, + f"HR error:\n{e}", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=8, + ) + ax.set_title("HR Diagram", fontsize=11) + ax.grid(True, alpha=0.3) + + ax = axes[2] + try: + age = md_raw.star_age[skip:] + plotted = 0 + for filt in filter_columns: + try: + data = getattr(md_raw, filt)[skip:] + except AttributeError: + try: + data = md_raw.data(filt)[skip:] + except Exception: + continue + ax.plot(age, data, lw=1.2, 
label=filt, alpha=0.85) + plotted += 1 + if plotted: + ax.invert_yaxis() + ax.legend(fontsize=8, ncol=2) + else: + ax.text( + 0.5, + 0.5, + "No filter data", + ha="center", + va="center", + transform=ax.transAxes, + ) + ax.set_xlabel("Age (yr)", fontsize=11) + ax.set_ylabel("Magnitude", fontsize=11) + except Exception as e: + ax.text( + 0.5, + 0.5, + f"LC error:\n{e}", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=8, + ) + ax.set_title("Light Curves", fontsize=11) + ax.grid(True, alpha=0.3) + + seen = {} + for name, color in zip(phase_names, p_cols): + seen.setdefault(name, color) + legend_handles = [ + plt.Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor=c, + markersize=8, + label=n, + markeredgecolor="none", + ) + for n, c in seen.items() + ] + fig.legend( + handles=legend_handles, + loc="lower center", + ncol=len(seen), + fontsize=9, + title="Evolutionary phase", + title_fontsize=9, + bbox_to_anchor=(0.5, -0.04), + ) + + plt.tight_layout(rect=[0, 0.06, 1, 1]) + out_png = out_path / "proof_plot.png" + plt.savefig(out_png, dpi=120, bbox_inches="tight") + plt.close(fig) + print(f" [proof_plot] saved → {out_png}") + + +def plot_newton_iter_result(test_name, out_path, success): + """ + proof_newton.png — 2 panels from iteration_colors.data: + Left: Newton iteration vs Teff, coloured by model_number + Right: Newton iteration vs first available filter magnitude, coloured by model_number + """ + iter_file = out_path / "iteration_colors.data" + if not iter_file.exists(): + print(f" [proof_newton] {iter_file} not found — skipping.") + return + + with open(iter_file) as f: + header_line = f.readline().strip().lstrip("#").strip() + col_names = header_line.split() + + try: + data = np.loadtxt(iter_file, comments="#") + except Exception as e: + print(f" [proof_newton] Failed to load {iter_file}: {e}") + return + if data.ndim == 1: + data = data.reshape(1, -1) + if data.shape[0] < 2: + print(f" [proof_newton] Too few rows 
({data.shape[0]}) — skipping.") + return + + def col(name): + try: + return data[:, col_names.index(name)] + except ValueError: + return None + + model_col = col("model") + iter_col = col("iter") + teff_col = col("Teff") + + if iter_col is None or teff_col is None or model_col is None: + print( + " [proof_newton] Expected columns (model/iter/Teff) not found — skipping." + ) + return + + # First non-sentinel filter magnitude column (after Flux_bol) + mag_col_data, mag_col_label = None, None + try: + flux_idx = col_names.index("Flux_bol") + for candidate in col_names[flux_idx + 1 :]: + c = col(candidate) + if c is not None and not np.all(c == -99.0): + mag_col_data = c + mag_col_label = candidate + break + except ValueError: + pass + + n_panels = 2 if mag_col_data is not None else 1 + fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5)) + if n_panels == 1: + axes = [axes] + fig.suptitle( + f"Newton Iterations | {test_name} — {'PASS ✓' if success else 'FAIL ✗'}", + fontsize=12, + ) + + sc = axes[0].scatter( + iter_col, + teff_col, + c=model_col, + cmap="viridis", + s=15, + alpha=0.7, + edgecolors="none", + ) + axes[0].set_xlabel("Newton iteration", fontsize=11) + axes[0].set_ylabel("Teff (K)", fontsize=11) + axes[0].set_title("Teff per Newton iteration", fontsize=11) + axes[0].grid(True, alpha=0.3) + plt.colorbar(sc, ax=axes[0], label="model number") + + if mag_col_data is not None: + sc2 = axes[1].scatter( + iter_col, + mag_col_data, + c=model_col, + cmap="plasma", + s=15, + alpha=0.7, + edgecolors="none", + ) + axes[1].set_xlabel("Newton iteration", fontsize=11) + axes[1].set_ylabel(f"{mag_col_label} (mag)", fontsize=11) + axes[1].set_title(f"{mag_col_label} per Newton iteration", fontsize=11) + axes[1].invert_yaxis() + axes[1].grid(True, alpha=0.3) + plt.colorbar(sc2, ax=axes[1], label="model number") + + plt.tight_layout() + out_png = out_path / "proof_newton.png" + plt.savefig(out_png, dpi=120, bbox_inches="tight") + plt.close(fig) + print(f" 
[proof_newton] saved → {out_png}") + + +def plot_sed_result(test_name, out_path, success): + """ + proof_sed.png — up to 5 filter SED convolution files (*_SED.csv, excluding VEGA_*) + overlaid on the full stellar SED. + """ + if not HAVE_PD: + return + + sed_files = sorted( + [ + f + for f in out_path.iterdir() + if f.name.endswith("_SED.csv") and not f.name.startswith("VEGA") + ] + ) + + if not sed_files: + print(f" [proof_sed] No *_SED.csv files in {out_path} — skipping.") + return + + fig, ax = plt.subplots(figsize=(10, 5)) + fig.suptitle( + f"SED Sample | {test_name} — {'PASS ✓' if success else 'FAIL ✗'}", + fontsize=12, + ) + + sed_plotted = False + for f in sed_files[:5]: + try: + df = ( + pd.read_csv(f, delimiter=",", header=0) + .rename(columns=str.strip) + .dropna() + ) + except Exception as e: + print(f" [proof_sed] Could not read {f.name}: {e}") + continue + + wavelengths = df.get("wavelengths", pd.Series()).to_numpy() + flux = df.get("fluxes", pd.Series()).to_numpy() + convolved_flux = df.get("convolved_flux", pd.Series()).to_numpy() + + if not sed_plotted and len(wavelengths) > 0 and len(flux) > 0: + ax.plot( + wavelengths, flux, color="black", lw=1.5, label="Full SED", zorder=5 + ) + sed_plotted = True + + filter_label = f.stem.replace("_SED", "") + if len(wavelengths) > 0 and len(convolved_flux) > 0: + ax.plot( + wavelengths, + convolved_flux, + lw=1.2, + label=f"{filter_label} (convolved)", + alpha=0.85, + ) + + ax.set_xlabel("Wavelength (Å)", fontsize=11) + ax.set_ylabel("Flux", fontsize=11) + ax.set_title("Filter convolution check", fontsize=11) + ax.legend(loc="best", fontsize=9) + ax.ticklabel_format(style="plain", useOffset=False) + ax.set_xlim([0, 60000]) + ax.grid(True, alpha=0.3) + plt.tight_layout() + + out_png = out_path / "proof_sed.png" + plt.savefig(out_png, dpi=120, bbox_inches="tight") + plt.close(fig) + print(f" [proof_sed] saved → {out_png}") + + +def run_all_proof_plots(test_name, output_dir, success): + out_path = 
resolve_outdir(output_dir) + out_path.mkdir(parents=True, exist_ok=True) + plot_history_result(test_name, out_path, success) + plot_newton_iter_result(test_name, out_path, success) + plot_sed_result(test_name, out_path, success) + + +# ============================================================================= +# DEFAULT INPUT PATHS +# ============================================================================= + +_DEFAULT_INSTRUMENT = "data/colors_data/filters/Generic/Johnson" +_DEFAULT_STELLAR_ATM = "data/colors_data/stellar_models/Kurucz2003all/" +_DEFAULT_VEGA_SED = "data/colors_data/stellar_models/vega_flam.csv" + + +def _run_test(name, instrument, stellar_atm, vega_sed, out): + """Pre-create output dir, patch inlist, run, plot.""" + resolve_outdir(out).mkdir(parents=True, exist_ok=True) + patch_inlist(instrument, stellar_atm, vega_sed, out) + success, output = run_star(RUN_TIMEOUT) + print_result(name, success, output) + run_all_proof_plots(name, out, success) + return success + + +# ============================================================================= +# INPUT PATH TESTS +# ============================================================================= + + +def test01_input_mesa_dir_relative(): + return _run_test( + "INPUT: MESA_DIR-relative", + "data/colors_data/filters/Generic/Johnson", + "data/colors_data/stellar_models/Kurucz2003all/", + "data/colors_data/stellar_models/vega_flam.csv", + "test_results/test01_output", + ) + + +def test02_input_cwd_dotslash(): + stage = "./test_staged/test02_input_cwd_dotslash" + copy_data(SOURCE_INSTRUMENT, f"{stage}/filters/Generic/Johnson") + copy_data(SOURCE_STELLAR_ATM, f"{stage}/stellar_models/Kurucz2003all") + copy_data(SOURCE_VEGA_SED, f"{stage}/stellar_models/vega_flam.csv") + return _run_test( + "INPUT: CWD-relative (./)", + f"{stage}/filters/Generic/Johnson", + f"{stage}/stellar_models/Kurucz2003all", + f"{stage}/stellar_models/vega_flam.csv", + "test_results/test02_output", + ) + + +def 
test03_input_cwd_dotdotslash(): + stage = "../test_staged/test03_input_cwd_dotdotslash" + copy_data(SOURCE_INSTRUMENT, f"{stage}/filters/Generic/Johnson") + copy_data(SOURCE_STELLAR_ATM, f"{stage}/stellar_models/Kurucz2003all") + copy_data(SOURCE_VEGA_SED, f"{stage}/stellar_models/vega_flam.csv") + return _run_test( + "INPUT: CWD-relative (../)", + f"{stage}/filters/Generic/Johnson", + f"{stage}/stellar_models/Kurucz2003all", + f"{stage}/stellar_models/vega_flam.csv", + "test_results/test03_output", + ) + + +def test04_input_absolute(): + stage = os.path.join(CWD, "test_staged/test04_input_absolute") + copy_data(SOURCE_INSTRUMENT, os.path.join(stage, "filters/Generic/Johnson")) + copy_data(SOURCE_STELLAR_ATM, os.path.join(stage, "stellar_models/Kurucz2003all")) + copy_data(SOURCE_VEGA_SED, os.path.join(stage, "stellar_models/vega_flam.csv")) + return _run_test( + "INPUT: Absolute path", + os.path.join(stage, "filters/Generic/Johnson"), + os.path.join(stage, "stellar_models/Kurucz2003all"), + os.path.join(stage, "stellar_models/vega_flam.csv"), + "test_results/test04_output", + ) + + +def test05_input_slash_mesa_fallback(): + return _run_test( + "INPUT: /-prefixed MESA_DIR fallback", + "/data/colors_data/filters/Generic/Johnson", + "/data/colors_data/stellar_models/Kurucz2003all/", + "/data/colors_data/stellar_models/vega_flam.csv", + "test_results/test05_output", + ) + + +# ============================================================================= +# OUTPUT PATH TESTS +# ============================================================================= + + +def test06_output_plain_relative(): + return _run_test( + "OUTPUT: Plain relative", + _DEFAULT_INSTRUMENT, + _DEFAULT_STELLAR_ATM, + _DEFAULT_VEGA_SED, + "test_results/test06_output", + ) + + +def test07_output_cwd_dotslash(): + return _run_test( + "OUTPUT: CWD-relative (./)", + _DEFAULT_INSTRUMENT, + _DEFAULT_STELLAR_ATM, + _DEFAULT_VEGA_SED, + "./test_results/test07_output", + ) + + +def 
test08_output_dotdotslash(): + return _run_test( + "OUTPUT: CWD-relative (../)", + _DEFAULT_INSTRUMENT, + _DEFAULT_STELLAR_ATM, + _DEFAULT_VEGA_SED, + "../test_results_parent/test08_output", + ) + + +def test09_output_absolute(): + return _run_test( + "OUTPUT: Absolute path", + _DEFAULT_INSTRUMENT, + _DEFAULT_STELLAR_ATM, + _DEFAULT_VEGA_SED, + os.path.join(CWD, "test_results/test09_output"), + ) + + +def test10_output_nested_subdir(): + return _run_test( + "OUTPUT: Nested subdirectory", + _DEFAULT_INSTRUMENT, + _DEFAULT_STELLAR_ATM, + _DEFAULT_VEGA_SED, + "test_results/test10_output/nested/SED", + ) + + +# ============================================================================= +# MAIN +# ============================================================================= + +TESTS = [ + ("INPUT: MESA_DIR-relative", test01_input_mesa_dir_relative), + ("INPUT: CWD-relative (./)", test02_input_cwd_dotslash), + ("INPUT: CWD-relative (../)", test03_input_cwd_dotdotslash), + ("INPUT: Absolute path", test04_input_absolute), + ("INPUT: /-prefixed MESA fallback", test05_input_slash_mesa_fallback), + ("OUTPUT: Plain relative", test06_output_plain_relative), + ("OUTPUT: CWD-relative (./)", test07_output_cwd_dotslash), + ("OUTPUT: CWD-relative (../)", test08_output_dotdotslash), + ("OUTPUT: Absolute path", test09_output_absolute), + ("OUTPUT: Nested subdirectory", test10_output_nested_subdir), +] + +if __name__ == "__main__": + print( + textwrap.dedent(f""" + MESA Colors — Path Resolution Test Suite + ========================================= + Work directory : {CWD} + MESA_DIR : {MESA_DIR} + Run timeout : {RUN_TIMEOUT}s per test + + NOTE: No cleanup is performed. All staged input data and all output + directories are left on disk for inspection after the run. 
+ + Staged input data : ./test_staged/ and ../test_staged/ + Output data : ./test_results/ and ../test_results_parent/ + Proof plots per test (in each output dir): + proof_plot.png — CMD + HR diagram + filter light curves + proof_newton.png — Teff and filter mag vs Newton iteration + proof_sed.png — filter SED convolution sample + """) + ) + + backup_inlist() + + results = {} + try: + for name, fn in TESTS: + print(f"\nRunning: {name} ...") + try: + results[name] = fn() + except Exception as e: + print(f" EXCEPTION: {e}") + results[name] = False + finally: + restore_inlist() + time.sleep(1) + finally: + restore_inlist() + if os.path.exists(INLIST_BACKUP): + os.remove(INLIST_BACKUP) + + print(f"\n\n{'=' * 70}") + print(" SUMMARY") + print(f"{'=' * 70}") + for name, passed in results.items(): + status = "✓ PASS" if passed else "✗ FAIL" + print(f" {status} {name}") + total = len(results) + passed = sum(results.values()) + print(f"\n {passed}/{total} tests passed") + print(f"{'=' * 70}") + print( + textwrap.dedent(f""" + Directories to inspect: + Input staging : {CWD}/test_staged/ + {os.path.dirname(CWD)}/test_staged/ + Output + plots : {CWD}/test_results/ + {os.path.dirname(CWD)}/test_results_parent/ + """) + ) + + exit(0 if passed == total else 1) diff --git a/star/test_suite/custom_colors/src/run_star_extras.f90 b/star/test_suite/custom_colors/src/run_star_extras.f90 index 8870fc6fc..8f8f163c0 100644 --- a/star/test_suite/custom_colors/src/run_star_extras.f90 +++ b/star/test_suite/custom_colors/src/run_star_extras.f90 @@ -253,4 +253,4 @@ subroutine extras_after_evolve(id, ierr) if (ierr /= 0) return end subroutine extras_after_evolve -end module run_star_extras +end module run_star_extras \ No newline at end of file