Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion eval_protocol/adapters/fireworks_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,29 @@ def convert_trace_dict_to_evaluation_row(
):
break # Break early if we've found all the metadata we need

# Extract router replay payloads when present
payloads = trace.get("payloads")
if isinstance(payloads, dict):
router_replay = payloads.get("router_replay")
if isinstance(router_replay, dict) and router_replay.get("data"):
try:
from .r3_deserializer import decompress_and_parse_r3

matrices, r3_meta = decompress_and_parse_r3(router_replay["data"])
if execution_metadata.extra is None:
execution_metadata.extra = {}
execution_metadata.extra["routing_matrices"] = matrices
execution_metadata.extra["routing_metadata"] = r3_meta
except Exception as e:
logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e)

return EvaluationRow(
messages=messages,
tools=tools,
input_metadata=InputMetadata(
row_id=row_id,
session_data={
"langfuse_trace_id": trace.get("id"), # Store the trace ID here
"langfuse_trace_id": trace.get("id"),
},
),
execution_metadata=execution_metadata,
Expand Down Expand Up @@ -426,6 +442,7 @@ def get_evaluation_rows(
max_retries: int = 3,
span_name: Optional[str] = None,
converter: Optional[TraceDictConverter] = None,
include_payloads: bool = False,
) -> List[EvaluationRow]:
"""Pull traces from Langfuse via proxy and convert to EvaluationRow format.

Expand All @@ -449,6 +466,8 @@ def get_evaluation_rows(
max_retries: Max retry attempts used by proxy (default: 3)
converter: Optional custom converter implementing TraceDictConverter protocol.
If provided, this will be used instead of the default conversion logic.
include_payloads: If True, request payload data (e.g., router replay)
from the gateway and decompress it into the returned EvaluationRows.

Returns:
List[EvaluationRow]: Converted evaluation rows
Expand Down Expand Up @@ -479,6 +498,7 @@ def get_evaluation_rows(
"to_timestamp": to_timestamp.isoformat() if to_timestamp else None,
"sleep_between_gets": sleep_between_gets,
"max_retries": max_retries,
"include_payloads": include_payloads if include_payloads else None,
}

# Remove None values
Expand Down
187 changes: 187 additions & 0 deletions eval_protocol/adapters/r3_deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
"""R3/v1 binary deserializer for router-replay payloads.

Implements the inverse of the packed binary format produced by the tracing
gateway's ``r3_serializer.serialize_r3``. See that module for the full
header specification.

The main entry point is :func:`decompress_and_parse_r3`, which accepts the
base64-encoded compressed blob returned by the gateway's
``/v1/traces/pointwise?include_payloads=true`` endpoint and produces
per-token routing matrices in the same ``List[Optional[str]]`` format used
by the direct inference path (``DeploymentSampler.sample_with_tokens()``).
"""

from __future__ import annotations

import base64
import struct
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple

import zstandard as zstd

# Four-byte magic identifying an R3/v1 payload; must match the gateway serializer.
MAGIC = b"R3V1"
# Little-endian, unpadded header layout: magic (4s); version, selector_mode,
# routing_dtype, flags (4 x uint8); total_token_count, replayed_token_count,
# replay_start_token, selector_byte_length (4 x uint32); matrix_byte_length (uint64).
HEADER_FORMAT = "<4sBBBBIIIIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)  # 32 bytes
BITS_PER_BYTE = 8


class _SelectorMode(IntEnum):
    """Token-selector encodings a header may declare.

    ``ALL`` replays every token, ``SUFFIX`` a contiguous tail, and
    ``BITMAP`` an arbitrary subset described by a bit per token.
    """

    ALL = 0
    SUFFIX = 1
    BITMAP = 2


class _RoutingDtype(IntEnum):
    """Element dtypes a packed routing matrix may use."""

    UINT8 = 1
    UINT16 = 2


def _lowercase_member_names(enum_cls):
    """Map each enum member to its lowercase name (e.g. ``ALL`` -> ``"all"``)."""
    return {member: member.name.lower() for member in enum_cls}


# Human-readable labels used when building the metadata dict for callers.
_SELECTOR_MODE_NAMES = _lowercase_member_names(_SelectorMode)
_ROUTING_DTYPE_NAMES = _lowercase_member_names(_RoutingDtype)


def _parse_header(raw: bytes) -> Dict[str, Any]:
    """Unpack and validate the fixed-size R3/v1 header at the front of *raw*.

    Returns a dict of the eight post-magic header fields (selector mode,
    dtype, flags, and the token/byte counters).

    Raises:
        ValueError: if *raw* is shorter than the header, carries the wrong
            magic bytes, or declares an unsupported version.
    """
    if len(raw) < HEADER_SIZE:
        raise ValueError(
            f"Payload too short for r3/v1 header: {len(raw)} < {HEADER_SIZE}"
        )

    fields = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE])
    magic, version = fields[0], fields[1]

    if magic != MAGIC:
        raise ValueError(f"Bad R3 magic: {magic!r}")
    if version != 1:
        raise ValueError(f"Unsupported R3 header version: {version}")

    # Remaining fields appear in the same order as in HEADER_FORMAT.
    field_names = (
        "selector_mode",
        "routing_dtype",
        "flags",
        "total_token_count",
        "replayed_token_count",
        "replay_start_token",
        "selector_byte_length",
        "matrix_byte_length",
    )
    return dict(zip(field_names, fields[2:]))


def _read_bitmap_positions(
    selector_bytes: bytes, total_token_count: int
) -> List[int]:
    """Return sorted token indices whose bit is set in the selector bitmap.

    Bits are packed little-endian within each byte (bit 0 of byte 0 is
    token 0). Token indices beyond the provided bitmap are treated as unset.
    """
    set_positions: List[int] = []
    for token_idx in range(total_token_count):
        byte_pos, bit_pos = divmod(token_idx, BITS_PER_BYTE)
        if byte_pos >= len(selector_bytes):
            continue  # bitmap shorter than token count: remaining bits are 0
        if selector_bytes[byte_pos] & (1 << bit_pos):
            set_positions.append(token_idx)
    return set_positions


def decompress_and_parse_r3(
    data_b64: str,
) -> Tuple[List[Optional[str]], Dict[str, Any]]:
    """Decompress and unpack an R3/v1 payload into per-token routing matrices.

    Args:
        data_b64: Base64-encoded zstd-compressed R3 binary blob, as returned
            by the tracing gateway in ``payloads.router_replay.data``.

    Returns:
        A tuple of ``(routing_matrices, metadata)`` where:

        - ``routing_matrices`` is a ``List[Optional[str]]`` of length
          ``total_token_count``. Each present position contains a
          base64-encoded routing matrix (matching the format returned by
          the direct inference path); absent positions are ``None``.
        - ``metadata`` is a dict with keys ``routing_dtype``,
          ``selector_mode``, ``total_token_count``, ``replayed_token_count``,
          ``replay_start_token``.

    Raises:
        ValueError: if the payload is malformed — bad magic/version,
            truncated selector/matrix sections, a matrix section that does
            not divide evenly across replayed tokens, or a selector that is
            inconsistent with the header counts or token range.
    """
    compressed = base64.b64decode(data_b64)

    # ZstdCompressor.compress() embeds the uncompressed size in the frame
    # header by default, so the library can auto-allocate the output buffer.
    # NOTE(review): decompress() trusts that embedded size; if payloads can
    # ever arrive from outside the gateway, consider passing max_output_size.
    decompressor = zstd.ZstdDecompressor()
    raw = decompressor.decompress(compressed)

    header = _parse_header(raw)

    selector_mode = header["selector_mode"]
    routing_dtype = header["routing_dtype"]
    total_token_count = header["total_token_count"]
    replayed_token_count = header["replayed_token_count"]
    replay_start_token = header["replay_start_token"]
    selector_byte_length = header["selector_byte_length"]
    matrix_byte_length = header["matrix_byte_length"]

    # Unknown mode/dtype codes fall back to their raw integer as a string.
    metadata: Dict[str, Any] = {
        "routing_dtype": _ROUTING_DTYPE_NAMES.get(routing_dtype, str(routing_dtype)),
        "selector_mode": _SELECTOR_MODE_NAMES.get(selector_mode, str(selector_mode)),
        "total_token_count": total_token_count,
        "replayed_token_count": replayed_token_count,
        "replay_start_token": replay_start_token,
    }

    # Nothing replayed: every position is absent.
    if replayed_token_count == 0:
        return [None] * total_token_count, metadata

    # Per-token matrix byte size is implicit in the payload: all replayed
    # tokens share the same matrix length, so we can recover it from the
    # matrix section total length divided by the replayed-token count.
    if matrix_byte_length % replayed_token_count != 0:
        raise ValueError(
            f"matrix_byte_length ({matrix_byte_length}) is not a multiple of "
            f"replayed_token_count ({replayed_token_count}); cannot split "
            "into per-token matrices"
        )
    matrix_elem_size = matrix_byte_length // replayed_token_count

    body = raw[HEADER_SIZE:]
    expected_body_length = selector_byte_length + matrix_byte_length
    if len(body) < expected_body_length:
        raise ValueError(
            f"Payload body too short for selector and matrix sections: "
            f"{len(body)} < {expected_body_length}"
        )

    selector_bytes = body[:selector_byte_length]
    matrix_bytes = body[selector_byte_length : selector_byte_length + matrix_byte_length]

    if selector_mode == _SelectorMode.ALL:
        replayed_positions = list(range(total_token_count))
    elif selector_mode == _SelectorMode.SUFFIX:
        replayed_positions = list(
            range(replay_start_token, replay_start_token + replayed_token_count)
        )
    elif selector_mode == _SelectorMode.BITMAP:
        replayed_positions = _read_bitmap_positions(selector_bytes, total_token_count)
    else:
        raise ValueError(f"Unknown selector_mode: {selector_mode}")

    if len(replayed_positions) != replayed_token_count:
        raise ValueError(
            f"Selector produced {len(replayed_positions)} replayed positions, "
            f"but header replayed_token_count is {replayed_token_count}"
        )

    # Bug fix: a SUFFIX selector whose start + count overruns the token range
    # used to crash below with an opaque IndexError when indexing `matrices`.
    # Reject out-of-range selectors with a diagnosable error instead.
    # (ALL and BITMAP positions are bounded by total_token_count already.)
    if replayed_positions and replayed_positions[-1] >= total_token_count:
        raise ValueError(
            f"Selector references token position {replayed_positions[-1]}, "
            f"which is outside the token range [0, {total_token_count})"
        )

    # Split matrix bytes into per-token chunks and base64-encode each one
    matrices: List[Optional[str]] = [None] * total_token_count
    for idx, pos in enumerate(replayed_positions):
        start = idx * matrix_elem_size
        end = start + matrix_elem_size
        matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii")

    return matrices, metadata
7 changes: 6 additions & 1 deletion eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ def __init__(
model_base_url: str = "https://tracing.fireworks.ai",
poll_interval: float = 1.0,
timeout_seconds: float = 120.0,
include_payloads: bool = False,
):
# Prefer constructor-provided configuration. These can be overridden via
# config.kwargs at call time for backward compatibility.
self._remote_base_url = remote_base_url
self._model_base_url = model_base_url
self._include_payloads = include_payloads
if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"):
self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL")
_ep_model_base_url = os.getenv("EP_MODEL_BASE_URL")
Expand Down Expand Up @@ -194,7 +196,10 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time

def _update_with_trace() -> None:
return update_row_with_remote_trace(row, default_fireworks_output_data_loader, model_base_url)
return update_row_with_remote_trace(
row, default_fireworks_output_data_loader, model_base_url,
include_payloads=self._include_payloads,
)

await asyncio.to_thread(_update_with_trace) # Update row with remote trace in-place
return row
Expand Down
19 changes: 15 additions & 4 deletions eval_protocol/pytest/tracing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ def fetch_traces() -> List[EvaluationRow]:
# Use EP_REMOTE_API_KEY for fetching remote traces, falling back to FIREWORKS_API_KEY
api_key = os.environ.get("EP_REMOTE_API_KEY") or os.environ.get("FIREWORKS_API_KEY")
adapter = FireworksTracingAdapter(base_url=base_url, api_key=api_key)
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
return adapter.get_evaluation_rows(
tags=[f"rollout_id:{config.rollout_id}"],
max_retries=5,
include_payloads=config.include_payloads,
)

return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation)

Expand Down Expand Up @@ -129,7 +133,7 @@ def build_init_request(

# Build final model base URL with tracing metadata
final_model_base_url = model_base_url
if model_base_url and ("tracing.fireworks.ai" in model_base_url or model_base_url.startswith("http://localhost")):
if model_base_url and ("tracing.fireworks.ai" in model_base_url or model_base_url.startswith("http://localhost") or "litellm-gateway" in model_base_url):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need the check for tracing.fireworks.ai or litellm-gateway? Which one is it? Are there cases where it's one and not the other, and vice versa?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for dev testing since that is litellm-gateway

Comment thread
SunnySoldier357 marked this conversation as resolved.
final_model_base_url = build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url)

# Extract API key from environment or completion_params
Expand All @@ -148,13 +152,20 @@ def build_init_request(


def update_row_with_remote_trace(
row: EvaluationRow, output_data_loader: Callable[[DataLoaderConfig], DynamicDataLoader], model_base_url: str
row: EvaluationRow,
output_data_loader: Callable[[DataLoaderConfig], DynamicDataLoader],
model_base_url: str,
include_payloads: bool = False,
) -> None:
"""Update row with remote trace data using output_data_loader (shared logic)."""
if not row.execution_metadata.rollout_id:
return None

loader_config = DataLoaderConfig(rollout_id=row.execution_metadata.rollout_id, model_base_url=model_base_url)
loader_config = DataLoaderConfig(
rollout_id=row.execution_metadata.rollout_id,
model_base_url=model_base_url,
include_payloads=include_payloads,
)
data_loader = output_data_loader(loader_config)
results = data_loader.load()
output_rows: List[EvaluationRow] = [r for result in results for r in result.rows]
Expand Down
1 change: 1 addition & 0 deletions eval_protocol/types/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class DataLoaderConfig(BaseModel):

rollout_id: str
model_base_url: Optional[str] = None
include_payloads: bool = False


class InitRequest(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dependencies = [
"deepdiff>=6.0.0",
"websockets>=15.0.1",
"fastapi>=0.116.1",
"zstandard>=0.19.0",
]

[project.urls]
Expand Down
Loading
Loading