[JAX] Fix get_seqlens_and_offsets() to accept vmapped seg ids and non-vmapped seg offsets #15199
lint.yml
on: pull_request
PyTorch C++: 33s
PyTorch Python: 2m 25s
JAX C++: 34s
JAX Python: 31s