Add tensor parameter containers#75

Merged: fnattino merged 24 commits into main from tensor-param-template on Feb 11, 2026
Conversation

fnattino (Collaborator) commented on Jan 15, 2026

Relates to #25.

fnattino (Collaborator, Author):
This PR adds:

  • A new parameter template, to be used in the models. It automatically broadcasts all parameters to the same shape, which can be useful in the following scenarios:
    • one parameter is a tensor and the others are scalars: the scalar parameters are automatically broadcast to the tensor shape;
    • all parameters are scalars, but the driving variables (i.e. the weather data) are tensors: not yet implemented, but I imagine this happening when instantiating a parameter object, which could take a shape argument.
  • A new Tensor trait, which can be used to define parameters that are expected to be tensors. This simplifies the definition of the parameter/state/rate containers in the models, and it allows the expected dtype of a given parameter/state/rate variable to be defined. In addition, input parameters are automatically cast into tensors if they are not already.

What do you think @SarahAlidoost and @SCiarella ? I think this can simplify some expressions in the models - I imagine a lot of the _broadcast_to can be removed due to the template already carrying out the broadcasting.
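For illustration, the first scenario (scalar parameters broadcast to a common tensor shape) can be sketched with plain torch calls. This is a minimal, hypothetical stand-in for what the template does internally; `broadcast_params` is not the actual implementation:

```python
import torch

def broadcast_params(params: dict) -> dict:
    """Cast scalar inputs to tensors and broadcast all parameters to one
    common shape (a simplified sketch of the template's behaviour)."""
    tensors = {k: torch.as_tensor(v) for k, v in params.items()}
    # Compute the shape all parameters must share.
    shape = torch.broadcast_shapes(*(t.shape for t in tensors.values()))
    return {k: t.broadcast_to(shape) for k, t in tensors.items()}

# "A" is a scalar, "B" a tensor of shape (3,): both end up with shape (3,).
params = broadcast_params({"A": 0.5, "B": torch.ones(3)})
```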

SCiarella (Collaborator):

Thanks @fnattino, this looks fantastic 🚀

I really like that the template automatically broadcasts to the correct shape and device at the beginning, because right now we are doing this many times in the integration loops.

Ideally, it would be nice to remove all the calls to _broadcast_to, but right now we have to check in each module whether the state variables, the rates, the kiosk, and the parameters all have the same shape/device. Can we later add templates for everything and have the engine perform the necessary shape/device checks at initialization?
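A minimal sketch of what such an engine-level consistency check could look like. This is purely illustrative; `assert_consistent` is hypothetical and not part of the PR:

```python
import torch

def assert_consistent(**tensors: torch.Tensor) -> None:
    """Hypothetical engine-level check: all state/rate/parameter tensors
    must share one shape and one device before the integration loop runs."""
    ref_name, ref = next(iter(tensors.items()))
    for name, t in tensors.items():
        if t.shape != ref.shape or t.device != ref.device:
            raise ValueError(
                f"{name} (shape={tuple(t.shape)}, device={t.device}) does not "
                f"match {ref_name} (shape={tuple(ref.shape)}, device={ref.device})"
            )

# Passes: all tensors share shape (3,) and the same device.
assert_consistent(state=torch.zeros(3), rate=torch.ones(3), param=torch.full((3,), 2.0))
```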

SarahAlidoost (Collaborator):

> This PR adds:
>
>   • A new parameter template, to be used in the models. It automatically broadcasts all parameters to the same shape, which can be useful in the following scenarios:
>     • one parameter is a tensor and the others are scalars: the scalar parameters are automatically broadcast to the tensor shape;
>     • all parameters are scalars, but the driving variables (i.e. the weather data) are tensors: not yet implemented, but I imagine this happening when instantiating a parameter object, which could take a shape argument.
>   • A new Tensor trait, which can be used to define parameters that are expected to be tensors. This simplifies the definition of the parameter/state/rate containers in the models, and it allows the expected dtype of a given parameter/state/rate variable to be defined. In addition, input parameters are automatically cast into tensors if they are not already.

One comment is about the naming "Tensor". If I understood correctly from the class definition, the init function doesn't do anything related to a tensor; it is only a type and a subclass of TraitType, right? I found it a bit confusing: when we do, for example, a = Tensor(0.0), I would assume a is a tensor like torch.Tensor or tf.Tensor. Can we rename it to something else?

> What do you think @SarahAlidoost and @SCiarella ? I think this can simplify some expressions in the models - I imagine a lot of the _broadcast_to can be removed due to the template already carrying out the broadcasting.

This is awesome! 🥇 Thanks. I like how things get simpler and cleaner. Just one comment about naming, see above.

fnattino (Collaborator, Author) commented on Jan 19, 2026

Thank you @SCiarella and @SarahAlidoost for the useful feedback!

@SCiarella:

> Ideally, it would be nice to remove all the calls to _broadcast_to, but right now we have to check in each module whether the state variables, the rates, the kiosk, and the parameters all have the same shape/device. Can we later add templates for everything and have the engine perform the necessary shape/device checks at initialization?

Indeed, I think it's a good idea to also add similar containers for states and rates, so all variables are initialized with the correct shape and device!

@SarahAlidoost:

> One comment is about the naming "Tensor". If I understood correctly from the class definition, the init function doesn't do anything related to a tensor; it is only a type and a subclass of TraitType, right? I found it a bit confusing: when we do, for example, a = Tensor(0.0), I would assume a is a tensor like torch.Tensor or tf.Tensor. Can we rename it to something else?

My idea was to use Tensor to define variables that are expected to be tensors, similar to how pcse has pcse.traitlets.Float or pcse.traitlets.Bool for floats and booleans, respectively. Right now, all the variables expected to be tensors were marked as generic Any. Variables that are defined as Tensor are automatically checked to be of torch.Tensor type, or cast into that type via the validate method. For instance:

```python
import torch
from diffwofost.physical_models.base import TensorParamTemplate
from diffwofost.physical_models.traitlets import Tensor

class Parameters(TensorParamTemplate):
    A = Tensor(0.)
    B = Tensor(0, dtype=int)

# Parameters A and B are cast into tensors
params = Parameters(dict(A=0., B=0))

params.A
# tensor(0., dtype=torch.float64)

params.B
# tensor(0)
```

@fnattino fnattino marked this pull request as ready for review February 6, 2026 19:09
fnattino (Collaborator, Author) commented on Feb 6, 2026

Hi @SarahAlidoost @SCiarella, this is now ready to be reviewed. I know it's quite a lot of changes, so I tried to leave as many comments as possible to facilitate the review.

SCiarella (Collaborator) left a review:

Thanks @fnattino, the changes look great to me, and I think we can merge once Sarah approves.

I have left some comments for enhancements/opening new issues.

The tests run smoothly on CPU, but there will be some problems on GPU, due to the definition of the engine, as described in #86. We should take care of the GPU issues in #86 while resolving the conflicts.

I would merge this as it is 🚀

SarahAlidoost (Collaborator) left a review:

@fnattino thanks! Looks awesome! 🥇 Nice to see that the models got cleaner with this implementation 👍. In addition to Simone's comments, I left a few minor comments. Also, can you please add this to api_reference.md:

```md
## **Other classes (for developers)**

::: diffwofost.physical_models.base.states_rates.TensorStatesTemplate

::: diffwofost.physical_models.base.states_rates.TensorRatesTemplate

::: diffwofost.physical_models.base.states_rates.TensorParamTemplate

::: diffwofost.physical_models.base.states_rates.TensorContainer

::: diffwofost.physical_models.traitlets.Tensor
```

fnattino and others added 10 commits February 11, 2026 09:40
Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com>
fnattino (Collaborator, Author):

Thank you both for the careful review! I am now waiting for all the tests to pass and, if successful, will merge.

@fnattino fnattino moved this from In progress to In review in DeltaCrop: epic1 Feb 11, 2026
@fnattino fnattino merged commit 0d3aaf7 into main Feb 11, 2026
11 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in DeltaCrop: epic1 Feb 11, 2026
@fnattino fnattino deleted the tensor-param-template branch February 11, 2026 10:28

Labels: none yet
Projects: DeltaCrop: epic1 (Status: Done)
3 participants