From 4a2cc586b2ec9e78ecf5fd490327c5c0f5d40c60 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 15 Jan 2026 14:24:35 +0100 Subject: [PATCH 1/4] Diffvectorize --- docs/api_reference.md | 2 + .../physical_models/crop/respiration.py | 160 ++++++ tests/physical_models/conftest.py | 1 + .../physical_models/crop/test_respiration.py | 466 ++++++++++++++++++ 4 files changed, 629 insertions(+) create mode 100644 src/diffwofost/physical_models/crop/respiration.py create mode 100644 tests/physical_models/crop/test_respiration.py diff --git a/docs/api_reference.md b/docs/api_reference.md index ed8303e..6381e93 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -16,6 +16,8 @@ hide: ::: diffwofost.physical_models.crop.partitioning.DVS_Partitioning +::: diffwofost.physical_models.crop.respiration.WOFOST_Maintenance_Respiration + ## **Utility (under development)** ::: diffwofost.physical_models.config.Configuration diff --git a/src/diffwofost/physical_models/crop/respiration.py b/src/diffwofost/physical_models/crop/respiration.py new file mode 100644 index 0000000..77f7fe1 --- /dev/null +++ b/src/diffwofost/physical_models/crop/respiration.py @@ -0,0 +1,160 @@ +"""Maintenance respiration for the WOFOST crop model.""" + +import datetime +import torch +from pcse.base import ParamTemplate +from pcse.base import RatesTemplate +from pcse.base import SimulationObject +from pcse.base.parameter_providers import ParameterProvider +from pcse.base.variablekiosk import VariableKiosk +from pcse.base.weather import WeatherDataContainer +from pcse.decorators import prepare_rates +from pcse.traitlets import Any +from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.utils import AfgenTrait +from diffwofost.physical_models.utils import _broadcast_to +from diffwofost.physical_models.utils import _get_drv +from diffwofost.physical_models.utils import _get_params_shape + + +class WOFOST_Maintenance_Respiration(SimulationObject): + """Maintenance respiration in WOFOST. + + WOFOST calculates the maintenance respiration as proportional to the dry + weights of the plant organs to be maintained, where each plant organ can be + assigned a different maintenance coefficient. Multiplying organ weight + with the maintenance coeffients yields the relative maintenance respiration + (`RMRES`) which is than corrected for senescence (parameter `RFSETB`). Finally, + the actual maintenance respiration rate is calculated using the daily mean + temperature, assuming a relative increase for each 10 degrees increase + in temperature as defined by `Q10`. + + **Simulation parameters** (provide in cropdata dictionary) + + | Name | Description | Type | Unit | + |--------|---------------------------------------------------------- |------|------------------| + | Q10 | Relative increase in maintenance respiration rate with | SCr | - | + | | each 10 degrees increase in temperature | | - | + | RMR | Relative maintenance respiration rate for roots | SCr | kg CH₂O kg⁻¹ d⁻¹ | + | RMS | Relative maintenance respiration rate for stems | SCr | kg CH₂O kg⁻¹ d⁻¹ | + | RML | Relative maintenance respiration rate for leaves | SCr | kg CH₂O kg⁻¹ d⁻¹ | + | RMO | Relative maintenance respiration rate for storage organs | SCr | kg CH₂O kg⁻¹ d⁻¹ | + | RFSETB | Reduction factor for senescence | TCr | - | + + **Rate variables** + + | Name | Description | Pbl | Unit | + |-------|--------------------------------------------|----|-------------------| + | PMRES | Potential maintenance respiration rate | N | kg CH₂O ha⁻¹ d⁻¹ | + + **Signals send or handled** + + None + + **External dependencies** + + | Name | Description | Provided by | Unit | + |------|-------------------------------------|--------------------------------|-----------| + | DVS | Crop development stage | DVS_Phenology | - | + | WRT | Dry weight of living roots | WOFOST_Root_Dynamics | kg ha⁻¹ | + | WST | Dry weight of living stems | WOFOST_Stem_Dynamics | kg ha⁻¹ | + | WLV | Dry weight of living leaves | WOFOST_Leaf_Dynamics | kg ha⁻¹ | + | WSO | Dry weight of living storage organs | WOFOST_Storage_Organ_Dynamics | kg ha⁻¹ | + + **Outputs** + + | Name | Description | Pbl | Unit | + |-------|--------------------------------------------|----|---------------------| + | PMRES | Potential maintenance respiration rate | N | kg CH₂O ha⁻¹ d⁻¹ | + + **Gradient mapping (which parameters have a gradient):** + + | Output | Parameters influencing it | + |--------|------------------------------------------| + | PMRES | Q10, RMR, RML, RMS, RMO, RFSETB | + """ + + params_shape = None # Shape of the parameters tensors + + @property + def device(self): + """Get device from ComputeConfig.""" + return ComputeConfig.get_device() + + @property + def dtype(self): + """Get dtype from ComputeConfig.""" + return ComputeConfig.get_dtype() + + class Parameters(ParamTemplate): + Q10 = Any() + RMR = Any() + RML = Any() + RMS = Any() + RMO = Any() + RFSETB = AfgenTrait() + + class RateVariables(RatesTemplate): + PMRES = Any() + + def __init__(self, kiosk, publish=None): + self.PMRES = torch.tensor( + 0.0, dtype=ComputeConfig.get_dtype(), device=ComputeConfig.get_device() + ) + super().__init__(kiosk, publish=publish) + + def initialize(self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider): + """Initialize the maintenance respiration module. + + Args: + day: Start date of the simulation + kiosk: Variable kiosk of this PCSE instance + parvalues: ParameterProvider object providing parameters as key/value pairs + """ + self.params = self.Parameters(parvalues) + self.rates = self.RateVariables(kiosk, publish=["PMRES"]) + self.kiosk = kiosk + self.params_shape = _get_params_shape(self.params) + + @prepare_rates + def calc_rates(self, day: datetime.date, drv: WeatherDataContainer): + """Calculate maintenance respiration rates. + + Args: + day: Current date + drv: Weather data for the current day + """ + p = self.params + kk = self.kiosk + r = self.rates + + Q10 = _broadcast_to(p.Q10, self.params_shape, dtype=self.dtype, device=self.device) + RMR = _broadcast_to(p.RMR, self.params_shape, dtype=self.dtype, device=self.device) + RML = _broadcast_to(p.RML, self.params_shape, dtype=self.dtype, device=self.device) + RMS = _broadcast_to(p.RMS, self.params_shape, dtype=self.dtype, device=self.device) + RMO = _broadcast_to(p.RMO, self.params_shape, dtype=self.dtype, device=self.device) + + WRT = _broadcast_to(kk["WRT"], self.params_shape, dtype=self.dtype, device=self.device) + WLV = _broadcast_to(kk["WLV"], self.params_shape, dtype=self.dtype, device=self.device) + WST = _broadcast_to(kk["WST"], self.params_shape, dtype=self.dtype, device=self.device) + WSO = _broadcast_to(kk["WSO"], self.params_shape, dtype=self.dtype, device=self.device) + DVS = _broadcast_to(kk["DVS"], self.params_shape, dtype=self.dtype, device=self.device) + + TEMP = _get_drv(drv.TEMP, self.params_shape, dtype=self.dtype, device=self.device) + + RMRES = RMR * WRT + RML * WLV + RMS * WST + RMO * WSO + RMRES = RMRES * p.RFSETB(DVS) + TEFF = Q10 ** ((TEMP - 25.0) / 10.0) + PMRES = RMRES * TEFF + + # No maintenance respiration before emergence (DVS < 0). + r.PMRES = torch.where(DVS < 0, torch.zeros_like(PMRES), PMRES) + + def __call__(self, day: datetime.date, drv: WeatherDataContainer): + """Calculate and return maintenance respiration (PMRES).""" + self.calc_rates(day, drv) + return self.rates.PMRES + + def integrate(self, day: datetime.date, delt: float = 1.0): + """No state variables to integrate for this module.""" + return diff --git a/tests/physical_models/conftest.py b/tests/physical_models/conftest.py index 862c62e..9f830b6 100644 --- a/tests/physical_models/conftest.py +++ b/tests/physical_models/conftest.py @@ -14,6 +14,7 @@ "phenology", "partitioning", "assimilation", + "respiration", ] FILE_NAMES = [ f"test_{model_name}_wofost72_{i:02d}.yaml" for model_name in model_names for i in range(1, 45) diff --git a/tests/physical_models/crop/test_respiration.py b/tests/physical_models/crop/test_respiration.py new file mode 100644 index 0000000..4345929 --- /dev/null +++ b/tests/physical_models/crop/test_respiration.py @@ -0,0 +1,466 @@ +import copy +import warnings +from unittest.mock import patch +import pytest +import torch +from pcse.models import Wofost72_PP +from diffwofost.physical_models.config import Configuration +from diffwofost.physical_models.crop.respiration import WOFOST_Maintenance_Respiration +from diffwofost.physical_models.utils import EngineTestHelper +from diffwofost.physical_models.utils import calculate_numerical_grad +from diffwofost.physical_models.utils import get_test_data +from diffwofost.physical_models.utils import prepare_engine_input +from .. import phy_data_folder + +respiration_config = Configuration( + CROP=WOFOST_Maintenance_Respiration, + OUTPUT_VARS=["PMRES"], +) + + +def get_test_diff_respiration_model(device: str = "cpu"): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, device=device) + return DiffRespiration( + copy.deepcopy(crop_model_params_provider), + weather_data_provider, + agro_management_inputs, + respiration_config, + copy.deepcopy(external_states), + device=device, + ) + + +class DiffRespiration(torch.nn.Module): + def __init__( + self, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config, + external_states, + device: str = "cpu", + ): + super().__init__() + self.crop_model_params_provider = crop_model_params_provider + self.weather_data_provider = weather_data_provider + self.agro_management_inputs = agro_management_inputs + self.config = config + self.external_states = external_states + self.device = device + + def forward(self, params_dict): + for name, value in params_dict.items(): + self.crop_model_params_provider.set_override(name, value, check=False) + + engine = EngineTestHelper( + self.crop_model_params_provider, + self.weather_data_provider, + self.agro_management_inputs, + self.config, + self.external_states, + device=self.device, + ) + engine.run_till_terminate() + results = engine.get_output() + + return {"PMRES": torch.stack([item["PMRES"] for item in results])} + + +class TestRespiration: + respiration_data_urls = [ + f"{phy_data_folder}/test_respiration_wofost72_{i:02d}.yaml" for i in range(1, 45) + ] + + wofost72_data_urls = [ + f"{phy_data_folder}/test_potentialproduction_wofost72_{i:02d}.yaml" + for i in range(1, 45) # there are 44 test files + ] + + @pytest.mark.parametrize("test_data_url", respiration_data_urls) + def test_respiration_with_testengine(self, test_data_url, device): + """EngineTestHelper (not Engine) allows forcing `external_states` from YAML.""" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, device=device) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} + assert all( + abs(reference[var] - model_cpu[var]) < precision + for var, precision in expected_precision.items() + ) + + @pytest.mark.parametrize("param", ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB", "TEMP"]) + def test_respiration_with_one_parameter_vector(self, param, device): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + if param == "TEMP": + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones(10, dtype=torch.float64) * wdc.TEMP + with pytest.raises(ValueError): + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + _ = engine.get_output() + return + + if param == "RFSETB": + repeated = crop_model_params_provider[param].repeat(10, 1) + else: + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} + assert all( + all(abs(reference[var] - model_cpu[var]) < precision) + for var, precision in expected_precision.items() + ) + + @pytest.mark.parametrize( + "param,delta", + [ + ("Q10", 0.2), + ("RMR", 0.002), + ("RML", 0.002), + ("RMS", 0.002), + ("RMO", 0.002), + ], + ) + def test_respiration_with_different_parameter_values(self, param, delta, device): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, device=device) + + test_value = crop_model_params_provider[param] + param_vec = torch.tensor([test_value - delta, test_value + delta, test_value]) + crop_model_params_provider.set_override(param, param_vec, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} + assert all( + abs(reference[var] - model_cpu[var][-1]) < precision + for var, precision in expected_precision.items() + ) + + def test_respiration_with_multiple_parameter_vectors(self, device): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, device=device) + + for param in ("Q10", "RMR", "RML", "RMS", "RMO"): + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) + crop_model_params_provider.set_override( + "RFSETB", crop_model_params_provider["RFSETB"].repeat(10, 1), check=False + ) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + all(abs(reference[var] - model[var]) < precision) + for var, precision in expected_precision.items() + ) + + def test_respiration_with_multiple_parameter_arrays(self, device): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + for param in ("Q10", "RMR", "RML", "RMS", "RMO"): + repeated = crop_model_params_provider[param].broadcast_to((30, 5)) + crop_model_params_provider.set_override(param, repeated, check=False) + crop_model_params_provider.set_override( + "RFSETB", crop_model_params_provider["RFSETB"].repeat(30, 5, 1), check=False + ) + + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones((30, 5), dtype=torch.float64) * wdc.TEMP + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + torch.all(abs(reference[var] - model[var]) < precision) + for var, precision in expected_precision.items() + ) + assert all(model[var].shape == (30, 5) for var in expected_precision.keys()) + + def test_respiration_with_incompatible_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params) + + crop_model_params_provider.set_override( + "RMR", crop_model_params_provider["RMR"].repeat(10), check=False + ) + crop_model_params_provider.set_override( + "RML", crop_model_params_provider["RML"].repeat(5), check=False + ) + + with pytest.raises(AssertionError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device="cpu", + ) + + def test_respiration_with_incompatible_weather_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + + crop_model_params_provider.set_override( + "RMR", crop_model_params_provider["RMR"].repeat(10), check=False + ) + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones(5, dtype=torch.float64) * wdc.TEMP + + with pytest.raises(ValueError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + respiration_config, + external_states, + device="cpu", + ) + + @pytest.mark.parametrize("test_data_url", wofost72_data_urls) + def test_wofost_pp_with_leaf_dynamics(self, test_data_url): + # prepare model input + test_data = get_test_data(test_data_url) + crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] + (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( + prepare_engine_input(test_data, crop_model_params) + ) + + # get expected results from YAML test data + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + with patch("pcse.crop.wofost72.MaintenanceRespiration", WOFOST_Maintenance_Respiration): + model = Wofost72_PP( + crop_model_params_provider, weather_data_provider, agro_management_inputs + ) + model.run_till_terminate() + actual_results = model.get_output() + + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + abs(reference[var] - model[var]) < precision + for var, precision in expected_precision.items() + ) + + +class TestDiffRespirationGradients: + """Parametrized tests for gradient calculations in maintenance respiration.""" + + param_configs = { + "single": { + "Q10": (2.0, torch.float64), + "RMR": (0.015, torch.float64), + "RML": (0.03, torch.float64), + "RMS": (0.02, torch.float64), + "RMO": (0.01, torch.float64), + }, + "tensor": { + "Q10": ([1.5, 2.0, 2.5], torch.float64), + "RMR": ([0.01, 0.015, 0.02], torch.float64), + "RML": ([0.02, 0.03, 0.04], torch.float64), + "RMS": ([0.01, 0.02, 0.03], torch.float64), + "RMO": ([0.005, 0.01, 0.02], torch.float64), + }, + } + + @pytest.mark.parametrize("param_name", ["Q10", "RMR", "RML", "RMS", "RMO"]) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_forward_backward_match(self, param_name, config_type, device): + model = get_test_diff_respiration_model(device=device) + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + output = model({param_name: param}) + loss = output["PMRES"].sum() + + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + assert grads is not None, f"Gradients for {param_name} should not be None" + + param.grad = None + loss.backward() + grad_backward = param.grad + assert grad_backward is not None, f"Backward gradients for {param_name} should not be None" + assert torch.all(grad_backward == grads), ( + f"Forward and backward gradients for {param_name} should match" + ) + + @pytest.mark.parametrize("param_name", ["Q10", "RMR", "RML", "RMS", "RMO"]) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_numerical(self, param_name, config_type, device): + value, _ = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) + + numerical_grad = calculate_numerical_grad( + lambda: get_test_diff_respiration_model(device=device), + param_name, + param, + "PMRES", + ) + + model = get_test_diff_respiration_model(device=device) + output = model({param_name: param}) + loss = output["PMRES"].sum() + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + + torch.testing.assert_close( + numerical_grad.detach().cpu(), + grads.detach().cpu(), + rtol=1e-3, + atol=1e-3, + ) + + if torch.all(grads == 0): + warnings.warn( + f"Gradient for parameter '{param_name}' with" + + f"respect to output 'PMRES' is zero: {grads.data}", + UserWarning, + ) From d54138c21df348235ebe88a46299f36212d1427c Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 12 Feb 2026 12:45:25 +0100 Subject: [PATCH 2/4] Fix tests --- .../physical_models/crop/respiration.py | 73 +++++++++---------- .../physical_models/crop/test_respiration.py | 44 ++++------- 2 files changed, 50 insertions(+), 67 deletions(-) diff --git a/src/diffwofost/physical_models/crop/respiration.py b/src/diffwofost/physical_models/crop/respiration.py index 77f7fe1..b1084e3 100644 --- a/src/diffwofost/physical_models/crop/respiration.py +++ b/src/diffwofost/physical_models/crop/respiration.py @@ -2,19 +2,18 @@ import datetime import torch -from pcse.base import ParamTemplate -from pcse.base import RatesTemplate from pcse.base import SimulationObject from pcse.base.parameter_providers import ParameterProvider from pcse.base.variablekiosk import VariableKiosk from pcse.base.weather import WeatherDataContainer from pcse.decorators import prepare_rates -from pcse.traitlets import Any +from diffwofost.physical_models.base import TensorParamTemplate +from diffwofost.physical_models.base import TensorRatesTemplate from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.traitlets import Tensor from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to from diffwofost.physical_models.utils import _get_drv -from diffwofost.physical_models.utils import _get_params_shape class WOFOST_Maintenance_Respiration(SimulationObject): @@ -74,8 +73,6 @@ class WOFOST_Maintenance_Respiration(SimulationObject): | PMRES | Q10, RMR, RML, RMS, RMO, RFSETB | """ - params_shape = None # Shape of the parameters tensors - @property def device(self): """Get device from ComputeConfig.""" @@ -86,35 +83,35 @@ def dtype(self): """Get dtype from ComputeConfig.""" return ComputeConfig.get_dtype() - class Parameters(ParamTemplate): - Q10 = Any() - RMR = Any() - RML = Any() - RMS = Any() - RMO = Any() + class Parameters(TensorParamTemplate): + Q10 = Tensor(1) + RMR = Tensor(1) + RML = Tensor(1) + RMS = Tensor(1) + RMO = Tensor(1) RFSETB = AfgenTrait() - class RateVariables(RatesTemplate): - PMRES = Any() - - def __init__(self, kiosk, publish=None): - self.PMRES = torch.tensor( - 0.0, dtype=ComputeConfig.get_dtype(), device=ComputeConfig.get_device() - ) - super().__init__(kiosk, publish=publish) + class RateVariables(TensorRatesTemplate): + PMRES = Tensor(1) - def initialize(self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider): + def initialize( + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | None = None, + ): """Initialize the maintenance respiration module. Args: day: Start date of the simulation kiosk: Variable kiosk of this PCSE instance parvalues: ParameterProvider object providing parameters as key/value pairs + shape: Shape of the parameters tensors (optional) """ - self.params = self.Parameters(parvalues) - self.rates = self.RateVariables(kiosk, publish=["PMRES"]) + self.params = self.Parameters(parvalues, shape=shape) + self.rates = self.RateVariables(kiosk, publish=["PMRES"], shape=shape) self.kiosk = kiosk - self.params_shape = _get_params_shape(self.params) @prepare_rates def calc_rates(self, day: datetime.date, drv: WeatherDataContainer): @@ -128,19 +125,21 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer): kk = self.kiosk r = self.rates - Q10 = _broadcast_to(p.Q10, self.params_shape, dtype=self.dtype, device=self.device) - RMR = _broadcast_to(p.RMR, self.params_shape, dtype=self.dtype, device=self.device) - RML = _broadcast_to(p.RML, self.params_shape, dtype=self.dtype, device=self.device) - RMS = _broadcast_to(p.RMS, self.params_shape, dtype=self.dtype, device=self.device) - RMO = _broadcast_to(p.RMO, self.params_shape, dtype=self.dtype, device=self.device) - - WRT = _broadcast_to(kk["WRT"], self.params_shape, dtype=self.dtype, device=self.device) - WLV = _broadcast_to(kk["WLV"], self.params_shape, dtype=self.dtype, device=self.device) - WST = _broadcast_to(kk["WST"], self.params_shape, dtype=self.dtype, device=self.device) - WSO = _broadcast_to(kk["WSO"], self.params_shape, dtype=self.dtype, device=self.device) - DVS = _broadcast_to(kk["DVS"], self.params_shape, dtype=self.dtype, device=self.device) - - TEMP = _get_drv(drv.TEMP, self.params_shape, dtype=self.dtype, device=self.device) + Q10 = p.Q10 + RMR = p.RMR + RML = p.RML + RMS = p.RMS + RMO = p.RMO + + WRT = kk["WRT"] + WLV = kk["WLV"] + WST = kk["WST"] + WSO = kk["WSO"] + # [!] DVS needs to be broadcasted explicetly because it is used + # in torch.where and the kiosk does not format it correctly + DVS = _broadcast_to(kk["DVS"], p.shape, self.dtype, self.device) + + TEMP = _get_drv(drv.TEMP, p.shape, self.dtype, self.device) RMRES = RMR * WRT + RML * WLV + RMS * WST + RMO * WSO RMRES = RMRES * p.RFSETB(DVS) diff --git a/tests/physical_models/crop/test_respiration.py b/tests/physical_models/crop/test_respiration.py index 4345929..45cb941 100644 --- a/tests/physical_models/crop/test_respiration.py +++ b/tests/physical_models/crop/test_respiration.py @@ -18,7 +18,7 @@ ) -def get_test_diff_respiration_model(device: str = "cpu"): +def get_test_diff_respiration_model(): test_data_url = f"{phy_data_folder}/test_respiration_wofost72_01.yaml" test_data = get_test_data(test_data_url) crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] @@ -27,14 +27,13 @@ def get_test_diff_respiration_model(device: str = "cpu"): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params, device=device) + ) = prepare_engine_input(test_data, crop_model_params) return DiffRespiration( copy.deepcopy(crop_model_params_provider), weather_data_provider, agro_management_inputs, respiration_config, copy.deepcopy(external_states), - device=device, ) @@ -46,7 +45,6 @@ def __init__( agro_management_inputs, config, external_states, - device: str = "cpu", ): super().__init__() self.crop_model_params_provider = crop_model_params_provider @@ -54,7 +52,6 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states - self.device = device def forward(self, params_dict): for name, value in params_dict.items(): @@ -66,7 +63,6 @@ def forward(self, params_dict): self.agro_management_inputs, self.config, self.external_states, - device=self.device, ) engine.run_till_terminate() results = engine.get_output() @@ -94,7 +90,7 @@ def test_respiration_with_testengine(self, test_data_url, device): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params, device=device) + ) = prepare_engine_input(test_data, crop_model_params) engine = EngineTestHelper( crop_model_params_provider, @@ -102,7 +98,6 @@ def test_respiration_with_testengine(self, test_data_url, device): agro_management_inputs, respiration_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -130,9 +125,7 @@ def test_respiration_with_one_parameter_vector(self, param, device): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input( - test_data, crop_model_params, meteo_range_checks=False, device=device - ) + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) if param == "TEMP": for (_, _), wdc in weather_data_provider.store.items(): @@ -144,7 +137,6 @@ def test_respiration_with_one_parameter_vector(self, param, device): agro_management_inputs, respiration_config, external_states, - device=device, ) engine.run_till_terminate() _ = engine.get_output() @@ -162,7 +154,6 @@ def test_respiration_with_one_parameter_vector(self, param, device): agro_management_inputs, respiration_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -197,7 +188,7 @@ def test_respiration_with_different_parameter_values(self, param, delta, device) weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params, device=device) + ) = prepare_engine_input(test_data, crop_model_params) test_value = crop_model_params_provider[param] param_vec = torch.tensor([test_value - delta, test_value + delta, test_value]) @@ -209,7 +200,6 @@ def test_respiration_with_different_parameter_values(self, param, delta, device) agro_management_inputs, respiration_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -234,7 +224,7 @@ def test_respiration_with_multiple_parameter_vectors(self, device): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params, device=device) + ) = prepare_engine_input(test_data, crop_model_params) for param in ("Q10", "RMR", "RML", "RMS", "RMO"): repeated = crop_model_params_provider[param].repeat(10) @@ -249,7 +239,6 @@ def test_respiration_with_multiple_parameter_vectors(self, device): agro_management_inputs, respiration_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -273,9 +262,7 @@ def test_respiration_with_multiple_parameter_arrays(self, device): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input( - test_data, crop_model_params, meteo_range_checks=False, device=device - ) + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) for param in ("Q10", "RMR", "RML", "RMS", "RMO"): repeated = crop_model_params_provider[param].broadcast_to((30, 5)) @@ -293,7 +280,6 @@ def test_respiration_with_multiple_parameter_arrays(self, device): agro_management_inputs, respiration_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -327,14 +313,13 @@ def test_respiration_with_incompatible_parameter_vectors(self): "RML", crop_model_params_provider["RML"].repeat(5), check=False ) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): EngineTestHelper( crop_model_params_provider, weather_data_provider, agro_management_inputs, respiration_config, external_states, - device="cpu", ) def test_respiration_with_incompatible_weather_parameter_vectors(self): @@ -361,11 +346,10 @@ def test_respiration_with_incompatible_weather_parameter_vectors(self): agro_management_inputs, respiration_config, external_states, - device="cpu", ) @pytest.mark.parametrize("test_data_url", wofost72_data_urls) - def test_wofost_pp_with_leaf_dynamics(self, test_data_url): + def test_wofost_pp_with_respiration(self, test_data_url): # prepare model input test_data = get_test_data(test_data_url) crop_model_params = ["Q10", "RMR", "RML", "RMS", "RMO", "RFSETB"] @@ -416,9 +400,9 @@ class TestDiffRespirationGradients: @pytest.mark.parametrize("param_name", ["Q10", "RMR", "RML", "RMS", "RMO"]) @pytest.mark.parametrize("config_type", ["single", "tensor"]) def test_gradients_forward_backward_match(self, param_name, config_type, device): - model = get_test_diff_respiration_model(device=device) + model = get_test_diff_respiration_model() value, dtype = self.param_configs[config_type][param_name] - param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) output = model({param_name: param}) loss = output["PMRES"].sum() @@ -437,16 +421,16 @@ def test_gradients_forward_backward_match(self, param_name, config_type, device) @pytest.mark.parametrize("config_type", ["single", "tensor"]) def test_gradients_numerical(self, param_name, config_type, device): value, _ = self.param_configs[config_type][param_name] - param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64)) numerical_grad = calculate_numerical_grad( - lambda: get_test_diff_respiration_model(device=device), + lambda: get_test_diff_respiration_model(), param_name, param, "PMRES", ) - model = get_test_diff_respiration_model(device=device) + model = get_test_diff_respiration_model() output = model({param_name: param}) loss = output["PMRES"].sum() grads = torch.autograd.grad(loss, param, retain_graph=True)[0] From ab44dcb5c7cf7432e6a1e3cf4276512f365aae25 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 16 Feb 2026 15:23:31 +0100 Subject: [PATCH 3/4] Update src/diffwofost/physical_models/crop/respiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/respiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/respiration.py b/src/diffwofost/physical_models/crop/respiration.py index b1084e3..0bbb5df 100644 --- a/src/diffwofost/physical_models/crop/respiration.py +++ b/src/diffwofost/physical_models/crop/respiration.py @@ -137,6 +137,7 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer): WSO = kk["WSO"] # [!] DVS needs to be broadcasted explicetly because it is used # in torch.where and the kiosk does not format it correctly + #TODO see #22 DVS = _broadcast_to(kk["DVS"], p.shape, self.dtype, self.device) TEMP = _get_drv(drv.TEMP, p.shape, self.dtype, self.device) From 8dc2f6b7527df9d82ea4b1daf504de5bda663f0c Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 16 Feb 2026 15:26:28 +0100 Subject: [PATCH 4/4] Fix --- .../physical_models/crop/respiration.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffwofost/physical_models/crop/respiration.py b/src/diffwofost/physical_models/crop/respiration.py index 0bbb5df..89d8518 100644 --- a/src/diffwofost/physical_models/crop/respiration.py +++ b/src/diffwofost/physical_models/crop/respiration.py @@ -84,15 +84,15 @@ def dtype(self): return ComputeConfig.get_dtype() class Parameters(TensorParamTemplate): - Q10 = Tensor(1) - RMR = Tensor(1) - RML = Tensor(1) - RMS = Tensor(1) - RMO = Tensor(1) + Q10 = Tensor(-99.0) + RMR = Tensor(-99.0) + RML = Tensor(-99.0) + RMS = Tensor(-99.0) + RMO = Tensor(-99.0) RFSETB = AfgenTrait() class RateVariables(TensorRatesTemplate): - PMRES = Tensor(1) + PMRES = Tensor(0.0) def initialize( self, @@ -110,7 +110,7 @@ def initialize( shape: Shape of the parameters tensors (optional) """ self.params = self.Parameters(parvalues, shape=shape) - self.rates = self.RateVariables(kiosk, publish=["PMRES"], shape=shape) + self.rates = self.RateVariables(kiosk, shape=shape) self.kiosk = kiosk @prepare_rates @@ -137,7 +137,7 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer): WSO = kk["WSO"] # [!] DVS needs to be broadcasted explicetly because it is used # in torch.where and the kiosk does not format it correctly - #TODO see #22 + # TODO see #22 DVS = _broadcast_to(kk["DVS"], p.shape, self.dtype, self.device) TEMP = _get_drv(drv.TEMP, p.shape, self.dtype, self.device)