diff --git a/DESCRIPTION b/DESCRIPTION index a96a5ac..95e9727 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -9,4 +9,7 @@ License: BSD 3-Clause Encoding: UTF-8 LazyData: true Imports: + CircE (>= 1.1), sn +Remotes: + CircE=MitchellAcoustics/CircE-R diff --git a/src/soundscapy/__init__.py b/src/soundscapy/__init__.py index 95f4321..7d9b014 100644 --- a/src/soundscapy/__init__.py +++ b/src/soundscapy/__init__.py @@ -1,6 +1,8 @@ """Soundscapy is a Python library for soundscape analysis and visualisation.""" # ruff: noqa: E402 +import importlib + from loguru import logger # https://loguru.readthedocs.io/en/latest/resources/recipes.html#configuring-loguru-to-be-used-by-a-library-or-an-application @@ -89,37 +91,68 @@ # Audio module not available - this is expected if dependencies aren't installed pass -# Try to import optional SPI module -try: - from soundscapy import spi - from soundscapy.spi import ( - CentredParams, - DirectParams, - MultiSkewNorm, - cp2dp, - dp2cp, - msn, - ) - - __all__ += [ +# Optional R-backed modules (spi, satp) are loaded lazily via __getattr__ so +# that `import soundscapy` does not start the R process. R only starts when +# the user explicitly accesses one of these names. +_SPI_ATTRS: frozenset[str] = frozenset( + { + "spi", "CentredParams", "DirectParams", "MultiSkewNorm", "cp2dp", "dp2cp", "msn", - "spi", - ] -except ImportError: - # SPI module not available - pass + "spi_score", + } +) +_SATP_ATTRS: frozenset[str] = frozenset({"satp", "SATP", "CircModelE"}) -try: - from soundscapy import satp - from soundscapy.satp import SATP, CircModelE - __all__ += ["SATP", "CircModelE", "satp"] +def __getattr__(name: str): # noqa: ANN202 + """ + Lazily import optional R-backed sub-modules on first access. -except ImportError: - # SATP module not available - pass + R is not started until one of these names is explicitly accessed. + After the first access each name is stored in the module's ``__dict__``, + so subsequent lookups skip this function entirely. 
+ """ + if name in _SPI_ATTRS: + try: + _spi = importlib.import_module("soundscapy.spi") + _g = globals() + _g["spi"] = _spi + # Pull the individual public names from the sub-module so callers + # can do ``sspy.MultiSkewNorm`` as well as ``sspy.spi.MultiSkewNorm``. + for _attr in _SPI_ATTRS - {"spi"}: + _g[_attr] = getattr(_spi, _attr) + return _g[name] + except ImportError as e: + msg = ( + f"soundscapy.{name} requires optional SPI dependencies. " + "Install with: pip install 'soundscapy[spi]'" + ) + raise ImportError(msg) from e + + if name in _SATP_ATTRS: + try: + _satp = importlib.import_module("soundscapy.satp") + _g = globals() + _g["satp"] = _satp + for _attr in _SATP_ATTRS - {"satp"}: + _g[_attr] = getattr(_satp, _attr) + return _g[name] + except ImportError as e: + msg = ( + f"soundscapy.{name} requires optional SATP dependencies. " + "Install with: pip install 'soundscapy[satp]'" + ) + raise ImportError(msg) from e + + msg = f"module 'soundscapy' has no attribute {name!r}" + raise AttributeError(msg) + + +def __dir__() -> list[str]: + """Extend dir() to include lazily-loaded optional names (PEP 562).""" + return sorted(list(globals()) + list(_SPI_ATTRS) + list(_SATP_ATTRS)) diff --git a/src/soundscapy/r_wrapper/__init__.py b/src/soundscapy/r_wrapper/__init__.py index 9d4a185..bcc1cea 100644 --- a/src/soundscapy/r_wrapper/__init__.py +++ b/src/soundscapy/r_wrapper/__init__.py @@ -13,9 +13,9 @@ raise ImportError(msg) from e # Now we can import our modules that depend on the optional packages -from ._circe_wrapper import bfgs, extract_bfgs_fit -from ._r_wrapper import PKG_SRC, get_r_session -from ._rsn_wrapper import ( +from ._circe_wrapper import bfgs, extract_bfgs_fit # noqa: F401 +from ._r_wrapper import PKG_SRC, get_r_session # noqa: F401 +from ._rsn_wrapper import ( # noqa: F401 cp2dp, dp2cp, extract_cp, @@ -25,16 +25,7 @@ selm, ) -__all__ = [ - "PKG_SRC", - "bfgs", - "cp2dp", - "dp2cp", - "extract_bfgs_fit", - "extract_cp", - "extract_dp", - 
"get_r_session", - "sample_msn", - "sample_mtsn", - "selm", -] +# r_wrapper is an internal implementation package. All user-facing names are +# re-exported from soundscapy.spi and soundscapy.satp. Nothing is in __all__ +# so that ``from soundscapy.r_wrapper import *`` imports nothing. +__all__: list[str] = [] diff --git a/src/soundscapy/r_wrapper/_circe_wrapper.py b/src/soundscapy/r_wrapper/_circe_wrapper.py index 2814fdc..05e9c46 100644 --- a/src/soundscapy/r_wrapper/_circe_wrapper.py +++ b/src/soundscapy/r_wrapper/_circe_wrapper.py @@ -1,6 +1,8 @@ +import numpy as np import pandas as pd from rpy2 import robjects as ro from rpy2.robjects import pandas2ri +from scipy.stats import chi2 as scipy_chi2 from soundscapy.sspylogging import get_logger from soundscapy.surveys.survey_utils import PAQ_IDS @@ -9,9 +11,6 @@ logger = get_logger() -_, _, _stats_package, _base_package, circe = get_r_session() -logger.debug("R session and packages retrieved successfully.") - def extract_bfgs_fit(bfgs_model: ro.ListVector) -> dict: """ @@ -33,10 +32,11 @@ def extract_bfgs_fit(bfgs_model: ro.ListVector) -> dict: >>> data_paqs = data[PAQ_IDS] >>> data_paqs = data_paqs.dropna() >>> data_cor = data_paqs.corr() - >>> n = data_paqs.shape[0] + >>> n = len(data_paqs) >>> circ_model = ModelType(name=CircModelE.CIRCUMPLEX) >>> circe_res = sspy.spi.bfgs( ... data_cor=data_cor, + ... n=n, ... scales=PAQ_IDS, ... m_val=3, ... equal_ang=circ_model.equal_ang, @@ -45,18 +45,38 @@ def extract_bfgs_fit(bfgs_model: ro.ListVector) -> dict: >>> fit_stats = sspy.r_wrapper.extract_bfgs_fit(circe_res) """ + # Session must already be active (bfgs_model was produced by bfgs()), but + # calling get_r_session() here ensures a clean error if somehow called in + # isolation. 
+ get_r_session() with (ro.default_converter + pandas2ri.converter).context(): py_res = { key.lower(): ro.conversion.get_conversion().rpy2py(val) for key, val in bfgs_model.items() } - py_res["p"] = 1 - _stats_package.pchisq(py_res["chisq"], py_res["dfnull"]).item() + + # Normalize all length-1 numpy arrays to Python scalars so callers + # never need to call .item() themselves. Vectors/matrices are kept + # as-is. This also avoids DeprecationWarning from numpy >= 1.25 when + # float() or int() is applied to an ndarray with ndim > 0. + py_res = { + k: (v.item() if isinstance(v, np.ndarray) and v.shape == (1,) else v) + for k, v in py_res.items() + } + + # Use scipy instead of R's pchisq to avoid py2rpy conversion of pandas + # Series objects produced by the pandas2ri context above. + # scipy.chi2.sf(x, df) == 1 - pchisq(x, df) by definition. + # Use the model's own degrees of freedom ("d"), NOT the null-model df + # ("dfnull" = k*(k-1)/2). Using dfnull gives a wildly wrong p-value. + py_res["p"] = float(scipy_chi2.sf(py_res["chisq"], py_res["d"])) return py_res def bfgs( data_cor: pd.DataFrame, + n: int, scales: list[str] = PAQ_IDS, m_val: int = 3, *, @@ -69,6 +89,8 @@ def bfgs( Parameters ---------- data_cor (pd.DataFrame): Correlation matrix of the data. + n (int): Number of observations (participants) used to compute the correlation + matrix. Used by CircE_BFGS for chi-square and RMSEA calculations. scales (list[str], optional): List of scale names. Defaults to PAQ_IDS. m_val (int, optional): Number of dimensions. Defaults to 3. equal_ang (bool, optional): Whether to enforce equal angles constraint. @@ -92,6 +114,7 @@ def bfgs( >>> circ_model = ModelType(name=CircModelE.CIRCUMPLEX) >>> circe_res = bfgs( ... data_cor=data_cor, + ... n=n, ... scales=PAQ_IDS, ... m_val=3, ... equal_ang=circ_model.equal_ang, @@ -99,16 +122,18 @@ def bfgs( ... 
) """ - n = data_cor.shape[0] - + r = get_r_session() with (ro.default_converter + pandas2ri.converter).context(): + # Only the Python→R conversion needs the pandas2ri context. + # Calling as_matrix() inside the context would cause its R-matrix + # return value to be auto-converted back to numpy by the active + # converter, producing a numpy array instead of an R matrix. r_data_cor = ro.conversion.get_conversion().py2rpy(data_cor) - r_cor_mat = _base_package.as_matrix(r_data_cor) - + r_cor_mat = r.base.as_matrix(r_data_cor) r_scales = ro.StrVector(scales) - return circe.CircE_BFGS( + return r.circe.CircE_BFGS( r_cor_mat, v_names=r_scales, m=m_val, diff --git a/src/soundscapy/r_wrapper/_r_wrapper.py b/src/soundscapy/r_wrapper/_r_wrapper.py index 5724d2a..b24c609 100644 --- a/src/soundscapy/r_wrapper/_r_wrapper.py +++ b/src/soundscapy/r_wrapper/_r_wrapper.py @@ -7,105 +7,218 @@ 3. Converting data between R and Python 4. Executing R functions for skew-normal calculations +Session state is held in a single module-level :class:`RSession` dataclass +instance (``_state``) rather than nine scattered globals. Functions that read +or write session fields do so directly — no ``global`` declarations are needed, +since mutating an object's attributes does not rebind the module-level name. +The sole exception is :func:`reset_r_session`, which creates a fresh +``RSession()`` and therefore does rebind ``_state``. + It is not intended to be used directly by end users. """ import importlib.metadata -import warnings +import os +import sys +from dataclasses import dataclass from enum import Enum from typing import Any, NoReturn +# NOTE: importing rpy2.robjects here unconditionally starts the embedded R +# process. There is no way to defer this further — R begins as soon as this +# module is loaded. 
The lazy __getattr__ in soundscapy/__init__.py ensures +# this module (and therefore R) is only loaded when the user first accesses +# soundscapy.spi or soundscapy.satp, not on a plain ``import soundscapy``. from rpy2 import robjects -from rpy2.robjects import numpy2ri, pandas2ri -# These are used in the docstring examples but not in the code -# They will be used by code that imports and uses this module from soundscapy.sspylogging import get_logger logger = get_logger() -# Cached values to avoid repeated checks -_r_checked = False -_sn_checked = False -_circe_checked = False +REQUIRED_R_VERSION: str = "3.6" -# Session state -_r_session = None -_sn_package = None -_circe_package = None -_stats_package = None -_base_package = None -_session_active = False -REQUIRED_R_VERSION = 3.6 -AUTO_INSTALL_R_PACKAGES = True +class PKG_SRC(str, Enum): # noqa: N801 + CIRCE = "MitchellAcoustics/CircE-R" -class PKG_SRC(str, Enum): - CIRCE = "MitchellAcoustics/CircE-R" +@dataclass +class RSession: + """ + Unified state container for the R session, loaded packages, and check flags. + + A single module-level instance (``_state``) replaces the previous nine + module-level globals. Module functions read and write its fields directly — + no ``global`` declarations are needed except in :func:`reset_r_session`, + which rebinds the name. + + :func:`get_r_session` returns ``_state`` directly once the session is + ready; callers access package objects via named fields + (``r.sn``, ``r.circe``, …). + + Attributes + ---------- + session : any + Reference to ``rpy2.robjects`` (populated by + :func:`initialize_r_session`). + sn : any + Loaded ``sn`` R package object. + stats : any + Loaded ``stats`` R package object. + base : any + Loaded ``base`` R package object. + circe : any + Loaded ``CircE`` R package object. + active : bool + ``True`` once :func:`initialize_r_session` completes successfully. 
+ r_checked : bool + ``True`` once :func:`check_r_availability` has passed (cached to avoid + re-querying the R version on every call). + sn_checked : bool + ``True`` once :func:`check_sn_package` has passed (cached). + circe_checked : bool + ``True`` once :func:`check_circe_package` has passed (cached). + + All fields are reset to their defaults when :func:`reset_r_session` is + called, ensuring a clean re-verification on the next + :func:`get_r_session` call. + + """ + + # Package references (populated by initialize_r_session) + session: Any = None + sn: Any = None + stats: Any = None + base: Any = None + circe: Any = None + + # Session status + active: bool = False + + # One-time check flags (cleared on reset so the next call re-verifies) + r_checked: bool = False + sn_checked: bool = False + circe_checked: bool = False + + @property + def is_ready(self) -> bool: + """``True`` when the session is active and all package refs are loaded.""" + return bool( + self.active + and self.session + and self.sn + and self.stats + and self.base + and self.circe + ) + + +# Single module-level state instance. All session functions operate on this +# object; only reset_r_session() rebinds the name (via ``global _state``). +_state = RSession() + + +def _confirm_install_r_packages() -> bool: + """ + Determine whether to auto-install missing R packages. + + Checks the ``SOUNDSCAPY_AUTO_INSTALL_R`` environment variable first: + + - ``"1"``, ``"true"``, or ``"yes"`` → install without prompting (CI / scripts) + - ``"0"``, ``"false"``, or ``"no"`` → never install + + If the variable is unset the user is prompted interactively when stdin is a + TTY. In non-interactive environments the default is *not* to install. 
+ """ + env_val = os.environ.get("SOUNDSCAPY_AUTO_INSTALL_R", "").lower() + if env_val in ("1", "true", "yes"): + return True + if env_val in ("0", "false", "no"): + return False + + if sys.stdin.isatty(): + try: + print( # noqa: T201 + "\nsoundscapy: One or more R packages required for this feature " + "are not installed.\n" + " sn → install.packages('sn')\n" + f" CircE → remotes::install_github('{PKG_SRC.CIRCE.value}')\n" + ) + response = input("Install them now via soundscapy? [y/N] ").strip().lower() + except EOFError: + pass + else: + return response in ("y", "yes") + + return False + + +def _ver(v: str) -> tuple[int, ...]: + """ + Parse a dotted version string into a comparable integer tuple. + + Avoids lexicographic pitfalls where ``"1.10" < "1.2"`` is True. + """ + return tuple(int(x) for x in v.split(".")) def check_r_availability() -> None: """ - Check if R is installed and accessible through rpy2. + Check that R is accessible and meets the minimum version requirement. + + Note: importing this module (or any rpy2-dependent module) already starts + the R process via ``from rpy2 import robjects``. This function therefore + cannot test whether R is *installed* — R is always already running by the + time it is called. Its purpose is to verify the R *version* and to cache + that result (``_state.r_checked``) so the version query runs at most once + per session. Raises ------ ImportError - If R is not installed or cannot be accessed. + If the running R version is older than :data:`REQUIRED_R_VERSION`, or + if the R version cannot be queried for any reason. """ - global _r_checked # noqa: PLW0603 - def _raise_r_not_found_error() -> NoReturn: + def _raise_r_version_too_old_error(r_version_str: str) -> NoReturn: msg = ( - "rpy2 is installed but it cannot find an R installation. " - "Please ensure R is installed and correctly configured. " - "On Linux: Install R with your package manager (e.g., apt-get install r-base)." 
# noqa: E501 - "On macOS: Install R from CRAN (https://cran.r-project.org/bin/macosx/). " - "On Windows: Install R from CRAN (https://cran.r-project.org/bin/windows/base/)." + f"R version {r_version_str} is too old. " + f"The 'sn' package requires R >= {REQUIRED_R_VERSION}. " + "Please upgrade your R installation." ) raise ImportError(msg) def _raise_r_access_error(e: Exception) -> NoReturn: msg = ( - f"Error accessing R installation: {e!s}. " + f"Error querying R version: {e!s}. " "Please ensure R is installed and correctly configured." ) raise ImportError(msg) - def _raise_r_version_too_old_error(r_version_num: float) -> NoReturn: - msg = ( - f"R version {r_version_num} is too old." - f"The 'sn' package requires R >= {REQUIRED_R_VERSION}." - "Please upgrade your R installation." - ) - raise ImportError(msg) - - if _r_checked: + if _state.r_checked: return try: - from rpy2 import robjects - - # Basic check to ensure R is running by getting R version r_version = robjects.r("R.version.string")[0] # type: ignore[index] logger.debug("R version: %s", r_version) - # Check if minimum R version requirements are met - # The 'sn' package requires R >= 3.6.0 - r_version_num = robjects.r( - "as.numeric(R.version$major) + as.numeric(R.version$minor)/10" - )[0] # type: ignore[index] + # Check if minimum R version requirements are met. + # Use _ver() tuple comparison to avoid float pitfalls (e.g. "2.1" minor + # parsed as 2.1/10 = 0.21 instead of the intended major.minor.patch). + # R's $minor field is like "6.0" for R 4.6.0 or "2.1" for R 4.2.1. 
+ r_version_str = robjects.r("paste(R.version$major, R.version$minor, sep='.')")[ + 0 + ] # type: ignore[index] - if r_version_num < REQUIRED_R_VERSION: - _raise_r_version_too_old_error(r_version_num) + if _ver(r_version_str) < _ver(REQUIRED_R_VERSION): + _raise_r_version_too_old_error(r_version_str) - _r_checked = True + _state.r_checked = True except ImportError: - _raise_r_not_found_error() # Call the handler + raise # from _raise_r_version_too_old_error — don't wrap it except Exception as e: # noqa: BLE001 - _raise_r_access_error(e) # Call the handler + _raise_r_access_error(e) def check_sn_package() -> None: @@ -118,7 +231,6 @@ def check_sn_package() -> None: If the 'sn' package is not installed. """ - global _sn_checked # noqa: PLW0603 def _raise_sn_version_too_old_error(version: str) -> NoReturn: msg = ( @@ -142,7 +254,7 @@ def _raise_sn_check_error(e: Exception) -> NoReturn: ) raise ImportError(msg) - if _sn_checked: + if _state.sn_checked: return # First ensure R is available @@ -156,25 +268,21 @@ def _raise_sn_check_error(e: Exception) -> NoReturn: # Just importing to verify it exists _ = rpackages.importr("sn") - # Get package version using R to verify compatibility - from rpy2 import robjects - # Use R code to get the package version version = robjects.r('as.character(packageVersion("sn"))')[0] # type: ignore[index] logger.debug("R 'sn' package version: %s", version) # Check if package version meets requirements # The SPI implementation requires 'sn' >= 2.0.0 - if version < "2.0.0": + if _ver(version) < (2, 0, 0): _raise_sn_version_too_old_error(version) - _sn_checked = True + _state.sn_checked = True except rpackages.PackageNotInstalledError: _raise_sn_not_installed_error() - except Exception as e: - if "sn" in str(e): - # Already a more specific error about the sn package - raise # Re-raising is okay here + except ImportError: + raise # Already a specific ImportError from our helpers — re-raise as-is + except Exception as e: # noqa: BLE001 
_raise_sn_check_error(e) @@ -188,12 +296,11 @@ def check_circe_package() -> None: If the 'CircE' package is not installed. """ - global _circe_checked # noqa: PLW0603 def _raise_circe_not_installed_error() -> NoReturn: msg = ( "R package 'CircE' is not installed. " - f"Please install it by running in R: remotes::install_github({PKG_SRC.CIRCE.value})" # noqa: E501 + f"Please install it by running in R: remotes::install_github('{PKG_SRC.CIRCE.value}')" # noqa: E501 ) raise ImportError(msg) @@ -201,18 +308,18 @@ def _raise_circe_version_too_old_error(version: str) -> NoReturn: msg = ( f"R 'CircE' package version {version} is too old. " "The SPI feature requires 'CircE' >= 1.1. " - f"Please upgrade the package by running in R: remotes::install_github({PKG_SRC.CIRCE.value})" # noqa: E501 + f"Please upgrade the package by running in R: remotes::install_github('{PKG_SRC.CIRCE.value}')" # noqa: E501 ) raise ImportError(msg) - def _raise_sn_check_error(e: Exception) -> NoReturn: + def _raise_circe_check_error(e: Exception) -> NoReturn: msg = ( f"Error checking for R 'CircE' package: {e!s}. 
" - f"Please ensure the package is installed by running in R: remotes::install_github({PKG_SRC.CIRCE.value})" # noqa: E501 + f"Please ensure the package is installed by running in R: remotes::install_github('{PKG_SRC.CIRCE.value}')" # noqa: E501 ) raise ImportError(msg) - if _circe_checked: + if _state.circe_checked: return # First ensure R is available @@ -226,28 +333,23 @@ def _raise_sn_check_error(e: Exception) -> NoReturn: # Just importing to verify it exists _ = rpackages.importr("CircE") - # Get package version using R to verify compatibility - from rpy2 import robjects - # Use R code to get the package version version = robjects.r('as.character(packageVersion("CircE"))')[0] # type: ignore[index] logger.debug("R 'CircE' package version: %s", version) - # Check if package version meets requirements - # The SPI implementation requires 'sn' >= 2.0.0 - if version < "1.1": + # Tuple comparison avoids lexicographic pitfalls ("1.10" > "1.2") + if _ver(version) < (1, 1): _raise_circe_version_too_old_error(version) - _circe_checked = True + _state.circe_checked = True except rpackages.PackageNotInstalledError: _raise_circe_not_installed_error() - except Exception as e: - if "CircE" in str(e): - # Already a more specific error about the sn package - raise # Re-raising is okay here - _raise_sn_check_error(e) + except ImportError: + raise # Already a specific ImportError from our helpers — re-raise as-is + except Exception as e: # noqa: BLE001 + _raise_circe_check_error(e) def check_dependencies() -> dict[str, Any]: @@ -282,24 +384,24 @@ def check_dependencies() -> dict[str, Any]: check_circe_package() except ImportError: - if AUTO_INSTALL_R_PACKAGES: - logger.warning( - "One or more R dependencies are missing. Attempting to auto-install required R packages..." 
# noqa: E501 - ) + if _confirm_install_r_packages(): + logger.info("User confirmed: installing missing R packages...") try: install_r_packages() - # After installation, check again to confirm everything is now available + # Re-check to confirm everything is now available check_r_availability() check_sn_package() check_circe_package() except Exception as install_e: msg = ( f"Auto-installation of R packages failed: {install_e!s}. " - "Please install the required R packages manually and ensure they are accessible." # noqa: E501 + "Please install the required R packages manually.\n" + " sn → install.packages('sn')\n" + f" CircE → remotes::install_github('{PKG_SRC.CIRCE.value}')" ) raise ImportError(msg) from install_e else: - raise # Re-raise the original ImportError if auto-install is not enabled + raise # User declined or non-interactive; re-raise the original ImportError # If we get here, all dependencies are available @@ -323,7 +425,7 @@ def initialize_r_session() -> dict[str, Any]: 1. Checks for R and package dependencies 2. Imports required R packages 3. Sets up the R environment - 4. Updates global session state + 4. Updates the ``_state`` singleton Returns ------- @@ -338,10 +440,8 @@ def initialize_r_session() -> dict[str, Any]: If session initialization fails. 
""" - global _r_session, _sn_package, _stats_package, _base_package, _session_active, _circe_package # noqa: E501, PLW0603 - # If session is already active, just return the state - if _session_active: + if _state.active: logger.debug("R session already initialized") return { "r_session": "active", @@ -357,146 +457,106 @@ def initialize_r_session() -> dict[str, Any]: try: import rpy2.robjects.packages as rpackages - from rpy2 import robjects # Import required packages - _sn_package = rpackages.importr("sn") - _circe_package = rpackages.importr("CircE") - _stats_package = rpackages.importr("stats") - _base_package = rpackages.importr("base") + _state.sn = rpackages.importr("sn") + _state.circe = rpackages.importr("CircE") + _state.stats = rpackages.importr("stats") + _state.base = rpackages.importr("base") logger.debug("Imported R packages: sn, CircE, stats, base") # Set R random seed for reproducibility robjects.r("set.seed(42)") - # Store R session - _r_session = robjects + # Store R session reference + _state.session = robjects - # Update session state - _session_active = True + # Mark session as active + _state.active = True logger.info("R session successfully initialized") - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # Activate numpy and pandas conversion - logger.debug("Activating numpy and pandas conversion") - logger.info( - "rpy2 throws a DeprecationWarning about global activation, which we're ignoring for now." 
# noqa: E501 - ) - # TODO(MitchellAcoustics): Remove global conversion, as recommended by rpy2 - # https://github.com/MitchellAcoustics/Soundscapy/issues/111 - numpy2ri.activate() - pandas2ri.activate() - return { "r_session": "active", - "sn_package": str(_sn_package), - "stats_package": str(_stats_package), - "base_package": str(_base_package), - "circe_package": str(_circe_package), + "sn_package": str(_state.sn), + "stats_package": str(_state.stats), + "base_package": str(_state.base), + "circe_package": str(_state.circe), **dep_info, } except Exception as e: logger.exception("Failed to initialize R session") - _session_active = False - _r_session = None - _sn_package = None - _stats_package = None - _base_package = None - _circe_package = None + # Reset to a clean state so the next call can retry from scratch. + # reset_r_session() always clears _state regardless of active flag. + reset_r_session() msg = f"Failed to initialize R session: {e!s}" raise RuntimeError(msg) from e -def shutdown_r_session() -> bool: +def reset_r_session() -> bool: """ - Shutdown the R session and clean up resources. + Unload R packages and reset all session state. - This function: - 1. Deactivates numpy conversion - 2. Resets global session state - 3. Performs garbage collection + Replaces ``_state`` with a fresh :class:`RSession` instance, clearing all + package references, the active flag, and the package-check caches. The + next call to :func:`get_r_session` will therefore re-verify package + availability and re-import everything from scratch. + + Note: the *R process itself continues running* — rpy2 does not support + terminating the embedded R interpreter. Returns ------- bool - True if successful, False otherwise. + ``True`` if successful, ``False`` if an error occurred. 
""" - global _r_session, _sn_package, _stats_package, _base_package, _session_active, _circe_package # noqa: E501, PLW0603 - - if not _session_active: - logger.debug("No active R session to shutdown") - return True + global _state # noqa: PLW0603 try: import gc - # Clear references to R objects - _r_session = None - _sn_package = None - _stats_package = None - _base_package = None - _circe_package = None - - # Update session state - _session_active = False - - # Force garbage collection to release R resources + was_active = _state.active + _state = RSession() gc.collect() - logger.info("R session successfully shutdown") + + if was_active: + logger.info("R session packages successfully unloaded") + else: + logger.debug("R session state cleared") except Exception: - logger.exception("Error during R session shutdown") + logger.exception("Error during R session reset") return False else: return True -def get_r_session() -> tuple[Any, Any, Any, Any, Any]: +def get_r_session() -> RSession: """ - Get the current R session and package objects. - - This function: - 1. Initializes the session if not already active - 2. Returns the session and package references + Get the current R session and package objects, initialising lazily if needed. Returns ------- - tuple[Any, Any, Any, Any, Any] - (r_session, sn_package, stats_package, base_package, circe_package) + RSession + The module-level ``_state`` instance once fully initialised. Access + package objects by name: ``r.sn``, ``r.circe``, ``r.base``, etc. Raises ------ RuntimeError - If session initialization fails. + If session initialisation fails. 
""" - global _r_session, _sn_package, _stats_package, _base_package, _session_active, _circe_package # noqa: E501, PLW0602 - - if not _session_active: + if not _state.active: logger.debug("R session not active, initializing") initialize_r_session() - if ( - not _session_active - or not _r_session - or not _sn_package - or not _stats_package - or not _base_package - or not _circe_package - ): + if not _state.is_ready: msg = "Failed to initialize R session" raise RuntimeError(msg) - return ( - _r_session, - _sn_package, - _stats_package, - _base_package, - _circe_package, - ) + return _state def install_r_packages(packages: list[str] | None = None) -> None: @@ -506,7 +566,7 @@ def install_r_packages(packages: list[str] | None = None) -> None: Parameters ---------- packages : list[str] | None, optional - List of R package names to install. Defaults to ["sn", "tvtnorm"]. + List of R package names to install. Defaults to ["sn", "CircE"]. Raises ------ @@ -515,7 +575,7 @@ def install_r_packages(packages: list[str] | None = None) -> None: """ if packages is None: - packages = ["sn", "tvtnorm", "CircE"] + packages = ["sn", "CircE"] check_r_availability() @@ -533,7 +593,7 @@ def install_r_packages(packages: list[str] | None = None) -> None: # Install missing packages if len(packnames_to_install) > 0: if "CircE" in packnames_to_install: - # CircE and RTHORR are only available from GitHub + # CircE is only available from GitHub remotes = rpackages.importr("remotes") remotes.install_github(PKG_SRC.CIRCE.value) packnames_to_install.remove("CircE") @@ -559,5 +619,4 @@ def is_session_active() -> bool: True if the session is active, False otherwise. 
""" - global _session_active # noqa: PLW0602 - return _session_active + return _state.active diff --git a/src/soundscapy/r_wrapper/_rsn_wrapper.py b/src/soundscapy/r_wrapper/_rsn_wrapper.py index c42493c..d72060c 100644 --- a/src/soundscapy/r_wrapper/_rsn_wrapper.py +++ b/src/soundscapy/r_wrapper/_rsn_wrapper.py @@ -2,7 +2,10 @@ import numpy as np import pandas as pd + +# NOTE: importing rpy2 here starts the embedded R process (see _r_wrapper.py). from rpy2 import robjects +from rpy2.robjects import numpy2ri, pandas2ri from rpy2.robjects.methods import RS4 from soundscapy.sspylogging import get_logger @@ -11,33 +14,37 @@ logger = get_logger() -_, sn, _, _, _ = get_r_session() -logger.debug("R session and packages retrieved successfully.") + +def _r2np(r_obj: object) -> np.ndarray: + """Convert a single R numeric object to a numpy array via an explicit converter.""" + with (robjects.default_converter + numpy2ri.converter).context(): + return robjects.conversion.get_conversion().rpy2py(r_obj) def selm(x: str, y: str, data: pd.DataFrame) -> RS4: + r = get_r_session() formula = f"cbind({x}, {y}) ~ 1" - return sn.selm(formula, data=data, family="SN") - - -def calc_cp(x: str, y: str, data: pd.DataFrame) -> tuple: - selm_model = selm(x, y, data) - return extract_cp(selm_model) - - -def calc_dp(x: str, y: str, data: pd.DataFrame) -> tuple: - selm_model = selm(x, y, data) - return extract_dp(selm_model) + with (robjects.default_converter + pandas2ri.converter).context(): + r_data = robjects.conversion.get_conversion().py2rpy(data) + return r.sn.selm(formula, data=r_data, family="SN") def extract_cp(selm_model: RS4) -> tuple: - cp = tuple(selm_model.slots["param"][1]) - return (cp[0].flatten(), cp[1], cp[2].flatten()) + # param[[1]] in R (0-indexed in rpy2) is the CP list: {mean, Sigma, skew} + cp_r = selm_model.slots["param"][1] + mean = _r2np(cp_r[0]).flatten() + sigma = _r2np(cp_r[1]) + skew = _r2np(cp_r[2]).flatten() + return (mean, sigma, skew) def 
extract_dp(selm_model: RS4) -> tuple: - dp = tuple(selm_model.slots["param"][0]) - return (dp[0].flatten(), dp[1], dp[2].flatten()) + # param[[0]] in R (0-indexed in rpy2) is the DP list: {xi, Omega, alpha} + dp_r = selm_model.slots["param"][0] + xi = _r2np(dp_r[0]).flatten() + omega = _r2np(dp_r[1]) + alpha = _r2np(dp_r[2]).flatten() + return (xi, omega, alpha) def sample_msn( @@ -47,19 +54,22 @@ def sample_msn( alpha: np.ndarray | None = None, n: int = 1000, ) -> np.ndarray: + r = get_r_session() if selm_model is not None: - return sn.rmsn(n, dp=selm_model.slots["param"][0]) - if xi is not None and omega is not None and alpha is not None: + r_result = r.sn.rmsn(n, dp=selm_model.slots["param"][0]) + elif xi is not None and omega is not None and alpha is not None: r_xi = robjects.FloatVector(xi.T) # Transpose to make it a column vector r_omega = robjects.r.matrix( robjects.FloatVector(omega.flatten()), nrow=omega.shape[0], ncol=omega.shape[1], ) # type: ignore[reportCallIssue] - r_alpha = robjects.FloatVector(alpha) # Transpose to make it a column vector - return sn.rmsn(n, xi=r_xi, Omega=r_omega, alpha=r_alpha) - msg = "Either selm_model or xi, omega, and alpha must be provided." - raise ValueError(msg) + r_alpha = robjects.FloatVector(alpha) + r_result = r.sn.rmsn(n, xi=r_xi, Omega=r_omega, alpha=r_alpha) + else: + msg = "Either selm_model or xi, omega, and alpha must be provided." + raise ValueError(msg) + return _r2np(r_result) def sample_mtsn( @@ -70,6 +80,7 @@ def sample_mtsn( a: float = -1, b: float = 1, n: int = 1000, + max_iter: int = 100_000, ) -> np.ndarray: """ Sample from a multivariate truncated skew-normal distribution. @@ -94,6 +105,10 @@ def sample_mtsn( Upper truncation bound for both dimensions, by default 1. n : int, optional Number of samples to generate, by default 1000. + max_iter : int, optional + Maximum total candidate draws before raising ``RuntimeError``. 
+ Guards against an infinite loop when the distribution has negligible + probability mass inside ``[a, b]``. Default: 100 000. Returns ------- @@ -104,11 +119,23 @@ def sample_mtsn( ------ ValueError If neither `selm_model` nor all of `xi`, `omega`, and `alpha` are provided. + RuntimeError + If ``max_iter`` candidate draws are exhausted before ``n`` accepted + samples are collected, which indicates the distribution is largely + outside ``[a, b]``. """ - samples = np.array([[0, 0]]) - n_samples = 0 - while n_samples < n: + accepted: list[np.ndarray] = [] + n_iter = 0 + while len(accepted) < n: + if n_iter >= max_iter: + msg = ( + f"sample_mtsn: reached max_iter={max_iter} without collecting " + f"{n} accepted samples (got {len(accepted)}). " + "The distribution may have negligible mass inside " + f"[{a}, {b}]. Adjust the bounds or increase max_iter." + ) + raise RuntimeError(msg) if selm_model is not None: sample = sample_msn(selm_model, n=1) elif xi is not None and omega is not None and alpha is not None: @@ -116,20 +143,11 @@ def sample_mtsn( else: msg = "Either selm_model or xi, omega, and alpha must be provided." raise ValueError(msg) + n_iter += 1 if a <= sample[0][0] <= b and a <= sample[0][1] <= b: - samples = np.append(samples, sample, axis=0) - if n_samples == 0: - samples = samples[1:] - n_samples += 1 - - # Ensure the sample is within the bounds [a, b] for both dimensions - if not np.all((a <= samples[:, 0]) & (samples[:, 0] <= b)): - msg = f"Sample x-values are out of bounds: [{a}, {b}]" - raise ValueError(msg) - if not np.all((a <= samples[:, 1]) & (samples[:, 1] <= b)): - msg = f"Sample y-values are out of bounds: [{a}, {b}]" - raise ValueError(msg) - return samples + accepted.append(sample) + + return np.vstack(accepted) def dp2cp( @@ -158,13 +176,14 @@ def dp2cp( Tuple containing the centred parameters (mean, sigma, skew). 
""" + r = get_r_session() r_xi = robjects.FloatVector(xi.T) # Transpose to make it a column vector r_omega = robjects.r.matrix( robjects.FloatVector(omega.flatten()), nrow=omega.shape[0], ncol=omega.shape[1], ) # type: ignore[reportCallIssue] - r_alpha = robjects.FloatVector(alpha) # Transpose to make it a column vector + r_alpha = robjects.FloatVector(alpha) dp_r = robjects.ListVector( { @@ -174,9 +193,8 @@ def dp2cp( } ) - cp_r = sn.dp2cp(dp_r, family=family) - - return tuple(cp_r) + cp_r = r.sn.dp2cp(dp_r, family=family) + return tuple(_r2np(cp_r[i]) for i in range(len(cp_r))) def cp2dp( @@ -205,13 +223,14 @@ def cp2dp( Tuple containing the direct parameters (xi, omega, alpha). """ + r = get_r_session() r_mean = robjects.FloatVector(mean.T) # Transpose to make it a column vector r_sigma = robjects.r.matrix( robjects.FloatVector(sigma.flatten()), nrow=sigma.shape[0], ncol=sigma.shape[1], ) # type: ignore[reportCallIssue] - r_skew = robjects.FloatVector(skew) # Transpose to make it a column vector + r_skew = robjects.FloatVector(skew) cp_r = robjects.ListVector( { "mean": r_mean, @@ -219,5 +238,5 @@ def cp2dp( "skew": r_skew, } ) - dp_r = sn.cp2dp(cp_r, family=family) - return tuple(dp_r) + dp_r = r.sn.cp2dp(cp_r, family=family) + return tuple(_r2np(dp_r[i]) for i in range(len(dp_r))) diff --git a/src/soundscapy/satp/__init__.py b/src/soundscapy/satp/__init__.py index de53314..2138f35 100644 --- a/src/soundscapy/satp/__init__.py +++ b/src/soundscapy/satp/__init__.py @@ -5,8 +5,6 @@ analysis, based on the R implementation. Requires optional dependencies. """ -import warnings - # Check for required dependencies directly # This will raise ImportError if any dependency is missing try: @@ -24,6 +22,3 @@ from soundscapy.satp.circe import SATP, CircE, CircModelE, ModelType __all__ = ["SATP", "CircE", "CircModelE", "ModelType", "circe"] - -msg = "The SATP analysis module is experimental. Use with caution." 
-warnings.warn(msg, UserWarning, stacklevel=2) diff --git a/src/soundscapy/satp/circe.py b/src/soundscapy/satp/circe.py index d8ebe60..18c2308 100644 --- a/src/soundscapy/satp/circe.py +++ b/src/soundscapy/satp/circe.py @@ -31,7 +31,7 @@ import warnings from enum import Enum from functools import partial -from typing import Annotated +from typing import Annotated, Any import numpy as np import pandas as pd @@ -40,7 +40,6 @@ from pandera.typing.pandas import DataFrame, Series from pydantic import BeforeValidator, ConfigDict from pydantic.dataclasses import dataclass -from rpy2 import robjects as ro import soundscapy.r_wrapper as sspyr from soundscapy import PAQ_IDS, PAQ_LABELS, get_logger @@ -120,12 +119,22 @@ class ModelType: @property def equal_ang(self) -> bool: - """Check if the model uses equal angles constraint.""" - return self.name in {CircModelE.EQUAL_ANG, CircModelE.EQUAL_COM} + """ + Check if the model uses equal angles constraint. + + True for EQUAL_ANG (angles only) and CIRCUMPLEX (both constraints). + EQUAL_COM has free angles (False); UNCONSTRAINED has neither (False). + """ + return self.name in {CircModelE.EQUAL_ANG, CircModelE.CIRCUMPLEX} @property def equal_com(self) -> bool: - """Check if the model uses equal communalities constraint.""" + """ + Check if the model uses equal communalities constraint. + + True for EQUAL_COM (communalities only) and CIRCUMPLEX (both). + EQUAL_ANG has free communalities (False); UNCONSTRAINED neither (False). 
+ """ return self.name in {CircModelE.EQUAL_COM, CircModelE.CIRCUMPLEX} @@ -164,14 +173,13 @@ def length_1_array_to_number(v: np.ndarray | float | None) -> float | None: class CircE: """A data class to hold the results of a CircE model fitting.""" - _raw_bfgs_fit: ro.ListVector model_type: ModelType datasource: str language: str n: Annotated[int, BeforeValidator(length_1_array_to_number)] m: Annotated[int, BeforeValidator(length_1_array_to_number)] chisq: Annotated[float, BeforeValidator(length_1_array_to_number)] - df: Annotated[int, BeforeValidator(length_1_array_to_number)] + d: Annotated[int, BeforeValidator(length_1_array_to_number)] p: Annotated[float, BeforeValidator(length_1_array_to_number)] cfi: Annotated[float, BeforeValidator(length_1_array_to_number)] gfi: Annotated[float, BeforeValidator(length_1_array_to_number)] @@ -186,7 +194,7 @@ class CircE: @classmethod def from_bfgs( cls, - bfgs_model: ro.ListVector, + bfgs_model: Any, datasource: str, language: str, model_type: ModelType, @@ -195,19 +203,23 @@ def from_bfgs( """Create a CircE instance from a fitted BFGS model.""" fit_stats = sspyr.extract_bfgs_fit(bfgs_model) polar_angles = None - # Only extract polar angles for models that support angular parameters - if model_type in (CircModelE.UNCONSTRAINED, CircModelE.EQUAL_COM): - polar_angles = pd.DataFrame(fit_stats.get("polar_angles", None)).T + # Only extract polar angles for models where angles are free parameters. + # model_type.name is the CircModelE enum; compare against that, not the + # ModelType dataclass wrapper (which would never compare equal to an enum). + # The R key is "polar.angles" (dot), not "polar_angles" (underscore). 
+ if model_type.name in (CircModelE.UNCONSTRAINED, CircModelE.EQUAL_COM): + raw_pa = fit_stats.get("polar.angles") + if raw_pa is not None: + polar_angles = pd.DataFrame(raw_pa).T return cls( - _raw_bfgs_fit=bfgs_model, model_type=model_type, datasource=datasource, language=language, n=n, m=fit_stats.get("m", None), chisq=fit_stats.get("chisq", None), - df=fit_stats.get("dfnull", None), + d=fit_stats.get("d", None), p=fit_stats.get("p", None), cfi=fit_stats.get("cfi", None), gfi=fit_stats.get("gfi", None), @@ -224,6 +236,7 @@ def from_bfgs( def compute_bfgs_fit( cls, data_cor: pd.DataFrame, + n: int, datasource: str, language: str, circ_model: CircModelE, @@ -231,6 +244,22 @@ def compute_bfgs_fit( """ Compute and return a CircEResult from the given data correlation matrix. + Parameters + ---------- + data_cor + Correlation matrix of the PAQ data (8x8). + n + Number of observations (participants) used to compute ``data_cor``. + This is used by ``CircE_BFGS`` for chi-square and RMSEA calculations + and must be the row count of the *original* data, not of the + correlation matrix. + datasource + Source identifier for the dataset. + language + Language code for the dataset. + circ_model + Circumplex model type to fit. + Examples -------- >>> import soundscapy as sspy @@ -238,17 +267,17 @@ def compute_bfgs_fit( >>> data_paqs = data[PAQ_IDS] >>> data_paqs = data_paqs.dropna() >>> data_cor = data_paqs.corr() + >>> n = len(data_paqs) >>> circ_model = sspy.satp.CircModelE.CIRCUMPLEX >>> circe_res = sspy.satp.CircE.compute_bfgs_fit( - ... data_cor, "ISD", "EN", circ_model) + ... data_cor, n, "ISD", "EN", circ_model) ... 
""" - # Get matrix dimensions for model fitting - n = data_cor.shape[0] model_type = ModelType(name=circ_model) bfgs_model = sspyr.bfgs( data_cor=data_cor, + n=n, scales=PAQ_IDS, m_val=3, equal_ang=model_type.equal_ang, @@ -316,11 +345,16 @@ def __init__( If data doesn't conform to SATPSchema requirements """ + warnings.warn( + "The SATP analysis module is experimental. Use with caution.", + UserWarning, + stacklevel=2, + ) # Initialize processing flags and store raw data self._ipsatized = False self._raw_data = data # Validate input data against schema requirements - self.data: DataFrame = SATPSchema.validate(data, lazy=True) + self.data: pd.DataFrame = SATPSchema.validate(data, lazy=True) # Apply ipsatization if requested if ipsatize_data: @@ -357,7 +391,15 @@ def ipsatize(self) -> None: Ipsatization centers each participant's responses around their mean, removing individual response style differences while preserving relative response patterns. + + Calling this method a second time is a no-op (guarded by + ``_ipsatized``): after the first call the ``participant`` column is + dropped by ``groupby.transform``, so a second call would raise + ``KeyError``. 
""" + if self._ipsatized: + logger.warning("Data has already been ipsatized; skipping.") + return # Apply ipsatization transformation and update flag self.data = self._ipsatize_df(self.data, by="participant") self._ipsatized = True @@ -399,11 +441,12 @@ def run(self, circ_model: CircModelE | None = None) -> None: """ # Determine which models to fit circ_models_to_run = [*CircModelE] if circ_model is None else [circ_model] + n = len(self.data) # Fit each requested model, capturing any errors for model in circ_models_to_run: try: self.model_results[model] = CircE.compute_bfgs_fit( - self.data_corr, self.datasource, self.language, model + self.data_corr, n, self.datasource, self.language, model ) except Exception as e: # noqa: BLE001, PERF203 # Log fitting errors but continue with other models diff --git a/src/soundscapy/spi/__init__.py b/src/soundscapy/spi/__init__.py index c1b1e69..baed5b4 100644 --- a/src/soundscapy/spi/__init__.py +++ b/src/soundscapy/spi/__init__.py @@ -5,8 +5,6 @@ based on the R implementation. Requires optional dependencies. """ -import warnings - # Check for required dependencies directly # This will raise ImportError if any dependency is missing try: @@ -39,6 +37,3 @@ "msn", "spi_score", ] - -msg = "The SPI analysis module is experimental. Use with caution." -warnings.warn(msg, UserWarning, stacklevel=2) diff --git a/src/soundscapy/spi/msn.py b/src/soundscapy/spi/msn.py index 9800955..9582f8d 100644 --- a/src/soundscapy/spi/msn.py +++ b/src/soundscapy/spi/msn.py @@ -4,6 +4,24 @@ Provides classes and functions for defining, fitting, sampling, and analyzing MSN distributions, often used in soundscape analysis for modeling ISOPleasant and ISOEventful ratings. + +Classes +------- +DirectParams + Container for direct parameters (xi, omega, alpha) of a skew-normal distribution. +CentredParams + Container for centred parameters (mean, sigma, skew) of a skew-normal distribution. 
+MultiSkewNorm + High-level interface for fitting, sampling, and scoring a 2-D skew-normal model. + +Functions +--------- +dp2cp(dp, family="SN") + Convert a :class:`DirectParams` object to a :class:`CentredParams` object via R. +cp2dp(cp, family="SN") + Convert a :class:`CentredParams` object to a :class:`DirectParams` object via R. +spi_score(target, test) + Soundscape Perception Index: ``int((1 - KS_statistic) * 100)``. """ import warnings @@ -121,10 +139,12 @@ def from_cp(cls, cp: "CentredParams") -> "DirectParams": """ warnings.warn( "Converting from Centred Parameters to Direct Parameters " - "is not guaranteed.", + "is not guaranteed to produce a unique result. " + "Prefer constructing from Direct Parameters (xi, omega, alpha) " + "directly when possible.", UserWarning, stacklevel=2, - ) # TODO(MitchellAcoustics): Add a more specific warning message + ) dp = cp2dp(cp) return cls(dp.xi, dp.omega, dp.alpha) @@ -203,8 +223,6 @@ class MultiSkewNorm: Attributes ---------- - selm_model - The fitted SELM model. cp : CentredParams The centred parameters of the fitted model. dp : DirectParams @@ -223,19 +241,25 @@ class MultiSkewNorm: define_dp(xi, omega, alpha) Defines the direct parameters of the model. sample(n=1000, return_sample=False) - Generates a sample from the fitted model. + Generates an unrestricted sample from the fitted model. + sample_mtsn(n=1000, a=-1, b=1, return_sample=False) + Generates a truncated sample (rejection sampling within [a, b]). sspy_plot(color='blue', title=None, n=1000) Plots the joint distribution of the generated sample. - ks2ds(test) - Computes the two-sample Kolmogorov-Smirnov statistic. - spi(test) - Computes the similarity percentage index. + ks2d2s(test) + Computes the two-sample, two-dimensional Kolmogorov-Smirnov statistic. + spi_score(test) + Computes the Soundscape Perception Index (SPI). 
""" def __init__(self) -> None: """Initialize the MultiSkewNorm object.""" - self.selm_model = None + warnings.warn( + "The SPI analysis module is experimental. Use with caution.", + UserWarning, + stacklevel=2, + ) self.cp = None self.dp = None self.sample_data = None @@ -243,7 +267,7 @@ def __init__(self) -> None: def __repr__(self) -> str: """Return a string representation of the MultiSkewNorm object.""" - if self.cp is None and self.dp is None and self.selm_model is None: + if self.cp is None and self.dp is None: return "MultiSkewNorm() (unfitted)" return f"MultiSkewNorm(dp={self.dp})" @@ -257,7 +281,7 @@ def summary(self) -> str: indicating the model is not fitted. """ - if self.cp is None and self.dp is None and self.selm_model is None: + if self.cp is None and self.dp is None: return "MultiSkewNorm is not fitted." lines = [] if self.data is not None: @@ -300,7 +324,9 @@ def fit( if data is not None: # If data is provided, convert it to a pandas DataFrame if isinstance(data, pd.DataFrame): - # If data is already a DataFrame, no need to convert + # Rename columns to "x"/"y" on a copy so we don't mutate the + # caller's DataFrame. + data = data.copy() data.columns = ["x", "y"] elif isinstance(data, np.ndarray): @@ -325,17 +351,16 @@ def fit( msg = "Either data or x and y must be provided" raise ValueError(msg) - # Fit the model + # Fit the model, extract parameters immediately, then discard the R object. + # Storing rpy2 objects (RS4) beyond the function boundary creates a + # persistent reference into R's heap that can outlive the session. m = sspyr.selm("x", "y", data) - - # Extract the parameters cp = sspyr.extract_cp(m) dp = sspyr.extract_dp(m) self.cp = CentredParams(*cp) self.dp = DirectParams(*dp) self.data = data - self.selm_model = m def define_dp( self, xi: np.ndarray, omega: np.ndarray, alpha: np.ndarray @@ -396,8 +421,9 @@ def from_params( msg = "Either params object or xi, omega, and alpha must be provided." 
raise ValueError(msg) if xi is not None and omega is not None and alpha is not None: - # If xi, omega, and alpha are provided, create DirectParams + # xi/omega/alpha provided — create DirectParams and derive CP instance.dp = DirectParams(xi, omega, alpha) + instance.cp = CentredParams.from_dp(instance.dp) elif mean is not None and sigma is not None and skew is not None: # If mean, sigma, and skew are provided, create CentredParams cp = CentredParams(mean, sigma, skew) @@ -448,14 +474,12 @@ def sample( parameters (`dp`) are also not defined. """ - if self.selm_model is not None: - sample = sspyr.sample_msn(selm_model=self.selm_model, n=n) - elif self.dp is not None: + if self.dp is not None: sample = sspyr.sample_msn( xi=self.dp.xi, omega=self.dp.omega, alpha=self.dp.alpha, n=n ) else: - msg = "Either selm_model or xi, omega, and alpha must be provided." + msg = "Model is not fitted. Call fit() or define_dp() first." raise ValueError(msg) self.sample_data = sample @@ -490,14 +514,7 @@ def sample_mtsn( The generated sample if `return_sample` is True, otherwise None. """ - if self.selm_model is not None: - sample = sspyr.sample_mtsn( - selm_model=self.selm_model, - n=n, - a=a, - b=b, - ) - elif self.dp is not None: + if self.dp is not None: sample = sspyr.sample_mtsn( xi=self.dp.xi, omega=self.dp.omega, @@ -507,7 +524,7 @@ def sample_mtsn( b=b, ) else: - msg = "Either selm_model or xi, omega, and alpha must be provided." + msg = "Model is not fitted. Call fit() or define_dp() first." raise ValueError(msg) # Store the sample data diff --git a/test/satp/__init__.py b/test/satp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/satp/test_circe.py b/test/satp/test_circe.py new file mode 100644 index 0000000..9dbe7cd --- /dev/null +++ b/test/satp/test_circe.py @@ -0,0 +1,491 @@ +""" +Integration tests for the CircE wrapper. + +Reference values are derived from the CircE R package itself +(https://github.com/MitchellAcoustics/CircE-R). 
The vocational interests +example from CircE.BFGS.Rd is run directly via our wrapper and the output +is verified against the values printed by the R package. + +Two datasets are used: +- ``VOCATIONAL_COR`` / ``VOCATIONAL_N``: 7-variable example from the CircE + package docs — used for low-level ``bfgs()`` / ``extract_bfgs_fit()`` tests + where we have exact reference values. +- ISD data: 8-PAQ SATP-format data — used for ``CircE.compute_bfgs_fit()`` + and ``SATP`` tests (which require PAQ_IDS columns). +""" + +import numpy as np +import pandas as pd +import pytest +from scipy.stats import chi2 as scipy_chi2 + +# --------------------------------------------------------------------------- +# Vocational interests correlation matrix from CircE.BFGS.Rd example +# (N=175, 7 variables, m=3) +# --------------------------------------------------------------------------- + +_V_NAMES = [ + "Health", + "Science", + "Technology", + "Trades", + "Business Operations", + "Business Contact", + "Social", +] + +_R_LOWER = np.array( + [ + [1, 0, 0, 0, 0, 0, 0], + [0.654, 1, 0, 0, 0, 0, 0], + [0.453, 0.644, 1, 0, 0, 0, 0], + [0.251, 0.440, 0.757, 1, 0, 0, 0], + [0.122, 0.158, 0.551, 0.493, 1, 0, 0], + [0.218, 0.210, 0.570, 0.463, 0.754, 1, 0], + [0.496, 0.264, 0.366, 0.202, 0.471, 0.650, 1], + ] +) +_R_SYM = _R_LOWER + _R_LOWER.T - np.diag(np.diag(_R_LOWER)) +VOCATIONAL_COR = pd.DataFrame(_R_SYM, index=_V_NAMES, columns=_V_NAMES) +VOCATIONAL_N = 175 + + +# --------------------------------------------------------------------------- +# ISD data fixture (8-PAQ format required by compute_bfgs_fit / SATP) +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def isd_paqs(): + """Return the ISD PAQ data (dropna) for use in SATP tests.""" + import soundscapy as sspy + from soundscapy.surveys.survey_utils import PAQ_IDS + + return sspy.isd.load()[PAQ_IDS].dropna() + + +@pytest.fixture(scope="module") +def isd_cor(isd_paqs): + 
"""Correlation matrix of ISD PAQ data.""" + return isd_paqs.corr() + + +@pytest.fixture(scope="module") +def isd_n(isd_paqs): + """Sample size of ISD PAQ data.""" + return len(isd_paqs) + + +# --------------------------------------------------------------------------- +# Tests for bfgs() / extract_bfgs_fit() wrappers +# --------------------------------------------------------------------------- + + +@pytest.mark.optional_deps("satp") +class TestBfgsWrapper: + """ + Direct tests of the bfgs() and extract_bfgs_fit() wrappers. + + All reference values come from running CircE.BFGS(R, v.names, m=3, N=175) + in R and reading the printed output. + """ + + def test_bfgs_returns_list_vector(self): + """bfgs() should return an rpy2 ListVector (the raw R model object).""" + from rpy2.robjects import ListVector + + from soundscapy.r_wrapper._circe_wrapper import bfgs + + result = bfgs( + VOCATIONAL_COR, + n=VOCATIONAL_N, + scales=_V_NAMES, + m_val=3, + equal_ang=False, + equal_com=False, + ) + assert isinstance(result, ListVector) + + def test_extract_bfgs_fit_returns_dict(self): + """extract_bfgs_fit() should return a plain Python dict.""" + from soundscapy.r_wrapper._circe_wrapper import bfgs, extract_bfgs_fit + + fit = extract_bfgs_fit( + bfgs( + VOCATIONAL_COR, + n=VOCATIONAL_N, + scales=_V_NAMES, + m_val=3, + equal_ang=False, + equal_com=False, + ) + ) + assert isinstance(fit, dict) + + def test_bfgs_fit_keys_present(self): + """extract_bfgs_fit() result must contain the expected fit-statistic keys.""" + from soundscapy.r_wrapper._circe_wrapper import bfgs, extract_bfgs_fit + + fit = extract_bfgs_fit( + bfgs( + VOCATIONAL_COR, + n=VOCATIONAL_N, + scales=_V_NAMES, + m_val=3, + equal_ang=False, + equal_com=False, + ) + ) + required = { + "chisq", + "d", + "dfnull", + "p", + "rmsea", + "rmsea.l", + "rmsea.u", + "cfi", + "gfi", + "agfi", + "srmr", + "mcsc", + "m", + } + assert required.issubset(fit.keys()) + + # 
------------------------------------------------------------------ + # Reference values from the CircE R package: + # model <- CircE.BFGS(R, v.names, m=3, N=175) + # + # Output (printed by R, rounded to 3 d.p.): + # chi-sq = 11.598, Model df = 5, Null df = 21 + # p = 0.041 (Ho: perfect fit) + # RMSEA = 0.087 [0.017, 0.154] + # CFI = 0.991, GFI = 0.989, AGFI = 0.938, SRMR = 0.038, MCSC = 0.29 + # ------------------------------------------------------------------ + + def test_bfgs_unconstrained_chisq(self): + """Chi-square statistic must match R package reference value (±0.01).""" + from soundscapy.r_wrapper._circe_wrapper import bfgs, extract_bfgs_fit + + fit = extract_bfgs_fit( + bfgs( + VOCATIONAL_COR, + n=VOCATIONAL_N, + scales=_V_NAMES, + m_val=3, + equal_ang=False, + equal_com=False, + ) + ) + assert pytest.approx(fit["chisq"], abs=0.01) == 11.598 + + def test_bfgs_unconstrained_model_df(self): + """Model degrees of freedom must be 5 (not 21 = dfnull).""" + from soundscapy.r_wrapper._circe_wrapper import bfgs, extract_bfgs_fit + + fit = extract_bfgs_fit( + bfgs( + VOCATIONAL_COR, + n=VOCATIONAL_N, + scales=_V_NAMES, + m_val=3, + equal_ang=False, + equal_com=False, + ) + ) + assert int(fit["d"]) == 5 + + def test_bfgs_unconstrained_p_value(self): + """ + p-value must be computed against model df (d=5), not null df (dfnull=21). + + R reference: p = 0.041. + If dfnull=21 were (wrongly) used the result would be ~0.95. 
+ """ + from soundscapy.r_wrapper._circe_wrapper import bfgs, extract_bfgs_fit + + fit = extract_bfgs_fit( + bfgs( + VOCATIONAL_COR, + n=VOCATIONAL_N, + scales=_V_NAMES, + m_val=3, + equal_ang=False, + equal_com=False, + ) + ) + assert pytest.approx(fit["p"], abs=0.005) == 0.041 + assert fit["p"] < 0.1, "p ≈ 0.95 suggests wrong df was used" + + def test_bfgs_p_equals_scipy_chi2_against_model_df(self): + """The stored p must equal scipy_chi2.sf(chisq, d) exactly.""" + from soundscapy.r_wrapper._circe_wrapper import bfgs, extract_bfgs_fit + + fit = extract_bfgs_fit( + bfgs( + VOCATIONAL_COR, + n=VOCATIONAL_N, + scales=_V_NAMES, + m_val=3, + equal_ang=False, + equal_com=False, + ) + ) + expected_p = scipy_chi2.sf(fit["chisq"], fit["d"]) + assert pytest.approx(fit["p"], rel=1e-6) == expected_p + + def test_bfgs_unconstrained_fit_indices(self): + """Fit indices must match R package reference values (±0.001).""" + from soundscapy.r_wrapper._circe_wrapper import bfgs, extract_bfgs_fit + + fit = extract_bfgs_fit( + bfgs( + VOCATIONAL_COR, + n=VOCATIONAL_N, + scales=_V_NAMES, + m_val=3, + equal_ang=False, + equal_com=False, + ) + ) + assert pytest.approx(float(fit["rmsea"]), abs=0.001) == 0.087 + assert pytest.approx(float(fit["rmsea.l"]), abs=0.001) == 0.017 + assert pytest.approx(float(fit["rmsea.u"]), abs=0.001) == 0.154 + assert pytest.approx(float(fit["cfi"]), abs=0.001) == 0.991 + assert pytest.approx(float(fit["gfi"]), abs=0.001) == 0.989 + assert pytest.approx(float(fit["agfi"]), abs=0.001) == 0.938 + assert pytest.approx(float(fit["srmr"]), abs=0.001) == 0.038 + assert pytest.approx(float(fit["mcsc"]), abs=0.001) == 0.290 + + def test_bfgs_equal_com_reference(self): + """ + Equal-communalities model must match R package output. 
+ + R reference (equal_ang=False, equal_com=True): + chi-sq = 50.409, Model df = 11, RMSEA = 0.143, CFI = 0.946, SRMR = 0.060 + """ + from soundscapy.r_wrapper._circe_wrapper import bfgs, extract_bfgs_fit + + fit = extract_bfgs_fit( + bfgs( + VOCATIONAL_COR, + n=VOCATIONAL_N, + scales=_V_NAMES, + m_val=3, + equal_ang=False, + equal_com=True, + ) + ) + assert pytest.approx(fit["chisq"], abs=0.01) == 50.409 + assert int(fit["d"]) == 11 + assert pytest.approx(float(fit["rmsea"]), abs=0.001) == 0.143 + assert pytest.approx(float(fit["cfi"]), abs=0.001) == 0.946 + assert pytest.approx(float(fit["srmr"]), abs=0.001) == 0.060 + # p-value must be very small for this over-constrained model + assert fit["p"] < 0.001 + + +# --------------------------------------------------------------------------- +# Tests for CircE dataclass (uses ISD 8-PAQ data) +# --------------------------------------------------------------------------- + + +@pytest.mark.optional_deps("satp") +class TestCircEDataclass: + """Tests for CircE.compute_bfgs_fit() using ISD data (8-PAQ format).""" + + def test_compute_bfgs_fit_returns_circe(self, isd_cor, isd_n): + """compute_bfgs_fit() must return a CircE instance.""" + from soundscapy.satp.circe import CircE, CircModelE + + result = CircE.compute_bfgs_fit( + isd_cor, isd_n, "ISD", "EN", CircModelE.UNCONSTRAINED + ) + assert isinstance(result, CircE) + + def test_circe_n_matches_input(self, isd_cor, isd_n): + """CircE.n must equal the n passed to compute_bfgs_fit, not N-1 from R.""" + from soundscapy.satp.circe import CircE, CircModelE + + result = CircE.compute_bfgs_fit( + isd_cor, isd_n, "ISD", "EN", CircModelE.UNCONSTRAINED + ) + assert result.n == isd_n + + def test_circe_d_is_model_df_not_null_df(self, isd_cor, isd_n): + """ + CircE.d must be the model degrees of freedom, not dfnull=28 (8-var null). + + For 8 variables, dfnull = 8*7/2 = 28. The unconstrained model df is + much smaller. 
+ """ + from soundscapy.satp.circe import CircE, CircModelE + + result = CircE.compute_bfgs_fit( + isd_cor, isd_n, "ISD", "EN", CircModelE.UNCONSTRAINED + ) + # Model df must be strictly less than null df (28 for 8 variables) + assert result.d < 28, ( + f"CircE.d = {result.d} — looks like dfnull was used instead of d." + ) + assert result.d > 0 + + def test_circe_p_value_formula(self, isd_cor, isd_n): + """CircE.p must equal scipy_chi2.sf(chisq, df) exactly.""" + from soundscapy.satp.circe import CircE, CircModelE + + result = CircE.compute_bfgs_fit( + isd_cor, isd_n, "ISD", "EN", CircModelE.UNCONSTRAINED + ) + expected_p = scipy_chi2.sf(result.chisq, result.d) + assert pytest.approx(result.p, rel=1e-6) == expected_p + + def test_circe_fit_stats_plausible(self, isd_cor, isd_n): + """All fit statistics must be in their valid ranges.""" + from soundscapy.satp.circe import CircE, CircModelE + + result = CircE.compute_bfgs_fit( + isd_cor, isd_n, "ISD", "EN", CircModelE.UNCONSTRAINED + ) + assert result.chisq >= 0 + assert 0.0 <= result.p <= 1.0 + assert result.rmsea >= 0 + assert 0.0 <= result.cfi <= 1.0 + assert 0.0 <= result.gfi <= 1.0 + assert result.srmr >= 0 + + def test_polar_angles_present_for_free_angle_models(self, isd_cor, isd_n): + """polar_angles must be a DataFrame for UNCONSTRAINED and EQUAL_COM models.""" + from soundscapy.satp.circe import CircE, CircModelE + + for model in (CircModelE.UNCONSTRAINED, CircModelE.EQUAL_COM): + result = CircE.compute_bfgs_fit(isd_cor, isd_n, "ISD", "EN", model) + got = type(result.polar_angles) + assert isinstance(result.polar_angles, pd.DataFrame), ( + f"{model.value}: polar_angles should be a DataFrame, got {got}" + ) + assert result.polar_angles.shape[1] == 8, ( + f"{model.value}: expected 8 variables in polar_angles columns" + ) + + def test_polar_angles_none_for_constrained_angle_models(self, isd_cor, isd_n): + """polar_angles must be None for EQUAL_ANG and CIRCUMPLEX models.""" + from soundscapy.satp.circe import 
CircE, CircModelE + + for model in (CircModelE.EQUAL_ANG, CircModelE.CIRCUMPLEX): + result = CircE.compute_bfgs_fit(isd_cor, isd_n, "ISD", "EN", model) + assert result.polar_angles is None, ( + f"{model.value}: polar_angles should be None for constrained models" + ) + + +# --------------------------------------------------------------------------- +# Tests for SATP class +# --------------------------------------------------------------------------- + + +@pytest.mark.optional_deps("satp") +class TestSATP: + """Integration tests for the full SATP analysis pipeline.""" + + @pytest.fixture + def satp_data(self): + """ + PAQ data from ISD using SessionID as participant grouper. + + The ISD dataset has 66 sessions, each with ~54 rows. Using SessionID + ensures every participant has many rows so ipsatization produces a + valid (non-degenerate) correlation matrix. + """ + import soundscapy as sspy + from soundscapy.surveys.survey_utils import PAQ_IDS + + raw = sspy.isd.load() + return ( + raw[[*PAQ_IDS, "SessionID"]] + .dropna(subset=PAQ_IDS) + .copy() + .rename(columns={"SessionID": "participant"}) + ) + + def test_satp_init_validates_schema(self, satp_data): + """SATP.__init__ must accept valid SATP-format data without raising.""" + from soundscapy.satp.circe import SATP + + satp = SATP(satp_data, language="EN", datasource="ISD") + assert satp is not None + + def test_satp_ipsatize(self, satp_data): + """ + After ipsatization each PAQ column must have zero mean per participant. + + The implementation uses groupby.transform, which centers each *column* + within each participant group, not each row across columns. The correct + invariant is therefore: for every (participant, PAQ_col) pair, the mean + of that column across the participant's rows is zero. 
+ """ + from soundscapy.satp.circe import SATP + from soundscapy.surveys.survey_utils import PAQ_IDS + + # Test _ipsatize_df directly so we can still access the participant labels + # (satp.data loses the participant column after the transform). + ipsatized = SATP._ipsatize_df(satp_data, by="participant") + # groupby.transform preserves the original index, so joining back is safe. + check = ipsatized[PAQ_IDS].assign(participant=satp_data["participant"]) + group_means = check.groupby("participant")[PAQ_IDS].mean() + np.testing.assert_allclose(group_means.to_numpy(), 0.0, atol=1e-10) + + def test_satp_run_single_model(self, satp_data): + """SATP.run(circ_model=...) must populate exactly that model slot.""" + from soundscapy.satp.circe import SATP, CircE, CircModelE + + satp = SATP(satp_data, language="EN", datasource="ISD") + satp.run(circ_model=CircModelE.UNCONSTRAINED) + + assert isinstance(satp.model_results[CircModelE.UNCONSTRAINED], CircE) + for model in [ + CircModelE.EQUAL_ANG, + CircModelE.EQUAL_COM, + CircModelE.CIRCUMPLEX, + ]: + assert satp.model_results[model] is None + + def test_satp_run_captures_n_correctly(self, satp_data): + """The n stored on the CircE result must equal len(satp.data).""" + from soundscapy.satp.circe import SATP, CircModelE + + satp = SATP(satp_data, language="EN", datasource="ISD") + n_data = len(satp.data) + satp.run(circ_model=CircModelE.UNCONSTRAINED) + + result = satp.model_results[CircModelE.UNCONSTRAINED] + assert result is not None + assert result.n == n_data + + def test_satp_run_p_value_formula(self, satp_data): + """CircE.p from a full SATP run must equal scipy_chi2.sf(chisq, df).""" + from soundscapy.satp.circe import SATP, CircModelE + + satp = SATP(satp_data, language="EN", datasource="ISD") + satp.run(circ_model=CircModelE.UNCONSTRAINED) + + result = satp.model_results[CircModelE.UNCONSTRAINED] + assert result is not None + expected_p = scipy_chi2.sf(result.chisq, result.d) + assert pytest.approx(result.p, rel=1e-6) 
== expected_p + + def test_satp_run_all_models_errors_captured(self, satp_data): + """ + SATP.run() runs all models; convergence failures are captured. + + Failures are stored in _errors and never propagate as exceptions. + """ + from soundscapy.satp.circe import SATP + + satp = SATP(satp_data, language="EN", datasource="ISD") + satp.run() # must not raise + + n_results = sum(v is not None for v in satp.model_results.values()) + n_errors = len(satp._errors) + assert n_results + n_errors == 4 diff --git a/test/spi/test_MSN.py b/test/spi/test_MSN.py index c9f0b56..2213155 100644 --- a/test/spi/test_MSN.py +++ b/test/spi/test_MSN.py @@ -192,7 +192,6 @@ class TestMultiSkewNorm: def test_init(self): """Test initialization of MultiSkewNorm.""" msn = MultiSkewNorm() - assert msn.selm_model is None assert msn.cp is None assert msn.dp is None assert msn.sample_data is None @@ -241,7 +240,6 @@ def test_fit_with_dataframe(self): msn = MultiSkewNorm() msn.fit(data=MOCK_DF.copy()) - assert msn.selm_model is not None # Check R model object exists assert isinstance(msn.cp, CentredParams) assert isinstance(msn.dp, DirectParams) assert msn.data is not None # Add assertion for type checker @@ -260,7 +258,6 @@ def test_fit_with_numpy_array(self): expected_df = pd.DataFrame(numpy_data, columns=["x", "y"]) - assert msn.selm_model is not None assert isinstance(msn.cp, CentredParams) assert isinstance(msn.dp, DirectParams) assert msn.data is not None # Add assertion for type checker @@ -285,7 +282,6 @@ def test_fit_with_x_y(self): expected_df = pd.DataFrame({"x": MOCK_X, "y": MOCK_Y}) - assert msn.selm_model is not None assert isinstance(msn.cp, CentredParams) assert isinstance(msn.dp, DirectParams) assert msn.data is not None # Add assertion for type checker @@ -293,6 +289,20 @@ def test_fit_with_x_y(self): assert msn.cp.mean.shape == (2,) assert msn.dp.xi.shape == (2,) + def test_fit_does_not_mutate_input_dataframe(self): + """ + fit() must not rename columns on the caller's DataFrame. 
+ + Uses non-default column names so a regression would be visible — + MOCK_DF already has columns ["x", "y"] and would pass trivially. + """ + input_df = pd.DataFrame(MOCK_DF.values, columns=["ISOPleasant", "ISOEventful"]) + msn = MultiSkewNorm() + msn.fit(data=input_df) + assert list(input_df.columns) == ["ISOPleasant", "ISOEventful"], ( + "fit() must not modify the caller's DataFrame columns" + ) + def test_fit_no_data(self): """Test fit method raises ValueError when no data is provided.""" msn = MultiSkewNorm() @@ -353,10 +363,88 @@ def test_sample_not_fitted_or_defined(self): # No mock needed msn = MultiSkewNorm() with pytest.raises( ValueError, - match="Either selm_model or xi, omega, and alpha must be provided.", + match="Model is not fitted. Call fit\\(\\) or define_dp\\(\\) first.", ): msn.sample() + # --- sample_mtsn tests --- + + def test_sample_mtsn_shape(self): + """sample_mtsn returns an (n, 2) array.""" + msn = MultiSkewNorm() + msn.define_dp(MOCK_XI, MOCK_OMEGA, MOCK_ALPHA) + result = msn.sample_mtsn(n=5, return_sample=True) + assert isinstance(result, np.ndarray) + assert result.shape == (5, 2) + + def test_sample_mtsn_within_bounds(self): + """All samples returned by sample_mtsn are within [a, b].""" + msn = MultiSkewNorm() + msn.define_dp(MOCK_XI, MOCK_OMEGA, MOCK_ALPHA) + result = msn.sample_mtsn(n=10, a=-1, b=1, return_sample=True) + assert result is not None + assert np.all(result >= -1), "Some samples are below the lower bound" + assert np.all(result <= 1), "Some samples are above the upper bound" + + def test_sample_mtsn_stores_sample(self): + """sample_mtsn stores the result in sample_data when return_sample=False.""" + msn = MultiSkewNorm() + msn.define_dp(MOCK_XI, MOCK_OMEGA, MOCK_ALPHA) + assert msn.sample_data is None + retval = msn.sample_mtsn(n=5, return_sample=False) + assert retval is None + assert isinstance(msn.sample_data, np.ndarray) + assert msn.sample_data.shape == (5, 2) + + def test_sample_mtsn_not_fitted(self): + 
"""sample_mtsn raises ValueError when the model has no parameters.""" + msn = MultiSkewNorm() + with pytest.raises( + ValueError, + match="Model is not fitted. Call fit\\(\\) or define_dp\\(\\) first.", + ): + msn.sample_mtsn() + + # --- from_params branch tests --- + + def test_from_params_with_direct_params_object(self): + """from_params(params=DirectParams(...)) sets dp and computes cp.""" + dp = DirectParams(MOCK_XI, MOCK_OMEGA, MOCK_ALPHA) + msn = MultiSkewNorm.from_params(params=dp) + assert isinstance(msn.dp, DirectParams) + np.testing.assert_array_equal(msn.dp.xi, MOCK_XI) + assert isinstance(msn.cp, CentredParams) + np.testing.assert_allclose(msn.cp.mean, EXPECTED_MEAN, atol=1e-5) + + def test_from_params_with_centred_params_object(self): + """from_params(params=CentredParams(...)) sets cp and converts to dp.""" + cp = CentredParams(EXPECTED_MEAN, EXPECTED_SIGMA_COV, EXPECTED_SKEW) + msn = MultiSkewNorm.from_params(params=cp) + assert isinstance(msn.cp, CentredParams) + assert isinstance(msn.dp, DirectParams) + + def test_from_params_with_xi_omega_alpha_kwargs_sets_cp(self): + """from_params(xi=..., omega=..., alpha=...) must populate both dp and cp.""" + msn = MultiSkewNorm.from_params(xi=MOCK_XI, omega=MOCK_OMEGA, alpha=MOCK_ALPHA) + assert isinstance(msn.dp, DirectParams) + assert isinstance(msn.cp, CentredParams), ( + "cp must not be None when from_params is called with DP kwargs" + ) + np.testing.assert_allclose(msn.cp.mean, EXPECTED_MEAN, atol=1e-5) + + def test_from_params_with_mean_sigma_skew_kwargs(self): + """from_params(mean=..., sigma=..., skew=...) 
creates MultiSkewNorm from CP.""" + msn = MultiSkewNorm.from_params( + mean=EXPECTED_MEAN, sigma=EXPECTED_SIGMA_COV, skew=EXPECTED_SKEW + ) + assert isinstance(msn.cp, CentredParams) + assert isinstance(msn.dp, DirectParams) + + def test_from_params_no_args_raises(self): + """from_params() with no arguments raises ValueError.""" + with pytest.raises(ValueError, match="Either params object"): + MultiSkewNorm.from_params() + @patch("soundscapy.spi.msn.scatter") # Keep mocking the plotting call def test_sspy_plot_calls_sample_if_needed(self, mock_scatter): """Test sspy_plot calls sample if sample_data is None.""" @@ -420,15 +508,17 @@ def test_ks2d2s_calls_sample_if_needed(self): test_data_df = pd.DataFrame(rng.random((40, 2)), columns=["col1", "col2"]) result = msn.ks2d2s(test_data_df) - # TODO: still need to implement check for actual result values # Check sample was called implicitly and data was generated assert isinstance(msn.sample_data, np.ndarray) assert msn.sample_data.shape[1] == 2 # Check sample data has 2 columns assert isinstance(result, tuple) - assert isinstance(result[0], float) - assert isinstance(result[1], float) + ks_stat, p_value = result + assert isinstance(ks_stat, float) + assert isinstance(p_value, float) + assert 0.0 <= ks_stat <= 1.0, "KS statistic must be in [0, 1]" + assert 0.0 <= p_value <= 1.0, "p-value must be in [0, 1]" def test_ks2d2s(self): """Test ks2d2s converts DataFrame input to numpy array.""" @@ -441,7 +531,6 @@ def test_ks2d2s(self): df_result = msn.ks2d2s(test_data_df) np_result = msn.ks2d2s(test_data_np) - # TODO(MitchellAcoustics): still need to implement check for actual result values # noqa: E501 assert df_result == np_result, ( "Results from DataFrame and numpy array should match." 
@@ -474,9 +563,8 @@ def test_spi(self): spi_value = msn.spi_score(test_data) - # Check the SPI calculation - # TODO(MitchellAcoustics): Implement actual SPI calculation check assert isinstance(spi_value, int) + assert 0 <= spi_value <= 100, "SPI score must be in [0, 100]" def test_spi_with_dataframe(self): """Test spi method with DataFrame input.""" @@ -485,27 +573,28 @@ def test_spi_with_dataframe(self): spi_value = msn.spi_score(test_data_df) - # Check the SPI calculation - # TODO(MitchellAcoustics): Implement actual SPI calculation check assert isinstance(spi_value, int) + assert 0 <= spi_value <= 100, "SPI score must be in [0, 100]" @pytest.mark.optional_deps("spi") -@pytest.mark.skip( - reason="Cannot directly convert cp to dp. Need to come up with a reasonable test." -) def test_cp2dp(): - """Test cp2dp function.""" - cp_input = CentredParams(EXPECTED_MEAN, EXPECTED_SIGMA_COV, EXPECTED_SKEW) - - # Perform the conversion - dp_output = cp2dp(cp_input) - - assert isinstance(dp_output, DirectParams) - # Check if the output DP matches the original MOCK_DP used to generate the CPs - np.testing.assert_allclose(dp_output.xi, MOCK_XI, atol=1e-5) - np.testing.assert_allclose(dp_output.omega, MOCK_OMEGA, atol=1e-5) - np.testing.assert_allclose(dp_output.alpha, MOCK_ALPHA, atol=1e-5) + """Test cp2dp via a round-trip: dp2cp(cp2dp(cp)) must reproduce the original CP.""" + # Convert known DP to CP + dp_input = DirectParams(MOCK_XI, MOCK_OMEGA, MOCK_ALPHA) + cp = dp2cp(dp_input) + + # Convert CP back to DP + dp_recovered = cp2dp(cp) + assert isinstance(dp_recovered, DirectParams) + + # Convert the recovered DP back to CP again; it must match the original CP. + # (The cp2dp→dp2cp round-trip is the numerically stable direction to test.) 
+ cp_roundtrip = dp2cp(dp_recovered) + assert isinstance(cp_roundtrip, CentredParams) + np.testing.assert_allclose(cp_roundtrip.mean, cp.mean, atol=1e-4) + np.testing.assert_allclose(cp_roundtrip.sigma, cp.sigma, atol=1e-4) + np.testing.assert_allclose(cp_roundtrip.skew, cp.skew, atol=1e-4) @pytest.mark.optional_deps("spi") diff --git a/test/spi/test_r_wrapper.py b/test/spi/test_r_wrapper.py index ae046aa..be7a54e 100644 --- a/test/spi/test_r_wrapper.py +++ b/test/spi/test_r_wrapper.py @@ -5,26 +5,30 @@ They are skipped if rpy2 is not installed. """ -import os - import pytest +# === Pure-Python tests (no R required) === + + +def test_ver_basic(): + """_ver parses simple dotted version strings into integer tuples.""" + from soundscapy.r_wrapper._r_wrapper import _ver -def test_initialize_r_session_fails(): - """Test that R session initialization fails if R is not available.""" - # Skip if dependencies are actually installed - if os.environ.get("SPI_DEPS") == "1": - pytest.skip("SPI dependencies are installed") + assert _ver("3.6") == (3, 6) + assert _ver("2.0.0") == (2, 0, 0) + assert _ver("1.1") == (1, 1) - from soundscapy.r_wrapper._r_wrapper import initialize_r_session - # Simulate R not being available - with pytest.raises(ImportError) as excinfo: - initialize_r_session() +def test_ver_avoids_lexicographic_pitfall(): + """_ver must compare 1.10 as greater than 1.2 (not less, as strings would).""" + from soundscapy.r_wrapper._r_wrapper import _ver - # Check for helpful error message - assert "R installation" in str(excinfo.value) - assert "install.packages('R')" in str(excinfo.value) + assert _ver("1.10") > _ver("1.2") + assert _ver("2.0.0") > _ver("1.9.9") + assert _ver("3.6.0") == _ver("3.6.0") + + +# === End-to-end R tests === @pytest.mark.optional_deps("r") @@ -43,65 +47,46 @@ def test_initialize_r_session(self): assert res is not None, "R session should be initialized successfully" assert res["r_session"] == "active", "R session should be active" - def 
test_shutdown_r_session(self): - """Test R session cleanup.""" - from soundscapy.r_wrapper._r_wrapper import shutdown_r_session + def test_reset_r_session(self): + """Test R session package unloading.""" + from soundscapy.r_wrapper._r_wrapper import reset_r_session # This should not raise if R session is active - res = shutdown_r_session() + res = reset_r_session() - assert res, "R session should be shut down successfully" + assert res, "R session packages should be unloaded successfully" def test_r_session_reinitialization(self): - """Test that the R session can be reinitialized after shutdown.""" + """Test that the R session can be reinitialized after reset.""" from soundscapy.r_wrapper._r_wrapper import ( initialize_r_session, - shutdown_r_session, + reset_r_session, ) # First initialize the R session res = initialize_r_session() assert res is not None, "R session should be initialized successfully" - # Now shut it down - shutdown_res = shutdown_r_session() - assert shutdown_res, "R session should be shut down successfully" + # Now reset it (unload packages; R process keeps running) + reset_res = reset_r_session() + assert reset_res, "R session packages should be unloaded successfully" # Reinitialize the R session reinit_res = initialize_r_session() assert reinit_res is not None, "R session should be reinitialized successfully" def test_check_sn_package(self): - """Test that the R 'sn' package availability is checked.""" - # Skip if dependencies are actually installed - - if os.environ.get("SPI_DEPS") == "1": - import soundscapy.r_wrapper as sspyr + """Test that the R 'sn' package is available when R deps are installed.""" + import soundscapy.r_wrapper as sspyr - sspyr._r_wrapper.check_sn_package() - - else: - with pytest.raises(ImportError) as excinfo: # noqa: PT012 - import soundscapy.r_wrapper as sspyr - - sspyr._r_wrapper.check_sn_package() - - assert "R package 'sn'" in str(excinfo.value) - assert "install.packages('sn')" in str(excinfo.value) + # Should 
not raise — this test only runs (via optional_deps("r")) when + # rpy2 is present and the tox commands_pre has installed sn. + sspyr._r_wrapper.check_sn_package() def test_check_circe_package(self): - """Test that the R 'circe' package availability is checked.""" - # Skip if dependencies are actually installed - if os.environ.get("SPI_DEPS") == "1": - import soundscapy.r_wrapper as sspyr - - sspyr._r_wrapper.check_circe_package() - - else: - with pytest.raises(ImportError) as excinfo: # noqa: PT012 - import soundscapy.r_wrapper as sspyr - - sspyr._r_wrapper.check_circe_package() + """Test that the R 'CircE' package is available when R deps are installed.""" + import soundscapy.r_wrapper as sspyr - assert "R package 'CircE'" in str(excinfo.value) - assert sspyr.PKG_SRC.CIRCE in str(excinfo.value) + # Should not raise — this test only runs (via optional_deps("r")) when + # rpy2 is present and the tox commands_pre has installed CircE. + sspyr._r_wrapper.check_circe_package() diff --git a/test/test_basic.py b/test/test_basic.py index af95e0b..9f27980 100644 --- a/test/test_basic.py +++ b/test/test_basic.py @@ -32,10 +32,15 @@ def test_soundscapy_spi_module(): # Test top-level imports assert hasattr(soundscapy, "MultiSkewNorm"), "MultiSkewNorm should be available" assert hasattr(soundscapy, "dp2cp"), "dp2cp should be available" - # assert hasattr(soundscapy, "calculate_spi"), "calculate_spi should be available" - # assert hasattr(soundscapy, "calculate_spi_from_data"), ( - # "calculate_spi_from_data should be available" - # ) + assert hasattr(soundscapy, "spi_score"), "spi_score should be available" + + +@pytest.mark.optional_deps("satp") +def test_soundscapy_satp_module(): + """Test that the SATP module can be imported when dependencies are available.""" + assert hasattr(soundscapy, "satp"), "Soundscapy should have a satp module" + assert hasattr(soundscapy, "SATP"), "SATP should be available" + assert hasattr(soundscapy, "CircModelE"), "CircModelE should be available" 
def test_spi_import_error(): @@ -44,11 +49,20 @@ def test_spi_import_error(): if os.environ.get("SPI_DEPS") == "1": pytest.skip("SPI dependencies are installed") - # Since direct imports are now used instead of __getattr__, we need to test - # through direct access to the module which would trigger ImportError with pytest.raises(ImportError) as excinfo: import soundscapy.spi # noqa: F401 - # Check error message contains helpful instructions assert "SPI functionality requires" in str(excinfo.value) assert "soundscapy[spi]" in str(excinfo.value) + + +def test_satp_import_error(): + """Test that helpful error message is shown when SATP dependencies are missing.""" + if os.environ.get("SATP_DEPS") == "1": + pytest.skip("SATP dependencies are installed") + + with pytest.raises(ImportError) as excinfo: + import soundscapy.satp # noqa: F401 + + assert "SATP functionality requires" in str(excinfo.value) + assert "soundscapy[satp]" in str(excinfo.value) diff --git a/tox.ini b/tox.ini index 0ca414c..12cf475 100644 --- a/tox.ini +++ b/tox.ini @@ -1,4 +1,3 @@ -# tox.ini [tox] env_list = docs, @@ -40,9 +39,9 @@ allowlist_externals = Rscript commands_pre = {[testenv]commands_pre} - # Ensure R 'sn' package is available - Rscript -e "if(!require('sn')) { pak::local_install_deps() }" - Rscript -e "if(!require('CircE')) { pak::pkg_install('MitchellAcoustics/CircE-R') }" + # Install required R packages (sn from CRAN; CircE from GitHub) + Rscript -e "if(!require('sn', quietly=TRUE)) { pak::pkg_install('sn') }" + Rscript -e "if(!require('CircE', quietly=TRUE)) { pak::pkg_install('CircE=MitchellAcoustics/CircE-R') }" commands = # Build the tutorials pytest --nbmake -n=auto docs --ignore=docs/tutorials/BinauralAnalysis.ipynb --ignore=docs/tutorials/4_Understanding_Soundscape_Perception_Index.ipynb --ignore=docs/tutorials/5_Working_with_Soundscape_Databases.ipynb --ignore=docs/tutorials/6_Soundscape_Assessment_Tutorial.ipynb 
--ignore=docs/tutorials/IoA_Soundscape_Assessment_Tutorial.ipynb --no-cov # BinauralAnalysis is too slow @@ -73,9 +72,9 @@ allowlist_externals = Rscript commands_pre = {[testenv]commands_pre} - # Ensure R 'sn' package is available - ; Rscript -e "if(!require('sn')) { pak::local_install_deps() }" - ; Rscript -e "if(!require('CircE')) { pak::pkg_install('MitchellAcoustics/CircE-R') }" + # Install required R packages (sn from CRAN; CircE from GitHub) + Rscript -e "if(!require('sn', quietly=TRUE)) { pak::pkg_install('sn') }" + Rscript -e "if(!require('CircE', quietly=TRUE)) { pak::pkg_install('CircE=MitchellAcoustics/CircE-R') }" commands = # Run core tests and R-specific tests pytest --cov --cov-report=xml -k "not optional_deps or optional_deps and spi or skip_if_deps and spi" @@ -89,9 +88,9 @@ allowlist_externals = Rscript commands_pre = {[testenv]commands_pre} - # Ensure R 'sn' package is available - Rscript -e "if(!require('sn')) { pak::local_install_deps() }" - Rscript -e "if(!require('CircE')) { pak::pkg_install('MitchellAcoustics/CircE-R') }" + # Install required R packages (sn from CRAN; CircE from GitHub) + Rscript -e "if(!require('sn', quietly=TRUE)) { pak::pkg_install('sn') }" + Rscript -e "if(!require('CircE', quietly=TRUE)) { pak::pkg_install('CircE=MitchellAcoustics/CircE-R') }" commands = # Run all tests, including SPI tests which are skipped with pytestmark pytest --cov --cov-report=xml