diff --git a/docs/code-docs/source/training.rst b/docs/code-docs/source/training.rst
index 9b01b2f85315..92e3bcf80f1f 100644
--- a/docs/code-docs/source/training.rst
+++ b/docs/code-docs/source/training.rst
@@ -119,6 +119,35 @@ If you call ``loss.backward()`` directly without using ``engine.scale()`` or ``e
 will raise a ``RuntimeError`` to prevent training with unscaled gradients, which
 can lead to incorrect results or gradient underflow.

+Using torch.autocast Outside the Engine
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+DeepSpeed applies ``torch.autocast`` internally during ``engine.forward()``.
+However, you may also want autocast to cover code that runs **outside** the engine,
+such as a loss function or post-processing logic. In that case, wrap the entire
+forward-plus-loss block in your own ``torch.autocast`` context:
+
+.. code-block:: python
+
+    # Autocast covers both the engine forward AND the loss computation
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        logits = model_engine(input_ids)
+        loss = loss_fn(logits.view(-1, vocab_size), labels.view(-1))
+
+Without the outer ``torch.autocast``, only the model's forward pass benefits from
+autocast; the loss function would run in full precision.
+
+When DeepSpeed detects a nested autocast context, it handles it as follows:
+
+* If ``torch_autocast`` is **enabled** in the DeepSpeed config, the engine overrides the
+  outer context with the dtype from the config. An info message is logged once.
+* If ``torch_autocast`` is **disabled** in the config (i.e., you are using DeepSpeed's
+  built-in bf16/fp16 support instead), the engine disables autocast inside
+  ``engine.forward()`` and a warning is logged once.
+
+In both cases, PyTorch's ``torch.autocast`` is idempotent when nested with the same
+dtype, so there is no performance or correctness penalty from the nesting.
+
 .. autofunction:: deepspeed.runtime.torch_autocast.init_autocast_params
 .. autofunction:: deepspeed.runtime.torch_autocast.is_autocast_initialized
 .. autofunction:: deepspeed.runtime.torch_autocast.get_default_autocast_lower_precision_modules
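
The nesting behavior described in the added section hinges on the ``torch_autocast`` block of the DeepSpeed config. As a minimal sketch of what such a config might look like (the ``enabled`` and ``dtype`` keys are assumed from DeepSpeed's torch.autocast support; verify the exact schema against the config documentation for your DeepSpeed version):

```json
{
  "torch_autocast": {
    "enabled": true,
    "dtype": "bfloat16"
  }
}
```

With ``enabled: true``, an outer user-supplied ``torch.autocast`` context would be overridden by the configured dtype as described above; with ``enabled: false``, the engine would instead rely on its built-in bf16/fp16 path and disable autocast inside ``engine.forward()``.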