[JAX] Optimize materialized attn mask to consume O(N) memory instead of O(N*N) #2700

@KshitijLakhani

Is your feature request related to a problem? Please describe.

The attention mask created in `_segment_ids_pos_to_seqlens_offsets()` consumes O(N*N) memory, which can cause OOM errors for long contexts.

Describe the solution you'd like

Instead of materializing an N*N mask, the masking done in `_segment_ids_pos_to_seqlens_offsets()` could be achieved with other JAX transformations that need only O(N) memory.
Ideally, the replacement should work for both CP and non-CP cases.
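
To illustrate the kind of O(N) formulation meant here, below is a rough sketch and not the actual Transformer Engine code. It assumes a 1-D `segment_ids` array in which 0 marks padding and each segment occupies a contiguous run of tokens. The function name `seqlens_offsets_linear` and the `max_segments` parameter are hypothetical. It finds segment starts by comparing each token with its predecessor, then uses a cumulative sum and scatter updates, so every intermediate array has length N or `max_segments`.

```python
import jax.numpy as jnp


def seqlens_offsets_linear(segment_ids, max_segments):
    """Hypothetical O(N) helper: per-segment lengths and start offsets.

    Assumes segment_ids has shape (N,), 0 means padding, and every
    segment is a contiguous run of identical non-zero ids.
    """
    n = segment_ids.shape[0]
    valid = segment_ids != 0

    # A token starts a segment if it is valid and differs from its predecessor.
    prev = jnp.concatenate(
        [jnp.zeros((1,), dtype=segment_ids.dtype), segment_ids[:-1]]
    )
    is_start = valid & (segment_ids != prev)

    # 0-based index of the run each token belongs to.
    run_idx = jnp.cumsum(is_start) - 1

    # Route padding tokens to an out-of-range index so mode="drop" ignores them.
    len_idx = jnp.where(valid, run_idx, max_segments)
    seqlens = jnp.zeros((max_segments,), dtype=jnp.int32)
    seqlens = seqlens.at[len_idx].add(1, mode="drop")

    # Record the position of each segment start; unused slots stay at -1.
    off_idx = jnp.where(is_start, run_idx, max_segments)
    offsets = jnp.full((max_segments,), -1, dtype=jnp.int32)
    offsets = offsets.at[off_idx].set(jnp.arange(n, dtype=jnp.int32), mode="drop")

    return seqlens, offsets


segment_ids = jnp.array([1, 1, 1, 2, 2, 0, 0, 3], dtype=jnp.int32)
seqlens, offsets = seqlens_offsets_linear(segment_ids, max_segments=4)
```

For the example input, the three segments have lengths 3, 2 and 1 and start at positions 0, 3 and 7; the unused fourth slot gets length 0 and offset -1. This sketch ignores `segment_pos`, batching, and CP sharding, all of which a real replacement would have to handle.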
