If you call ``loss.backward()`` directly without using ``engine.scale()`` or ``engine.backward()``, DeepSpeed
will raise a ``RuntimeError`` to prevent training with unscaled gradients, which can lead to incorrect results
or gradient underflow.

Using torch.autocast Outside the Engine
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

DeepSpeed applies ``torch.autocast`` internally during ``engine.forward()``.
However, you may also want autocast to cover code that runs **outside** the engine,
such as a loss function or post-processing logic. In that case, wrap the entire
forward-plus-loss block in your own ``torch.autocast`` context:

.. code-block:: python

    # Autocast covers both the engine forward AND the loss computation
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model_engine(input_ids)
        loss = loss_fn(logits.view(-1, vocab_size), labels.view(-1))

Without the outer ``torch.autocast``, only the model's forward pass benefits from
autocast; the loss function would run in full precision.
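
To make that boundary concrete, here is a small, self-contained sketch in plain
PyTorch (using ``device_type="cpu"`` so it runs without a GPU; the same behavior
holds for ``"cuda"``). An autocast-eligible op inside the context produces the
lower-precision dtype, while the identical op outside the context stays in full
precision:

.. code-block:: python

    import torch

    x = torch.randn(8, 8)
    w = torch.randn(8, 8)

    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        inside = x @ w   # matmul is autocast-eligible -> runs in bfloat16

    outside = x @ w      # same op outside the context -> stays float32

    print(inside.dtype)   # torch.bfloat16
    print(outside.dtype)  # torch.float32

This is exactly why a loss function placed outside the autocast region (and
outside ``engine.forward()``) runs in full precision.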

When DeepSpeed detects a nested autocast context, it handles it as follows:

* If ``torch_autocast`` is **enabled** in the DeepSpeed config, the engine overrides the
outer context with the dtype from the config. An info message is logged once.
* If ``torch_autocast`` is **disabled** in the config (i.e., you are using DeepSpeed's
built-in bf16/fp16 support instead), the engine disables autocast inside
``engine.forward()`` and a warning is logged once.
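
The first case assumes ``torch_autocast`` is enabled in the DeepSpeed config. As a
rough illustration only (the exact schema and field names may differ in your
DeepSpeed version, and are assumptions here apart from the ``torch_autocast`` key
itself), the relevant section has a shape along these lines:

.. code-block:: python

    # Assumed shape of the relevant config section; fields other than
    # "torch_autocast" are illustrative, not authoritative.
    ds_config = {
        "torch_autocast": {
            "enabled": True,      # engine applies torch.autocast in forward()
            "dtype": "bfloat16",  # dtype the engine enforces, overriding the
                                  # dtype of any outer autocast context
        },
    }

With a config like this, an outer ``torch.autocast`` context requesting a
different dtype would be overridden inside ``engine.forward()``.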

In both cases, PyTorch's ``torch.autocast`` is idempotent when nested with the same
dtype, so there is no performance or correctness penalty from the nesting.
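
The same-dtype nesting behavior is easy to check directly in PyTorch (CPU shown so
the snippet runs anywhere): wrapping an already-autocast region in a second context
with the same dtype changes nothing about the resulting precision.

.. code-block:: python

    import torch

    x = torch.randn(8, 8)
    w = torch.randn(8, 8)

    # Nested contexts with the same dtype...
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            nested = x @ w

    # ...behave the same as a single context.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        flat = x @ w

    print(nested.dtype, flat.dtype)  # both torch.bfloat16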

.. autofunction:: deepspeed.runtime.torch_autocast.init_autocast_params
.. autofunction:: deepspeed.runtime.torch_autocast.is_autocast_initialized
.. autofunction:: deepspeed.runtime.torch_autocast.get_default_autocast_lower_precision_modules