
Commit 581d15e

fix sphinx syntax
1 parent b7c5dbf commit 581d15e

1 file changed

Lines changed: 33 additions & 47 deletions

intermediate_source/variable_length_attention_tutorial.py
@@ -42,29 +42,21 @@
 
 ######################################################################
 # Overview of Variable Length Attention
-# -------------------------------------
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 # In normal SDPA, sequences are expected to be a fixed length. In
 # practice, this means that input tensors are often **padded** to the same
 # length in a batch. However, this wastes both memory and compute through
 # storing this padding and performing unnecessary computations.
-
-######################################################################
 # Variable length attention handles sequences of varying length by
 # **packing** the tensors in a batch together and essentially collapsing
 # the batch dimension.
 
+######################################################################
 # However, we still need to maintain the boundaries between documents. To
 # do so, we compute cumulative sequence positions for query and key/value
 # that mark the end of documents. In the diagram below, doc 1 is 7 tokens
-# long, doc 2 is 10 tokens long, etc. so
-# ``cu_seq_lens = [0, 7, 17, ...]``.
-
-######################################################################
-# Padding vs Packing Diagram
-
-# .. figure:: ../_static/img/varlen_diagram.png
-#    :alt: Padding vs Packing Diagram
+# long, doc 2 is 10 tokens long, etc. so ``cu_seq_lens = [0, 7, 17, ...]``.
 
 ######################################################################
 # Note that ``NestedTensor`` is another way to enable
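To make the packing idea from this hunk concrete, here is a small, self-contained sketch (illustrative only, not code from the tutorial file) that packs two documents of 7 and 10 tokens and builds the cumulative boundaries; the hidden size of 64 is arbitrary:

    import torch

    # Two documents of different lengths, matching the example above.
    doc1 = torch.randn(7, 64)   # 7 tokens, hidden dim 64
    doc2 = torch.randn(10, 64)  # 10 tokens, hidden dim 64

    # Packing: concatenate along the token dimension instead of padding to length 10.
    packed = torch.cat([doc1, doc2], dim=0)  # shape (17, 64)

    # Cumulative boundaries marking where each document ends.
    seq_lens = torch.tensor([7, 10])
    cu_seq_lens = torch.cat(
        [torch.zeros(1, dtype=torch.int32), torch.cumsum(seq_lens, dim=0).to(torch.int32)]
    )

    print(packed.shape)  # torch.Size([17, 64])
    print(cu_seq_lens)   # tensor([ 0,  7, 17], dtype=torch.int32)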
@@ -74,13 +66,13 @@
 
 ######################################################################
 # Definition
-# ~~~~~~~~~~
-
+# ----------
+#
 # Below is the definition of ``varlen_attn`` which returns the output
 # tensor from the attention computation.
-
+#
 # .. code:: python
-
+#
 #    def varlen_attn(
 #        query: torch.Tensor,
 #        key: torch.Tensor,
@@ -93,33 +85,28 @@
 #        return_aux: AuxRequest | None = None,
 #    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
 #
-
-######################################################################
 # ``query``, ``key``, and ``value`` correspond to the ``q``, ``k``, and
 # ``v`` of the packed input. ``cu_seq_q`` and ``cu_seq_k`` are the
 # cumulative indices for query and key/value, respectively. These mark the
 # logical boundaries that separate the documents in our input. ``max_q``
 # and ``max_k`` are the maximum sequence lengths of query and key,
 # respectively. ``is_causal`` applies causal masking if set to True and
 # ``return_aux`` specifies which auxiliary outputs to return (ie ``lse``).
-#
 
 ######################################################################
 # **Note on causal masking**
-
 # When ``is_causal`` is set to True, causal masking is applied which means
 # that tokens can only attend to previous tokens. For bidirectional
 # attention, set this flag to False.
-
-######################################################################
-# In torchtitan (PyTorch’s pretraining framework), we set
+#
+# In torchtitan (PyTorch's pretraining framework), we set
 # ``is_causal = True`` uniformly to prevent the model from cheating and
-# artifically driving the loss down too quickly.
+# artificially driving the loss down too quickly.
 
 
 ######################################################################
 # Example
-# -------
+# ~~~~~~~
 #
 # Let’s walk through a simple example of how we would use ``varlen_attn``
 # in the context of training a Transformer model.
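As a hedged sketch of how a call might look once the metadata is in hand: the import path, the ``(total_tokens, num_heads, head_dim)`` layout, and the CUDA/bfloat16 choices below are assumptions inferred from the signature quoted above, not something this diff confirms.

    import torch
    # Assumed import path for the prototype API; check your PyTorch version.
    from torch.nn.attention.varlen import varlen_attn

    num_heads, head_dim = 8, 64
    device, dtype = "cuda", torch.bfloat16

    # Packed q/k/v for two documents of 7 and 10 tokens (17 tokens total),
    # laid out as (total_tokens, num_heads, head_dim), an assumed layout.
    q = torch.randn(17, num_heads, head_dim, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    cu_seq = torch.tensor([0, 7, 17], dtype=torch.int32, device=device)
    max_len = 10  # longest document in the packed batch

    # Argument order follows the signature quoted in the tutorial.
    out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=True)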
@@ -128,12 +115,12 @@
 
 ######################################################################
 # Creating Required Metadata for ``varlen_attn`` from Input Batches
-# ~~~~~~~~~~~~~~~~~~~~~
-
+# -----------------------------------------------------------------
+#
 # Given an input batch, how would we construct the metadata that
 # ``varlen_attn`` expects? More specifically, how do we calculate the
 # cumulative sequence indices?
-
+#
 # The helper function ``create_varlen_metadata`` returns the required
 # ``cu_seqlens`` and ``max_seqlen`` given ``input_batch`` and the end of
 # sequence token ID that marks the end of documents.
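The helper itself lives in an unchanged part of the tutorial file and is not shown in this diff; the sketch below is one plausible way to derive the same metadata from EOS positions (names and edge-case handling are illustrative, not the tutorial's implementation):

    import torch

    def create_varlen_metadata_sketch(input_batch: torch.Tensor, eos_id: int):
        """Illustrative stand-in, not the tutorial's create_varlen_metadata."""
        flat = input_batch.reshape(-1)
        total = flat.numel()

        # One position past each EOS token is a document boundary.
        boundaries = torch.nonzero(flat == eos_id).flatten() + 1
        # Close the final document if the batch does not end with an EOS token.
        if boundaries.numel() == 0 or boundaries[-1] != total:
            boundaries = torch.cat([boundaries, torch.tensor([total], device=flat.device)])

        zero = torch.zeros(1, dtype=boundaries.dtype, device=flat.device)
        cu_seqlens = torch.cat([zero, boundaries]).to(torch.int32)
        max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
        return cu_seqlens, max_seqlen

    # Example: with eos_id = 0, the stream below holds documents of length 3, 4, and 1.
    tokens = torch.tensor([[5, 3, 0, 7, 7, 7, 0, 2]])
    print(create_varlen_metadata_sketch(tokens, eos_id=0))
    # (tensor([0, 3, 7, 8], dtype=torch.int32), 4)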
@@ -184,16 +171,16 @@ def create_varlen_metadata(input_batch: torch.Tensor, eos_id: int):
 
 ######################################################################
 # Implementing the Attention Block with ``varlen_attn``
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-# Lets explore how we would use ``varlen_attn`` in an Attention module.
+# -----------------------------------------------------
+#
+# Let's explore how we would use ``varlen_attn`` in an Attention module.
 # We define an attention module as usual, but in the ``forward`` method,
 # we call the new ``varlen_attn`` custom op.
-
+#
 # This function expects the ``cu_seq`` indices and ``max_len`` that we
 # computed earlier using ``create_varlen_metadata`` to mark the boundaries
 # of the different documents.
-
+#
 # Before we call ``varlen_attn``, we also pack our input so that it has
 # the shape ``(total tokens, dim)``. Recall that variable length attention
 # allows us to collapse the ``batch_size`` dimension so that we can lay
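A minimal sketch of that packing step (shapes and the projection layer are made up for illustration; the tutorial's module handles this inside its own ``forward``):

    import torch

    batch_size, seq_len, dim, n_heads = 4, 128, 256, 8
    head_dim = dim // n_heads
    x = torch.randn(batch_size, seq_len, dim)

    # Pack: collapse the batch dimension into one long stream of tokens.
    total_tokens = batch_size * seq_len
    x_packed = x.reshape(total_tokens, dim)  # (total tokens, dim)

    # Project and split heads on the packed layout (illustrative Linear only).
    wq = torch.nn.Linear(dim, dim, bias=False)
    q = wq(x_packed).view(total_tokens, n_heads, head_dim)

    # q, k, and v built this way are what get handed to varlen_attn together
    # with the cu_seq indices and max_len from create_varlen_metadata.
    print(x_packed.shape, q.shape)  # torch.Size([512, 256]) torch.Size([512, 8, 32])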
@@ -244,22 +231,21 @@ def forward(
 
 
 ######################################################################
-# We can also use ``torch.compile`` with ``varlen_attn`` and define
-
+# We can also use ``torch.compile`` with ``varlen_attn`` and define:
+#
 # .. code:: python
-
+#
 #    compiled_varlen_attn: ClassVar[Callable] = torch.compile(
 #        varlen_attn, mode="max-autotune-no-cudagraphs"
 #    )
-
+#
 # We can call ``compiled_varlen_attn`` instead of ``varlen_attn`` in the
 # Attention forward, and everything else stays the same.
-#
 
 
 ######################################################################
 # Creating a Transformer
-# ~~~~~~~~~~~~~~~~~~~~~~
+# ----------------------
 #
 # Now, we can use this ``SimpleVarlenAttention`` module in a simple
 # Transformer.
@@ -288,7 +274,7 @@ def forward(
 
 ######################################################################
 # Running a Training Step
-# ~~~~~~~~~~~~~~~~~~~~~~~
+# -----------------------
 #
 # Now we’re ready to put all the pieces together! Let’s run a training
 # step with our ``SimpleVarlenTransformer``. We define our model, compute
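The full training code sits in the unchanged ``main()`` of the tutorial file; as a rough, hypothetical outline of how the pieces could be wired together (the model's forward signature here is an assumption, not the tutorial's exact interface):

    import torch
    import torch.nn.functional as F

    def training_step(model, input_batch, targets, eos_id, optimizer):
        # Hypothetical wiring; the tutorial's SimpleVarlenTransformer may take
        # its metadata in a different form than this sketch assumes.
        cu_seqlens, max_seqlen = create_varlen_metadata(input_batch, eos_id)

        # Forward on the packed representation, then a standard LM loss.
        logits = model(input_batch, cu_seqlens, max_seqlen)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.detach()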
@@ -344,20 +330,20 @@ def main():
 
 ######################################################################
 # Conclusion
-# -----------
+# ~~~~~~~~~~
 #
 # In this tutorial, we have covered how to use the ``varlen_attn`` API in PyTorch to efficiently
-# to handle sequences of varying lengths without padding. We explored how to create the
+# handle sequences of varying lengths without padding. We explored how to create the
 # necessary metadata including the cumulative sequence indices, implemented a simple
 # Transformer attention layer with variable length attention, and ran a complete
 # training step.
 
+######################################################################
 # This approach eliminates wasted computation on padding tokens
 # and enables more efficient training and inference for models processing
 # documents of different lengths.
-
-######################################################################
-# See Also:
-# -----------
-# * [Implementing High-Performance Transformers with Scaled Dot Product Attention ](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html)
-# * [torch.nn.functional.scaled_dot_product_attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+#
+# .. seealso::
+#
+#    - `Implementing High-Performance Transformers with Scaled Dot Product Attention <https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
+#    - `torch.nn.functional.scaled_dot_product_attention <https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`_
