11import asyncio
22import time
3- from typing import List
3+ from typing import List , Any , cast
44
try:
    # Use the real LangChain message classes whenever the package is installed.
    from langchain_core.messages import (  # type: ignore
        BaseMessage as LCBaseMessage,
        HumanMessage,
    )
except ImportError:  # pragma: no cover - optional dependency path
    # langchain_core is optional: define just enough of the two message
    # classes for the annotations and constructor calls in this module.

    class LCBaseMessage:  # type: ignore
        """Minimal stand-in for the LangChain base message.

        Carries only the two attributes this module reads: ``content`` and
        ``type``.
        """

        content: str
        type: str

        def __init__(self, content: str = "", msg_type: str = "assistant"):
            self.type = msg_type
            self.content = content

    class HumanMessage(LCBaseMessage):  # type: ignore
        """Stand-in for a user-authored message; ``type`` is always "human"."""

        def __init__(self, content: str):
            super().__init__(content, "human")
1120
1221
1322from eval_protocol .models import EvaluationRow , Message
14- from openai .types import CompletionUsage
1523from eval_protocol .pytest .rollout_processor import RolloutProcessor
1624from eval_protocol .pytest .types import RolloutProcessorConfig
1725
@@ -34,27 +42,17 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig):
3442 async def _process_row (row : EvaluationRow ) -> EvaluationRow :
3543 start_time = time .perf_counter ()
3644
37- # Build LC messages from EP row
38- try :
39- from langchain_core .messages import HumanMessage
40- except Exception :
41- # Fallback minimal message if langchain_core is unavailable
42- class HumanMessage (BaseMessage ): # type: ignore
43- def __init__ (self , content : str ):
44- self .content = content
45- self .type = "human"
46-
47- lm_messages : List [BaseMessage ] = []
45+ # Build LC messages from EP row (minimal: last user to HumanMessage)
46+ lm_messages : List [LCBaseMessage ] = []
4847 if row .messages :
4948 last_user = [m for m in row .messages if m .role == "user" ]
5049 if last_user :
5150 content = last_user [- 1 ].content or ""
5251 if isinstance (content , list ):
53- # Flatten our SDK content parts into a single string for LangChain
5452 content = "" .join ([getattr (p , "text" , str (p )) for p in content ])
5553 lm_messages .append (HumanMessage (content = str (content )))
5654 if not lm_messages :
57- lm_messages = [HumanMessage (content = "" )] # minimal
55+ lm_messages = [HumanMessage (content = "" )]
5856
5957 target = await self .get_invoke_target (config )
6058
@@ -72,7 +70,7 @@ async def _invoke_direct(payload):
7270
7371 invoke_fn = _invoke_direct
7472 elif callable (target ):
75- # If target is a normal callable, call it directly; if it returns an awaitable, await it
73+
7674 async def _invoke_wrapper (payload ):
7775 result = target (payload )
7876 if asyncio .iscoroutine (result ):
@@ -84,44 +82,18 @@ async def _invoke_wrapper(payload):
8482 raise TypeError ("Unsupported invoke target for LangGraphRolloutProcessor" )
8583
8684 result_obj = await invoke_fn ({"messages" : lm_messages })
87- # Accept both dicts and objects with .get/.messages
8885 if isinstance (result_obj , dict ):
89- result_messages : List [BaseMessage ] = result_obj .get ("messages" , [])
86+ result_messages : List [LCBaseMessage ] = result_obj .get ("messages" , [])
9087 else :
9188 result_messages = getattr (result_obj , "messages" , [])
9289
93- # TODO: i didn't see a langgraph example so couldn't fully test this. should uncomment and test when we have example ready.
94- # total_input_tokens = 0
95- # total_output_tokens = 0
96- # total_tokens = 0
97-
98- # for msg in result_messages:
99- # if isinstance(msg, BaseMessage):
100- # usage = getattr(msg, 'response_metadata', {})
101- # else:
102- # usage = msg.get("response_metadata", {})
103-
104- # if usage:
105- # total_input_tokens += usage.get("prompt_tokens", 0)
106- # total_output_tokens += usage.get("completion_tokens", 0)
107- # total_tokens += usage.get("total_tokens", 0)
108-
109- # row.execution_metadata.usage = CompletionUsage(
110- # prompt_tokens=total_input_tokens,
111- # completion_tokens=total_output_tokens,
112- # total_tokens=total_tokens,
113- # )
114-
def _serialize_message(msg: LCBaseMessage) -> Message:
    """Convert one LangChain-style message into an eval_protocol ``Message``.

    The SDK adapter ``serialize_lc_message_to_ep`` is used when it can be
    imported. Only an ``ImportError`` triggers the fallback; errors raised by
    the adapter itself propagate to the caller.

    Args:
        msg: A message object exposing ``content`` and ``type`` attributes.

    Returns:
        The equivalent ``Message``. In the fallback path only the role and
        the stringified content are carried over.
    """
    try:
        from eval_protocol.adapters.langchain import serialize_lc_message_to_ep as _ser
    except ImportError:
        # LangChain labels messages by ``type`` ("human", "ai", ...), while
        # rows in this module are filtered by role == "user". Translate the
        # two names that differ so fallback output stays recognisable as
        # user/assistant turns; other types pass through unchanged.
        role_by_lc_type = {"human": "user", "ai": "assistant"}
        lc_type = getattr(msg, "type", "assistant")
        content = getattr(msg, "content", "")
        return Message(role=role_by_lc_type.get(lc_type, lc_type), content=str(content))
    # cast(Any, ...) keeps type checkers quiet when LCBaseMessage is the
    # local fallback class rather than the real LangChain type.
    return _ser(cast(Any, msg))
12597
12698 row .messages = [_serialize_message (m ) for m in result_messages ]
12799
0 commit comments