diff --git a/schemas/postgres/migrations/0027_playground_session_rendered_messages.sql b/schemas/postgres/migrations/0027_playground_session_rendered_messages.sql new file mode 100644 index 0000000..aa9f4ac --- /dev/null +++ b/schemas/postgres/migrations/0027_playground_session_rendered_messages.sql @@ -0,0 +1,33 @@ +-- 0027_playground_session_rendered_messages.sql +-- Add a structured form of the rendered prompt to playground_session. +-- The existing rendered_prompt text stays as a human-readable +-- newline-joined view (used by the trace UI today); the new +-- rendered_messages jsonb is what replay / re-dispatch will read so +-- the message structure round-trips exactly. + +begin; + +alter table playground_session + add column rendered_messages jsonb; + +-- Backfill: existing sessions wrapped as a single human message so the +-- column is never null going forward. This preserves the meaning of +-- old sessions whose rendered_prompt was a single concatenated string. +update playground_session + set rendered_messages = jsonb_build_array( + jsonb_build_object('role', 'human', 'content', rendered_prompt) + ) + where rendered_messages is null; + +alter table playground_session + alter column rendered_messages set not null; + +alter table playground_session + add constraint playground_session_rendered_messages_nonempty + check (jsonb_typeof(rendered_messages) = 'array' + and jsonb_array_length(rendered_messages) > 0); + +insert into schema_migrations (version) values ('0027_playground_session_rendered_messages') +on conflict (version) do nothing; + +commit; diff --git a/services/api/tests/unit/test_playground_create_validation.py b/services/api/tests/unit/test_playground_create_validation.py new file mode 100644 index 0000000..4d0b892 --- /dev/null +++ b/services/api/tests/unit/test_playground_create_validation.py @@ -0,0 +1,134 @@ +"""PlaygroundCreate xor: exactly one of prompt_version_id / +raw_template / raw_messages is required per request. Zero or more than +one is a 422 (pydantic ValidationError).""" + +from __future__ import annotations + +from uuid import uuid4 + +import pytest +from pydantic import ValidationError +from tracebility_api.routers.playground import Message, PlaygroundCreate + +_MODEL = "anthropic/claude-sonnet-4-6" + + +def test_accepts_raw_messages(): + body = PlaygroundCreate( + project_id=uuid4(), + raw_messages=[Message(role="human", content="hi {{ x }}")], + variables={"x": "y"}, + model=_MODEL, + ) + assert body.raw_messages is not None + assert body.raw_template is None + assert body.prompt_version_id is None + + +def test_accepts_raw_template(): + body = PlaygroundCreate( + project_id=uuid4(), + raw_template="hi {{ x }}", + model=_MODEL, + ) + assert body.raw_template == "hi {{ x }}" + assert body.raw_messages is None + + +def test_accepts_prompt_version_id(): + body = PlaygroundCreate( + project_id=uuid4(), + prompt_version_id=uuid4(), + model=_MODEL, + ) + assert body.prompt_version_id is not None + assert body.raw_template is None + assert body.raw_messages is None + + +def test_rejects_zero_template_sources(): + with pytest.raises(ValidationError) as exc: + PlaygroundCreate(project_id=uuid4(), model=_MODEL) + assert "required" in str(exc.value).lower() + + +def test_rejects_template_and_messages_together(): + with pytest.raises(ValidationError): + PlaygroundCreate( + project_id=uuid4(), + raw_template="hi", + raw_messages=[Message(role="human", content="hi")], + model=_MODEL, + ) + + +def test_rejects_prompt_id_and_raw_template_together(): + with pytest.raises(ValidationError): + PlaygroundCreate( + project_id=uuid4(), + prompt_version_id=uuid4(), + raw_template="hi", + model=_MODEL, + ) + + +def test_rejects_prompt_id_and_raw_messages_together(): + with pytest.raises(ValidationError): + PlaygroundCreate( + project_id=uuid4(), + prompt_version_id=uuid4(), + raw_messages=[Message(role="human", content="hi")], + model=_MODEL, + ) + + +def test_rejects_all_three_together(): + with pytest.raises(ValidationError) as exc: + PlaygroundCreate( + project_id=uuid4(), + prompt_version_id=uuid4(), + raw_template="hi", + raw_messages=[Message(role="human", content="hi")], + model=_MODEL, + ) + assert "mutually exclusive" in str(exc.value).lower() + + +def test_rejects_empty_raw_messages_list(): + """An empty messages list is not a valid template source - at least + one message is required. (The check constraint on the prompt_version + table enforces this on the storage side; the request-side validator + closes the gap so we don't even attempt a render with zero messages.) + """ + with pytest.raises(ValidationError): + PlaygroundCreate( + project_id=uuid4(), + raw_messages=[], + model=_MODEL, + ) + + +def test_rejects_all_three_explicit_none(): + """Same as the zero-source case, but with explicit Nones in the body + (e.g. a JSON client that sends nulls instead of omitting fields). + Pydantic should treat None and omitted identically.""" + with pytest.raises(ValidationError): + PlaygroundCreate( + project_id=uuid4(), + prompt_version_id=None, + raw_template=None, + raw_messages=None, + model=_MODEL, + ) + + +def test_rejects_raw_messages_with_invalid_role(): + """Pydantic enforces role in {system, human} via the Message Literal; + sending an out-of-range role rejects the request before our validator + even runs.""" + with pytest.raises(ValidationError): + PlaygroundCreate( + project_id=uuid4(), + raw_messages=[{"role": "assistant", "content": "hi"}], + model=_MODEL, + ) diff --git a/services/api/tests/unit/test_playground_render_messages.py b/services/api/tests/unit/test_playground_render_messages.py new file mode 100644 index 0000000..d920a4c --- /dev/null +++ b/services/api/tests/unit/test_playground_render_messages.py @@ -0,0 +1,91 @@ +"""Per-message {{ var }} substitution: each message's content is rendered +against the same variable dict; roles are preserved verbatim. + +Spec decision 9: missing variables render as empty string. +""" + +from __future__ import annotations + +from tracebility_api.routers.playground import ( + Message, + _render_messages, +) + + +def test_renders_variables_per_message(): + msgs = [ + Message(role="system", content="You are a {{ tone }} assistant."), + Message(role="human", content="Summarize: {{ doc }}"), + ] + out = _render_messages(msgs, {"tone": "terse", "doc": "lorem ipsum"}) + assert out == [ + Message(role="system", content="You are a terse assistant."), + Message(role="human", content="Summarize: lorem ipsum"), + ] + + +def test_missing_variable_renders_empty(): + """Per spec decision 9: a placeholder whose key is absent from the + variables dict renders as the empty string.""" + msgs = [Message(role="human", content="Echo: {{ x }}")] + out = _render_messages(msgs, {}) + assert out == [Message(role="human", content="Echo: ")] + + +def test_no_variables_passes_through(): + msgs = [ + Message(role="system", content="static prompt"), + Message(role="human", content="hi"), + ] + assert _render_messages(msgs, {"unused": "value"}) == msgs + + +def test_returns_new_list_does_not_mutate_input(): + msgs = [Message(role="human", content="{{ x }}")] + out = _render_messages(msgs, {"x": "y"}) + assert out is not msgs + assert msgs[0].content == "{{ x }}" # original untouched + + +def test_non_string_value_serializes_via_json(): + """Non-string variable values serialize via json.dumps so dicts and + lists round-trip as readable JSON.""" + msgs = [Message(role="human", content="ctx={{ ctx }}")] + out = _render_messages(msgs, {"ctx": {"a": 1}}) + assert out == [Message(role="human", content='ctx={"a": 1}')] + + +def test_repeated_variable_in_one_content(): + """Both occurrences are substituted; re.sub default replaces all.""" + msgs = [Message(role="human", content="{{ x }} and {{ x }}")] + assert _render_messages(msgs, {"x": "hi"}) == [Message(role="human", content="hi and hi")] + + +def test_whitespace_around_placeholder(): + """The regex tolerates `\\s*` on either side of the var name; both + {{x}} and {{ x }} resolve identically.""" + msgs = [Message(role="human", content="a={{x}} b={{ x }}")] + out = _render_messages(msgs, {"x": "1"}) + assert out == [Message(role="human", content="a=1 b=1")] + + +def test_same_var_across_multiple_messages(): + """A single variables dict is applied to every message in order.""" + msgs = [ + Message(role="system", content="tone: {{ tone }}"), + Message(role="human", content="again, tone: {{ tone }}"), + ] + out = _render_messages(msgs, {"tone": "terse"}) + assert [m.content for m in out] == [ + "tone: terse", + "again, tone: terse", + ] + + +def test_returns_fresh_message_objects(): + """Pydantic equality is value-based; assert object identity too so a + future shortcut that returns the input message unchanged on a no-op + render would still trip the no-mutation contract.""" + msgs = [Message(role="human", content="static")] + out = _render_messages(msgs, {}) + assert out[0] is not msgs[0] diff --git a/services/api/tests/unit/test_playground_resolve_messages.py b/services/api/tests/unit/test_playground_resolve_messages.py new file mode 100644 index 0000000..63f2f54 --- /dev/null +++ b/services/api/tests/unit/test_playground_resolve_messages.py @@ -0,0 +1,135 @@ +"""_resolve_messages picks the right source per the xor validator and +returns the canonical message list shape. The asyncpg pool is mocked +because we're testing routing logic, not SQL — the SQL path itself is +covered by the prompt_version integration test.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock +from uuid import uuid4 + +import pytest +from fastapi import HTTPException +from tracebility_api.routers.playground import ( + Message, + PlaygroundCreate, + _resolve_messages, +) + +_MODEL = "anthropic/claude-sonnet-4-6" + + +@pytest.mark.asyncio +async def test_raw_messages_used_verbatim(): + body = PlaygroundCreate( + project_id=uuid4(), + raw_messages=[ + Message(role="system", content="be terse"), + Message(role="human", content="echo {{ x }}"), + ], + variables={}, + model=_MODEL, + ) + pool = AsyncMock() + out, version_row = await _resolve_messages(pool, body) + + assert version_row is None + assert out == body.raw_messages + pool.fetchrow.assert_not_called() + + +@pytest.mark.asyncio +async def test_raw_template_wrapped_as_single_human_message(): + body = PlaygroundCreate( + project_id=uuid4(), + raw_template="echo {{ x }}", + model=_MODEL, + ) + pool = AsyncMock() + out, version_row = await _resolve_messages(pool, body) + + assert version_row is None + assert out == [Message(role="human", content="echo {{ x }}")] + pool.fetchrow.assert_not_called() + + +@pytest.mark.asyncio +async def test_prompt_version_id_reads_template_messages(): + """When prompt_version_id is set, _resolve_messages reads the + template_messages jsonb column and validates each entry.""" + version_id = uuid4() + body = PlaygroundCreate( + project_id=uuid4(), + prompt_version_id=version_id, + model=_MODEL, + ) + pool = AsyncMock() + pool.fetchrow.return_value = { + "id": version_id, + "prompt_id": uuid4(), + "template": "ignored legacy field", + "template_messages": [ + {"role": "system", "content": "be terse"}, + {"role": "human", "content": "echo {{ x }}"}, + ], + } + + out, version_row = await _resolve_messages(pool, body) + + assert version_row is not None + assert out == [ + Message(role="system", content="be terse"), + Message(role="human", content="echo {{ x }}"), + ] + pool.fetchrow.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_prompt_version_id_handles_jsonb_string_form(): + """Some asyncpg/codec configs hand back jsonb as a string. The + helper decodes defensively.""" + version_id = uuid4() + body = PlaygroundCreate( + project_id=uuid4(), + prompt_version_id=version_id, + model=_MODEL, + ) + pool = AsyncMock() + pool.fetchrow.return_value = { + "id": version_id, + "prompt_id": uuid4(), + "template": "x", + "template_messages": '[{"role": "human", "content": "x"}]', + } + + out, _ = await _resolve_messages(pool, body) + assert out == [Message(role="human", content="x")] + + +@pytest.mark.asyncio +async def test_prompt_version_id_missing_returns_404(): + body = PlaygroundCreate( + project_id=uuid4(), + prompt_version_id=uuid4(), + model=_MODEL, + ) + pool = AsyncMock() + pool.fetchrow.return_value = None + + with pytest.raises(HTTPException) as exc: + await _resolve_messages(pool, body) + assert exc.value.status_code == 404 + + +def test_rendered_prompt_join_format(): + """create_session computes rendered_prompt = '\\n\\n'.join(m.content + for m in rendered_messages). Pin the join semantics so a future + refactor of that line doesn't silently change the trace UI's + display string. We can't easily run create_session in isolation, + but the join is a one-liner that's safe to assert directly.""" + msgs = [ + Message(role="system", content="be terse"), + Message(role="human", content="echo hello"), + ] + rendered_prompt = "\n\n".join(m.content for m in msgs) + assert rendered_prompt == "be terse\n\necho hello" diff --git a/services/api/tests/unit/test_playground_role_mapping.py b/services/api/tests/unit/test_playground_role_mapping.py new file mode 100644 index 0000000..38e7847 --- /dev/null +++ b/services/api/tests/unit/test_playground_role_mapping.py @@ -0,0 +1,51 @@ +"""Prompt-side `human` <-> dispatch-side `user`. System passes through. + +Two Message types live in the codebase: the prompt-side pydantic model +(LangSmith vocabulary: system / human) and the dispatch-side dataclass +(provider vocabulary: system / user / assistant / tool). _to_dispatch_messages +bridges them. AI / tool roles are deferred (spec decision 2). +""" + +from __future__ import annotations + +from tracebility_api.llm import Message as DispatchMessage +from tracebility_api.routers.playground import ( + Message, + _to_dispatch_messages, +) + + +def test_human_maps_to_user(): + out = _to_dispatch_messages([Message(role="human", content="hi")]) + assert out == [DispatchMessage(role="user", content="hi")] + + +def test_system_passes_through(): + out = _to_dispatch_messages([Message(role="system", content="be terse")]) + assert out == [DispatchMessage(role="system", content="be terse")] + + +def test_full_conversation_order_preserved(): + out = _to_dispatch_messages( + [ + Message(role="system", content="be terse"), + Message(role="human", content="hi"), + ] + ) + assert [m.role for m in out] == ["system", "user"] + assert [m.content for m in out] == ["be terse", "hi"] + + +def test_empty_list_returns_empty_list(): + """Edge case: pydantic doesn't enforce non-empty at the prompt-side + Message level; the API's xor validator does. The mapper itself is + safe on empty input.""" + assert _to_dispatch_messages([]) == [] + + +def test_constructs_dispatch_message_instances(): + """The mapper returns DispatchMessage objects, not the prompt-side + Message it received. Locks in the return-type contract.""" + src = [Message(role="human", content="hi")] + out = _to_dispatch_messages(src) + assert isinstance(out[0], DispatchMessage) diff --git a/services/api/tests/unit/test_prompt_version_create_validation.py b/services/api/tests/unit/test_prompt_version_create_validation.py new file mode 100644 index 0000000..b6c7312 --- /dev/null +++ b/services/api/tests/unit/test_prompt_version_create_validation.py @@ -0,0 +1,139 @@ +"""PromptVersionCreate xor validator + to_messages() helper. + +The model validator ensures exactly one of template, template_messages +is set. to_messages() resolves both shapes to the canonical Message +list for the storage path.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError +from tracebility_api.routers.prompts import Message, PromptVersionCreate + + +def test_accepts_template_messages(): + body = PromptVersionCreate( + template_messages=[Message(role="human", content="hi")], + ) + assert body.template_messages is not None + assert body.template is None + + +def test_accepts_legacy_template(): + body = PromptVersionCreate(template="hi") + assert body.template == "hi" + assert body.template_messages is None + + +def test_rejects_neither(): + with pytest.raises(ValidationError) as exc: + PromptVersionCreate() + assert "required" in str(exc.value).lower() + + +def test_rejects_both(): + with pytest.raises(ValidationError) as exc: + PromptVersionCreate( + template="hi", + template_messages=[Message(role="human", content="hi")], + ) + assert "mutually exclusive" in str(exc.value).lower() + + +def test_rejects_empty_template_messages(): + """min_length=1 on the field; pydantic rejects [] before the + model validator runs.""" + with pytest.raises(ValidationError): + PromptVersionCreate(template_messages=[]) + + +def test_to_messages_returns_template_messages_as_is(): + msgs = [ + Message(role="system", content="be terse"), + Message(role="human", content="hi"), + ] + body = PromptVersionCreate(template_messages=msgs) + assert body.to_messages() == msgs + + +def test_to_messages_wraps_legacy_template_as_single_human(): + body = PromptVersionCreate(template="hi {{ x }}") + assert body.to_messages() == [Message(role="human", content="hi {{ x }}")] + + +def test_to_messages_returns_fresh_list_for_template_messages(): + """Ensure to_messages() doesn't strip pydantic Message identity.""" + msgs = [Message(role="human", content="hi")] + body = PromptVersionCreate(template_messages=msgs) + out = body.to_messages() + assert all(isinstance(m, Message) for m in out) + + +def test_to_messages_round_trips_via_model_dump(): + """The no-op short-circuit compares [m.model_dump() for m in messages] + against the deserialized template_messages from the latest version + row. Pin the round-trip: the same messages on both sides must compare + equal as list-of-dicts. + + This is the riskiest comparison in the handler (jsonb-vs-list, dict + key order, str-decoding) and a regression here would silently create + duplicate versions instead of short-circuiting.""" + msgs = [ + Message(role="system", content="be terse"), + Message(role="human", content="echo {{ x }}"), + ] + body = PromptVersionCreate(template_messages=msgs) + # Outgoing form (what the handler builds before the INSERT compare). + outgoing = [m.model_dump() for m in body.to_messages()] + # Incoming form simulating what asyncpg returns for the jsonb column + # after `select template_messages from prompt_version` on a row that + # was previously inserted via this same model_dump path. + incoming_from_db = [ + {"role": "system", "content": "be terse"}, + {"role": "human", "content": "echo {{ x }}"}, + ] + assert outgoing == incoming_from_db + + +def test_legacy_template_round_trips_via_model_dump(): + """The legacy single-string body wraps to the same shape as a + structured raw_messages with a single human turn — confirming a + user can re-save a legacy prompt as the structured form without + accidentally creating v2.""" + legacy_body = PromptVersionCreate(template="echo {{ x }}") + structured_body = PromptVersionCreate( + template_messages=[Message(role="human", content="echo {{ x }}")] + ) + legacy_messages = [m.model_dump() for m in legacy_body.to_messages()] + structured_messages = [m.model_dump() for m in structured_body.to_messages()] + assert legacy_messages == structured_messages + + +def test_derive_legacy_template_single_human(): + """One bare human message -> its content is the legacy template.""" + from tracebility_api.routers.prompts import _derive_legacy_template + + out = _derive_legacy_template([Message(role="human", content="hi")]) + assert out == "hi" + + +def test_derive_legacy_template_multi_message_returns_empty(): + """Multi-message versions can't be honestly represented as a single + string -> empty rather than misleading.""" + from tracebility_api.routers.prompts import _derive_legacy_template + + out = _derive_legacy_template( + [ + Message(role="system", content="be terse"), + Message(role="human", content="hi"), + ] + ) + assert out == "" + + +def test_derive_legacy_template_single_system_returns_empty(): + """Single system-only message: not a 'normal' prompt; return empty.""" + from tracebility_api.routers.prompts import _derive_legacy_template + + out = _derive_legacy_template([Message(role="system", content="be terse")]) + assert out == "" diff --git a/services/api/tracebility_api/routers/playground.py b/services/api/tracebility_api/routers/playground.py index de3dbf9..dd11c0b 100644 --- a/services/api/tracebility_api/routers/playground.py +++ b/services/api/tracebility_api/routers/playground.py @@ -5,7 +5,9 @@ Concretely, the request handler: 1. Persists a `playground_session` row in postgres (status=running). -2. Renders the template (Jinja-style ``{{ var }}`` substitution). +2. Renders each message's content (Jinja-style ``{{ var }}`` substitution) + and persists both the structured rendered_messages and a newline- + joined rendered_prompt for the trace UI's display path. 3. Calls the chosen LLM provider over HTTP. 4. Writes a `run` + `span` to ClickHouse with `sdk='playground'` so the result is visible at `/runs/{id}` like any other trace. @@ -33,18 +35,21 @@ import re import time import uuid +from collections.abc import Mapping from datetime import UTC, datetime -from typing import Any +from typing import Any, Literal from uuid import UUID import asyncpg import structlog from fastapi import APIRouter, Depends, HTTPException, Query, Request, status -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from .. import audit from ..auth import Principal, assert_workspace_role, require_user from ..clickhouse_client import ClickHouseQuery +from ..llm import Message as DispatchMessage +from .prompts import Message log = structlog.get_logger("tracebility.api.playground") @@ -63,12 +68,37 @@ class PlaygroundCreate(BaseModel): project_id: UUID prompt_version_id: UUID | None = None + # Legacy single-string template. Wraps to a single human message + # at render time. Drops in the cleanup PR after Plan C lands. raw_template: str | None = None + # Structured form added in Plan B. Mutually exclusive with + # raw_template and prompt_version_id (xor enforced below). + # min_length=1 rejects an empty list at field-validation time with a + # structured `type=too_short` error scoped to `loc=('raw_messages',)`, + # before the model validator runs. + raw_messages: list[Message] | None = Field(default=None, min_length=1) variables: dict[str, Any] = Field(default_factory=dict) model: str = Field(min_length=1, max_length=128) temperature: float | None = Field(default=None, ge=0.0, le=2.0) max_tokens: int | None = Field(default=None, ge=1, le=8192) + @model_validator(mode="after") + def _exactly_one_template_source(self) -> PlaygroundCreate: + sources = [ + self.prompt_version_id is not None, + self.raw_template is not None, + self.raw_messages is not None, + ] + n = sum(sources) + if n == 0: + raise ValueError("one of prompt_version_id, raw_template, raw_messages is required") + if n > 1: + raise ValueError( + "prompt_version_id, raw_template, raw_messages are mutually " + "exclusive; provide exactly one" + ) + return self + class PlaygroundSessionOut(BaseModel): id: UUID @@ -145,12 +175,6 @@ async def create_session( body: PlaygroundCreate, principal: Principal = Depends(require_user), ) -> PlaygroundSessionOut: - if body.prompt_version_id is None and not body.raw_template: - raise HTTPException( - status.HTTP_400_BAD_REQUEST, - "either prompt_version_id or raw_template is required", - ) - pool: asyncpg.Pool = request.app.state.pg workspace_id = await _assert_project_role( pool, principal, body.project_id, ("owner", "admin", "member") @@ -158,22 +182,24 @@ async def create_session( provider = _resolve_provider(body.model) - # Resolve the template body. If a prompt_version was provided, the - # postgres row is authoritative; the raw_template field on the - # request is ignored to avoid silent divergence. - template_body, version_row = await _resolve_template(pool, body) + # Resolve the message list to render. raw_messages > prompt_version_id + # > raw_template per the xor validator on PlaygroundCreate. + template_messages, version_row = await _resolve_messages(pool, body) - rendered = _render_template(template_body, body.variables) + rendered_messages = _render_messages(template_messages, body.variables) + # Newline-joined readable view for the trace UI's existing display + # path. The structured rendered_messages is what replay reads. + rendered_prompt = "\n\n".join(m.content for m in rendered_messages) started = time.monotonic() session_row = await pool.fetchrow( """ insert into playground_session ( project_id, prompt_version_id, raw_template, rendered_prompt, - variables, provider, model, temperature, max_tokens, - status, created_by + rendered_messages, variables, provider, model, temperature, + max_tokens, status, created_by ) - values ($1, $2, $3, $4, $5::jsonb, $6, $7, $8, $9, 'running', $10) + values ($1, $2, $3, $4, $5::jsonb, $6::jsonb, $7, $8, $9, $10, 'running', $11) returning id, project_id, prompt_version_id, raw_template, rendered_prompt, variables, provider, model, temperature, max_tokens, status, output_text, prompt_tokens, @@ -182,8 +208,13 @@ async def create_session( """, body.project_id, body.prompt_version_id, - None if version_row is not None else body.raw_template, - rendered, + # raw_template column on the session is only persisted for the + # legacy single-string path. raw_messages and prompt_version_id + # paths leave it null so the trace UI doesn't show a misleading + # legacy snippet. + None if (version_row is not None or body.raw_messages is not None) else body.raw_template, + rendered_prompt, + _json.dumps([m.model_dump() for m in rendered_messages]), _json.dumps(body.variables), provider, body.model, @@ -205,6 +236,7 @@ async def create_session( "provider": provider, "prompt_version_id": str(body.prompt_version_id) if body.prompt_version_id else None, "raw_template": body.raw_template is not None, + "raw_messages": body.raw_messages is not None, }, request=request, workspace_id=workspace_id, @@ -215,9 +247,9 @@ async def create_session( # recorded in dispatch_cost by the gateway; the session row only # holds the surface state. if provider == "stub": - result_dict = await _dispatch_stub(body.model, rendered) + result_dict = await _dispatch_stub(body.model, rendered_prompt) else: - from ..llm import DispatchError, Message + from ..llm import DispatchError from ..llm import dispatch as gateway_dispatch try: @@ -227,7 +259,7 @@ async def create_session( surface="playground", surface_ref_id=session_id, model=f"{provider}/{body.model.removeprefix(provider + '/')}", - messages=[Message(role="user", content=rendered)], + messages=_to_dispatch_messages(rendered_messages), temperature=body.temperature, max_tokens=body.max_tokens or _DEFAULT_MAX_TOKENS, ) @@ -270,7 +302,7 @@ async def create_session( run_id=run_id, model=body.model, temperature=body.temperature, - prompt=rendered, + prompt=rendered_prompt, output=result["text"], prompt_tokens=result["prompt_tokens"], completion_tokens=result["completion_tokens"], @@ -391,37 +423,113 @@ def _resolve_provider(model: str) -> str: ) -async def _resolve_template( +async def _resolve_messages( pool: asyncpg.Pool, body: PlaygroundCreate, -) -> tuple[str, asyncpg.Record | None]: +) -> tuple[list[Message], asyncpg.Record | None]: + """Return the message list to render + the version row (if any). + + Resolution order matches PlaygroundCreate's xor validator: + 1. raw_messages - explicit; use as-is. + 2. prompt_version_id - read template_messages from prompt_version. + 3. raw_template - wrap as [{role: human, content: }]. + + Exactly one of these is set per the model validator; the validator + runs before this function so the assertions below are guards, not + logic. + """ + if body.raw_messages is not None: + return body.raw_messages, None if body.prompt_version_id is not None: version_row = await pool.fetchrow( - """select id, prompt_id, template from prompt_version - where id = $1""", + """select id, prompt_id, template, template_messages + from prompt_version where id = $1""", body.prompt_version_id, ) if version_row is None: raise HTTPException(status.HTTP_404_NOT_FOUND, "prompt version not found") - return version_row["template"], version_row + msgs_raw = version_row["template_messages"] + if isinstance(msgs_raw, str): + msgs_raw = _json.loads(msgs_raw) + messages = [Message.model_validate(m) for m in msgs_raw] + return messages, version_row + # raw_template path assert body.raw_template is not None # validated above - return body.raw_template, None + return [Message(role="human", content=body.raw_template)], None + + +def _coerce_var_value(value: Any) -> str: + """Stringify a variable value the same way for both render paths. + + Strings pass through. Other types serialize via json.dumps so dicts + and lists round-trip as readable JSON. Falls back to str() on + json-incompatible objects (datetimes etc.) so we never crash mid-render. + """ + if isinstance(value, str): + return value + try: + return _json.dumps(value) + except (TypeError, ValueError): + return str(value) + +def _render_messages( + messages: list[Message], + variables: dict[str, Any], +) -> list[Message]: + """Render `{{ var }}` substitutions in each message's content. + + Missing variables render as empty string per spec decision 9 — the + user can iterate without the renderer fighting them. Non-string + values serialize via json.dumps (via _coerce_var_value) so a dict or + list passed as a variable round-trips as readable JSON. + + Returns a fresh list of new Message objects — never mutates the + input list or its contents. + """ -def _render_template(template: str, variables: dict[str, Any]) -> str: def _repl(match: re.Match[str]) -> str: key = match.group(1) if key not in variables: - return match.group(0) - value = variables[key] - if isinstance(value, str): - return value - try: - return _json.dumps(value) - except (TypeError, ValueError): - return str(value) + return "" # spec decision 9: missing var -> empty string + return _coerce_var_value(variables[key]) + + return [Message(role=m.role, content=_VAR_RE.sub(_repl, m.content)) for m in messages] + - return _VAR_RE.sub(_repl, template) +# Single source of truth for the prompt -> dispatch role translation. +# When AI / tool roles land (spec decision 2 deferral), extend this dict +# and the prompt-side Message Literal in routers/prompts.py together. +_PROMPT_TO_DISPATCH_ROLE: Mapping[ + Literal["system", "human"], + Literal["system", "user", "assistant", "tool"], +] = { + "system": "system", + "human": "user", +} + + +def _to_dispatch_messages(messages: list[Message]) -> list[DispatchMessage]: + """Convert prompt-side roles to dispatcher roles. + + Prompt side: `system` | `human` (LangSmith vocabulary). + Dispatch side: `system` | `user` | `assistant` | `tool` (provider + vocabulary, what LiteLLM expects). + + The only translation today is `human` -> `user`; system passes + through. AI / tool roles are deferred (spec decision 2). When they + land, extend `_PROMPT_TO_DISPATCH_ROLE` AND the prompt-side + `Message.role` Literal in `routers/prompts.py` in the same change — + the dispatch-side dataclass already accepts assistant / tool, so + only the prompt side and the bridging table need updating. + """ + return [ + DispatchMessage( + role=_PROMPT_TO_DISPATCH_ROLE[m.role], + content=m.content, + ) + for m in messages + ] async def _dispatch_stub(model: str, prompt: str) -> dict[str, Any]: diff --git a/services/api/tracebility_api/routers/prompts.py b/services/api/tracebility_api/routers/prompts.py index 7585878..6fe4d2b 100644 --- a/services/api/tracebility_api/routers/prompts.py +++ b/services/api/tracebility_api/routers/prompts.py @@ -24,8 +24,8 @@ import asyncpg import structlog -from fastapi import APIRouter, Depends, HTTPException, Query, Request, status -from pydantic import BaseModel, Field +from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status +from pydantic import BaseModel, Field, model_validator from .. import audit from ..auth import Principal, assert_workspace_role, require_user @@ -97,12 +97,44 @@ class PromptVersionList(BaseModel): class PromptVersionCreate(BaseModel): - template: str = Field(min_length=1) + """Body for POST /v1/prompts/{prompt_id}/versions. + + Exactly one of `template_messages` (preferred, structured) and + `template` (legacy single-string) must be provided. The legacy + field stays for one release of back-compat; it wraps to a single + human message internally before write. + """ + + template_messages: list[Message] | None = Field(default=None, min_length=1) + template: str | None = Field(default=None, min_length=1) input_schema: dict[str, Any] | None = None model_params: dict[str, Any] | None = None aliases: list[str] = Field(default_factory=list) commit_message: str | None = Field(default=None, max_length=2000) + @model_validator(mode="after") + def _exactly_one_template_source(self) -> PromptVersionCreate: + n = sum( + [ + self.template_messages is not None, + self.template is not None, + ] + ) + if n == 0: + raise ValueError("one of template, template_messages is required") + if n > 1: + raise ValueError( + "template and template_messages are mutually exclusive; provide exactly one" + ) + return self + + def to_messages(self) -> list[Message]: + """Resolve to the canonical list-of-messages form for storage.""" + if self.template_messages is not None: + return self.template_messages + assert self.template is not None # validated above + return [Message(role="human", content=self.template)] + class AliasUpdate(BaseModel): alias: str = Field(min_length=1, max_length=64, pattern=r"^[a-z0-9][a-z0-9_-]*$") @@ -319,6 +351,7 @@ async def list_versions( ) async def create_version( request: Request, + response: Response, prompt_id: UUID, body: PromptVersionCreate, principal: Principal = Depends(require_user), @@ -336,14 +369,41 @@ async def create_version( input_schema_json = _json.dumps(body.input_schema) if body.input_schema is not None else None model_params_json = _json.dumps(body.model_params) if body.model_params is not None else None + messages = body.to_messages() + messages_json_list = [m.model_dump() for m in messages] + async with pool.acquire() as conn, conn.transaction(): - next_version = await conn.fetchval( + # Latest version for this prompt, if any. Used both to compute the + # next version number and to detect a no-op duplicate. + latest = await conn.fetchrow( """ - select coalesce(max(version), 0) + 1 - from prompt_version where prompt_id = $1 + select id, prompt_id, version, template, template_messages, + input_schema, model_params, aliases, commit_message, + created_at + from prompt_version + where prompt_id = $1 + order by version desc + limit 1 """, prompt_id, ) + + if latest is not None: + # Decode jsonb defensively (some asyncpg codec configs return str). + latest_msgs_raw = latest["template_messages"] + if isinstance(latest_msgs_raw, str): + latest_msgs_raw = _json.loads(latest_msgs_raw) + # No-op short-circuit: identical to the most recent version. + # End-to-end coverage of this branch lives in the deferred + # integration test (Plan B Task 6 follow-up); the comparison + # contract is pinned by tests/unit/test_prompt_version_create_validation.py + # so a regression in to_messages()/model_dump fails fast. + if latest_msgs_raw == messages_json_list: + response.status_code = status.HTTP_200_OK + return _hydrate_version(latest) + + next_version = (latest["version"] + 1) if latest is not None else 1 + if aliases: await conn.execute( """ @@ -356,20 +416,24 @@ async def create_version( prompt_id, aliases, ) + + legacy_template = _derive_legacy_template(messages) row = await conn.fetchrow( """ insert into prompt_version ( - prompt_id, version, template, input_schema, model_params, - aliases, commit_message, created_by + prompt_id, version, template, template_messages, + input_schema, model_params, aliases, commit_message, + created_by ) - values ($1, $2, $3, $4::jsonb, $5::jsonb, $6, $7, $8) + values ($1, $2, $3, $4::jsonb, $5::jsonb, $6::jsonb, $7, $8, $9) returning id, prompt_id, version, template, template_messages, input_schema, model_params, aliases, commit_message, created_at """, prompt_id, next_version, - body.template, + legacy_template, + _json.dumps(messages_json_list), input_schema_json, model_params_json, aliases, @@ -392,6 +456,7 @@ async def create_version( "prompt_id": str(prompt_id), "version": version.version, "aliases": aliases, + "template_messages": body.template_messages is not None, }, request=request, workspace_id=workspace_id, @@ -504,6 +569,23 @@ async def assign_alias( return version +def _derive_legacy_template(messages: list[Message]) -> str: + """Compute the legacy `template` text from the canonical messages. + + Returns the human content only when the version is exactly one bare + human message; otherwise empty string. Lying about the legacy field + (e.g., picking the human content from a multi-message version) would + quietly mislead old clients during the deprecation window. + + This rule is the ONE source of truth — the create write path and the + read hydration path both call this so the response shape can never + diverge between create and read. + """ + if len(messages) == 1 and messages[0].role == "human": + return messages[0].content + return "" + + def _normalize_aliases(raw: list[str]) -> list[str]: seen: set[str] = set() out: list[str] = [] @@ -541,9 +623,7 @@ def _hydrate_version(row: asyncpg.Record) -> PromptVersionOut: # data-corruption bug. If/when ai/tool roles land, extend Message, # don't loosen this validation. messages = [Message.model_validate(m) for m in msgs_raw] - legacy_template = ( - messages[0].content if len(messages) == 1 and messages[0].role == "human" else "" - ) + legacy_template = _derive_legacy_template(messages) return PromptVersionOut( id=data["id"], prompt_id=data["prompt_id"],