Variable batch size and LR scheduler#7104

Merged
tjruwase merged 30 commits into deepspeedai:master from bm-synth:variable_batch_size_and_lr_2 on Mar 27, 2025

Conversation

@bm-synth
Contributor

@bm-synth bm-synth commented Mar 3, 2025

# Background and rationale

In many use cases, particularly LLMs, one is faced with inputs (sentences) of variable lengths. A common practice is to pack batches by token count (not a fixed batch size), i.e. by putting together sentences whose given metric (e.g. sequence length) adds up to a user-provided value. As an example, in [Attention is all you need](https://arxiv.org/abs/1706.03762), section 5.1:

> Sentence pairs were batched together by approximate sequence length. Each training batch contained a set of sentence pairs containing approximately 25000 source tokens and 25000 target tokens.

Dynamic batch sizes have been requested in [DeepSpeed issue 1051](https://github.com/deepspeedai/DeepSpeed/issues/1051), [DeepSpeed issue 3455](https://github.com/deepspeedai/DeepSpeed/issues/3455), [PyTorch Lightning issue 16914](https://github.com/Lightning-AI/pytorch-lightning/issues/16914) and [Hugging Face Accelerate issue 2647](https://github.com/huggingface/accelerate/issues/2647), and are already available in many libraries, e.g. [NVIDIA Triton](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher) and [Meta FairSeq](https://github.com/facebookresearch/fairseq) (implementation [here](https://github.com/facebookresearch/fairseq/blob/34973a94d09ecc12092a5ecc8afece5e536b7692/fairseq/data/fairseq_dataset.py#L104)).

The immediate use case for this is when one needs to maximize GPU utilization. Moreover, this is particularly relevant for curriculum learning, where a `BxTxE` (Batch x Time x Embedding) shaped input should ideally have high `B` and low `T` at the early curriculum steps (many short sentences packed together as a batch), and low `B` and high `T` at the late steps (few long sentences in the batch). A dynamic size `T` is already supported by DeepSpeed, e.g. in the documentation for pipeline parallelism's [reset_activation_shape()](https://deepspeed.readthedocs.io/en/stable/pipeline.html#deepspeed.runtime.pipe.engine.PipelineEngine.reset_activation_shape):

> For curriculum learning that changes the seqlen of each sample, we need to call this whenever the seqlen is going to change.

However, a dynamic `B` is not supported. A dynamic `B` requires an adequate increase or decrease of the learning rate. This technique has been applied previously, and the two most common LR scaling algorithms are:

  1. Linear Scaling Rule: "When the minibatch size is multiplied by k, multiply the learning rate by k", as in [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour, Goyal et al.](https://arxiv.org/abs/1706.02677)
  2. Square Root scaling: "when multiplying the batch size by k, multiply the learning rate by √k, to keep the variance in the gradient expectation constant", from [One weird trick for parallelizing convolutional neural networks, A. Krizhevsky](https://arxiv.org/abs/1404.5997)
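
Both rules reduce to multiplying the reference LR by a function of the batch-size ratio `k`. Below is a minimal standalone sketch of such a scaling function (illustrative only; this PR's actual logic lives in `scale_lr`):

```python
# Minimal sketch of the two LR scaling rules above (illustrative, not the merged code).
def scale_lr(ref_lr: float, ref_batch_size: int, batch_size: int,
             method: str | None = "linear") -> float:
    k = batch_size / ref_batch_size  # batch-size ratio vs. the reference config
    if method == "linear":           # Goyal et al.: multiply the LR by k
        return ref_lr * k
    if method == "sqrt":             # Krizhevsky: multiply the LR by sqrt(k)
        return ref_lr * k ** 0.5
    if method is None:               # dynamic batching without LR scaling
        return ref_lr
    raise ValueError(f"unknown lr_scaling_method: {method}")
```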

In practice, the user picks the total token count per batch as the metric that drives batching, instead of batching by sentence count. During runtime, the variable batch size is computed and the LR is adjusted accordingly, based on the reference LR and batch size provided in the config.

# Illustration of dynamic batch size, sequence length and LR

Imagine we picked a limit of `30` tokens per batch, and have set a reference `lr=1e-3` for a `train_batch_size=2` (in the DeepSpeed config). The batching algorithm for curriculum may pack the data into batches of short sentences (left) at the early stages, and batches of long sentences (right) at later stages, e.g.:

![dynamic_batch_size_and_lr](https://github.com/microsoft/DeepSpeed/assets/150697676/324bda09-8f0b-430c-bb33-cc1bd01c3fe7)

Above, we collected samples until we filled up the batch with at most 30 tokens. The batch sizes (number of samples) then became `10` and `4` in the left and right examples, respectively. Using the linear scaling rule, the LRs for those batches become `5e-3` and `2e-3`.
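
To make the arithmetic explicit, the linear rule multiplies the reference LR by `k = batch_size / train_batch_size`:

```python
ref_lr, ref_batch_size = 1e-3, 2
for batch_size in (10, 4):
    print(batch_size, ref_lr * batch_size / ref_batch_size)
# 10 -> 0.005 (5e-3), 4 -> 0.002 (2e-3)
```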

# Pipeline parallelism

Pipeline parallelism requires the same batch size and same sequence length across all micro-batches in a batch, as the activation sizes must be fixed between gradient accumulation steps. Between batches, these may change, as long as `engine.reset_activation_shape()` is called so that the new shapes are communicated on the first gradient accumulation step of the batch. Enforcing a similar `BxTxE` across micro-batches may lead to smaller micro-batches. As an example, below is an illustration of a 2-node, 2-gradient-accumulation-step (i.e. 4 micro-batches) batching for the same dataset, when preparing data for the regular DDP (left) and for the pipeline parallelism use case (right):

![dynamic_batch_size_and_lr_microbatching](https://github.com/microsoft/DeepSpeed/assets/150697676/3fed5e1c-f2f5-4efe-a9c5-5b5e20719d45)

We can see that the pipeline use case (right) has the same `BxTxE` shape across all 4 micro-batches in the same batch, and in order to respect that, it packs fewer samples per batch than the standard use case (left-hand side).
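
A minimal sketch of the resulting per-batch training loop in the pipeline case follows. `engine.reset_activation_shape()` and `engine.train_batch()` are existing DeepSpeed `PipelineEngine` methods; that the dataloader yields one micro-batch iterator per effective batch is an assumption of this sketch:

```python
for microbatch_iter in dataloader:  # assumed: one iterator of micro-batches per batch
    # B and T are fixed within this batch but may differ from the previous batch,
    # so the pipeline stages must re-communicate activation shapes first.
    engine.reset_activation_shape()
    loss = engine.train_batch(data_iter=microbatch_iter)
```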

# Attention Head

For an input of size `BxTxE`, the attention mask has a shape of `TxT` when it is fixed across samples of the same size, or `BxTxT` when each sample has a different mask (when samples have different sizes, as in the dataset above). This 3D attention matrix can be illustrated for DDP micro-batch 1 (picture above, top-left, 4 sentences) as:

![dynamic_batch_size_and_lr_attn_matrix](https://github.com/microsoft/DeepSpeed/assets/150697676/707d2f17-66da-4034-8a12-a87df2044bfb)

Note the memory savings: the attention head has a size of `BxTxT`, i.e. a linear memory dependency on the batch size `B` and a quadratic memory dependency on the largest sequence length `T` in the (micro-)batch. Thus, supporting a dynamic size `T` allows for an increase of `B`.
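
For illustration (this helper is not part of the PR), a per-sample boolean mask of shape `BxTxT` could be built as follows, assuming each micro-batch is padded to its longest sequence:

```python
import torch

def build_attn_mask(seqlens: list[int]) -> torch.Tensor:
    """Boolean BxTxT mask: each sample attends only to its own real tokens."""
    B, T = len(seqlens), max(seqlens)  # T = longest sequence in the micro-batch
    mask = torch.zeros(B, T, T, dtype=torch.bool)
    for b, n in enumerate(seqlens):
        mask[b, :n, :n] = True
    return mask

mask = build_attn_mask([5, 3, 8, 2])  # illustrative lengths -> shape 4x8x8
```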

# PR overview

This PR implements dynamic batching and LR scaling. The necessary dataloader and LR scheduler can be retrieved by calling `get_dataloader_and_lr_scheduler_for_variable_batch_size`. A short explanation of that function follows (a usage sketch comes after this list):

  • The logic behind the LR scaling algorithms is in `scale_lr`.
  • The partitioning of samples into batches is done by `batch_by_seqlen`, which returns `microbatch_sample_ids` (the list of sample ids per micro-batch), `batch_sizes` (the effective size of each batch), and `batch_max_seqlens` (the longest sequence across all micro-batches in a batch).
  • For pipeline parallelism, all micro-batches in a pipeline pass are required to have the same activation shapes. This is enabled by setting the following parameters to `True`:
    • `required_microbatches_of_same_sizes`, which forces the `B` dimension to be the same across all gradient accumulation steps of all dataloaders in a batch;
    • `required_microbatches_of_same_lengths`, which forces the `T` dimension to be the same across all gradient accumulation steps. It works by calling the user-provided `sample_padding_fn(sentence, len)`, which pads a given sentence to the argument length.
  • `dataloader_for_variable_batch_size` relies on `microbatch_sample_ids`; it iterates, collates and pads the samples of every batch, and returns a dataloader that iterates the final (variable-size) batches.
  • `lr_scheduler_for_variable_batch_size` relies on `batch_sizes` to compute the learning rate for each effective batch, taking into account the batch size and LR in the config file and scaling the LR based on the size of each effective batch and the scaling rule mentioned above (linear, square root, etc.).
    • Note that the returned `lr_scheduler` accepts either:
      1. a user-provided `Optimizer`, in which case it scales the learning rates (in param groups) at every batch, or
      2. a user-defined `LRScheduler`, in which case it first gets the learning rate from the scheduler and then scales it accordingly.
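
Putting the pieces together for the pipeline-parallel case, a call could look roughly like the sketch below. This is a hedged sketch: the argument names are taken from the description above, the toy `dataset` and `pad_sentence` are placeholders, and the merged signature may differ:

```python
dataset = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]         # toy dataset of token ids
pad_sentence = lambda s, n: s + [0] * (n - len(s))  # pads a sentence to length n

dataloader, lr_scheduler, _ = get_dataloader_and_lr_scheduler_for_variable_batch_size(
    dataset=dataset,
    max_tokens=100,                              # token budget per sequence pack
    lr_scaling_method="linear",                  # or "sqrt", or None
    required_microbatches_of_same_sizes=True,    # fix B across micro-batches (pipelining)
    required_microbatches_of_same_lengths=True,  # fix T across micro-batches (pipelining)
    sample_padding_fn=pad_sentence,              # (sentence, length) -> padded sentence
)
```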

# Example

An example of the use case, with and without pipelining, is provided in [`DeepSpeedExamples/training/data_efficiency/variable_batch_size_and_lr/variable_batch_size_and_lr_example.py`](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/data_efficiency/variable_batch_size_and_lr). The example shows an attention head with a variable-sized `BxTxT` attention per batch, followed by a fixed-size feed-forward network. These are the main blocks in a Large Language Model. The feed-forward (or linear) layer that follows the attention head requires a constant input size, equivalent to the largest sentence in the whole dataset, so the output of the attention must be padded (see `feedforward: needs to convert BxTxE to BxMxE by padding extra tokens` in the code).
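
As a minimal sketch of that padding step (an illustrative helper, assuming `M` is the dataset-wide maximum sequence length; not the example file's actual code):

```python
import torch
import torch.nn.functional as F

def pad_to_max(x: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    """Pad the T dimension of a BxTxE tensor up to M = max_seqlen (BxMxE)."""
    B, T, E = x.shape
    return F.pad(x, (0, 0, 0, max_seqlen - T))  # (E-left, E-right, T-front, T-back)

out = pad_to_max(torch.randn(4, 7, 16), max_seqlen=12)  # BxTxE -> BxMxE: 4x12x16
```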

# Config

The example file also includes the relevant DeepSpeed config, annotated with comments:

```python
config = {
  "train_batch_size": 16,
  # `train_micro_batch_size_per_gpu` tells how many sequence packs of `max_tokens` each will be collated together.
  # I.e. the number of tokens per micro-batch (i.e. per GPU iteration) is `train_micro_batch_size_per_gpu`*`max_tokens`.
  "train_micro_batch_size_per_gpu": 2,
  "data_efficiency": {
    "enabled": True,
    # seed to be applied to all data efficiency modules, including dynamic batching
    "seed": 42,
    "data_sampling": {
      "num_workers": 0,  # dataloader num_workers argument
      "pin_memory": False,  # dataloader pin_memory argument
      "dynamic_batching": {
        # enables or disables dynamic batching
        "enabled": True,
        # how many tokens we need to fill a pack of sequences (that will be collated together as a sample)
        "max_tokens": 100,
        # Path to read the length of every sequence from, or to write it to.
        # Sequence lengths will be loaded from: {metrics_path}/seqlen/seqlen_sample_to_metric.bin and *.idx
        # If the files don't exist, they'll be computed and saved on the first run, and loaded on subsequent runs.
        "metrics_path": "./curriculum_output/",
        # As the batch size increases/decreases, which method to use to scale the LR accordingly?
        # Options: linear, sqrt (square root), or None to disable
        "lr_scaling_method": "linear",
        # how to pick sentences to be packed into samples:
        # - dataloader: in the same order as they come in with the dataloader
        # - seqlen: by sequence length (shortest to longest)
        # - random: random order, using the seed in config['data_efficiency']['seed']
        "sentence_picking_order": "dataloader",  # "random" / "seqlen" / "dataloader"
        # minimum number of sequences required to reach `max_tokens`. If a sentence pack is smaller, it's discarded.
        "min_batch_size": 1,
        # maximum number of sequences allowed to reach `max_tokens`. If a sentence pack is larger, it's discarded.
        "max_batch_size": 10,
        # enable the output of micro-batching information about sentence packing
        "verbose": True,
      },
    },
  },
}
```

# Future work

A follow-up PR will enable dynamic batching when calling `deepspeed.initialize`. I.e. instead of this:

```python
engine, _, _, _ = deepspeed.initialize(config=config, model=model)
dataloader, lr_scheduler, _ = get_dataloader_and_lr_scheduler_for_variable_batch_size_deepspeed(...)
engine.lr_scheduler = lr_scheduler
```

we'd ideally have this:

```python
engine, _, dataloader, lr_scheduler = deepspeed.initialize(config=config, model=model)
```

where `initialize` will internally call `get_dataloader_and_lr_scheduler_for_variable_batch_size_deepspeed`.

bm-synth added 3 commits March 3, 2025 11:25
Signed-off-by: Bruno Magalhaes <bruno.magalhaes@synthesia.io>
sign off of all commits in tree

Signed-off-by: Bruno Magalhaes <bruno.magalhaes@synthesia.io>
@bm-synth bm-synth requested a review from GuanhuaWang as a code owner March 4, 2025 01:04
@bm-synth bm-synth force-pushed the variable_batch_size_and_lr_2 branch from c5615c3 to 810d89b Compare March 4, 2025 01:13
@tjruwase
Contributor

tjruwase commented Mar 4, 2025

@bm-synth, please see the formatting failure.

@loadams
Collaborator

loadams commented Mar 14, 2025

Hi @bm-synth - could you take a look at the test failures here?

@loadams
Collaborator

loadams commented Mar 25, 2025

Hi @bm-synth - could you take a look at the test failures here?

Just a reminder @bm-synth - our CI appears to be running fine now, so the errors look related to the PR, for example here. Let us know if you have any other questions/issues.

@bm-synth
Contributor Author

bm-synth commented Mar 27, 2025

@loadams fixed that missing data_sampling key error here.

@bm-synth
Contributor Author

bm-synth commented Mar 27, 2025

@loadams @tjruwase all tests passed. I believe it can be merged. I will then start working on a better integration with the `deepspeed.initialize(...)` in a separate PR.

@tjruwase tjruwase added this pull request to the merge queue Mar 27, 2025
@tjruwase
Contributor

> @loadams @tjruwase all tests passed. I believe it can be merged. I will then start working on a better integration with the `deepspeed.initialize(...)` in a separate PR.

@bm-synth, thanks for the quick resolution.

If you ever decide to write a blog on this awesome feature, we would be delighted to collaborate and advertise here: https://www.deepspeed.ai/

Merged via the queue into deepspeedai:master with commit 20f988e Mar 27, 2025
11 checks passed
@bm-synth bm-synth deleted the variable_batch_size_and_lr_2 branch March 27, 2025 16:58
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Mar 28, 2025
deepcharm pushed a commit to deepcharm/DeepSpeed that referenced this pull request Mar 31, 2025
ys950902 pushed a commit to ys950902/DeepSpeed that referenced this pull request May 21, 2025
deepcharm pushed a commit to deepcharm/DeepSpeed that referenced this pull request Jun 16, 2025