Hello,
TL;DR
I have a couple of recommendations I'd like to discuss:
- Disable caching by default, both to avoid stale summaries when the module changes and to avoid leaking heavy references into memory.
- Use the models themselves directly as keys for the cache dictionary.
- Perhaps consider using only weak references in `ModelStatistics`, but this seems like too much work at this point.
Explanations
I recently spent a good 5 hours with `gc` trying to understand why my neural network was never garbage collected after I had used torchinfo and deleted everything.
After these three steps:
create net -> compute summary -> delete summary and net
I no longer had any handle on my net or my summary, and yet they were never garbage collected!
I finally figured out that the culprit was `_cached_forward_pass` (torchinfo/torchinfo/torchinfo.py, line 51 at e67e748):

```python
_cached_forward_pass: dict[str, list[LayerInfo]] = {}
```
Debugging was hard even with `gc.get_referrers`, since `LayerInfo` has a very uninformative `repr` that obfuscates the nature of the object to anyone not familiar with the internals.
Furthermore, I noticed that the keys of this dict are simply the class `__name__` (torchinfo/torchinfo/torchinfo.py, line 271 at e67e748):

```python
model_name = model.__class__.__name__
```
I think this is more confusing than using a plain reference to the model, and it could also create conflicts between classes that share the same `__name__` (think of dynamically created classes with unimportant names, for example, or simply nested classes).
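To illustrate the conflict: two distinct classes can share a `__name__`, so a cache keyed by that string would silently conflate them (a hypothetical sketch, not torchinfo code):

```python
def make_model(width):
    class Net:  # a new, distinct class is created on every call
        def __init__(self):
            self.width = width
    return Net()

a = make_model(16)
b = make_model(32)

# Both classes collapse to the same string key...
assert a.__class__.__name__ == b.__class__.__name__ == "Net"
# ...even though they are entirely different classes.
assert a.__class__ is not b.__class__
```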
Finally, the cache may lead to incorrect output if the network changes before being "re-summarized": in effect, the same forward pass, and thus the same summary, would be reused.
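Here is a hypothetical sketch of that staleness (invented names, not torchinfo's implementation): a cache keyed by class name keeps serving the old summary after the model has changed.

```python
_summaries: dict[str, str] = {}

class Net:
    def __init__(self):
        self.layers = [1, 2]

def summarize(model):
    # Hypothetical summary function that caches by class name.
    name = model.__class__.__name__
    if name not in _summaries:
        _summaries[name] = f"{len(model.layers)} layers"
    return _summaries[name]

net = Net()
assert summarize(net) == "2 layers"

net.layers.append(3)                 # the network changed...
assert summarize(net) == "2 layers"  # ...but the cached summary is stale
```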
All the best!
Élie