|
 
 ######################################################################
 # Overview of Variable Length Attention
-# -------------------------------------
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 # In normal SDPA, sequences are expected to have a fixed length. In
 # practice, this means that input tensors are often **padded** to the same
 # length in a batch. However, this wastes both memory and compute by
 # storing the padding and performing unnecessary computation on it.
-
-######################################################################
 # Variable length attention handles sequences of varying length by
 # **packing** the tensors in a batch together and essentially collapsing
 # the batch dimension.
 
+######################################################################
 # However, we still need to maintain the boundaries between documents. To
 # do so, we compute cumulative sequence positions for query and key/value
 # that mark the end of documents. In the diagram below, doc 1 is 7 tokens
-# long, doc 2 is 10 tokens long, etc. so
-# ``cu_seq_lens = [0, 7, 17, ...]``.
-
-######################################################################
-# Padding vs Packing Diagram
-
-# .. figure:: ../_static/img/varlen_diagram.png
-#    :alt: Padding vs Packing Diagram
+# long, doc 2 is 10 tokens long, etc. so ``cu_seq_lens = [0, 7, 17, ...]``.
 
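+######################################################################
+# For illustration, a minimal sketch (with made-up document lengths) of
+# how these cumulative positions relate to packing:
+#
+# .. code:: python
+#
+#    import torch
+#
+#    seq_lens = [7, 10, 5]            # hypothetical per-document lengths
+#    # a padded layout would allocate batch_size * max_len slots:
+#    padded_slots = len(seq_lens) * max(seq_lens)    # 3 * 10 = 30
+#    # a packed layout allocates only the tokens that exist:
+#    total_tokens = sum(seq_lens)                    # 22
+#    # boundaries: a leading 0 followed by the running sum of lengths
+#    cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0)
+#    # tensor([ 0,  7, 17, 22])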
|
 ######################################################################
 # Note that ``NestedTensor`` is another way to enable attention over
 # ragged, variable-length sequences in PyTorch without padding.
|
 
 ######################################################################
 # Definition
-# ~~~~~~~~~~
-
+# ----------
+#
 # Below is the definition of ``varlen_attn`` which returns the output
 # tensor from the attention computation.
-
+#
 # .. code:: python
-
+#
 #    def varlen_attn(
 #        query: torch.Tensor,
 #        key: torch.Tensor,
 #        value: torch.Tensor,
 #        cu_seq_q: torch.Tensor,
 #        cu_seq_k: torch.Tensor,
 #        max_q: int,
 #        max_k: int,
 #        is_causal: bool = False,
 #        return_aux: AuxRequest | None = None,
 #    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
 #
-
-######################################################################
 # ``query``, ``key``, and ``value`` correspond to the ``q``, ``k``, and
 # ``v`` of the packed input. ``cu_seq_q`` and ``cu_seq_k`` are the
 # cumulative indices for query and key/value, respectively. These mark the
 # logical boundaries that separate the documents in our input. ``max_q``
 # and ``max_k`` are the maximum sequence lengths of query and key,
 # respectively. ``is_causal`` applies causal masking if set to True, and
 # ``return_aux`` specifies which auxiliary outputs to return (such as ``lse``).
-#
 
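+######################################################################
+# For illustration, a minimal sketch of a call. The import path
+# ``torch.nn.attention.varlen`` and the shapes reflect the prototype API
+# and may change; all numbers are made up:
+#
+# .. code:: python
+#
+#    import torch
+#    from torch.nn.attention.varlen import varlen_attn
+#
+#    # packed inputs: (total_tokens, num_heads, head_dim) -- no batch dim
+#    q = torch.randn(22, 8, 64, device="cuda", dtype=torch.bfloat16)
+#    k = torch.randn(22, 8, 64, device="cuda", dtype=torch.bfloat16)
+#    v = torch.randn(22, 8, 64, device="cuda", dtype=torch.bfloat16)
+#    # document boundaries from before: docs of length 7, 10, and 5
+#    cu_seq = torch.tensor([0, 7, 17, 22], device="cuda", dtype=torch.int32)
+#
+#    out = varlen_attn(q, k, v, cu_seq, cu_seq, max_q=10, max_k=10, is_causal=True)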
|
 ######################################################################
 # **Note on causal masking**
-
 # When ``is_causal`` is set to True, causal masking is applied, which means
 # that tokens can only attend to previous tokens. For bidirectional
 # attention, set this flag to False.
-
-######################################################################
-# In torchtitan (PyTorch’s pretraining framework), we set
+#
+# In torchtitan (PyTorch's pretraining framework), we set
 # ``is_causal = True`` uniformly to prevent the model from cheating and
-# artifically driving the loss down too quickly.
+# artificially driving the loss down too quickly.
 
|
 
 ######################################################################
 # Example
-# -------
+# ~~~~~~~
 #
 # Let’s walk through a simple example of how we would use ``varlen_attn``
 # in the context of training a Transformer model.
|
 
 ######################################################################
 # Creating Required Metadata for ``varlen_attn`` from Input Batches
-# ~~~~~~~~~~~~~~~~~~~~~
-
+# -----------------------------------------------------------------
+#
 # Given an input batch, how would we construct the metadata that
 # ``varlen_attn`` expects? More specifically, how do we calculate the
 # cumulative sequence indices?
-
+#
 # The helper function ``create_varlen_metadata`` returns the required
 # ``cu_seqlens`` and ``max_seqlen`` given ``input_batch`` and the
 # end-of-sequence token ID that marks the end of documents.
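+
+######################################################################
+# As a rough sketch of the idea (hypothetical code; the tutorial's actual
+# helper below may differ in details), one could locate the EOS tokens
+# and take a cumulative sum:
+#
+# .. code:: python
+#
+#    import torch
+#
+#    def create_varlen_metadata_sketch(input_batch: torch.Tensor, eos_id: int):
+#        # packed view of the batch: no batch dimension
+#        flat = input_batch.reshape(-1)
+#        # one past each EOS position marks the end of a document
+#        # (assumes every document ends with ``eos_id``)
+#        ends = (flat == eos_id).nonzero(as_tuple=True)[0] + 1
+#        cu_seqlens = torch.zeros(
+#            ends.numel() + 1, dtype=torch.int32, device=flat.device
+#        )
+#        cu_seqlens[1:] = ends
+#        # longest document, needed by the attention kernel
+#        max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
+#        return cu_seqlens, max_seqlen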
@@ -184,16 +171,16 @@ def create_varlen_metadata(input_batch: torch.Tensor, eos_id: int):
 
 ######################################################################
 # Implementing the Attention Block with ``varlen_attn``
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-# Let’s explore how we would use ``varlen_attn`` in an Attention module.
+# -----------------------------------------------------
+#
+# Let's explore how we would use ``varlen_attn`` in an Attention module.
 # We define an attention module as usual, but in the ``forward`` method,
 # we call the new ``varlen_attn`` custom op.
-
+#
 # This function expects the ``cu_seq`` indices and ``max_len`` that we
 # computed earlier using ``create_varlen_metadata`` to mark the boundaries
 # of the different documents.
-
+#
 # Before we call ``varlen_attn``, we also pack our input so that it has
 # the shape ``(total tokens, dim)``. Recall that variable length attention
 # allows us to collapse the ``batch_size`` dimension so that we can lay
 # out all of the documents' tokens contiguously in a single dimension.
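+
+######################################################################
+# A condensed sketch of what such a ``forward`` could look like (the
+# projection names ``wq``/``wk``/``wv``/``wo`` are illustrative, and the
+# real module below may differ):
+#
+# .. code:: python
+#
+#    def forward(self, x, cu_seq, max_len):
+#        # pack: collapse leading batch/seq dims into (total_tokens, dim)
+#        x = x.reshape(-1, x.shape[-1])
+#        total_tokens = x.shape[0]
+#        # project and split heads: (total_tokens, n_heads, head_dim)
+#        q = self.wq(x).view(total_tokens, self.n_heads, self.head_dim)
+#        k = self.wk(x).view(total_tokens, self.n_heads, self.head_dim)
+#        v = self.wv(x).view(total_tokens, self.n_heads, self.head_dim)
+#        out = varlen_attn(
+#            q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=True
+#        )
+#        return self.wo(out.reshape(total_tokens, -1))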
@@ -244,22 +231,21 @@ def forward(
 
 
 ######################################################################
-# We can also use ``torch.compile`` with ``varlen_attn`` and define
-
+# We can also use ``torch.compile`` with ``varlen_attn`` and define:
+#
 # .. code:: python
-
+#
 #    compiled_varlen_attn: ClassVar[Callable] = torch.compile(
 #        varlen_attn, mode="max-autotune-no-cudagraphs"
 #    )
-
+#
 # We can call ``compiled_varlen_attn`` instead of ``varlen_attn`` in the
 # Attention forward, and everything else stays the same.
-#
 
 
 ######################################################################
 # Creating a Transformer
-# ~~~~~~~~~~~~~~~~~~~~~~
+# ----------------------
 #
 # Now, we can use this ``SimpleVarlenAttention`` module in a simple
 # Transformer.
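+
+######################################################################
+# For a sense of the overall shape, a hypothetical minimal layout (the
+# actual model below may differ): the transformer simply threads the
+# varlen metadata through to every attention layer.
+#
+# .. code:: python
+#
+#    import torch.nn as nn
+#
+#    class SimpleVarlenTransformer(nn.Module):
+#        def __init__(self, dim, n_heads, n_layers, vocab_size):
+#            super().__init__()
+#            self.embed = nn.Embedding(vocab_size, dim)
+#            self.layers = nn.ModuleList(
+#                [SimpleVarlenAttention(dim, n_heads) for _ in range(n_layers)]
+#            )
+#            self.out = nn.Linear(dim, vocab_size)
+#
+#        def forward(self, tokens, cu_seq, max_len):
+#            # (batch, seq) -> packed hidden states (total_tokens, dim)
+#            h = self.embed(tokens).reshape(-1, self.embed.embedding_dim)
+#            for layer in self.layers:
+#                h = h + layer(h, cu_seq, max_len)   # residual connection
+#            return self.out(h)                      # (total_tokens, vocab)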
@@ -288,7 +274,7 @@ def forward(
 
 ######################################################################
 # Running a Training Step
-# ~~~~~~~~~~~~~~~~~~~~~~~
+# -----------------------
 #
 # Now we’re ready to put all the pieces together! Let’s run a training
 # step with our ``SimpleVarlenTransformer``. We define our model, compute
 # the varlen metadata from a sample batch, and run the forward pass,
 # loss, and backward pass.
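+
+######################################################################
+# A minimal sketch of such a step (hyperparameters, ``eos_id``, and the
+# naive target shift are made up; the actual ``main()`` below may differ):
+#
+# .. code:: python
+#
+#    model = SimpleVarlenTransformer(
+#        dim=256, n_heads=8, n_layers=2, vocab_size=1000
+#    ).cuda()
+#    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
+#
+#    input_batch = torch.randint(0, 1000, (4, 128), device="cuda")  # (batch, seq)
+#    cu_seq, max_len = create_varlen_metadata(input_batch, eos_id=0)
+#
+#    logits = model(input_batch, cu_seq, max_len)   # (total_tokens, vocab)
+#    # naive next-token targets: shift by one (ignores document boundaries)
+#    targets = input_batch.reshape(-1).roll(-1)
+#    loss = torch.nn.functional.cross_entropy(logits, targets)
+#    loss.backward()
+#    optimizer.step()
+#    optimizer.zero_grad()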
@@ -344,20 +330,20 @@ def main():
 
 ######################################################################
 # Conclusion
-# -----------
+# ~~~~~~~~~~
 #
 # In this tutorial, we have covered how to use the ``varlen_attn`` API in PyTorch to efficiently
-# to handle sequences of varying lengths without padding. We explored how to create the
+# handle sequences of varying lengths without padding. We explored how to create the
 # necessary metadata, including the cumulative sequence indices, implemented a simple
 # Transformer attention layer with variable length attention, and ran a complete
 # training step.
 
+######################################################################
 # This approach eliminates wasted computation on padding tokens
 # and enables more efficient training and inference for models processing
 # documents of different lengths.
-
-######################################################################
-# See Also:
-# -----------
-# * [Implementing High-Performance Transformers with Scaled Dot Product Attention ](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html)
-# * [torch.nn.functional.scaled_dot_product_attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+#
+# .. seealso::
+#
+#    - `Implementing High-Performance Transformers with Scaled Dot Product Attention <https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
+#    - `torch.nn.functional.scaled_dot_product_attention <https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`_