Is your feature request related to a problem? Please describe.
The attention mask created in _segment_ids_pos_to_seqlens_offsets() consumes O(N*N) memory, which can cause OOM errors for larger contexts.
Describe the solution you'd like
Instead of creating an N*N mask, the masking done in _segment_ids_pos_to_seqlens_offsets() can be achieved with other JAX transformations that use only O(N) memory.
Ideally, the solution should work for both CP and non-CP cases.
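
To illustrate the kind of O(N) approach meant here, below is a minimal sketch, not the actual implementation. It assumes a 1-D segment_ids array where 0 marks padding and adjacent segments have distinct IDs. It ignores segment_pos and the CP case entirely. The function name seqlens_and_offsets_linear and the max_segments parameter are hypothetical and only for illustration.

```python
import jax
import jax.numpy as jnp


def seqlens_and_offsets_linear(segment_ids, max_segments):
    """Hypothetical O(N) sketch: per-segment lengths and start offsets.

    Assumptions (not taken from the real function): segment_ids is 1-D,
    0 means padding, and neighbouring segments carry different IDs.
    """
    valid = segment_ids != 0
    # Shift right by one so each token can be compared with its predecessor.
    prev = jnp.concatenate(
        [jnp.zeros((1,), dtype=segment_ids.dtype), segment_ids[:-1]]
    )
    # A segment starts at a non-padding token whose ID differs from the previous one.
    is_start = valid & (segment_ids != prev)
    # Running count of starts gives a 0-based segment index for every token.
    seg_idx = jnp.cumsum(is_start.astype(jnp.int32)) - 1
    # Route padding tokens to an out-of-range index so the scatter drops them.
    seg_idx = jnp.where(valid, seg_idx, max_segments)
    seqlens = jnp.zeros((max_segments,), dtype=jnp.int32).at[seg_idx].add(
        1, mode="drop"
    )
    # Start positions, padded with -1 up to the static size max_segments.
    offsets = jnp.nonzero(is_start, size=max_segments, fill_value=-1)[0]
    return seqlens, offsets


example_ids = jnp.array([1, 1, 1, 2, 2, 0, 3, 3, 0, 0], dtype=jnp.int32)
seqlens, offsets = seqlens_and_offsets_linear(example_ids, max_segments=4)
# seqlens -> [3, 2, 2, 0], offsets -> [0, 3, 6, -1]
```

Every intermediate array here has length N or max_segments, so no N*N mask is materialized. Because max_segments fixes the output shapes, the function can be jitted with that argument marked static, e.g. jax.jit(seqlens_and_offsets_linear, static_argnums=1).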