Skip to content

Commit 854cb5c

Browse files
committed
LangGraph simple example
1 parent 35f32aa commit 854cb5c

File tree

12 files changed

+609
-125
lines changed

12 files changed

+609
-125
lines changed

eval_protocol/adapters/bigquery.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from __future__ import annotations
88

99
import logging
10-
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Union, cast, TypeAlias
10+
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeAlias
1111

1212
from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message
1313

@@ -108,10 +108,7 @@ def __init__(
108108
# Avoid strict typing on optional dependency
109109
self.client = _bigquery_runtime.Client(**client_args) # type: ignore[no-untyped-call, assignment]
110110

111-
except DefaultCredentialsError as e:
112-
logger.error("Failed to authenticate with BigQuery: %s", e)
113-
raise
114-
except Exception as e:
111+
except (DefaultCredentialsError, ImportError, ValueError, TypeError) as e:
115112
logger.error("Failed to initialize BigQuery client: %s", e)
116113
raise
117114

@@ -191,10 +188,7 @@ def get_evaluation_rows(
191188

192189
row_count += 1
193190

194-
except (NotFound, Forbidden) as e:
195-
logger.error("BigQuery access error: %s", e)
196-
raise
197-
except Exception as e:
191+
except (NotFound, Forbidden, RuntimeError, ValueError, TypeError, AttributeError) as e:
198192
logger.error("Error executing BigQuery query: %s", e)
199193
raise
200194

eval_protocol/adapters/langchain.py

Lines changed: 37 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from typing import Any, Dict, List, Optional
55

6-
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
6+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
77

88
from eval_protocol.models import Message
99

@@ -49,75 +49,12 @@ def serialize_lc_message_to_ep(msg: BaseMessage) -> Message:
4949
parts.append(item)
5050
content = "\n".join(parts)
5151

52-
tool_calls_payload: Optional[List[Dict[str, Any]]] = None
53-
54-
def _normalize_tool_calls(tc_list: List[Any]) -> List[Dict[str, Any]]:
55-
mapped: List[Dict[str, Any]] = []
56-
for call in tc_list:
57-
if not isinstance(call, dict):
58-
continue
59-
try:
60-
call_id = call.get("id") or "toolcall_0"
61-
if isinstance(call.get("function"), dict):
62-
fn = call["function"]
63-
fn_name = fn.get("name") or call.get("name") or "tool"
64-
fn_args = fn.get("arguments")
65-
else:
66-
fn_name = call.get("name") or "tool"
67-
fn_args = call.get("arguments") if call.get("arguments") is not None else call.get("args")
68-
if not isinstance(fn_args, str):
69-
import json as _json
70-
71-
fn_args = _json.dumps(fn_args or {}, ensure_ascii=False)
72-
mapped.append(
73-
{
74-
"id": call_id,
75-
"type": "function",
76-
"function": {"name": fn_name, "arguments": fn_args},
77-
}
78-
)
79-
except Exception:
80-
continue
81-
return mapped
82-
83-
ak = getattr(msg, "additional_kwargs", None)
84-
if isinstance(ak, dict):
85-
tc = ak.get("tool_calls")
86-
if isinstance(tc, list) and tc:
87-
mapped = _normalize_tool_calls(tc)
88-
if mapped:
89-
tool_calls_payload = mapped
90-
91-
if tool_calls_payload is None:
92-
raw_attr_tc = getattr(msg, "tool_calls", None)
93-
if isinstance(raw_attr_tc, list) and raw_attr_tc:
94-
mapped = _normalize_tool_calls(raw_attr_tc)
95-
if mapped:
96-
tool_calls_payload = mapped
97-
98-
# Extract reasoning/thinking parts into reasoning_content
99-
reasoning_content = None
100-
if isinstance(msg.content, list):
101-
collected = [
102-
it.get("thinking", "") for it in msg.content if isinstance(it, dict) and it.get("type") == "thinking"
103-
]
104-
if collected:
105-
reasoning_content = "\n\n".join([s for s in collected if s]) or None
106-
107-
# Message.tool_calls expects List[ChatCompletionMessageToolCall] | None.
108-
# We pass through Dicts at runtime but avoid type error by casting.
109-
ep_msg = Message(
110-
role="assistant",
111-
content=content,
112-
tool_calls=tool_calls_payload, # type: ignore[arg-type]
113-
reasoning_content=reasoning_content,
114-
)
52+
ep_msg = Message(role="assistant", content=content)
11553
_dbg_print(
11654
"[EP-Ser] -> EP Message:",
11755
{
11856
"role": ep_msg.role,
11957
"content_len": len(ep_msg.content or ""),
120-
"tool_calls": len(ep_msg.tool_calls or []) if isinstance(ep_msg.tool_calls, list) else 0,
12158
},
12259
)
12360
return ep_msg
@@ -141,3 +78,38 @@ def _normalize_tool_calls(tc_list: List[Any]) -> List[Dict[str, Any]]:
14178
ep_msg = Message(role=getattr(msg, "type", "assistant"), content=str(getattr(msg, "content", "")))
14279
_dbg_print("[EP-Ser] -> EP Message (fallback):", {"role": ep_msg.role, "len": len(ep_msg.content or "")})
14380
return ep_msg
81+
82+
83+
def serialize_ep_messages_to_lc(messages: List[Message]) -> List[BaseMessage]:
    """Convert eval_protocol Message objects to a LangChain BaseMessage list.

    - Flattens list-valued content into a single newline-joined string,
      keeping only the non-empty ``text`` attributes of each part.
    - Maps EP roles to LC message classes: "user" -> HumanMessage,
      "assistant" -> AIMessage, "system" -> SystemMessage; any other
      (or missing) role falls back to HumanMessage.

    Args:
        messages: EP messages to convert; ``None`` is treated as empty.

    Returns:
        LangChain messages in the same order as the input.
    """
    lc_messages: List[BaseMessage] = []
    for m in messages or []:
        content = m.content
        if isinstance(content, list):
            # getattr with a default never raises AttributeError, so no
            # try/except is needed; parts lacking a ``text`` attribute
            # simply contribute an empty string that is filtered out below.
            text_parts = [getattr(part, "text", "") for part in content]
            content = "\n".join([t for t in text_parts if t])
        if content is None:
            content = ""
        text = str(content)

        role = (m.role or "").lower()
        if role == "user":
            lc_messages.append(HumanMessage(content=text))
        elif role == "assistant":
            lc_messages.append(AIMessage(content=text))
        elif role == "system":
            # SystemMessage is imported at module level alongside the other
            # langchain_core message classes; no local import is needed.
            lc_messages.append(SystemMessage(content=text))
        else:
            # Unknown roles (e.g. "tool") are passed through as human input.
            lc_messages.append(HumanMessage(content=text))
    return lc_messages

eval_protocol/pytest/default_langchain_rollout_processor.py

Lines changed: 23 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
import asyncio
22
import time
3-
from typing import List
3+
from typing import List, Any, cast
44

55
try:
    # Prefer the real LangChain types when the optional dependency is installed.
    from langchain_core.messages import BaseMessage as LCBaseMessage, HumanMessage  # type: ignore
except ImportError:  # pragma: no cover - optional dependency path
    # Minimal fallbacks to satisfy typing when langchain is not present
    class LCBaseMessage:  # type: ignore
        # Mirrors the only two attributes this module reads from LangChain
        # messages: the text payload and the role-like ``type`` string.
        content: str
        type: str

        def __init__(self, content: str = "", msg_type: str = "assistant"):
            self.content = content
            self.type = msg_type

    class HumanMessage(LCBaseMessage):  # type: ignore
        # Fallback human message; pins ``type`` to "human" so downstream
        # role mapping behaves the same as with the real class.
        def __init__(self, content: str):
            super().__init__(content=content, msg_type="human")
1120

1221

1322
from eval_protocol.models import EvaluationRow, Message
14-
from openai.types import CompletionUsage
1523
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1624
from eval_protocol.pytest.types import RolloutProcessorConfig
1725

@@ -34,27 +42,17 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig):
3442
async def _process_row(row: EvaluationRow) -> EvaluationRow:
3543
start_time = time.perf_counter()
3644

37-
# Build LC messages from EP row
38-
try:
39-
from langchain_core.messages import HumanMessage
40-
except Exception:
41-
# Fallback minimal message if langchain_core is unavailable
42-
class HumanMessage(BaseMessage): # type: ignore
43-
def __init__(self, content: str):
44-
self.content = content
45-
self.type = "human"
46-
47-
lm_messages: List[BaseMessage] = []
45+
# Build LC messages from EP row (minimal: last user to HumanMessage)
46+
lm_messages: List[LCBaseMessage] = []
4847
if row.messages:
4948
last_user = [m for m in row.messages if m.role == "user"]
5049
if last_user:
5150
content = last_user[-1].content or ""
5251
if isinstance(content, list):
53-
# Flatten our SDK content parts into a single string for LangChain
5452
content = "".join([getattr(p, "text", str(p)) for p in content])
5553
lm_messages.append(HumanMessage(content=str(content)))
5654
if not lm_messages:
57-
lm_messages = [HumanMessage(content="")] # minimal
55+
lm_messages = [HumanMessage(content="")]
5856

5957
target = await self.get_invoke_target(config)
6058

@@ -72,7 +70,7 @@ async def _invoke_direct(payload):
7270

7371
invoke_fn = _invoke_direct
7472
elif callable(target):
75-
# If target is a normal callable, call it directly; if it returns an awaitable, await it
73+
7674
async def _invoke_wrapper(payload):
7775
result = target(payload)
7876
if asyncio.iscoroutine(result):
@@ -84,44 +82,18 @@ async def _invoke_wrapper(payload):
8482
raise TypeError("Unsupported invoke target for LangGraphRolloutProcessor")
8583

8684
result_obj = await invoke_fn({"messages": lm_messages})
87-
# Accept both dicts and objects with .get/.messages
8885
if isinstance(result_obj, dict):
89-
result_messages: List[BaseMessage] = result_obj.get("messages", [])
86+
result_messages: List[LCBaseMessage] = result_obj.get("messages", [])
9087
else:
9188
result_messages = getattr(result_obj, "messages", [])
9289

93-
# TODO: i didn't see a langgraph example so couldn't fully test this. should uncomment and test when we have example ready.
94-
# total_input_tokens = 0
95-
# total_output_tokens = 0
96-
# total_tokens = 0
97-
98-
# for msg in result_messages:
99-
# if isinstance(msg, BaseMessage):
100-
# usage = getattr(msg, 'response_metadata', {})
101-
# else:
102-
# usage = msg.get("response_metadata", {})
103-
104-
# if usage:
105-
# total_input_tokens += usage.get("prompt_tokens", 0)
106-
# total_output_tokens += usage.get("completion_tokens", 0)
107-
# total_tokens += usage.get("total_tokens", 0)
108-
109-
# row.execution_metadata.usage = CompletionUsage(
110-
# prompt_tokens=total_input_tokens,
111-
# completion_tokens=total_output_tokens,
112-
# total_tokens=total_tokens,
113-
# )
114-
115-
def _serialize_message(msg: BaseMessage) -> Message:
116-
# Prefer SDK-level serializer
90+
def _serialize_message(msg: LCBaseMessage) -> Message:
11791
try:
11892
from eval_protocol.adapters.langchain import serialize_lc_message_to_ep as _ser
119-
120-
return _ser(msg)
121-
except Exception:
122-
# Minimal fallback: best-effort string content only
93+
except ImportError:
12394
content = getattr(msg, "content", "")
12495
return Message(role=getattr(msg, "type", "assistant"), content=str(content))
96+
return _ser(cast(Any, msg))
12597

12698
row.messages = [_serialize_message(m) for m in result_messages]
12799

eval_protocol/pytest/handle_persist_flow.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name:
4242
if len(dataset_name) > 63:
4343
dataset_name = dataset_name[:63]
4444

45+
# Fireworks requires: last character of id must not be '-'
46+
dataset_name = dataset_name.rstrip("-")
47+
48+
# Ensure non-empty after stripping; fallback to safe_test_func_name
49+
if not dataset_name:
50+
dataset_name = safe_test_func_name[:63].rstrip("-") or "dataset"
51+
4552
exp_file = exp_dir / f"{experiment_id}.jsonl"
4653
with open(exp_file, "w", encoding="utf-8") as f:
4754
for row in exp_rows:

0 commit comments

Comments (0)