Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"attrs>=24.3.0",
"numpy>=1.25.0",
"pandas>=2.2.3",
"pint>=0.24.4",
]
readme = "README.md"
classifiers = [
Expand Down
5 changes: 4 additions & 1 deletion requirements-docs-locked.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ defusedxml==0.7.1
exceptiongroup==1.2.2 ; python_full_version < '3.11'
executing==2.1.0
fastjsonschema==2.21.1
flexcache==0.3
flexparser==0.4
fonttools==4.56.0
fqdn==1.5.1
ghp-import==2.1.0
Expand Down Expand Up @@ -97,6 +99,7 @@ parso==0.8.4
pathspec==0.12.1
pexpect==4.9.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')
pillow==11.1.0
pint==0.24.4
platformdirs==4.3.6
prometheus-client==0.21.1
prompt-toolkit==3.0.48
Expand Down Expand Up @@ -136,7 +139,7 @@ tomli==2.2.1 ; python_full_version < '3.11'
tornado==6.4.2
traitlets==5.14.3
types-python-dateutil==2.9.0.20241206
typing-extensions==4.12.2 ; python_full_version < '3.13'
typing-extensions==4.12.2
tzdata==2025.1
uri-template==1.3.0
urllib3==2.3.0
Expand Down
5 changes: 5 additions & 0 deletions requirements-incl-optional-locked.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ attrs==24.3.0
contourpy==1.3.0 ; python_full_version < '3.10'
contourpy==1.3.1 ; python_full_version >= '3.10'
cycler==0.12.1
flexcache==0.3
flexparser==0.4
fonttools==4.56.0
importlib-resources==6.5.2 ; python_full_version < '3.10'
kiwisolver==1.4.7 ; python_full_version < '3.10'
Expand All @@ -14,11 +16,14 @@ numpy==2.0.2
packaging==24.2
pandas==2.2.3
pillow==11.1.0
pint==0.24.4
platformdirs==4.3.6
pyparsing==3.2.1
python-dateutil==2.9.0.post0
pytz==2025.1
scipy==1.13.1 ; python_full_version < '3.10'
scipy==1.15.2 ; python_full_version >= '3.10'
six==1.17.0
typing-extensions==4.12.2
tzdata==2025.1
zipp==3.21.0 ; python_full_version < '3.10'
5 changes: 5 additions & 0 deletions requirements-locked.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# This file was autogenerated by uv via the following command:
# uv export -o requirements-locked.txt --no-hashes --no-dev --no-emit-project
attrs==24.3.0
flexcache==0.3
flexparser==0.4
numpy==2.0.2
pandas==2.2.3
pint==0.24.4
platformdirs==4.3.6
python-dateutil==2.9.0.post0
pytz==2025.1
six==1.17.0
typing-extensions==4.12.2
tzdata==2025.1
63 changes: 38 additions & 25 deletions src/gradient_aware_harmonisation/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,19 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol, Union, overload
from typing import TYPE_CHECKING, Any, Protocol, overload

import numpy as np
import attr
import numpy.typing as npt
from attrs import define
from attrs import define, field

from gradient_aware_harmonisation.typing import NP_FLOAT_OR_INT
from gradient_aware_harmonisation.utils import validate_domain

if TYPE_CHECKING:
import scipy.interpolate
from typing_extensions import TypeAlias

NP_FLOAT_OR_INT: TypeAlias = Union[np.floating[Any], np.integer[Any]]
"""
Type alias for a numpy float or int (not complex)
"""

NP_ARRAY_OF_FLOAT_OR_INT: TypeAlias = npt.NDArray[NP_FLOAT_OR_INT]
"""
Expand All @@ -30,8 +29,25 @@ class Spline(Protocol):
Single spline
"""

# domain: [float, float]
# """Domain over the spline can be used"""
domain: tuple[NP_FLOAT_OR_INT, NP_FLOAT_OR_INT] = field()
"""
Domain over which spline can be evaluated
"""

@domain.validator
def domain_validator(
self,
attribute: attr.Attribute[Any],
value: tuple[NP_FLOAT_OR_INT, NP_FLOAT_OR_INT],
) -> None:
"""
Validate the received values
"""
try:
validate_domain(value)
except AssertionError as exc:
msg = "The value supplied for `domain` failed validation."
raise ValueError(msg) from exc

@overload
def __call__(self, x: int | float) -> int | float: ...
Expand Down Expand Up @@ -64,11 +80,10 @@ class SplineScipy:
An adapter which wraps various classes from [scipy.interpolate][]
"""

# domain: ClassVar[list[float, float]] = [
# np.finfo(np.float64).tiny,
# np.finfo(np.float64).max,
# ]
# """domain of spline (reals)"""
domain: tuple[NP_FLOAT_OR_INT, NP_FLOAT_OR_INT]
"""
Domain over which spline can be evaluated
"""

scipy_spline: scipy.interpolate.BSpline | scipy.interpolate.PPoly

Expand Down Expand Up @@ -108,7 +123,7 @@ def derivative(self) -> SplineScipy:
:
Derivative of self
"""
return SplineScipy(self.scipy_spline.derivative())
return SplineScipy(self.scipy_spline.derivative(), domain=self.domain)

def antiderivative(self) -> SplineScipy:
"""
Expand All @@ -134,11 +149,10 @@ class SumOfSplines:
spline_two: Spline
"""Second spline"""

# domain: ClassVar[list[float, float]] = [
# np.finfo(np.float64).tiny,
# np.finfo(np.float64).max,
# ]
# """Domain of spline"""
domain: tuple[NP_FLOAT_OR_INT, NP_FLOAT_OR_INT]
"""
Domain over which spline can be evaluated
"""

@overload
def __call__(self, x: int | float) -> int | float: ...
Expand Down Expand Up @@ -204,11 +218,10 @@ class ProductOfSplines:
spline_two: Spline
"""Second spline"""

# domain: ClassVar[list[float, float]] = [
# np.finfo(np.float64).tiny,
# np.finfo(np.float64).max,
# ]
# """Domain of spline"""
domain: tuple[NP_FLOAT_OR_INT, NP_FLOAT_OR_INT]
"""
Domain over which spline can be evaluated
"""

@overload
def __call__(self, x: int | float) -> int | float: ...
Expand Down
1 change: 0 additions & 1 deletion src/gradient_aware_harmonisation/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,4 @@ def to_spline(self, **kwargs: Any) -> SplineScipy:
scipy.interpolate.make_interp_spline(
x=self.time_axis, y=self.values, **kwargs
),
# domain=(self.time_axis.min(), self.time_axis.max())
)
6 changes: 6 additions & 0 deletions src/gradient_aware_harmonisation/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import numpy.typing as npt
import pint.facets.numpy.quantity

if TYPE_CHECKING:
from typing_extensions import TypeAlias
Expand All @@ -21,3 +22,8 @@
"""
Type alias for an array of numpy float or int (not complex)
"""

PINT_SCALAR: TypeAlias = pint.facets.numpy.quantity.NumpyQuantity[NP_FLOAT_OR_INT]
"""
Type alias for a pint quantity that wraps a numpy scalar
"""
29 changes: 29 additions & 0 deletions src/gradient_aware_harmonisation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
SplineScipy,
SumOfSplines,
)
from gradient_aware_harmonisation.typing import NP_FLOAT_OR_INT, PINT_SCALAR


class GetHarmonisedSplineLike(Protocol):
Expand Down Expand Up @@ -185,3 +186,31 @@ def add_constant_to_spline(in_spline: Spline, constant: float | int) -> Spline:
)
),
)


def validate_domain(
domain: Union[
tuple[PINT_SCALAR, PINT_SCALAR], tuple[NP_FLOAT_OR_INT, NP_FLOAT_OR_INT]
],
) -> None:
"""
Check that domain values are valid

Parameters
----------
domain
Domain to check

Raises
------
AssertionError
`len(domain) != 2` or `domain[1] <= domain[0]`.
"""
expected_domain_length = 2
if len(domain) != expected_domain_length:
raise AssertionError(len(domain))

if domain[1] <= domain[0]:
msg = f"domain[1] must be greater than domain[0]. Received {domain=}."

raise AssertionError(msg)
41 changes: 41 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading