11"""
2- VLLMPolicy - Policy for TRL's VLLMClient
2+ VLLMPolicy - Policy for TRL's VLLMClient or colocated vLLM LLM.
33
4- Simple policy that calls TRL's vllm_client directly instead of going through LiteLLM.
5- Works with `trl vllm-serve` endpoints.
4+ Thin adapter that turns Eval Protocol-style message lists into a single prompt,
5+ then calls either:
6+
7+ - TRL's VLLMClient (server mode), or
8+ - a colocated vLLM LLM instance (SamplingParams mode).
69"""
710
11+ import logging
812from typing import Any , Dict , List , Optional
913
1014
15+ logger = logging .getLogger (__name__ )
16+
17+
1118class VLLMPolicy :
1219 """
1320 Policy that uses TRL's VLLMClient for generation.
@@ -52,7 +59,7 @@ async def _make_llm_call(
5259 tools : Optional [List ] = None ,
5360 ) -> Dict [str , Any ]:
5461 """
55- Make LLM call using TRL's VLLMClient.
62+ Make LLM call using TRL's VLLMClient or a colocated vLLM LLM .
5663
5764 Args:
5865 messages: List of message dicts with 'role' and 'content'
@@ -70,29 +77,29 @@ async def _make_llm_call(
7077 add_generation_prompt = True ,
7178 tokenize = False ,
7279 )
73- print ("\n [VLLMPolicy] ===== CHAT TEMPLATE APPLIED =====" , flush = True )
74- print (f"[VLLMPolicy] Input messages ({ len (messages )} messages):" , flush = True )
75- for i , msg in enumerate (messages ):
76- content_preview = str (msg .get ("content" , "" ))[:100 ]
77- print (f" [{ i } ] { msg .get ('role' , '?' )} : { content_preview } ..." , flush = True )
78- print (f"[VLLMPolicy] Formatted prompt (length={ len (prompt_text )} ):" , flush = True )
79- print ("[VLLMPolicy] Prompt preview (last 500 chars):" , flush = True )
80- print (f"{ prompt_text [- 500 :]} " , flush = True )
81- print ("[VLLMPolicy] ===================================" , flush = True )
80+ logger .debug (
81+ "[VLLMPolicy] Chat template applied for %d messages (prompt length=%d)" ,
82+ len (messages ),
83+ len (prompt_text ),
84+ )
8285 except Exception as e :
83- print (f"[VLLMPolicy] Warning: Failed to apply chat template: { e } " , flush = True )
84- # Fallback: simple concatenation
85- prompt_text = "\n " .join (f"{ m ['role' ]} : { m ['content' ]} " for m in messages )
86+ logger .warning (
87+ "[VLLMPolicy] Failed to apply chat template: %s" ,
88+ e ,
89+ exc_info = True ,
90+ )
91+ # Fallback: simple concatenation (defensive .get access)
92+ prompt_text = "\n " .join (f"{ m .get ('role' , '?' )} : { m .get ('content' , '' )} " for m in messages )
8693 else :
8794 # No tokenizer: simple concatenation
88- prompt_text = "\n " .join (f"{ m [ 'role' ] } : { m [ 'content' ] } " for m in messages )
95+ prompt_text = "\n " .join (f"{ m . get ( 'role' , '?' ) } : { m . get ( 'content' , '' ) } " for m in messages )
8996
9097 # Check if vllm_client is VLLMClient (server mode) or LLM (colocate mode)
9198 is_llm_object = hasattr (self .vllm_client , "llm_engine" ) # LLM has llm_engine
9299
93100 if is_llm_object :
94101 # Colocate mode: use SamplingParams
95- print ("[VLLMPolicy] Using vLLM LLM (colocate mode) with SamplingParams" , flush = True )
102+ logger . debug ("[VLLMPolicy] Using vLLM LLM (colocate mode) with SamplingParams" )
96103 from vllm import SamplingParams
97104
98105 sampling_params = SamplingParams (
@@ -103,7 +110,7 @@ async def _make_llm_call(
103110 n = 1 ,
104111 )
105112
106- print ("[VLLMPolicy] Calling LLM.generate()..." , flush = True )
113+ logger . debug ("[VLLMPolicy] Calling LLM.generate()" )
107114 outputs = self .vllm_client .generate ([prompt_text ], sampling_params = sampling_params , use_tqdm = False )
108115
109116 # Extract from vLLM output format
@@ -116,7 +123,7 @@ async def _make_llm_call(
116123 }
117124 else :
118125 # Server mode: use VLLMClient with kwargs
119- print ("[VLLMPolicy] Using VLLMClient (server mode)" , flush = True )
126+ logger . debug ("[VLLMPolicy] Using VLLMClient (server mode)" )
120127 vllm_params = {
121128 "temperature" : self .temperature ,
122129 "max_tokens" : self .max_tokens ,
@@ -126,7 +133,7 @@ async def _make_llm_call(
126133 }
127134 vllm_params .update (self .kwargs )
128135
129- print ("[VLLMPolicy] Calling vllm_client.generate()..." , flush = True )
136+ logger . debug ("[VLLMPolicy] Calling vllm_client.generate()" )
130137 response = self .vllm_client .generate (
131138 prompts = [prompt_text ],
132139 ** vllm_params ,
@@ -140,16 +147,18 @@ async def _make_llm_call(
140147 if self .tokenizer is not None :
141148 try :
142149 completion_text = self .tokenizer .decode (completion_ids , skip_special_tokens = True )
143- print ("\n [VLLMPolicy] ===== GENERATION RESULT =====" , flush = True )
144- print (f"[VLLMPolicy] Prompt tokens: { len (prompt_ids )} " , flush = True )
145- print (f"[VLLMPolicy] Completion tokens: { len (completion_ids )} " , flush = True )
146- print (f"[VLLMPolicy] FULL decoded completion ({ len (completion_text )} chars):" , flush = True )
147- print ("───────────────────────────────────────" , flush = True )
148- print (f"{ completion_text } " , flush = True )
149- print ("───────────────────────────────────────" , flush = True )
150- print ("[VLLMPolicy] ==============================" , flush = True )
150+ logger .debug (
151+ "[VLLMPolicy] Generation result: prompt_tokens=%d, completion_tokens=%d, completion_chars=%d" ,
152+ len (prompt_ids ),
153+ len (completion_ids ),
154+ len (completion_text ),
155+ )
151156 except Exception as e :
152- print (f"[VLLMPolicy] Warning: Failed to decode completion: { e } " , flush = True )
157+ logger .warning (
158+ "[VLLMPolicy] Failed to decode completion: %s" ,
159+ e ,
160+ exc_info = True ,
161+ )
153162 completion_text = f"<decoded_error:{ len (completion_ids )} _tokens>"
154163 else :
155164 # Fallback: just indicate number of tokens
0 commit comments