Commit 0097a2d

Add Trainer Protocol (#533)

Authored by Allen Wang (allenwang28)
Co-authored-by: Allen Wang <allencwang@fb.com>

1 parent 71eaa0a · commit 0097a2d

File tree

3 files changed: +526 −0 lines changed

src/forge/api/__init__.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Forge public API module.

This module defines the public interfaces that all Forge implementations conform to.
"""

from forge.api.trainer import Trainer
from forge.api.types import (
    ForwardBackwardResult,
    LossFn,
    OptimStepResult,
    ParallelismConfig,
    TextTrainBatch,
    TrainerConfig,
    TrainerStatus,
)

__all__ = [
    "Trainer",
    "TextTrainBatch",
    "ForwardBackwardResult",
    "OptimStepResult",
    "TrainerConfig",
    "TrainerStatus",
    "ParallelismConfig",
    "LossFn",
]
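
For orientation (not part of the diff): downstream code would import these names from forge.api and program against the Trainer protocol rather than a concrete implementation. A minimal sketch follows; the helper name train_step and the reliance on the trainer's default loss are illustrative assumptions, while the methods and result fields follow the docstrings in trainer.py below.

# Illustrative sketch only; not part of this commit.
from forge.api import TextTrainBatch, Trainer


async def train_step(trainer: Trainer, microbatches: list[TextTrainBatch]) -> int:
    """Accumulate gradients over several microbatches, then apply them once."""
    for batch in microbatches:
        await trainer.forward_backward(batch)  # uses the trainer's default loss
    result = await trainer.optim_step()
    print(f"step={result.step} lr={result.learning_rate}")
    return result.step
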

src/forge/api/trainer.py

Lines changed: 298 additions & 0 deletions
@@ -0,0 +1,298 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Protocol, runtime_checkable

import torch

from forge.api.types import (
    ForwardBackwardResult,
    LossFn,
    OptimStepResult,
    TextTrainBatch,
    TrainerConfig,
    TrainerStatus,
)


@runtime_checkable
class Trainer(Protocol):
    """Protocol defining the standard interface for all Forge trainers.

    Trainer implementations are expected to accept a default loss function at
    initialization time. This loss function is used when loss_fn is not
    provided to forward_backward(). The default loss should follow the
    LossFn signature.
    """

    async def forward_backward(
        self, batch: TextTrainBatch, loss_fn: LossFn | None = None
    ) -> ForwardBackwardResult:
        """Execute forward pass and backward pass for one batch of data.

        Basic usage - single batch per optimizer step:
            >>> batch = TextTrainBatch(
            >>>     input_ids=torch.tensor([[1, 2, 3, 4, 5]]),
            >>>     target_ids=torch.tensor([[2, 3, 4, 5, 6]]),
            >>> )
            >>> result = await trainer.forward_backward(batch)
            >>> await trainer.optim_step()  # Apply gradients

        To accumulate gradients over multiple batches before optimizer step:
            >>> await trainer.forward_backward(batch1)  # Accumulates
            >>> await trainer.forward_backward(batch2)  # Accumulates another batch
            >>> await trainer.optim_step()  # Apply all accumulated gradients

        Custom loss function for specific batches:
            >>> def custom_loss(outputs: dict[str, Any], batch: TextTrainBatch) -> torch.Tensor:
            >>>     # Custom loss computation (e.g., PPO clip, DPO, cut cross entropy, etc.)
            >>>     logits = outputs["logits"]
            >>>     # ... compute loss from logits, or use other outputs like hidden_states
            >>>     return loss
            >>>
            >>> result = await trainer.forward_backward(batch, loss_fn=custom_loss)

        Args:
            batch: TextTrainBatch containing input_ids, target_ids, and optional
                target_mask/target_weights. See forge.api.types.TextTrainBatch for details.
            loss_fn: Optional custom loss function. If None, uses the loss function
                configured at trainer creation. Signature: (outputs, batch) -> loss
                where outputs is a dict with at least "logits" key.
                Useful for mixed training objectives or experimentation.

        Returns:
            ForwardBackwardResult containing loss and metrics

        Note:
            The default loss function is configured at trainer creation time via the
            `loss` parameter. The `loss_fn` parameter here allows per-batch override.
            All loss functions should accept (outputs: dict[str, Any], batch: TextTrainBatch)
            where outputs contains at minimum a "logits" key.
        """
        ...

    async def optim_step(self) -> OptimStepResult:
        """Apply optimizer step using accumulated gradients, then clear gradients.

        This method:
        1. Applies accumulated gradients via the optimizer
        2. Steps the learning rate scheduler
        3. Clears all gradients (zero_grad)
        4. Increments the training step counter
        5. May trigger automatic checkpointing (implementation-dependent)

        Gradients must have been accumulated via forward_backward() calls before
        calling this method.

        Returns:
            OptimStepResult containing step number, learning rate, and accumulated batch count

        Example:
            >>> # Accumulate over 4 batches
            >>> for batch in batches[:4]:
            >>>     await trainer.forward_backward(batch)
            >>> result = await trainer.optim_step()
            >>> result.step
            1000
            >>> result.learning_rate
            0.0001
            >>> result.accumulated_microbatches
            4
        """
        ...

    async def clear_gradients(self) -> None:
        """Clear accumulated gradients without applying them.

        Use this when you need to discard accumulated gradients without performing
        an optimizer step. Common scenarios:
        - Exception during gradient accumulation
        - Skipping a training step due to some condition
        - Recovering from OOM or other errors

        This is equivalent to calling optimizer.zero_grad() and resetting internal
        accumulation counters.

        Example - Error recovery:
            >>> try:
            >>>     for batch in batches:
            >>>         await trainer.forward_backward(batch)
            >>>     await trainer.optim_step()
            >>> except torch.cuda.OutOfMemoryError:
            >>>     await trainer.clear_gradients()  # Discard partial gradients
            >>>     # Retry with smaller batches

        Example - Conditional skip:
            >>> await trainer.forward_backward(batch)
            >>> if should_skip_step():
            >>>     await trainer.clear_gradients()  # Don't apply these gradients
            >>> else:
            >>>     await trainer.optim_step()
        """
        ...

    async def forward(self, inputs: dict[str, torch.Tensor]) -> torch.Tensor:
        """Run forward pass only, without backward pass (for evaluation/inference).

        This method executes the model's forward pass without computing gradients.
        Useful for:
        - Evaluation on validation/test data
        - Getting model predictions/logits
        - Debugging model outputs

        Args:
            inputs: Dictionary containing model inputs. Typically includes:
                - input_ids: torch.Tensor [batch_size, seq_len]
                Other keys depend on the model architecture.

        Returns:
            Model output logits. Shape: [batch_size, seq_len, vocab_size]

        Note:
            This runs in torch.no_grad() context - no gradients are computed.

        Example:
            >>> eval_batch = {"input_ids": torch.tensor([[1, 2, 3, 4]])}
            >>> logits = await trainer.forward(eval_batch)  # [1, 4, vocab_size]
            >>> predictions = logits.argmax(dim=-1)  # [1, 4]
        """
        ...

    async def save(
        self,
        name: str | None = None,
        path: str | None = None,
        weights_only: bool = False,
    ) -> str:
        """Save trainer state or weights to persistent storage.

        By default, saves complete training state (model weights, optimizer state,
        learning rate scheduler state, and step counter). Set weights_only=True to
        save only model weights for inference/deployment.

        Args:
            name: Optional checkpoint name/identifier. If None, uses the current
                step number (e.g., "step-1000" or "weights-step-1000").
            path: Optional base directory or URI where checkpoint should be saved.
                If None, uses the default checkpoint directory configured at trainer
                creation. Supports different backends via URI schemes:
                - `/local/path` - local filesystem
                - `ts://key` - TorchStore
                - `s3://bucket/key` - S3
            weights_only: If True, saves only model weights (lighter, for inference).
                If False (default), saves full training state including optimizer.

        Returns:
            Full path/URI where checkpoint was saved

        Example:
            >>> # Save full training state (default)
            >>> path = await trainer.save(name="checkpoint-1000")
            >>> path
            "/default/checkpoint-1000"
            >>>
            >>> # Save weights only for inference
            >>> path = await trainer.save(name="policy-v1", weights_only=True)
            >>> path
            "/default/policy-v1"
            >>>
            >>> # Save to TorchStore
            >>> path = await trainer.save(name="best", path="ts://checkpoints")
            >>> path
            "ts://checkpoints/best"
        """
        ...

    async def load(self, path: str | None = None) -> str:
        """Load a previously saved checkpoint.

        Restores training state from a checkpoint. Automatically handles both
        full checkpoints and weights-only checkpoints.

        Args:
            path: Optional path or URI to the checkpoint to load. If None, loads
                the most recent checkpoint from the default directory. Can be:
                - `/local/path/checkpoint` - local filesystem
                - `ts://key` - TorchStore
                - `s3://bucket/key` - S3

        Returns:
            Path/URI that was loaded

        Example:
            >>> # Load latest checkpoint from default location
            >>> path = await trainer.load()
            >>> path
            "/default/step-5000"
            >>>
            >>> # Load specific checkpoint by path
            >>> path = await trainer.load("/checkpoints/step-5000")
            >>> path
            "/checkpoints/step-5000"
            >>>
            >>> # Load from TorchStore
            >>> path = await trainer.load("ts://checkpoint-key")
            >>> path
            "ts://checkpoint-key"
        """
        ...

    async def get_config(self) -> TrainerConfig:
        """Get static trainer and model configuration.

        Returns configuration information that doesn't change during training.
        For runtime state like current step, use get_status() instead.

        Returns:
            TrainerConfig containing model name, model_config, and parallelism settings

        Example:
            >>> config = await trainer.get_config()
            >>> config.model_name
            "Qwen/Qwen2.5-7B"
            >>> config.model_config["vocab_size"]
            151936
            >>> config.parallelism.dp_degree
            4
            >>> config.parallelism.device
            "cuda:0"
        """
        ...

    async def get_status(self) -> TrainerStatus:
        """Get current runtime status of the trainer.

        Returns dynamic information about the trainer's current state that changes
        during training.

        Returns:
            TrainerStatus containing current step and accumulated batch count

        Example:
            >>> status = await trainer.get_status()
            >>> status.step
            1000
            >>> status.accumulated_microbatches
            2
        """
        ...

    async def get_tokenizer(self):
        """Get the tokenizer associated with this model.

        Returns the tokenizer used for encoding/decoding text with this model.
        Useful for preprocessing inputs or decoding model outputs.

        Returns:
            PreTrainedTokenizer: The HuggingFace tokenizer for this model

        Example:
            >>> tokenizer = await trainer.get_tokenizer()
            >>> tokens = tokenizer.encode("Hello world")
            >>> text = tokenizer.decode([1, 2, 3, 4])
        """
        ...
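
Since Trainer is declared @runtime_checkable, conformance is structural: any class exposing these async methods satisfies the protocol, and isinstance() can confirm that the methods exist (though not their signatures or return types). Below is a bare, hypothetical skeleton, not part of this commit, with stubbed-out bodies, to show how an implementation lines up against the protocol.

# Hypothetical skeleton; not part of this commit.
from forge.api import Trainer


class DummyTrainer:
    """Stub showing the method surface the Trainer protocol expects."""

    async def forward_backward(self, batch, loss_fn=None): ...
    async def optim_step(self): ...
    async def clear_gradients(self) -> None: ...
    async def forward(self, inputs): ...
    async def save(self, name=None, path=None, weights_only=False) -> str: ...
    async def load(self, path=None) -> str: ...
    async def get_config(self): ...
    async def get_status(self): ...
    async def get_tokenizer(self): ...


# runtime_checkable protocols only check method presence at isinstance() time,
# not signatures, so this passes even though the stubs do nothing.
assert isinstance(DummyTrainer(), Trainer)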
