Skip to content

Commit b97e75a

Browse files
committed
Move loss_fn into model definition in PP example
Instead of passing loss_fn to AutoParallelPP, wrap the model in a ModelWithLoss module that bakes cross-entropy loss into forward(). This makes the example compatible with tracing self.model directly. stack-info: PR: #324, branch: xmfan/stack/25
1 parent f2853db commit b97e75a

3 files changed

Lines changed: 17 additions & 75 deletions

File tree

autoparallel/api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,6 @@ def __init__(
273273
self.enable_ac = enable_ac
274274
self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
275275
self.reshard_after_forward = reshard_after_forward
276-
self.loss_fn = None
277276

278277
if dynamic:
279278
self.fake_mode.shape_env = ShapeEnv()

autoparallel/api_pp.py

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -62,78 +62,6 @@ def forward(self, *args):
6262

6363

6464
class AutoParallelPP(AutoParallel):
65-
def __init__(
66-
self,
67-
model,
68-
input_fn,
69-
mesh,
70-
mp_policy=None,
71-
compile: bool = False,
72-
enable_ac: bool = True,
73-
ac_stage_size_in_GiB=None,
74-
reshard_after_forward: bool = True,
75-
dynamic: bool = False,
76-
loss_fn: Optional[Any] = None,
77-
numerics_logger=None,
78-
**kwargs,
79-
):
80-
# Call parent __init__ without loss_fn
81-
super().__init__(
82-
model=model,
83-
input_fn=input_fn,
84-
mesh=mesh,
85-
mp_policy=mp_policy,
86-
compile=compile,
87-
enable_ac=enable_ac,
88-
ac_stage_size_in_GiB=ac_stage_size_in_GiB,
89-
reshard_after_forward=reshard_after_forward,
90-
dynamic=dynamic,
91-
numerics_logger=numerics_logger,
92-
**kwargs,
93-
)
94-
# Set loss_fn after parent initialization
95-
self.loss_fn = loss_fn
96-
97-
def _prepare_model_wrapper_and_inputs(
98-
self, raw_inputs: Any
99-
) -> tuple[Any, tuple[Any, ...]]:
100-
"""
101-
Prepare the model wrapper and formatted inputs for tracing.
102-
103-
Overrides the base class to handle loss_fn when provided.
104-
105-
Args:
106-
raw_inputs: The raw inputs from input_fn()
107-
108-
Returns:
109-
A tuple of (model_wrapper, formatted_inputs) where:
110-
- model_wrapper is a callable that will be traced
111-
- formatted_inputs are the inputs to pass to model_wrapper
112-
"""
113-
if self.loss_fn is not None:
114-
# Expected format: ((inp1, inp2,...), target)
115-
if isinstance(raw_inputs, tuple) and len(raw_inputs) == 2:
116-
model_inputs, target = raw_inputs
117-
# Normalize inputs to always be a tuple
118-
if not isinstance(model_inputs, tuple):
119-
model_inputs = (model_inputs,)
120-
formatted_inputs = (model_inputs, target)
121-
122-
def model_with_loss(model_inputs, target) -> Any:
123-
output = self.model(*model_inputs)
124-
loss = self.loss_fn(output, target) # type: ignore[misc]
125-
return loss
126-
127-
return model_with_loss, formatted_inputs
128-
else:
129-
raise ValueError(
130-
"When loss_fn is provided, input_fn must return (inputs, target) "
131-
"where inputs can be a single tensor or tuple of tensors"
132-
)
133-
else:
134-
# No loss function, use parent implementation
135-
return super()._prepare_model_wrapper_and_inputs(raw_inputs)
136-
13765
def apply_placement_pp(
13866
self, sharding_placement=None, graph_passes: list[str] = []
13967
) -> dict[str, Any]:

examples/example_pp_graph_passes.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,12 @@ def _get_pp_module_and_graphs(
104104
dynamic=True,
105105
compile=False,
106106
reshard_after_forward=False,
107-
loss_fn=dsv3_loss_fn if use_loss_fn else None,
108107
) as autop:
109108
autop.add_parameter_memory_constraint(low=None, high=None)
110109

111110
# x_sharding = (Shard(0), Replicate())
112111
x_sharding = (Shard(0), Shard(0))
113-
if autop.loss_fn is not None:
112+
if use_loss_fn:
114113
autop.add_input_constraints([x_sharding, x_sharding])
115114
autop.add_output_constraints([(Replicate(), Replicate())])
116115
else:
@@ -466,6 +465,22 @@ def test_combined(
466465
model = DeepSeekV3Model(config).bfloat16()
467466
model.tok_embeddings = None # type: ignore[assignment]
468467

468+
if use_loss_fn:
469+
470+
class ModelWithLoss(torch.nn.Module):
471+
def __init__(self, model):
472+
super().__init__()
473+
self.model = model
474+
475+
def forward(self, h, labels):
476+
output = self.model(h)
477+
return dsv3_loss_fn(output, labels)
478+
479+
def init_weights(self, *args, **kwargs):
480+
return self.model.init_weights(*args, **kwargs)
481+
482+
model = ModelWithLoss(model)
483+
469484
def make_input_fn(sharded: bool = False, with_target: bool = False):
470485
"""Create input generator. `sharded` uses mesh-adjusted batch size."""
471486

0 commit comments

Comments (0)