Thank you for the great implementation!
torch.compile can be enabled by adding the "@torch.compile" decorator immediately before every forward() function in modeling_retnet.py.
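
For anyone who wants to try it, here is a minimal sketch of what the change looks like. TinyBlock is a hypothetical stand-in, not an actual class from modeling_retnet.py, and the sketch assumes PyTorch 2.0 or later, where torch.compile is available.

```python
import torch
from torch import nn


class TinyBlock(nn.Module):
    """Hypothetical stand-in for a module defined in modeling_retnet.py."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    # The only change: decorate forward() with torch.compile.
    # Compilation is lazy, so it happens on the first call, not at import time.
    @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


# The module is constructed and used exactly as before.
block = TinyBlock()
```

The first forward call (for example block(torch.randn(2, 16))) is slower because it triggers compilation, so any speedup should be measured after a few warm-up steps.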

With torch.compile enabled, training runs about 5 times faster.
This is not a bug report, just information sharing. I wasn't sure whether I should post it here, but I'm doing so just in case.