Conversation
|
This PR adds:
What do you think @SarahAlidoost and @SCiarella ? I think this can simplify some expressions in the models - I imagine a lot of the |
|
Thanks @fnattino, this looks fantastic 🚀 I really like the template to automatically broadcast to the correct shape and device at the beginning, because right now we are doing it quite a lot of times in the integration loops. Ideally, it would be nice to remove all the calls to |
One thing is the naming "
It is Awesome! 🥇 Thanks. I like how things get simpler and cleaner. Just one comment about naming, see above. |
|
Thank you @SCiarella and @SarahAlidoost for the useful feedback!
Indeed, I think it's a good idea to also add similar containers for states and rates, so all variables are initialized with the correct shape and device!
My idea was to use import torch
from diffwofost.physical_models.base import TensorParamTemplate
from diffwofost.physical_models.traitlets import Tensor
class Parameters(TensorParamTemplate):
A = Tensor(0.)
B = Tensor(0, dtype=int)
# Parameters A and B are casted into tensors
params = Parameters(dict(A=0., B=0))
params.A
# tensor(0., dtype=torch.float64)
params.B
# tensor(0) |
|
Hi @SarahAlidoost @SCiarella this is now ready to be reviewed. I know it's quite some changes, so I tried to leave as many comments as possible to facilitate the review. |
SCiarella
left a comment
There was a problem hiding this comment.
Thanks @fnattino, the changes look great to me, and I think we can merge once Sarah approves.
I have left some comments for enhancements/opening new issues.
The tests run smoothly on cpu, but there will be some problems on GPU for the issues mentioned in #86, due to the definition of the engine. We should take care of the GPU issues in #86 while solving the conflicts.
I would merge this as it is 🚀
There was a problem hiding this comment.
@fnattino thanks! looks awesome! 🥇 nice to see that the models got cleaned using this implementation 👍 . In addition to Simone's comments, I left a few minor comments. Also, can you please add this to the api_reference.md:
## **Other classes (for developers)**
::: diffwofost.physical_models.base.states_rates.TensorStatesTemplate
::: diffwofost.physical_models.base.states_rates.TensorRatesTemplate
::: diffwofost.physical_models.base.states_rates.TensorParamTemplate
::: diffwofost.physical_models.base.states_rates.TensorContainer
::: diffwofost.physical_models.traitlets.Tensor
Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com>
Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com>
Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com>
Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com>
Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com>
Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com>
…into tensor-param-template
|
Thank you both for the careful review! I am now waiting for all tests to pass, and, if successful, will merge afterwards. |
|



relates #25