With both my implementation and an older version written by someone else, I got this assertion error:

But if I print any one of m, v, or grad, the test passes (for both my implementation and the other one):

Note that I convert m and v to ndl.Tensor only once and use .data in the calculations, and my code passes the earlier test_optim_sgd_z_memory_check_1.
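For context, here is a toy sketch of what I understand a live-tensor count to measure. It is not needle code: `Toy`, `LIVE`, `step`, and `live_after` are hypothetical names made up for illustration, and it assumes CPython reference counting. It only shows why keeping a reference to the computation graph makes the count grow while working on detached values keeps it flat; it does not by itself explain why printing changes the result.

```python
import gc

LIVE = 0  # hypothetical global counter of live Toy objects


class Toy:
    """Toy stand-in for a tensor that remembers the inputs it was built from."""

    def __init__(self, value, inputs=()):
        global LIVE
        LIVE += 1
        self.value = value
        self.inputs = inputs  # references to parents keep them alive

    def __del__(self):
        global LIVE
        LIVE -= 1

    def detach(self):
        # New object with the same value and no link to the graph.
        return Toy(self.value)

    def __mul__(self, k):
        # The result holds a reference to self, like a graph node.
        return Toy(self.value * k, inputs=(self,))


def step(m, use_detach):
    out = m * 0.9
    return out.detach() if use_detach else out


def live_after(steps, use_detach):
    """Number of extra live objects while the running state is still held."""
    start = LIVE
    m = Toy(1.0)
    for _ in range(steps):
        m = step(m, use_detach)
    gc.collect()
    return LIVE - start


print(live_after(5, use_detach=True))   # only the current state stays alive
print(live_after(5, use_detach=False))  # every intermediate stays alive
```

In this sketch the detached variant keeps one object alive no matter how many steps run, while the non-detached variant keeps the whole chain.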
Could anyone explain this? I'm quite confused about why printing would affect the tensor count. Thanks for the help :)