From 2992994442be226a065af1fa01a04c9f2bd7adee Mon Sep 17 00:00:00 2001 From: SCiarella Date: Wed, 28 Jan 2026 14:04:40 +0100 Subject: [PATCH 1/9] diffvectorize --- docs/api_reference.md | 2 + .../crop/storage_organ_dynamics.py | 233 +++++++ .../crop/test_storage_organ_dynamics.py | 572 ++++++++++++++++++ 3 files changed, 807 insertions(+) create mode 100644 src/diffwofost/physical_models/crop/storage_organ_dynamics.py create mode 100644 tests/physical_models/crop/test_storage_organ_dynamics.py diff --git a/docs/api_reference.md b/docs/api_reference.md index 4eff37d..b52fbdd 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -16,6 +16,8 @@ hide: ::: diffwofost.physical_models.crop.root_dynamics.WOFOST_Root_Dynamics +::: diffwofost.physical_models.crop.storage_organ_dynamics.WOFOST_Storage_Organ_Dynamics. + ## **Utility (under development)** diff --git a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py new file mode 100644 index 0000000..4ef523e --- /dev/null +++ b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py @@ -0,0 +1,233 @@ +import datetime +import torch +from pcse.base import ParamTemplate +from pcse.base import RatesTemplate +from pcse.base import SimulationObject +from pcse.base import StatesTemplate +from pcse.base.parameter_providers import ParameterProvider +from pcse.base.variablekiosk import VariableKiosk +from pcse.base.weather import WeatherDataContainer +from pcse.decorators import prepare_rates +from pcse.decorators import prepare_states +from pcse.traitlets import Any +from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.utils import _broadcast_to +from diffwofost.physical_models.utils import _get_params_shape + + +class WOFOST_Storage_Organ_Dynamics(SimulationObject): + """Implementation of storage organ dynamics. + + Storage organs are the most simple component of the plant in WOFOST and + consist of a static pool of biomass. Growth of the storage organs is the + result of assimilate partitioning. Death of storage organs is not + implemented and the corresponding rate variable (DRSO) is always set to + zero. + + Pods are green elements of the plant canopy and can as such contribute + to the total photosynthetic active area. This is expressed as the Pod + Area Index which is obtained by multiplying pod biomass with a fixed + Specific Pod Area (SPA). + + **Simulation parameters** + + | Name | Description | Type | Unit | + |------|===============================================|========|=============| + | TDWI | Initial total crop dry weight | SCr | kg ha⁻¹ | + | SPA | Specific Pod Area | SCr | ha kg⁻¹ | + + **State variables** + + | Name | Description | Pbl | Unit | + |------|==================================================|======|=============| + | PAI | Pod Area Index | Y | - | + | WSO | Weight of living storage organs | Y | kg ha⁻¹ | + | DWSO | Weight of dead storage organs | N | kg ha⁻¹ | + | TWSO | Total weight of storage organs | Y | kg ha⁻¹ | + + **Rate variables** + + | Name | Description | Pbl | Unit | + |------|==================================================|======|=============| + | GRSO | Growth rate storage organs | N | kg ha⁻¹ d⁻¹ | + | DRSO | Death rate storage organs | N | kg ha⁻¹ d⁻¹ | + | GWSO | Net change in storage organ biomass | N | kg ha⁻¹ d⁻¹ | + + **Signals send or handled** + + None + + **External dependencies** + + | Name | Description | Provided by | Unit | + |------|====================================|=====================|=============| + | ADMI | Above-ground dry matter increase | CropSimulation | kg ha⁻¹ d⁻¹ | + | FO | Fraction biomass to storage organs | DVS_Partitioning | - | + | FR | Fraction biomass to roots | DVS_Partitioning | - | + + **Outputs:** + + | Name | Description | Provided by | Unit | + |------|------------------------------|-------------|--------------| + | PAI | Pod Area Index | Y | - | + | TWSO | Total weight storage organs | Y | kg ha⁻¹ | + | WSO | Weight living storage organs | Y | kg ha⁻¹ | + + **Gradient mapping (which parameters have a gradient):** + + | Output | Parameters influencing it | + |--------|----------------------------| + | PAI | SPA | + | TWSO | TDWI | + | WSO | TDWI | + """ + + params_shape = None # Shape of the parameters tensors + + @property + def device(self): + """Get device from ComputeConfig.""" + return ComputeConfig.get_device() + + @property + def dtype(self): + """Get dtype from ComputeConfig.""" + return ComputeConfig.get_dtype() + + class Parameters(ParamTemplate): + SPA = Any() + TDWI = Any() + + def __init__(self, parvalues): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + self.SPA = [torch.tensor(-99.0, dtype=dtype, device=device)] + self.TDWI = [torch.tensor(-99.0, dtype=dtype, device=device)] + + # Call parent init + super().__init__(parvalues) + + class StateVariables(StatesTemplate): + WSO = Any() # Weight living storage organs + DWSO = Any() # Weight dead storage organs + TWSO = Any() # Total weight storage organs + PAI = Any() # Pod Area Index + + def __init__(self, kiosk, publish=None, **kwargs): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + if "WSO" not in kwargs: + self.WSO = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "DWSO" not in kwargs: + self.DWSO = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "TWSO" not in kwargs: + self.TWSO = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "PAI" not in kwargs: + self.PAI = [torch.tensor(-99.0, dtype=dtype, device=device)] + + # Call parent init + super().__init__(kiosk, publish=publish, **kwargs) + + class RateVariables(RatesTemplate): + GRSO = Any() + DRSO = Any() + GWSO = Any() + + def __init__(self, kiosk, publish=None): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + self.GRSO = torch.tensor(0.0, dtype=dtype, device=device) + self.DRSO = torch.tensor(0.0, dtype=dtype, device=device) + self.GWSO = torch.tensor(0.0, dtype=dtype, device=device) + + # Call parent init + super().__init__(kiosk, publish=publish) + + def initialize( + self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + ) -> None: + """Initialize the storage organ dynamics model. + + :param day: start date of the simulation + :param kiosk: variable kiosk of this PCSE instance + :param parvalues: `ParameterProvider` object providing parameters as + key/value pairs + """ + self.kiosk = kiosk + self.params = self.Parameters(parvalues) + self.rates = self.RateVariables(kiosk, publish=["GRSO"]) + + # INITIAL STATES + params = self.params + self.params_shape = _get_params_shape(params) + shape = self.params_shape + + # Initial storage organ biomass + TDWI = _broadcast_to(params.TDWI, shape, dtype=self.dtype, device=self.device) + SPA = _broadcast_to(params.SPA, shape, dtype=self.dtype, device=self.device) + FO = _broadcast_to(self.kiosk["FO"], shape, dtype=self.dtype, device=self.device) + FR = _broadcast_to(self.kiosk["FR"], shape, dtype=self.dtype, device=self.device) + + WSO = (TDWI * (1 - FR)) * FO + DWSO = torch.zeros(shape, dtype=self.dtype, device=self.device) + TWSO = WSO + DWSO + # Initial Pod Area Index + PAI = WSO * SPA + + self.states = self.StateVariables( + kiosk, publish=["TWSO", "WSO", "PAI"], WSO=WSO, DWSO=DWSO, TWSO=TWSO, PAI=PAI + ) + + @prepare_rates + def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None) -> None: + """Calculate the rates of change of the state variables. + + Args: + day (datetime.date, optional): The current date of the simulation. + drv (WeatherDataContainer, optional): A dictionary-like container holding + weather data elements as key/value. + """ + rates = self.rates + k = self.kiosk + + FO = _broadcast_to(k["FO"], self.params_shape, dtype=self.dtype, device=self.device) + ADMI = _broadcast_to(k["ADMI"], self.params_shape, dtype=self.dtype, device=self.device) + REALLOC_SO = _broadcast_to( + k.get("REALLOC_SO", 0.0), self.params_shape, dtype=self.dtype, device=self.device + ) + + # Growth/death rate organs + rates.GRSO = ADMI * FO + rates.DRSO = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) + rates.GWSO = rates.GRSO - rates.DRSO + REALLOC_SO + + @prepare_states + def integrate(self, day: datetime.date = None, delt=1.0) -> None: + """Integrate the state variables. + + Args: + day (datetime.date, optional): The current date of the simulation. + delt (float, optional): The time step for integration. Defaults to 1.0. + """ + params = self.params + rates = self.rates + states = self.states + + SPA = _broadcast_to(params.SPA, self.params_shape, dtype=self.dtype, device=self.device) + + # Stem biomass (living, dead, total) + states.WSO = states.WSO + rates.GWSO + states.DWSO = states.DWSO + rates.DRSO + states.TWSO = states.WSO + states.DWSO + + # Calculate Pod Area Index (SAI) + states.PAI = states.WSO * SPA diff --git a/tests/physical_models/crop/test_storage_organ_dynamics.py b/tests/physical_models/crop/test_storage_organ_dynamics.py new file mode 100644 index 0000000..4a20334 --- /dev/null +++ b/tests/physical_models/crop/test_storage_organ_dynamics.py @@ -0,0 +1,572 @@ +import copy +import warnings +from unittest.mock import patch +import pytest +import torch +from pcse.models import Wofost72_PP +from diffwofost.physical_models.config import Configuration +from diffwofost.physical_models.crop.storage_organ_dynamics import WOFOST_Storage_Organ_Dynamics +from diffwofost.physical_models.utils import EngineTestHelper +from diffwofost.physical_models.utils import calculate_numerical_grad +from diffwofost.physical_models.utils import get_test_data +from diffwofost.physical_models.utils import prepare_engine_input +from .. import phy_data_folder + +storage_dynamics_config = Configuration( + CROP=WOFOST_Storage_Organ_Dynamics, + OUTPUT_VARS=["PAI", "TWSO", "WSO", "DWSO"], +) + +# [!] Notice that the storage organ module does not have dedicated test data. +# This means that we can only test the execution of the module, +# but not the correctness of its results (except when used within Wofost72_PP). + + +def _prepare_common_storage_inputs(test_data_url, device, meteo_range_checks=True): + # prepare model input + test_data = get_test_data(test_data_url) + crop_model_params = ["TDWI", "SPA"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=meteo_range_checks, device=device + ) + + # Patch missing states + for state in external_states: + if "FO" not in state: + state["FO"] = 0.5 + if "FR" not in state: + state["FR"] = 0.5 + if "ADMI" not in state: + state["ADMI"] = 100.0 + # DVS is unused in storage organ dynamics but good to have if something changes + if "DVS" not in state: + state["DVS"] = 0.0 + + # Patch missing parameters + if "SPA" not in crop_model_params_provider: + crop_model_params_provider.set_override( + "SPA", + torch.tensor(0.01, dtype=torch.float64, device=device), + check=False, + ) + if "TDWI" not in crop_model_params_provider: + crop_model_params_provider.set_override( + "TDWI", torch.tensor(20.0, dtype=torch.float64, device=device), check=False + ) + + return ( + test_data, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) + + +def get_test_diff_storage_model(device: str = "cpu"): + # [!] The storage organ module does not have dedicated test data. + # We reuse the partitioning test data as they contain relevant parameters and states. + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + + ( + _, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device=device) + + return DiffStorageDynamics( + copy.deepcopy(crop_model_params_provider), + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + copy.deepcopy(external_states), + device=device, + ) + + +class DiffStorageDynamics(torch.nn.Module): + def __init__( + self, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config, + external_states, + device: str = "cpu", + ): + super().__init__() + self.crop_model_params_provider = crop_model_params_provider + self.weather_data_provider = weather_data_provider + self.agro_management_inputs = agro_management_inputs + self.config = config + self.external_states = external_states + self.device = device + + def forward(self, params_dict): + # pass new value of parameters to the model + for name, value in params_dict.items(): + self.crop_model_params_provider.set_override(name, value, check=False) + + engine = EngineTestHelper( + self.crop_model_params_provider, + self.weather_data_provider, + self.agro_management_inputs, + self.config, + self.external_states, + device=self.device, + ) + engine.run_till_terminate() + results = engine.get_output() + + return { + var: torch.stack([item[var] for item in results]) + for var in ["PAI", "TWSO", "WSO", "DWSO"] + } + + +class TestStorageOrganDynamics: + # [!] The storage module does not have dedicated test data. + # We reuse the partitioning test data as they contain relevant parameters and states. + storage_dynamics_data_urls = [ + f"{phy_data_folder}/test_partitioning_wofost72_{i:02d}.yaml" for i in range(1, 45) + ] + + wofost72_data_urls = [ + f"{phy_data_folder}/test_potentialproduction_wofost72_{i:02d}.yaml" + for i in range(1, 45) # there are 44 test files + ] + + @pytest.mark.parametrize("test_data_url", storage_dynamics_data_urls) + def test_storage_dynamics_with_testengine(self, test_data_url, device): + """EngineTestHelper and not Engine because it allows to specify `external_states`.""" + ( + test_data, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device=device) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + # get expected results from YAML test data + expected_results = test_data["ModelResults"] + + # Assertions on values removed as test data is not appropriate for this module + assert len(actual_results) == len(expected_results) + + @pytest.mark.parametrize("param", ["TDWI", "SPA", "TEMP"]) + def test_storage_dynamics_with_one_parameter_vector(self, param, device): + # prepare model input + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + ( + test_data, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device=device, meteo_range_checks=False) + + # Setting a vector (with one value) for the selected parameter + if param == "TEMP": + # Vectorize weather variable + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones(10, dtype=torch.float64) * wdc.TEMP + else: + # Broadcast all parameters to match the batch size of 10 + for p_name in ["TDWI", "SPA"]: + if p_name in crop_model_params_provider: + p_val = crop_model_params_provider[p_name] + if p_val.dim() == 0: # scalar + crop_model_params_provider.set_override( + p_name, p_val.repeat(10), check=False + ) + elif p_val.dim() == 2 and p_val.shape[0] == 1: # table (1, M) -> (10, M) + crop_model_params_provider.set_override( + p_name, p_val.repeat(10, 1), check=False + ) + + if param == "TEMP": + # Vectorize weather variable + # We expect the model to handle scalar parameters with vectorized weather + # via implicit broadcasting or explicit checks passing. + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + else: + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + # get expected results from YAML test data + expected_results = test_data["ModelResults"] + + # Assertions on values removed as test data is not appropriate for this module + assert len(actual_results) == len(expected_results) + + @pytest.mark.parametrize( + "param,delta", + [ + ("TDWI", 0.1), + ("SPA", 0.0001), + ], + ) + def test_storage_dynamics_with_different_parameter_values(self, param, delta, device): + # prepare model input + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + ( + test_data, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device=device) + + # Setting a vector with multiple values for the selected parameter + test_value = crop_model_params_provider[param] + + param_vec = torch.tensor( + [test_value - delta, test_value + delta, test_value], + device=device, + dtype=torch.float64, + ) + target_batch_size = 3 + crop_model_params_provider.set_override(param, param_vec, check=False) + + # Broadcast all other params + for p_name in ["TDWI", "SPA"]: + if p_name == param: + continue + if p_name not in crop_model_params_provider: + continue + + p_val = crop_model_params_provider[p_name] + if p_val.dim() == 0: + crop_model_params_provider.set_override( + p_name, p_val.repeat(target_batch_size), check=False + ) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + # get expected results from YAML test data + expected_results = test_data["ModelResults"] + + # Assertions on values removed as test data is not appropriate for this module + assert len(actual_results) == len(expected_results) + + def test_storage_dynamics_with_multiple_parameter_vectors(self, device): + # prepare model input + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + ( + test_data, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device=device) + + # Setting a vector (with one value) for the TDWI and SPA parameters + for param in ("TDWI", "SPA"): + if param == "SPA" and crop_model_params_provider[param].dim() == 2: + # In case SPA is treated as table somehow, though here it is scalar + repeated = crop_model_params_provider[param].repeat(10, 1) + else: + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + # get expected results from YAML test data + expected_results = test_data["ModelResults"] + + # Assertions on values removed as test data is not appropriate for this module + assert len(actual_results) == len(expected_results) + + def test_storage_dynamics_with_multiple_parameter_arrays(self, device): + # prepare model input + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + ( + test_data, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device=device, meteo_range_checks=False) + + # Setting an array with arbitrary shape (and one value) + for param in ("TDWI", "SPA"): + repeated = crop_model_params_provider[param].broadcast_to((30, 5)) + crop_model_params_provider.set_override(param, repeated, check=False) + + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones((30, 5), dtype=torch.float64) * wdc.TEMP + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + # get expected results from YAML test data + expected_results = test_data["ModelResults"] + + # Assertions on values removed as test data is not appropriate for this module + assert len(actual_results) == len(expected_results) + + def test_storage_dynamics_with_incompatible_parameter_vectors(self): + # prepare model input + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + ( + _, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = _prepare_common_storage_inputs(test_data_url, device="cpu") + + # Setting a vector (with one value) for the TDWI and SPA parameters, + # but with different lengths + crop_model_params_provider.set_override( + "TDWI", crop_model_params_provider["TDWI"].repeat(10), check=False + ) + crop_model_params_provider.set_override( + "SPA", crop_model_params_provider["SPA"].repeat(5), check=False + ) + + with pytest.raises(AssertionError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + storage_dynamics_config, + external_states, + device="cpu", + ) + + @pytest.mark.parametrize("test_data_url", wofost72_data_urls) + def test_wofost_pp_with_storage_dynamics(self, test_data_url): + # prepare model input + test_data = get_test_data(test_data_url) + crop_model_params = ["TDWI", "SPA"] + (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( + prepare_engine_input(test_data, crop_model_params) + ) + + # get expected results from YAML test data + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + with patch("pcse.crop.wofost72.Storage_Organ_Dynamics", WOFOST_Storage_Organ_Dynamics): + model = Wofost72_PP( + crop_model_params_provider, weather_data_provider, agro_management_inputs + ) + model.run_till_terminate() + actual_results = model.get_output() + + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + abs(reference[var] - model[var]) < precision + for var, precision in expected_precision.items() + ) + + +class TestDiffStorageDynamicsGradients: + """Parametrized tests for gradient calculations in storage organ dynamics.""" + + # Define parameters and outputs + param_names = ["TDWI", "SPA"] + output_names = ["PAI", "TWSO", "WSO"] + + # Define parameter configurations (value, dtype) + param_configs = { + "single": { + "TDWI": (0.2, torch.float64), + "SPA": (0.01, torch.float64), + }, + "tensor": { + "TDWI": ([0.1, 0.2, 0.3], torch.float64), + "SPA": ([0.01, 0.02, 0.03], torch.float64), + }, + } + + # Define which parameter-output pairs should have gradients + # Format: {param_name: [list of outputs that should have gradients]} + gradient_mapping = { + "TDWI": ["PAI", "TWSO", "WSO", "DWSO"], + "SPA": ["PAI"], + } + + # Generate all combinations + gradient_params = [] + no_gradient_params = [] + for param_name in param_names: + for output_name in output_names: + if output_name in gradient_mapping.get(param_name, []): + gradient_params.append((param_name, output_name)) + else: + no_gradient_params.append((param_name, output_name)) + + @pytest.mark.parametrize("param_name,output_name", no_gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_no_gradients(self, param_name, output_name, config_type, device): + """Test cases where parameters should not have gradients for specific outputs.""" + model = get_test_diff_storage_model(device=device) + + if config_type == "tensor": + for p_name, (p_val, p_dtype) in self.param_configs["tensor"].items(): + if p_name != param_name: + model.crop_model_params_provider.set_override( + p_name, torch.tensor(p_val, dtype=p_dtype, device=device), check=False + ) + + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + output = model({param_name: param}) + loss = output[output_name].sum() + + if not loss.requires_grad: + return + + try: + grads = torch.autograd.grad(loss, param, retain_graph=True, allow_unused=True)[0] + except RuntimeError as e: + if "does not require grad" in str(e): + return + raise e + + if grads is not None: + assert torch.all((grads == 0) | torch.isnan(grads)), ( + f"Gradient for {param_name} w.r.t. {output_name} should be zero or NaN" + ) + + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): + """Test that forward and backward gradients match for parameter-output pairs.""" + model = get_test_diff_storage_model(device=device) + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + + overrides = {param_name: param} + if config_type == "tensor": + for p_name, (p_val, p_dtype) in self.param_configs["tensor"].items(): + if p_name != param_name: + overrides[p_name] = torch.tensor(p_val, dtype=p_dtype, device=device) + + output = model(overrides) + loss = output[output_name].sum() + + # this is ∂loss/∂param + # this is called forward gradient here because it is calculated without backpropagation. + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + + assert grads is not None, f"Gradients for {param_name} should not be None" + + param.grad = None # clear any existing gradient + loss.backward() + + # this is ∂loss/∂param calculated using backpropagation + grad_backward = param.grad + + assert grad_backward is not None, f"Backward gradients for {param_name} should not be None" + assert torch.all(grad_backward == grads), ( + f"Forward and backward gradients for {param_name} should match" + ) + + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_numerical(self, param_name, output_name, config_type, device): + """Test that analytical gradients match numerical gradients.""" + value, _ = self.param_configs[config_type][param_name] + + # we pass `param` and not `param.data` because we want `requires_grad=True` + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) + + def model_factory(): + m = get_test_diff_storage_model(device=device) + if config_type == "tensor": + for p_name, (p_val, p_dtype) in self.param_configs["tensor"].items(): + if p_name != param_name: + m.crop_model_params_provider.set_override( + p_name, torch.tensor(p_val, dtype=p_dtype, device=device), check=False + ) + return m + + numerical_grad = calculate_numerical_grad(model_factory, param_name, param, output_name) + + model = model_factory() + output = model({param_name: param}) + loss = output[output_name].sum() + + # this is ∂loss/∂param, for comparison with numerical gradient + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + + torch.testing.assert_close( + numerical_grad.detach().cpu(), + grads.detach().cpu(), + rtol=1e-3, + atol=1e-3, + ) + + # Warn if gradient is zero (but this shouldn't happen for gradient_params) + if torch.all(grads == 0): + warnings.warn( + f"Gradient for parameter '{param_name}' with respect to output " + + f"'{output_name}' is zero: {grads.data}", + UserWarning, + ) From e79da0647132af4f0bcd0d46e9c468088b673504 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Wed, 28 Jan 2026 14:36:44 +0100 Subject: [PATCH 2/9] Fix docs --- docs/api_reference.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api_reference.md b/docs/api_reference.md index b52fbdd..7575569 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -16,7 +16,7 @@ hide: ::: diffwofost.physical_models.crop.root_dynamics.WOFOST_Root_Dynamics -::: diffwofost.physical_models.crop.storage_organ_dynamics.WOFOST_Storage_Organ_Dynamics. +::: diffwofost.physical_models.crop.storage_organ_dynamics.WOFOST_Storage_Organ_Dynamics ## **Utility (under development)** From ea838f7a54f3e6fde4a77cf496903df49737fed8 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 26 Feb 2026 12:11:16 +0100 Subject: [PATCH 3/9] Update PR --- .../crop/storage_organ_dynamics.py | 118 +++++------------- .../crop/test_storage_organ_dynamics.py | 38 +++--- 2 files changed, 47 insertions(+), 109 deletions(-) diff --git a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py index 4ef523e..250bd5a 100644 --- a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py +++ b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py @@ -1,18 +1,16 @@ import datetime import torch -from pcse.base import ParamTemplate from pcse.base import RatesTemplate from pcse.base import SimulationObject -from pcse.base import StatesTemplate from pcse.base.parameter_providers import ParameterProvider from pcse.base.variablekiosk import VariableKiosk from pcse.base.weather import WeatherDataContainer from pcse.decorators import prepare_rates from pcse.decorators import prepare_states -from pcse.traitlets import Any +from diffwofost.physical_models.base import TensorParamTemplate +from diffwofost.physical_models.base import TensorStatesTemplate from diffwofost.physical_models.config import ComputeConfig -from diffwofost.physical_models.utils import _broadcast_to -from diffwofost.physical_models.utils import _get_params_shape +from diffwofost.physical_models.traitlets import Tensor class WOFOST_Storage_Organ_Dynamics(SimulationObject): @@ -94,91 +92,43 @@ def dtype(self): """Get dtype from ComputeConfig.""" return ComputeConfig.get_dtype() - class Parameters(ParamTemplate): - SPA = Any() - TDWI = Any() - - def __init__(self, parvalues): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values - self.SPA = [torch.tensor(-99.0, dtype=dtype, device=device)] - self.TDWI = [torch.tensor(-99.0, dtype=dtype, device=device)] - - # Call parent init - super().__init__(parvalues) - - class StateVariables(StatesTemplate): - WSO = Any() # Weight living storage organs - DWSO = Any() # Weight dead storage organs - TWSO = Any() # Total weight storage organs - PAI = Any() # Pod Area Index - - def __init__(self, kiosk, publish=None, **kwargs): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - - # Set default values - if "WSO" not in kwargs: - self.WSO = [torch.tensor(-99.0, dtype=dtype, device=device)] - if "DWSO" not in kwargs: - self.DWSO = [torch.tensor(-99.0, dtype=dtype, device=device)] - if "TWSO" not in kwargs: - self.TWSO = [torch.tensor(-99.0, dtype=dtype, device=device)] - if "PAI" not in kwargs: - self.PAI = [torch.tensor(-99.0, dtype=dtype, device=device)] - - # Call parent init - super().__init__(kiosk, publish=publish, **kwargs) + class Parameters(TensorParamTemplate): + SPA = Tensor(-99.0) + TDWI = Tensor(-99.0) - class RateVariables(RatesTemplate): - GRSO = Any() - DRSO = Any() - GWSO = Any() - - def __init__(self, kiosk, publish=None): - # Get dtype and device from ComputeConfig - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() + class StateVariables(TensorStatesTemplate): + WSO = Tensor(-99.0) # Weight living storage organs + DWSO = Tensor(-99.0) # Weight dead storage organs + TWSO = Tensor(-99.0) # Total weight storage organs + PAI = Tensor(-99.0) # Pod Area Index - # Set default values - self.GRSO = torch.tensor(0.0, dtype=dtype, device=device) - self.DRSO = torch.tensor(0.0, dtype=dtype, device=device) - self.GWSO = torch.tensor(0.0, dtype=dtype, device=device) - - # Call parent init - super().__init__(kiosk, publish=publish) + class RateVariables(RatesTemplate): + GRSO = Tensor(0.0) + DRSO = Tensor(0.0) + GWSO = Tensor(0.0) def initialize( - self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | torch.Size | None = None, ) -> None: - """Initialize the storage organ dynamics model. - - :param day: start date of the simulation - :param kiosk: variable kiosk of this PCSE instance - :param parvalues: `ParameterProvider` object providing parameters as - key/value pairs - """ + """Initialize the storage organ dynamics model.""" self.kiosk = kiosk - self.params = self.Parameters(parvalues) + self.params = self.Parameters(parvalues, shape=shape) self.rates = self.RateVariables(kiosk, publish=["GRSO"]) - # INITIAL STATES - params = self.params - self.params_shape = _get_params_shape(params) - shape = self.params_shape + self._drso_zeros = torch.zeros(self.params.shape, dtype=self.dtype, device=self.device) # Initial storage organ biomass - TDWI = _broadcast_to(params.TDWI, shape, dtype=self.dtype, device=self.device) - SPA = _broadcast_to(params.SPA, shape, dtype=self.dtype, device=self.device) - FO = _broadcast_to(self.kiosk["FO"], shape, dtype=self.dtype, device=self.device) - FR = _broadcast_to(self.kiosk["FR"], shape, dtype=self.dtype, device=self.device) + TDWI = self.params.TDWI + SPA = self.params.SPA + FO = self.kiosk["FO"] + FR = self.kiosk["FR"] WSO = (TDWI * (1 - FR)) * FO - DWSO = torch.zeros(shape, dtype=self.dtype, device=self.device) + DWSO = self._drso_zeros TWSO = WSO + DWSO # Initial Pod Area Index PAI = WSO * SPA @@ -199,15 +149,13 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None rates = self.rates k = self.kiosk - FO = _broadcast_to(k["FO"], self.params_shape, dtype=self.dtype, device=self.device) - ADMI = _broadcast_to(k["ADMI"], self.params_shape, dtype=self.dtype, device=self.device) - REALLOC_SO = _broadcast_to( - k.get("REALLOC_SO", 0.0), self.params_shape, dtype=self.dtype, device=self.device - ) + FO = k["FO"] + ADMI = k["ADMI"] + REALLOC_SO = k.get("REALLOC_SO", self._drso_zeros) # Growth/death rate organs rates.GRSO = ADMI * FO - rates.DRSO = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) + rates.DRSO = self._drso_zeros rates.GWSO = rates.GRSO - rates.DRSO + REALLOC_SO @prepare_states @@ -222,7 +170,7 @@ def integrate(self, day: datetime.date = None, delt=1.0) -> None: rates = self.rates states = self.states - SPA = _broadcast_to(params.SPA, self.params_shape, dtype=self.dtype, device=self.device) + SPA = params.SPA # Stem biomass (living, dead, total) states.WSO = states.WSO + rates.GWSO diff --git a/tests/physical_models/crop/test_storage_organ_dynamics.py b/tests/physical_models/crop/test_storage_organ_dynamics.py index 4a20334..6352b42 100644 --- a/tests/physical_models/crop/test_storage_organ_dynamics.py +++ b/tests/physical_models/crop/test_storage_organ_dynamics.py @@ -22,7 +22,7 @@ # but not the correctness of its results (except when used within Wofost72_PP). -def _prepare_common_storage_inputs(test_data_url, device, meteo_range_checks=True): +def _prepare_common_storage_inputs(test_data_url, meteo_range_checks=True): # prepare model input test_data = get_test_data(test_data_url) crop_model_params = ["TDWI", "SPA"] @@ -31,9 +31,7 @@ def _prepare_common_storage_inputs(test_data_url, device, meteo_range_checks=Tru weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input( - test_data, crop_model_params, meteo_range_checks=meteo_range_checks, device=device - ) + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=meteo_range_checks) # Patch missing states for state in external_states: @@ -51,12 +49,12 @@ def _prepare_common_storage_inputs(test_data_url, device, meteo_range_checks=Tru if "SPA" not in crop_model_params_provider: crop_model_params_provider.set_override( "SPA", - torch.tensor(0.01, dtype=torch.float64, device=device), + torch.tensor(0.01, dtype=torch.float64), check=False, ) if "TDWI" not in crop_model_params_provider: crop_model_params_provider.set_override( - "TDWI", torch.tensor(20.0, dtype=torch.float64, device=device), check=False + "TDWI", torch.tensor(20.0, dtype=torch.float64), check=False ) return ( @@ -79,7 +77,7 @@ def get_test_diff_storage_model(device: str = "cpu"): weather_data_provider, agro_management_inputs, external_states, - ) = _prepare_common_storage_inputs(test_data_url, device=device) + ) = _prepare_common_storage_inputs(test_data_url) return DiffStorageDynamics( copy.deepcopy(crop_model_params_provider), @@ -120,7 +118,6 @@ def forward(self, params_dict): self.agro_management_inputs, self.config, self.external_states, - device=self.device, ) engine.run_till_terminate() results = engine.get_output() @@ -152,7 +149,7 @@ def test_storage_dynamics_with_testengine(self, test_data_url, device): weather_data_provider, agro_management_inputs, external_states, - ) = _prepare_common_storage_inputs(test_data_url, device=device) + ) = _prepare_common_storage_inputs(test_data_url) engine = EngineTestHelper( crop_model_params_provider, @@ -160,7 +157,6 @@ def test_storage_dynamics_with_testengine(self, test_data_url, device): agro_management_inputs, storage_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -181,13 +177,13 @@ def test_storage_dynamics_with_one_parameter_vector(self, param, device): weather_data_provider, agro_management_inputs, external_states, - ) = _prepare_common_storage_inputs(test_data_url, device=device, meteo_range_checks=False) + ) = _prepare_common_storage_inputs(test_data_url, meteo_range_checks=False) # Setting a vector (with one value) for the selected parameter if param == "TEMP": # Vectorize weather variable for (_, _), wdc in weather_data_provider.store.items(): - wdc.TEMP = torch.ones(10, dtype=torch.float64) * wdc.TEMP + wdc.TEMP = torch.ones(10, dtype=torch.float64, device=device) * wdc.TEMP else: # Broadcast all parameters to match the batch size of 10 for p_name in ["TDWI", "SPA"]: @@ -212,7 +208,6 @@ def test_storage_dynamics_with_one_parameter_vector(self, param, device): agro_management_inputs, storage_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -223,7 +218,6 @@ def test_storage_dynamics_with_one_parameter_vector(self, param, device): agro_management_inputs, storage_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -250,7 +244,7 @@ def test_storage_dynamics_with_different_parameter_values(self, param, delta, de weather_data_provider, agro_management_inputs, external_states, - ) = _prepare_common_storage_inputs(test_data_url, device=device) + ) = _prepare_common_storage_inputs(test_data_url) # Setting a vector with multiple values for the selected parameter test_value = crop_model_params_provider[param] @@ -282,7 +276,6 @@ def test_storage_dynamics_with_different_parameter_values(self, param, delta, de agro_management_inputs, storage_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -302,7 +295,7 @@ def test_storage_dynamics_with_multiple_parameter_vectors(self, device): weather_data_provider, agro_management_inputs, external_states, - ) = _prepare_common_storage_inputs(test_data_url, device=device) + ) = _prepare_common_storage_inputs(test_data_url) # Setting a vector (with one value) for the TDWI and SPA parameters for param in ("TDWI", "SPA"): @@ -319,7 +312,6 @@ def test_storage_dynamics_with_multiple_parameter_vectors(self, device): agro_management_inputs, storage_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -339,7 +331,7 @@ def test_storage_dynamics_with_multiple_parameter_arrays(self, device): weather_data_provider, agro_management_inputs, external_states, - ) = _prepare_common_storage_inputs(test_data_url, device=device, meteo_range_checks=False) + ) = _prepare_common_storage_inputs(test_data_url, meteo_range_checks=False) # Setting an array with arbitrary shape (and one value) for param in ("TDWI", "SPA"): @@ -347,7 +339,7 @@ def test_storage_dynamics_with_multiple_parameter_arrays(self, device): crop_model_params_provider.set_override(param, repeated, check=False) for (_, _), wdc in weather_data_provider.store.items(): - wdc.TEMP = torch.ones((30, 5), dtype=torch.float64) * wdc.TEMP + wdc.TEMP = torch.ones((30, 5), dtype=torch.float64, device=device) * wdc.TEMP engine = EngineTestHelper( crop_model_params_provider, @@ -355,7 +347,6 @@ def test_storage_dynamics_with_multiple_parameter_arrays(self, device): agro_management_inputs, storage_dynamics_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -375,7 +366,7 @@ def test_storage_dynamics_with_incompatible_parameter_vectors(self): weather_data_provider, agro_management_inputs, external_states, - ) = _prepare_common_storage_inputs(test_data_url, device="cpu") + ) = _prepare_common_storage_inputs(test_data_url) # Setting a vector (with one value) for the TDWI and SPA parameters, # but with different lengths @@ -386,14 +377,13 @@ def test_storage_dynamics_with_incompatible_parameter_vectors(self): "SPA", crop_model_params_provider["SPA"].repeat(5), check=False ) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): EngineTestHelper( crop_model_params_provider, weather_data_provider, agro_management_inputs, storage_dynamics_config, external_states, - device="cpu", ) @pytest.mark.parametrize("test_data_url", wofost72_data_urls) From 85cce6614b2d0c88a8a62dbbbfb8bc7f86cb3de0 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 26 Feb 2026 12:12:13 +0100 Subject: [PATCH 4/9] Clean --- src/diffwofost/physical_models/crop/storage_organ_dynamics.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py index 250bd5a..b2a7e8c 100644 --- a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py +++ b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py @@ -80,8 +80,6 @@ class WOFOST_Storage_Organ_Dynamics(SimulationObject): | WSO | TDWI | """ - params_shape = None # Shape of the parameters tensors - @property def device(self): """Get device from ComputeConfig.""" From b695b36f71b1bf94d568e75307db0223dac89497 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 2 Mar 2026 10:12:43 +0100 Subject: [PATCH 5/9] Update src/diffwofost/physical_models/crop/storage_organ_dynamics.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- .../physical_models/crop/storage_organ_dynamics.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py index b2a7e8c..3bc77e6 100644 --- a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py +++ b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py @@ -112,7 +112,18 @@ def initialize( parvalues: ParameterProvider, shape: tuple | torch.Size | None = None, ) -> None: - """Initialize the storage organ dynamics model.""" + """Initialize the storage organ dynamics model. + + Args: + day (datetime.date): The starting date of the simulation. + kiosk (VariableKiosk): A container for registering and publishing + (internal and external) state variables. See PCSE documentation for + details. + parvalues (ParameterProvider): A dictionary-like container holding + all parameter sets (crop, soil, site) as key/value. The values are + arrays or scalars. See PCSE documentation for details. + shape (tuple | torch.Size | None): Target shape for the state and rate variables. + """ self.kiosk = kiosk self.params = self.Parameters(parvalues, shape=shape) self.rates = self.RateVariables(kiosk, publish=["GRSO"]) From fe1d4556ef360a12a0ebe1305ddd8eaa69d881da Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 2 Mar 2026 10:13:00 +0100 Subject: [PATCH 6/9] Update src/diffwofost/physical_models/crop/storage_organ_dynamics.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/storage_organ_dynamics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py index 3bc77e6..dd1e7cb 100644 --- a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py +++ b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py @@ -179,7 +179,6 @@ def integrate(self, day: datetime.date = None, delt=1.0) -> None: rates = self.rates states = self.states - SPA = params.SPA # Stem biomass (living, dead, total) states.WSO = states.WSO + rates.GWSO From 6a7d035a3999a0512e3dce3e899da39c5dd1ba5f Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 2 Mar 2026 10:13:08 +0100 Subject: [PATCH 7/9] Update src/diffwofost/physical_models/crop/storage_organ_dynamics.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/storage_organ_dynamics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py index dd1e7cb..a69152e 100644 --- a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py +++ b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py @@ -186,4 +186,4 @@ def integrate(self, day: datetime.date = None, delt=1.0) -> None: states.TWSO = states.WSO + states.DWSO # Calculate Pod Area Index (SAI) - states.PAI = states.WSO * SPA + states.PAI = states.WSO * params.SPA From e0edcfae2ab250f458801cbfda01eaef4aaebf86 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 2 Mar 2026 10:13:32 +0100 Subject: [PATCH 8/9] Update tests/physical_models/crop/test_storage_organ_dynamics.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- tests/physical_models/crop/test_storage_organ_dynamics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/physical_models/crop/test_storage_organ_dynamics.py b/tests/physical_models/crop/test_storage_organ_dynamics.py index 6352b42..55dcd13 100644 --- a/tests/physical_models/crop/test_storage_organ_dynamics.py +++ b/tests/physical_models/crop/test_storage_organ_dynamics.py @@ -141,7 +141,7 @@ class TestStorageOrganDynamics: ] @pytest.mark.parametrize("test_data_url", storage_dynamics_data_urls) - def test_storage_dynamics_with_testengine(self, test_data_url, device): + def test_storage_dynamics_with_testengine(self, test_data_url): """EngineTestHelper and not Engine because it allows to specify `external_states`.""" ( test_data, From 93a082d41f90f56bd56df808a10141a1b036c6cd Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 2 Mar 2026 10:17:20 +0100 Subject: [PATCH 9/9] Fix ruff --- src/diffwofost/physical_models/crop/storage_organ_dynamics.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py index a69152e..6845de1 100644 --- a/src/diffwofost/physical_models/crop/storage_organ_dynamics.py +++ b/src/diffwofost/physical_models/crop/storage_organ_dynamics.py @@ -123,7 +123,7 @@ def initialize( all parameter sets (crop, soil, site) as key/value. The values are arrays or scalars. See PCSE documentation for details. shape (tuple | torch.Size | None): Target shape for the state and rate variables. - """ + """ self.kiosk = kiosk self.params = self.Parameters(parvalues, shape=shape) self.rates = self.RateVariables(kiosk, publish=["GRSO"]) @@ -179,7 +179,6 @@ def integrate(self, day: datetime.date = None, delt=1.0) -> None: rates = self.rates states = self.states - # Stem biomass (living, dead, total) states.WSO = states.WSO + rates.GWSO states.DWSO = states.DWSO + rates.DRSO