From ae16823df06a80e2f326d537f1b62177f7e01bd0 Mon Sep 17 00:00:00 2001 From: Florence Bockting <48919471+florence-bockting@users.noreply.github.com> Date: Tue, 13 May 2025 07:13:58 +0200 Subject: [PATCH] initial try to add domain to spline --- pyproject.toml | 1 + requirements-docs-locked.txt | 5 +- requirements-incl-optional-locked.txt | 5 ++ requirements-locked.txt | 5 ++ src/gradient_aware_harmonisation/spline.py | 63 +++++++++++-------- .../timeseries.py | 1 - src/gradient_aware_harmonisation/typing.py | 6 ++ src/gradient_aware_harmonisation/utils.py | 29 +++++++++ uv.lock | 41 ++++++++++++ 9 files changed, 129 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c23dbbd..da8fdb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "attrs>=24.3.0", "numpy>=1.25.0", "pandas>=2.2.3", + "pint>=0.24.4", ] readme = "README.md" classifiers = [ diff --git a/requirements-docs-locked.txt b/requirements-docs-locked.txt index 9623167..9e56077 100644 --- a/requirements-docs-locked.txt +++ b/requirements-docs-locked.txt @@ -26,6 +26,8 @@ defusedxml==0.7.1 exceptiongroup==1.2.2 ; python_full_version < '3.11' executing==2.1.0 fastjsonschema==2.21.1 +flexcache==0.3 +flexparser==0.4 fonttools==4.56.0 fqdn==1.5.1 ghp-import==2.1.0 @@ -97,6 +99,7 @@ parso==0.8.4 pathspec==0.12.1 pexpect==4.9.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32') pillow==11.1.0 +pint==0.24.4 platformdirs==4.3.6 prometheus-client==0.21.1 prompt-toolkit==3.0.48 @@ -136,7 +139,7 @@ tomli==2.2.1 ; python_full_version < '3.11' tornado==6.4.2 traitlets==5.14.3 types-python-dateutil==2.9.0.20241206 -typing-extensions==4.12.2 ; python_full_version < '3.13' +typing-extensions==4.12.2 tzdata==2025.1 uri-template==1.3.0 urllib3==2.3.0 diff --git a/requirements-incl-optional-locked.txt b/requirements-incl-optional-locked.txt index 76b5451..dc5ebd5 100644 --- a/requirements-incl-optional-locked.txt +++ b/requirements-incl-optional-locked.txt @@ -4,6 +4,8 @@ attrs==24.3.0 contourpy==1.3.0 ; python_full_version < '3.10' contourpy==1.3.1 ; python_full_version >= '3.10' cycler==0.12.1 +flexcache==0.3 +flexparser==0.4 fonttools==4.56.0 importlib-resources==6.5.2 ; python_full_version < '3.10' kiwisolver==1.4.7 ; python_full_version < '3.10' @@ -14,11 +16,14 @@ numpy==2.0.2 packaging==24.2 pandas==2.2.3 pillow==11.1.0 +pint==0.24.4 +platformdirs==4.3.6 pyparsing==3.2.1 python-dateutil==2.9.0.post0 pytz==2025.1 scipy==1.13.1 ; python_full_version < '3.10' scipy==1.15.2 ; python_full_version >= '3.10' six==1.17.0 +typing-extensions==4.12.2 tzdata==2025.1 zipp==3.21.0 ; python_full_version < '3.10' diff --git a/requirements-locked.txt b/requirements-locked.txt index 6b0d33c..5542ffd 100644 --- a/requirements-locked.txt +++ b/requirements-locked.txt @@ -1,9 +1,14 @@ # This file was autogenerated by uv via the following command: # uv export -o requirements-locked.txt --no-hashes --no-dev --no-emit-project attrs==24.3.0 +flexcache==0.3 +flexparser==0.4 numpy==2.0.2 pandas==2.2.3 +pint==0.24.4 +platformdirs==4.3.6 python-dateutil==2.9.0.post0 pytz==2025.1 six==1.17.0 +typing-extensions==4.12.2 tzdata==2025.1 diff --git a/src/gradient_aware_harmonisation/spline.py b/src/gradient_aware_harmonisation/spline.py index 670e42e..92f3dc7 100644 --- a/src/gradient_aware_harmonisation/spline.py +++ b/src/gradient_aware_harmonisation/spline.py @@ -4,20 +4,19 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol, Union, overload +from typing import TYPE_CHECKING, Any, Protocol, overload -import numpy as np +import attr import numpy.typing as npt -from attrs import define +from attrs import define, field + +from gradient_aware_harmonisation.typing import NP_FLOAT_OR_INT +from gradient_aware_harmonisation.utils import validate_domain if TYPE_CHECKING: import scipy.interpolate from typing_extensions import TypeAlias -NP_FLOAT_OR_INT: TypeAlias = Union[np.floating[Any], np.integer[Any]] -""" -Type alias for a numpy float or int (not complex) -""" NP_ARRAY_OF_FLOAT_OR_INT: TypeAlias = npt.NDArray[NP_FLOAT_OR_INT] """ @@ -30,8 +29,25 @@ class Spline(Protocol): Single spline """ - # domain: [float, float] - # """Domain over the spline can be used""" + domain: tuple[NP_FLOAT_OR_INT, NP_FLOAT_OR_INT] = field() + """ + Domain over which spline can be evaluated + """ + + @domain.validator + def domain_validator( + self, + attribute: attr.Attribute[Any], + value: tuple[NP_FLOAT_OR_INT, NP_FLOAT_OR_INT], + ) -> None: + """ + Validate the received values + """ + try: + validate_domain(value) + except AssertionError as exc: + msg = "The value supplied for `domain` failed validation." + raise ValueError(msg) from exc @overload def __call__(self, x: int | float) -> int | float: ... @@ -64,11 +80,10 @@ class SplineScipy: An adapter which wraps various classes from [scipy.interpolate][] """ - # domain: ClassVar[list[float, float]] = [ - # np.finfo(np.float64).tiny, - # np.finfo(np.float64).max, - # ] - # """domain of spline (reals)""" + domain: tuple[NP_FLOAT_OR_INT, NP_FLOAT_OR_INT] + """ + Domain over which spline can be evaluated + """ scipy_spline: scipy.interpolate.BSpline | scipy.interpolate.PPoly @@ -108,7 +123,7 @@ def derivative(self) -> SplineScipy: : Derivative of self """ - return SplineScipy(self.scipy_spline.derivative()) + return SplineScipy(self.scipy_spline.derivative(), domain=self.domain) def antiderivative(self) -> SplineScipy: """ @@ -134,11 +149,10 @@ class SumOfSplines: spline_two: Spline """Second spline""" - # domain: ClassVar[list[float, float]] = [ - # np.finfo(np.float64).tiny, - # np.finfo(np.float64).max, - # ] - # """Domain of spline""" + domain: tuple[NP_FLOAT_OR_INT, NP_FLOAT_OR_INT] + """ + Domain over which spline can be evaluated + """ @overload def __call__(self, x: int | float) -> int | float: ... @@ -204,11 +218,10 @@ class ProductOfSplines: spline_two: Spline """Second spline""" - # domain: ClassVar[list[float, float]] = [ - # np.finfo(np.float64).tiny, - # np.finfo(np.float64).max, - # ] - # """Domain of spline""" + domain: tuple[NP_FLOAT_OR_INT, NP_FLOAT_OR_INT] + """ + Domain over which spline can be evaluated + """ @overload def __call__(self, x: int | float) -> int | float: ... diff --git a/src/gradient_aware_harmonisation/timeseries.py b/src/gradient_aware_harmonisation/timeseries.py index 33c8b40..4e53797 100644 --- a/src/gradient_aware_harmonisation/timeseries.py +++ b/src/gradient_aware_harmonisation/timeseries.py @@ -76,5 +76,4 @@ def to_spline(self, **kwargs: Any) -> SplineScipy: scipy.interpolate.make_interp_spline( x=self.time_axis, y=self.values, **kwargs ), - # domain=(self.time_axis.min(), self.time_axis.max()) ) diff --git a/src/gradient_aware_harmonisation/typing.py b/src/gradient_aware_harmonisation/typing.py index 05d6a35..4aadaea 100644 --- a/src/gradient_aware_harmonisation/typing.py +++ b/src/gradient_aware_harmonisation/typing.py @@ -8,6 +8,7 @@ import numpy as np import numpy.typing as npt +import pint.facets.numpy.quantity if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -21,3 +22,8 @@ """ Type alias for an array of numpy float or int (not complex) """ + +PINT_SCALAR: TypeAlias = pint.facets.numpy.quantity.NumpyQuantity[NP_FLOAT_OR_INT] +""" +Type alias for a pint quantity that wraps a numpy scalar +""" diff --git a/src/gradient_aware_harmonisation/utils.py b/src/gradient_aware_harmonisation/utils.py index c33eca9..4d8ced0 100644 --- a/src/gradient_aware_harmonisation/utils.py +++ b/src/gradient_aware_harmonisation/utils.py @@ -18,6 +18,7 @@ SplineScipy, SumOfSplines, ) +from gradient_aware_harmonisation.typing import NP_FLOAT_OR_INT, PINT_SCALAR class GetHarmonisedSplineLike(Protocol): @@ -185,3 +186,31 @@ def add_constant_to_spline(in_spline: Spline, constant: float | int) -> Spline: ) ), ) + + +def validate_domain( + domain: Union[ + tuple[PINT_SCALAR, PINT_SCALAR], tuple[NP_FLOAT_OR_INT, NP_FLOAT_OR_INT] + ], +) -> None: + """ + Check that domain values are valid + + Parameters + ---------- + domain + Domain to check + + Raises + ------ + AssertionError + `len(domain) != 2` or `domain[1] <= domain[0]`. + """ + expected_domain_length = 2 + if len(domain) != expected_domain_length: + raise AssertionError(len(domain)) + + if domain[1] <= domain[0]: + msg = f"domain[1] must be greater than domain[0]. Received {domain=}." + + raise AssertionError(msg) diff --git a/uv.lock b/uv.lock index dc71835..6a52b0c 100644 --- a/uv.lock +++ b/uv.lock @@ -665,6 +665,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b9/f8/feced7779d755758a52d1f6635d990b8d98dc0a29fa568bbe0625f18fdf3/filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0", size = 16163 }, ] +[[package]] +name = "flexcache" +version = "0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/b0/8a21e330561c65653d010ef112bf38f60890051d244ede197ddaa08e50c1/flexcache-0.3.tar.gz", hash = "sha256:18743bd5a0621bfe2cf8d519e4c3bfdf57a269c15d1ced3fb4b64e0ff4600656", size = 15816 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/cd/c883e1a7c447479d6e13985565080e3fea88ab5a107c21684c813dba1875/flexcache-0.3-py3-none-any.whl", hash = "sha256:d43c9fea82336af6e0115e308d9d33a185390b8346a017564611f1466dcd2e32", size = 13263 }, +] + +[[package]] +name = "flexparser" +version = "0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/82/99/b4de7e39e8eaf8207ba1a8fa2241dd98b2ba72ae6e16960d8351736d8702/flexparser-0.4.tar.gz", hash = "sha256:266d98905595be2ccc5da964fe0a2c3526fbbffdc45b65b3146d75db992ef6b2", size = 31799 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/5e/3be305568fe5f34448807976dc82fc151d76c3e0e03958f34770286278c1/flexparser-0.4-py3-none-any.whl", hash = "sha256:3738b456192dcb3e15620f324c447721023c0293f6af9955b481e91d00179846", size = 27625 }, +] + [[package]] name = "fonttools" version = "4.56.0" @@ -743,6 +767,7 @@ dependencies = [ { name = "attrs" }, { name = "numpy" }, { name = "pandas" }, + { name = "pint" }, ] [package.optional-dependencies] @@ -1088,6 +1113,7 @@ requires-dist = [ { name = "matplotlib", marker = "extra == 'plots'", specifier = ">=3.7.1" }, { name = "numpy", specifier = ">=1.25.0" }, { name = "pandas", specifier = ">=2.2.3" }, + { name = "pint", specifier = ">=0.24.4" }, { name = "scipy", marker = "extra == 'scipy'", specifier = ">=1.13.1" }, ] provides-extras = ["plots", "scipy", "full"] @@ -2977,6 +3003,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/67/936f9814bdd74b2dfd4822f1f7725ab5d8ff4103919a1664eb4874c58b2f/pillow-11.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4637b88343166249fe8aa94e7c4a62a180c4b3898283bb5d3d2fd5fe10d8e4e0", size = 2626353 }, ] +[[package]] +name = "pint" +version = "0.24.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "flexcache" }, + { name = "flexparser" }, + { name = "platformdirs" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/20/bb/52b15ddf7b7706ed591134a895dbf6e41c8348171fb635e655e0a4bbb0ea/pint-0.24.4.tar.gz", hash = "sha256:35275439b574837a6cd3020a5a4a73645eb125ce4152a73a2f126bf164b91b80", size = 342225 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/16/bd2f5904557265882108dc2e04f18abc05ab0c2b7082ae9430091daf1d5c/Pint-0.24.4-py3-none-any.whl", hash = "sha256:aa54926c8772159fcf65f82cc0d34de6768c151b32ad1deb0331291c38fe7659", size = 302029 }, +] + [[package]] name = "pip" version = "24.3.1"