Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
d302d4f
add tensor trait
fnattino Jan 15, 2026
ee1082b
add tensor param template
fnattino Jan 15, 2026
52c2809
use new param template and trait in existing models
fnattino Jan 15, 2026
72d8633
update assimilation and partitioning
fnattino Jan 15, 2026
95387a8
add state and rate variable tensors
fnattino Jan 27, 2026
2c934fc
add tensor variable containers in existing models
fnattino Jan 27, 2026
da73f9a
add info text for clearer errors
fnattino Jan 28, 2026
86a586b
fix for afgen
fnattino Feb 4, 2026
b40d8d6
fix errors in tests
fnattino Feb 4, 2026
dacf9af
add tests for tensor containers
fnattino Feb 6, 2026
cfc786b
move function to extract paramter shapes to engine
fnattino Feb 6, 2026
7e3fcc9
simplify broadcast
fnattino Feb 6, 2026
0b72ed5
new engine cannot run with pcse models anymore
fnattino Feb 6, 2026
8f79c4d
simplify all models with new parameter containers
fnattino Feb 6, 2026
f43cb60
simplify torch ones/full
fnattino Feb 11, 2026
067bd79
Update src/diffwofost/physical_models/base/states_rates.py
fnattino Feb 11, 2026
fd1843c
Update src/diffwofost/physical_models/base/states_rates.py
fnattino Feb 11, 2026
dfa2b18
Update src/diffwofost/physical_models/base/states_rates.py
fnattino Feb 11, 2026
9cd75a7
Update tests/physical_models/test_engine.py
fnattino Feb 11, 2026
a071993
Update tests/physical_models/test_engine.py
fnattino Feb 11, 2026
b148bea
Update tests/physical_models/test_engine.py
fnattino Feb 11, 2026
e431c29
add elements to api reference
fnattino Feb 11, 2026
6548380
Merge branch 'tensor-param-template' of github.com:WUR-AI/diffWOFOST …
fnattino Feb 11, 2026
9006f3f
ruff fixes
fnattino Feb 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,15 @@ hide:
::: diffwofost.physical_models.engine.Engine

::: diffwofost.physical_models.utils.EngineTestHelper

## **Other classes (for developers)**

::: diffwofost.physical_models.base.states_rates.TensorStatesTemplate

::: diffwofost.physical_models.base.states_rates.TensorRatesTemplate

::: diffwofost.physical_models.base.states_rates.TensorParamTemplate

::: diffwofost.physical_models.base.states_rates.TensorContainer

::: diffwofost.physical_models.traitlets.Tensor
5 changes: 5 additions & 0 deletions src/diffwofost/physical_models/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .states_rates import TensorParamTemplate
from .states_rates import TensorRatesTemplate
from .states_rates import TensorStatesTemplate

__all__ = ["TensorParamTemplate", "TensorRatesTemplate", "TensorStatesTemplate"]
130 changes: 130 additions & 0 deletions src/diffwofost/physical_models/base/states_rates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from pcse.base import ParamTemplate
from pcse.base import RatesTemplate
from pcse.base import StatesTemplate
from pcse.traitlets import HasTraits
from ..traitlets import Tensor
from ..utils import AfgenTrait


class TensorContainer(HasTraits):
def __init__(self, shape=None, do_not_broadcast=None, **variables):
"""Container of tensor variables.

It includes functionality to broadcast variables to a common shape. This common shape can
be inferred from the container's tensor and AFGEN variables, or it can be set as an input
argument.

Args:
shape (tuple | torch.Size, optional): Shape to which the variables in the container
are broadcasted. If given, it should match the shape of all the input variables that
already have dimensions. Defaults to None.
do_not_broadcast (list, optional): Name of the variables that are not broadcasted
to the container shape. Defaults to None, which means that all variables are
broadcasted.
variables (dict): Collection of variables to initialize the container, as key-value
pairs.
"""
self._shape = ()
self._do_not_broadcast = [] if do_not_broadcast is None else do_not_broadcast
HasTraits.__init__(self, **variables)
self._broadcast(shape)

def _broadcast(self, shape=None):
# Identify which variables should be broadcasted. Also check that the input shape is
# compatible with the existing variable shapes
vars_to_broadcast = self._get_vars_to_broadcast()
vars_shape = self._get_vars_shape()
if shape and vars_shape and vars_shape != shape:
raise ValueError(f"Input shape {shape} does not match variable shape {vars_shape}")
shape = tuple(shape or vars_shape)

# Broadcast all required variables to the identified shape.
for varname, var in vars_to_broadcast.items():
try:
broadcasted = var.expand(shape)
except RuntimeError as error:
raise ValueError(f"Cannot broadcast {varname} to shape {shape}") from error
setattr(self, varname, broadcasted)

# Finally, update the shape of the container
self.shape = shape

def _get_vars_to_broadcast(self):
vars = {}
for varname, trait in self.traits().items():
if varname not in self._do_not_broadcast:
if isinstance(trait, Tensor):
vars[varname] = getattr(self, varname)
return vars

def _get_vars_shape(self):
shape = ()
for varname, trait in self.traits().items():
if varname not in self._do_not_broadcast:
if isinstance(trait, Tensor) or isinstance(trait, AfgenTrait):
var = getattr(self, varname)
if not var.shape or shape == var.shape:
continue
elif var.shape and not shape:
shape = tuple(var.shape)
else:
raise ValueError(
f"Incompatible shapes within variables: {shape} and {var.shape}"
)
return shape

@property
def shape(self):
"""Base shape of the variables in the container."""
return self._shape

@shape.setter
def shape(self, shape):
if self.shape and self.shape != shape:
raise ValueError(f"Container shape already set to {self.shape}")
self._shape = shape


class TensorParamTemplate(TensorContainer, ParamTemplate):
"""Template for storing parameter values as tensors.

It includes functionality to broadcast parameters to a common shape. See
`diffwofost.base.states_rates.TensorContainer` and
`pcse.base.states_rates.ParamTemplate` for details.
"""

def __init__(self, parvalues, shape=None, do_not_broadcast=None):
self._shape = ()
self._do_not_broadcast = [] if do_not_broadcast is None else do_not_broadcast
ParamTemplate.__init__(self, parvalues=parvalues)
self._broadcast(shape)


class TensorStatesTemplate(TensorContainer, StatesTemplate):
"""Template for storing state variable values as tensors.

It includes functionality to broadcast state variables to a common shape. See
`diffwofost.base.states_rates.TensorContainer` and
`pcse.base.states_rates.StatesTemplate` for details.
"""

def __init__(self, kiosk=None, publish=None, shape=None, do_not_broadcast=None, **kwargs):
self._shape = ()
self._do_not_broadcast = [] if do_not_broadcast is None else do_not_broadcast
StatesTemplate.__init__(self, kiosk=kiosk, publish=publish, **kwargs)
self._broadcast(shape)


class TensorRatesTemplate(TensorContainer, RatesTemplate):
"""Template for storing rate variable values as tensors.

It includes functionality to broadcast rate variables to a common shape. See
`diffwofost.base.states_rates.TensorContainer` and
`pcse.base.states_rates.RatesTemplate` for details.
"""

def __init__(self, kiosk=None, publish=None, shape=None, do_not_broadcast=None):
self._shape = ()
self._do_not_broadcast = [] if do_not_broadcast is None else do_not_broadcast
RatesTemplate.__init__(self, kiosk=kiosk, publish=publish)
self._broadcast(shape)
67 changes: 31 additions & 36 deletions src/diffwofost/physical_models/crop/assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,20 @@
import datetime
from collections import deque
import torch
from pcse.base import ParamTemplate
from pcse.base import RatesTemplate
from pcse.base import SimulationObject
from pcse.base.parameter_providers import ParameterProvider
from pcse.base.variablekiosk import VariableKiosk
from pcse.base.weather import WeatherDataContainer
from pcse.decorators import prepare_rates
from pcse.decorators import prepare_states
from pcse.traitlets import Any
from pcse.util import astro
from diffwofost.physical_models.base import TensorParamTemplate
from diffwofost.physical_models.base import TensorRatesTemplate
from diffwofost.physical_models.config import ComputeConfig
from diffwofost.physical_models.traitlets import Tensor
from diffwofost.physical_models.utils import AfgenTrait
from diffwofost.physical_models.utils import _broadcast_to
from diffwofost.physical_models.utils import _get_drv
from diffwofost.physical_models.utils import _get_params_shape


def _as_python_float(x) -> float:
Expand All @@ -42,6 +41,8 @@ def totass7(
COSLD: torch.Tensor,
*,
epsilon: torch.Tensor,
dtype: torch.Size | tuple,
device: str,
) -> torch.Tensor:
"""Calculates daily total gross CO2 assimilation.

Expand Down Expand Up @@ -69,9 +70,9 @@ def totass7(
COSLD R4 Amplitude of sine of solar height - I
DTGA R4 Daily total gross assimilation kg CO2/ha/d O
"""
xgauss = torch.tensor([0.1127017, 0.5000000, 0.8872983], dtype=DAYL.dtype, device=DAYL.device)
wgauss = torch.tensor([0.2777778, 0.4444444, 0.2777778], dtype=DAYL.dtype, device=DAYL.device)
pi = torch.tensor(torch.pi, dtype=DAYL.dtype, device=DAYL.device)
xgauss = torch.tensor([0.1127017, 0.5000000, 0.8872983], dtype=dtype, device=device)
wgauss = torch.tensor([0.2777778, 0.4444444, 0.2777778], dtype=dtype, device=device)
pi = torch.tensor(torch.pi, dtype=dtype, device=device)

# Only compute where it can be non-zero.
mask = (AMAX > 0) & (LAI > 0) & (DAYL > 0)
Expand Down Expand Up @@ -229,8 +230,6 @@ class WOFOST72_Assimilation(SimulationObject):
| PGASS | AMAXTB, EFFTB, KDIFTB, TMPFTB, TMNFTB |
""" # noqa: E501

params_shape = None

@property
def device(self):
"""Get device from ComputeConfig."""
Expand All @@ -241,33 +240,27 @@ def dtype(self):
"""Get dtype from ComputeConfig."""
return ComputeConfig.get_dtype()

class Parameters(ParamTemplate):
class Parameters(TensorParamTemplate):
AMAXTB = AfgenTrait()
EFFTB = AfgenTrait()
KDIFTB = AfgenTrait()
TMPFTB = AfgenTrait()
TMNFTB = AfgenTrait()

def __init__(self, parvalues):
super().__init__(parvalues)

class RateVariables(RatesTemplate):
PGASS = Any()

def __init__(self, kiosk, publish=None):
dtype = ComputeConfig.get_dtype()
device = ComputeConfig.get_device()
self.PGASS = torch.tensor(0.0, dtype=dtype, device=device)
super().__init__(kiosk, publish=publish)
class RateVariables(TensorRatesTemplate):
PGASS = Tensor(0.0)

def initialize(
self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider
self,
day: datetime.date,
kiosk: VariableKiosk,
parvalues: ParameterProvider,
shape: tuple | torch.Size | None = None,
) -> None:
"""Initialize the assimilation module."""
self.kiosk = kiosk
self.params = self.Parameters(parvalues)
self.params_shape = _get_params_shape(self.params)
self.rates = self.RateVariables(kiosk, publish=["PGASS"])
self.params = self.Parameters(parvalues, shape=shape)
self.rates = self.RateVariables(kiosk, publish=["PGASS"], shape=shape)

# 7-day running average buffer for TMIN (stored as tensors).
self._tmn_window = deque(maxlen=7)
Expand All @@ -285,16 +278,16 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None
_exist_required_external_variables(k)

# External states
dvs = _broadcast_to(k["DVS"], self.params_shape, dtype=self.dtype, device=self.device)
lai = _broadcast_to(k["LAI"], self.params_shape, dtype=self.dtype, device=self.device)
dvs = _broadcast_to(k["DVS"], self.params.shape, dtype=self.dtype, device=self.device)
lai = _broadcast_to(k["LAI"], self.params.shape, dtype=self.dtype, device=self.device)

# Weather drivers
irrad = _get_drv(drv.IRRAD, self.params_shape, dtype=self.dtype, device=self.device)
dtemp = _get_drv(drv.DTEMP, self.params_shape, dtype=self.dtype, device=self.device)
tmin = _get_drv(drv.TMIN, self.params_shape, dtype=self.dtype, device=self.device)
irrad = _get_drv(drv.IRRAD, self.params.shape, dtype=self.dtype, device=self.device)
dtemp = _get_drv(drv.DTEMP, self.params.shape, dtype=self.dtype, device=self.device)
tmin = _get_drv(drv.TMIN, self.params.shape, dtype=self.dtype, device=self.device)

# Assimilation is zero before crop emergence (DVS < 0)
dvs_mask = (dvs >= 0).to(dtype=self.dtype)
dvs_mask = dvs >= 0
# 7-day running average of TMIN
self._tmn_window.appendleft(tmin * dvs_mask)
self._tmn_window_mask.appendleft(dvs_mask)
Expand All @@ -307,11 +300,11 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None
irrad_for_astro = _as_python_float(drv.IRRAD)
dayl, _daylp, sinld, cosld, difpp, _atmtr, dsinbe, _angot = astro(day, lat, irrad_for_astro)

dayl_t = _broadcast_to(dayl, self.params_shape, dtype=self.dtype, device=self.device)
sinld_t = _broadcast_to(sinld, self.params_shape, dtype=self.dtype, device=self.device)
cosld_t = _broadcast_to(cosld, self.params_shape, dtype=self.dtype, device=self.device)
difpp_t = _broadcast_to(difpp, self.params_shape, dtype=self.dtype, device=self.device)
dsinbe_t = _broadcast_to(dsinbe, self.params_shape, dtype=self.dtype, device=self.device)
dayl_t = _broadcast_to(dayl, self.params.shape, dtype=self.dtype, device=self.device)
sinld_t = _broadcast_to(sinld, self.params.shape, dtype=self.dtype, device=self.device)
cosld_t = _broadcast_to(cosld, self.params.shape, dtype=self.dtype, device=self.device)
difpp_t = _broadcast_to(difpp, self.params.shape, dtype=self.dtype, device=self.device)
dsinbe_t = _broadcast_to(dsinbe, self.params.shape, dtype=self.dtype, device=self.device)

# Parameter tables
amax = p.AMAXTB(dvs)
Expand All @@ -331,6 +324,8 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None
sinld_t,
cosld_t,
epsilon=self._epsilon,
dtype=self.dtype,
device=self.device,
)

# Correction for low minimum temperature potential
Expand Down
Loading