The current router kernels live in common and are callable from the PyTorch side but not from JAX. We need a JAX router binding, either for standalone use or for later integration into the MaxText MoE layer.
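For reference, a minimal pure-JAX sketch of what a top-k softmax router computes (assuming the kernels implement standard top-k gating; the function name and renormalization choice here are illustrative, not taken from the existing kernels):

```python
import jax
import jax.numpy as jnp

def route_tokens(logits, k):
    """Reference top-k softmax routing: per-token expert weights and indices."""
    # Softmax over the expert dimension, then select the k largest per token.
    probs = jax.nn.softmax(logits, axis=-1)
    weights, indices = jax.lax.top_k(probs, k)
    # Renormalize the selected weights so they sum to 1 per token.
    weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
    return weights, indices

# One token, four experts: expert 3 then expert 1 should be selected.
logits = jnp.array([[1.0, 2.0, 0.5, 3.0]])
w, idx = route_tokens(logits, k=2)
```

A reference like this is also useful as a correctness check against the CUDA kernels once the JAX binding exists.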
Since these router kernels are CUDA kernels, lowering them through the JAX FFI is sufficient; no kernel rewrite is needed.
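A rough sketch of what the FFI binding could look like, using the `jax.ffi` API. The target name `"topk_router_fwd"` and the output shapes are assumptions for illustration; they would need to match the actual XLA FFI handler exported by the CUDA kernel, registered via `jax.ffi.register_ffi_target(..., platform="CUDA")`:

```python
import jax
import jax.numpy as jnp
import numpy as np

def topk_router_ffi(logits, k):
    """Call a CUDA router kernel exposed as an XLA FFI target.

    Assumes the handler was registered beforehand, e.g.:
        jax.ffi.register_ffi_target("topk_router_fwd", capsule, platform="CUDA")
    where `capsule` wraps the existing CUDA kernel's FFI handler.
    """
    n_tokens = logits.shape[0]
    # Output buffers the kernel is assumed to produce: weights and expert ids.
    out_types = (
        jax.ShapeDtypeStruct((n_tokens, k), logits.dtype),
        jax.ShapeDtypeStruct((n_tokens, k), jnp.int32),
    )
    # k is passed as a static FFI attribute; "topk_router_fwd" is a placeholder name.
    return jax.ffi.ffi_call("topk_router_fwd", out_types)(logits, k=np.int32(k))
```

The resulting call is jittable and composable with the rest of a JAX model, which should make later integration into the MaxText MoE layer straightforward.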