From 9d85c4d6e40eb39d31b2f281a1d7063ffaa7292b Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 3 Mar 2026 17:57:48 -0800 Subject: [PATCH 1/3] Document torch.autocast nesting behavior with DeepSpeed engine Explain why users may want to wrap both engine.forward() and loss computation in their own torch.autocast context, and how DeepSpeed handles the resulting nested autocast. Signed-off-by: Masahiro Tanaka --- docs/_pages/config-json.md | 2 +- docs/code-docs/source/training.rst | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index d5344d3b2320..a99363338d22 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -389,7 +389,7 @@ Example of **scheduler** | Parameter | Type | Default | Description | | --------- | ---- | ------- | ----------- | -| **enabled** | boolean | `false` | Enable torch.autocast (no manual `torch.autocast` call needed in your code). | +| **enabled** | boolean | `false` | Enable torch.autocast for `engine.forward()`. You may still add your own `torch.autocast` outside the engine to cover loss functions or other user-side code (see [nesting behavior](https://deepspeed.readthedocs.io/en/latest/training.html#using-torch-autocast-outside-the-engine)). | | **dtype** | string | `"bfloat16"` | Lower precision dtype (`"bfloat16"` or `"float16"`). Also used for gradient/parameter communication of `lower_precision_safe_modules`. | | **lower_precision_safe_modules** | list | `["torch.nn.Linear", "torch.nn.Conv1d", "torch.nn.Conv2d", "torch.nn.Conv3d"]` | Module types for lower-precision communication (all-reduce/all-gather). | diff --git a/docs/code-docs/source/training.rst b/docs/code-docs/source/training.rst index 9b01b2f85315..84670bfdf2e8 100644 --- a/docs/code-docs/source/training.rst +++ b/docs/code-docs/source/training.rst @@ -119,6 +119,35 @@ If you call ``loss.backward()`` directly without using ``engine.scale()`` or ``e will raise a ``RuntimeError`` to prevent training with unscaled gradients, which can lead to incorrect results or gradient underflow. +Using torch.autocast Outside the Engine +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +DeepSpeed applies ``torch.autocast`` internally during ``engine.forward()``. +However, you may also want autocast to cover code that runs **outside** the engine, +such as a loss function or post-processing logic. In that case, wrap the entire +forward-plus-loss block in your own ``torch.autocast`` context: + +.. code-block:: python + + # Autocast covers both the engine forward AND the loss computation + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = model_engine(input_ids) + loss = loss_fn(logits.view(-1, vocab_size), labels.view(-1)) + +Without the outer ``torch.autocast``, only the model forward pass benefits from +autocast; the loss function would run in full precision. + +When DeepSpeed detects a nested autocast context, it handles it as follows: + +* If ``torch_autocast`` is **enabled** in the DeepSpeed config, the engine overrides the + outer context with the dtype from the config. An info message is logged once. +* If ``torch_autocast`` is **disabled** in the config (i.e., you are using DeepSpeed's + built-in bf16/fp16 support instead), the engine disables autocast inside + ``engine.forward()`` and a warning is logged once. + +In both cases, PyTorch's ``torch.autocast`` is idempotent when nested with the same +dtype, so there is no performance or correctness penalty from the nesting. + .. autofunction:: deepspeed.runtime.torch_autocast.init_autocast_params .. autofunction:: deepspeed.runtime.torch_autocast.is_autocast_initialized .. autofunction:: deepspeed.runtime.torch_autocast.get_default_autocast_lower_precision_modules From b244ffccc76869bebdd25f26ebfcb6fcb819f80e Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 3 Mar 2026 18:12:09 -0800 Subject: [PATCH 2/3] Revert config-json.md change The nesting note is better kept only in training.rst. Signed-off-by: Masahiro Tanaka --- docs/_pages/config-json.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index a99363338d22..d5344d3b2320 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -389,7 +389,7 @@ Example of **scheduler** | Parameter | Type | Default | Description | | --------- | ---- | ------- | ----------- | -| **enabled** | boolean | `false` | Enable torch.autocast for `engine.forward()`. You may still add your own `torch.autocast` outside the engine to cover loss functions or other user-side code (see [nesting behavior](https://deepspeed.readthedocs.io/en/latest/training.html#using-torch-autocast-outside-the-engine)). | +| **enabled** | boolean | `false` | Enable torch.autocast (no manual `torch.autocast` call needed in your code). | | **dtype** | string | `"bfloat16"` | Lower precision dtype (`"bfloat16"` or `"float16"`). Also used for gradient/parameter communication of `lower_precision_safe_modules`. | | **lower_precision_safe_modules** | list | `["torch.nn.Linear", "torch.nn.Conv1d", "torch.nn.Conv2d", "torch.nn.Conv3d"]` | Module types for lower-precision communication (all-reduce/all-gather). | From e573be57e114d606f3052df5b1efbe39cd98d259 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Wed, 4 Mar 2026 09:42:31 -0800 Subject: [PATCH 3/3] Update docs/code-docs/source/training.rst Co-authored-by: Stas Bekman --- docs/code-docs/source/training.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/code-docs/source/training.rst b/docs/code-docs/source/training.rst index 84670bfdf2e8..92e3bcf80f1f 100644 --- a/docs/code-docs/source/training.rst +++ b/docs/code-docs/source/training.rst @@ -134,7 +134,7 @@ forward-plus-loss block in your own ``torch.autocast`` context: logits = model_engine(input_ids) loss = loss_fn(logits.view(-1, vocab_size), labels.view(-1)) -Without the outer ``torch.autocast``, only the model forward pass benefits from +Without the outer ``torch.autocast``, only the model's forward pass benefits from autocast; the loss function would run in full precision. When DeepSpeed detects a nested autocast context, it handles it as follows: