From 91e4cb1f2802dc4b387912d15820869ab2196200 Mon Sep 17 00:00:00 2001 From: gitlost-murali Date: Sun, 9 Nov 2025 11:41:21 +0000 Subject: [PATCH 1/9] feat: add per-request control for number of completions Add optional n parameter to Generator.generate() to allow overriding the default sampling_params.n on a per-request basis. Update GRPO rollout to explicitly request n=1 for single trajectory generation. --- src/forge/actors/generator.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 32ae69906..c2a8c72fb 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -287,12 +287,14 @@ def split_keys(keys): return state_dict @endpoint - async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: + async def generate(self, prompt: str, *, priority: int = 0, n: int | None = None) -> list[Completion]: """Generate a response for the given prompt Args: prompt (str): The prompt to generate a response for. priority (int, optional): The priority of the request. Defaults to 0. + n (int, optional): Number of completions to generate. If not provided, uses the default + from self.sampling_params.n. Returns: list[Completion]: n completions from vLLM based on your prompt. @@ -301,12 +303,18 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: t.start() record_metric("generator/generate/count_requests", 1, Reduce.SUM) + # Use provided n or fall back to default, creating modified params if needed + if n is not None and n != self.sampling_params.n: + params = SamplingParams.from_optional(**{**self.sampling_params.to_dict(), 'n': n}) + else: + params = self.sampling_params + self.request_id += 1 % sys.maxsize request_id = str(self.request_id) tokenization_kwargs = {} # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507 - truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens + truncate_prompt_tokens = params.truncate_prompt_tokens _validate_truncation_size( self.vllm_config.model_config.max_model_len, truncate_prompt_tokens, @@ -315,7 +323,7 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: prompt_str, request = self.processor.process_inputs( request_id=request_id, prompt={"prompt": prompt}, - params=self.sampling_params, + params=params, arrival_time=None, tokenization_kwargs=tokenization_kwargs, trace_headers=None, @@ -331,21 +339,21 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: await self.request_lock.wait_for(lambda: self.accepting_requests) # Explicitly keeping the redundant logic to make it easier to pick up vLLM changes - if (num_samples := self.sampling_params.n) == 1: + if (num_samples := params.n) == 1: self.output_processor.add_request(request, prompt_str, None, 0) request, _ = self._preprocess_add_request(request) request_fut = asyncio.Future() self.requests[request_id] = (None, request_fut) self.scheduler.add_request(request) else: - parent_req = ParentRequest(request_id, self.sampling_params) + parent_req = ParentRequest(request_id, params) for idx in range(num_samples): # Note: `get_child_info` mutates ParentRequest to track the # generated child request - child_request_id, params = parent_req.get_child_info(idx) + child_request_id, params_child = parent_req.get_child_info(idx) child_request = request if idx == num_samples - 1 else copy(request) child_request.request_id = child_request_id - child_request.sampling_params = params + child_request.sampling_params = params_child self.output_processor.add_request( child_request, prompt_str, parent_req, idx ) From 18bba893a732df69b5f8faac72bf170ecc505fed Mon Sep 17 00:00:00 2001 From: gitlost-murali Date: Mon, 10 Nov 2025 23:26:38 +0000 Subject: [PATCH 2/9] refactor: Refactor sampling parameter handling in Generator class Updated the logic for modifying the sampling_params.n attribute to use the __replace__ method for better clarity and consistency. Added a new unit test to verify the behavior of the n parameter logic in the Generator class. --- src/forge/actors/generator.py | 2 +- tests/unit_tests/test_generator_config.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index c2a8c72fb..fdabefb2d 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -305,7 +305,7 @@ async def generate(self, prompt: str, *, priority: int = 0, n: int | None = None # Use provided n or fall back to default, creating modified params if needed if n is not None and n != self.sampling_params.n: - params = SamplingParams.from_optional(**{**self.sampling_params.to_dict(), 'n': n}) + params = self.sampling_params.__replace__(n=n) else: params = self.sampling_params diff --git a/tests/unit_tests/test_generator_config.py b/tests/unit_tests/test_generator_config.py index 94cb58859..a5ea2b31e 100644 --- a/tests/unit_tests/test_generator_config.py +++ b/tests/unit_tests/test_generator_config.py @@ -132,6 +132,27 @@ def test_generator_yaml_config_loading(self): self.assertEqual(generator.sampling_params.n, 2) self.assertEqual(generator.sampling_params.max_tokens, 32) + @pytest.mark.skipif( + _import_error(), + reason="Import error, likely due to missing dependencies on CI.", + ) + def test_generate_n_parameter_logic(self): + from forge.actors.generator import Generator + + generator = Generator(sampling_params={"n": 2, "max_tokens": 16}) + base_params = generator.sampling_params + + def get_params_for(n_override: int | None): + if n_override in (None, base_params.n): + return base_params + return base_params.__replace__(n=n_override) + + self.assertIs(get_params_for(None), base_params) + self.assertIs(get_params_for(2), base_params) + updated = get_params_for(4) + self.assertEqual(updated.n, 4) + self.assertIsNot(updated, base_params) + if __name__ == "__main__": unittest.main() From af98cfa2a8c33df79615e33bbb115dbbb40f9549 Mon Sep 17 00:00:00 2001 From: gitlost-murali Date: Mon, 10 Nov 2025 23:33:58 +0000 Subject: [PATCH 3/9] chore: pre-commit refactor --- src/forge/actors/generator.py | 4 +++- src/forge/controller/launcher.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index fdabefb2d..0b5f4b2c5 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -287,7 +287,9 @@ def split_keys(keys): return state_dict @endpoint - async def generate(self, prompt: str, *, priority: int = 0, n: int | None = None) -> list[Completion]: + async def generate( + self, prompt: str, *, priority: int = 0, n: int | None = None + ) -> list[Completion]: """Generate a response for the given prompt Args: diff --git a/src/forge/controller/launcher.py b/src/forge/controller/launcher.py index a11ab50be..dd74591f1 100644 --- a/src/forge/controller/launcher.py +++ b/src/forge/controller/launcher.py @@ -17,8 +17,6 @@ import monarch import torchx.specs as specs - -from forge.types import Launcher, LauncherConfig from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport @@ -29,6 +27,8 @@ from monarch.tools.commands import create, info from monarch.tools.config import Config, Workspace +from forge.types import Launcher, LauncherConfig + _MAST_AVAILABLE = False try: From 66a7ee9497721ba5ac87f8ecd81f67ee022bf619 Mon Sep 17 00:00:00 2001 From: gitlost-murali Date: Mon, 10 Nov 2025 23:45:26 +0000 Subject: [PATCH 4/9] chore: remove redundant comment --- src/forge/actors/generator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 0b5f4b2c5..ac4cf5f49 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -305,7 +305,6 @@ async def generate( t.start() record_metric("generator/generate/count_requests", 1, Reduce.SUM) - # Use provided n or fall back to default, creating modified params if needed if n is not None and n != self.sampling_params.n: params = self.sampling_params.__replace__(n=n) else: From 5620d8eb3a1aa6a44f2477d12bd69ab024d0965b Mon Sep 17 00:00:00 2001 From: gitlost-murali Date: Thu, 13 Nov 2025 21:19:55 +0000 Subject: [PATCH 5/9] refactor: replace n parameter with sampling_params in generate() Allows per-request override of any sampling parameter (temperature, top_p, n, etc.) instead of just n. Preserves output_kind=FINAL_ONLY enforcement from post_init logic. --- src/forge/actors/generator.py | 18 +++++++++++------- tests/unit_tests/test_generator_config.py | 21 --------------------- 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index ac4cf5f49..5426edc1c 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -288,15 +288,19 @@ def split_keys(keys): @endpoint async def generate( - self, prompt: str, *, priority: int = 0, n: int | None = None + self, + prompt: str, + *, + priority: int = 0, + sampling_params: SamplingParams | None = None, ) -> list[Completion]: """Generate a response for the given prompt Args: prompt (str): The prompt to generate a response for. priority (int, optional): The priority of the request. Defaults to 0. - n (int, optional): Number of completions to generate. If not provided, uses the default - from self.sampling_params.n. + sampling_params (SamplingParams, optional): Sampling parameters to use for this request. + If not provided, uses self.sampling_params. Returns: list[Completion]: n completions from vLLM based on your prompt. @@ -305,10 +309,10 @@ async def generate( t.start() record_metric("generator/generate/count_requests", 1, Reduce.SUM) - if n is not None and n != self.sampling_params.n: - params = self.sampling_params.__replace__(n=n) - else: - params = self.sampling_params + params = sampling_params or self.sampling_params + # Ensure output_kind is set to FINAL_ONLY (as required by post_init) + if params.output_kind != RequestOutputKind.FINAL_ONLY: + params = params.__replace__(output_kind=RequestOutputKind.FINAL_ONLY) self.request_id += 1 % sys.maxsize request_id = str(self.request_id) diff --git a/tests/unit_tests/test_generator_config.py b/tests/unit_tests/test_generator_config.py index a5ea2b31e..94cb58859 100644 --- a/tests/unit_tests/test_generator_config.py +++ b/tests/unit_tests/test_generator_config.py @@ -132,27 +132,6 @@ def test_generator_yaml_config_loading(self): self.assertEqual(generator.sampling_params.n, 2) self.assertEqual(generator.sampling_params.max_tokens, 32) - @pytest.mark.skipif( - _import_error(), - reason="Import error, likely due to missing dependencies on CI.", - ) - def test_generate_n_parameter_logic(self): - from forge.actors.generator import Generator - - generator = Generator(sampling_params={"n": 2, "max_tokens": 16}) - base_params = generator.sampling_params - - def get_params_for(n_override: int | None): - if n_override in (None, base_params.n): - return base_params - return base_params.__replace__(n=n_override) - - self.assertIs(get_params_for(None), base_params) - self.assertIs(get_params_for(2), base_params) - updated = get_params_for(4) - self.assertEqual(updated.n, 4) - self.assertIsNot(updated, base_params) - if __name__ == "__main__": unittest.main() From c87c1545f923c212b37383a040cfac684a7a380f Mon Sep 17 00:00:00 2001 From: gitlost-murali Date: Thu, 13 Nov 2025 21:32:50 +0000 Subject: [PATCH 6/9] test: verify output_kind override to FINAL_ONLY Add test to ensure Generator always overrides output_kind to FINAL_ONLY when initialized with custom sampling_params dict, protecting against accidental removal of this system requirement. --- tests/unit_tests/test_generator_config.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/unit_tests/test_generator_config.py b/tests/unit_tests/test_generator_config.py index 94cb58859..542772f7f 100644 --- a/tests/unit_tests/test_generator_config.py +++ b/tests/unit_tests/test_generator_config.py @@ -132,6 +132,26 @@ def test_generator_yaml_config_loading(self): self.assertEqual(generator.sampling_params.n, 2) self.assertEqual(generator.sampling_params.max_tokens, 32) + @pytest.mark.skipif( + _import_error(), + reason="Import error, likely due to missing dependencies on CI.", + ) + def test_generator_overrides_output_kind_to_final_only(self): + """Generator overrides output_kind to FINAL_ONLY when initialized with dict.""" + from forge.actors.generator import Generator + from vllm.sampling_params import RequestOutputKind + + sampling_dict = { + "n": 1, + "max_tokens": 16, + "output_kind": RequestOutputKind.DELTA, + } + + generator = Generator(sampling_params=sampling_dict) + self.assertEqual( + generator.sampling_params.output_kind, RequestOutputKind.FINAL_ONLY + ) + if __name__ == "__main__": unittest.main() From 45e58920f30d47bf623cda0717587c70c381c98d Mon Sep 17 00:00:00 2001 From: Murali Manohar Kondragunta Date: Fri, 14 Nov 2025 01:08:56 +0100 Subject: [PATCH 7/9] refactor: simplify overriding output_kind Co-authored-by: Felipe Mello --- src/forge/actors/generator.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 5426edc1c..b85e161e6 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -309,10 +309,11 @@ async def generate( t.start() record_metric("generator/generate/count_requests", 1, Reduce.SUM) + if sampling_params is not None: + # as in `post_init` + sampling_params.output_kind = RequestOutputKind.FINAL_ONLY + params = sampling_params or self.sampling_params - # Ensure output_kind is set to FINAL_ONLY (as required by post_init) - if params.output_kind != RequestOutputKind.FINAL_ONLY: - params = params.__replace__(output_kind=RequestOutputKind.FINAL_ONLY) self.request_id += 1 % sys.maxsize request_id = str(self.request_id) From df3b61390656995f27516886c24166b381da4fe0 Mon Sep 17 00:00:00 2001 From: gitlost-murali Date: Fri, 14 Nov 2025 00:54:05 +0000 Subject: [PATCH 8/9] test: remove redundant output_kind override test --- tests/unit_tests/test_generator_config.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/tests/unit_tests/test_generator_config.py b/tests/unit_tests/test_generator_config.py index 542772f7f..94cb58859 100644 --- a/tests/unit_tests/test_generator_config.py +++ b/tests/unit_tests/test_generator_config.py @@ -132,26 +132,6 @@ def test_generator_yaml_config_loading(self): self.assertEqual(generator.sampling_params.n, 2) self.assertEqual(generator.sampling_params.max_tokens, 32) - @pytest.mark.skipif( - _import_error(), - reason="Import error, likely due to missing dependencies on CI.", - ) - def test_generator_overrides_output_kind_to_final_only(self): - """Generator overrides output_kind to FINAL_ONLY when initialized with dict.""" - from forge.actors.generator import Generator - from vllm.sampling_params import RequestOutputKind - - sampling_dict = { - "n": 1, - "max_tokens": 16, - "output_kind": RequestOutputKind.DELTA, - } - - generator = Generator(sampling_params=sampling_dict) - self.assertEqual( - generator.sampling_params.output_kind, RequestOutputKind.FINAL_ONLY - ) - if __name__ == "__main__": unittest.main() From 86e711c7ccb87418451e476f84e2754429fe71b9 Mon Sep 17 00:00:00 2001 From: gitlost-murali Date: Fri, 14 Nov 2025 01:03:03 +0000 Subject: [PATCH 9/9] chore: fix linting issue --- src/forge/actors/generator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index b85e161e6..0d808fcfb 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -310,9 +310,9 @@ async def generate( record_metric("generator/generate/count_requests", 1, Reduce.SUM) if sampling_params is not None: - # as in `post_init` - sampling_params.output_kind = RequestOutputKind.FINAL_ONLY - + # as in `post_init` + sampling_params.output_kind = RequestOutputKind.FINAL_ONLY + params = sampling_params or self.sampling_params self.request_id += 1 % sys.maxsize