diff --git a/doc/api.rst b/doc/api.rst
index 1b9bdc775..7c134c103 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -331,6 +331,7 @@ API Reference
     AttackOutcome
     AttackResult
     DecomposedSeedGroup
+    JsonResponseConfig
     Message
     MessagePiece
     PromptDataType
diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py
index ee1d249b4..0c4868e81 100644
--- a/pyrit/models/__init__.py
+++ b/pyrit/models/__init__.py
@@ -25,6 +25,7 @@
 )
 from pyrit.models.embeddings import EmbeddingData, EmbeddingResponse, EmbeddingSupport, EmbeddingUsageInformation
 from pyrit.models.identifiers import Identifier
+from pyrit.models.json_response_config import JsonResponseConfig
 from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError
 from pyrit.models.message import (
     Message,
@@ -69,6 +70,7 @@
     "group_message_pieces_into_conversations",
     "Identifier",
     "ImagePathDataTypeSerializer",
+    "JsonResponseConfig",
     "Message",
     "MessagePiece",
     "PromptDataType",
diff --git a/pyrit/models/json_response_config.py b/pyrit/models/json_response_config.py
new file mode 100644
index 000000000..53f55d6c9
--- /dev/null
+++ b/pyrit/models/json_response_config.py
@@ -0,0 +1,68 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+
+@dataclass
+class JsonResponseConfig:
+    """
+    Configuration for requesting JSON-formatted responses from a target such as OpenAI.
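+
+    Example:
+        >>> config = JsonResponseConfig.from_metadata(
+        ...     metadata={"response_format": "json", "json_schema": {"type": "object"}}
+        ... )
+        >>> (config.enabled, config.schema_name, config.strict)
+        (True, 'CustomSchema', True)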
""" - if message_piece.prompt_metadata: - response_format = message_piece.prompt_metadata.get("response_format") - if response_format == "json": - if not self.is_json_response_supported(): - target_name = self.get_identifier()["__type__"] - raise ValueError(f"This target {target_name} does not support JSON response format.") - return True - return False + config = self.get_json_response_config(message_piece=message_piece) + return config.enabled + + def get_json_response_config(self, *, message_piece: MessagePiece) -> JsonResponseConfig: + """ + Get the JSON response configuration from the message piece metadata. + + Args: + message_piece: A MessagePiece object with a `prompt_metadata` dictionary that may + include JSON response configuration. + + Returns: + JsonResponseConfig: The JSON response configuration. + + Raises: + ValueError: If JSON response format is requested but unsupported. + """ + config = JsonResponseConfig.from_metadata(metadata=message_piece.prompt_metadata) + + if config.enabled and not self.is_json_response_supported(): + target_name = self.get_identifier()["__type__"] + raise ValueError(f"This target {target_name} does not support JSON response format.") + + return config diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index f68e9fd54..98c154e73 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Any, MutableSequence, Optional +from typing import Any, Dict, MutableSequence, Optional from pyrit.common import convert_local_image_to_data_url from pyrit.exceptions import ( @@ -13,6 +13,7 @@ from pyrit.models import ( ChatMessage, ChatMessageListDictContent, + JsonResponseConfig, Message, MessagePiece, construct_response_from_request, @@ -182,8 +183,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: self._validate_request(message=message) message_piece: MessagePiece = message.message_pieces[0] - - is_json_response = self.is_response_format_json(message_piece) + json_config = self.get_json_response_config(message_piece=message_piece) # Get conversation from memory and append the current message conversation = self._memory.get_conversation(conversation_id=message_piece.conversation_id) @@ -191,7 +191,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: logger.info(f"Sending the following prompt to the prompt target: {message}") - body = await self._construct_request_body(conversation=conversation, is_json_response=is_json_response) + body = await self._construct_request_body(conversation=conversation, json_config=json_config) # Use unified error handling - automatically detects ChatCompletion and validates response = await self._handle_openai_request( @@ -389,8 +389,11 @@ async def _build_chat_messages_for_multi_modal_async(self, conversation: Mutable chat_messages.append(chat_message.model_dump(exclude_none=True)) return chat_messages - async def _construct_request_body(self, conversation: MutableSequence[Message], is_json_response: bool) -> dict: + async def _construct_request_body( + self, *, conversation: MutableSequence[Message], json_config: JsonResponseConfig + ) -> dict: messages = await self._build_chat_messages_async(conversation) + response_format = self._build_response_format(json_config) body_parameters = { "model": self._model_name, @@ -404,7 +407,7 @@ async def _construct_request_body(self, 
+        """
+        config = JsonResponseConfig.from_metadata(metadata=message_piece.prompt_metadata)
+
+        if config.enabled and not self.is_json_response_supported():
+            target_name = self.get_identifier()["__type__"]
+            raise ValueError(f"This target {target_name} does not support JSON response format.")
+
+        return config
diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py
index f68e9fd54..98c154e73 100644
--- a/pyrit/prompt_target/openai/openai_chat_target.py
+++ b/pyrit/prompt_target/openai/openai_chat_target.py
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.
 
 import logging
-from typing import Any, MutableSequence, Optional
+from typing import Any, Dict, MutableSequence, Optional
 
 from pyrit.common import convert_local_image_to_data_url
 from pyrit.exceptions import (
@@ -13,6 +13,7 @@
 from pyrit.models import (
     ChatMessage,
     ChatMessageListDictContent,
+    JsonResponseConfig,
     Message,
     MessagePiece,
     construct_response_from_request,
@@ -182,8 +183,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
         self._validate_request(message=message)
 
         message_piece: MessagePiece = message.message_pieces[0]
-
-        is_json_response = self.is_response_format_json(message_piece)
+        json_config = self.get_json_response_config(message_piece=message_piece)
 
         # Get conversation from memory and append the current message
         conversation = self._memory.get_conversation(conversation_id=message_piece.conversation_id)
@@ -191,7 +191,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
 
         logger.info(f"Sending the following prompt to the prompt target: {message}")
 
-        body = await self._construct_request_body(conversation=conversation, is_json_response=is_json_response)
+        body = await self._construct_request_body(conversation=conversation, json_config=json_config)
 
         # Use unified error handling - automatically detects ChatCompletion and validates
         response = await self._handle_openai_request(
@@ -389,8 +389,11 @@ async def _build_chat_messages_for_multi_modal_async(self, conversation: Mutable
             chat_messages.append(chat_message.model_dump(exclude_none=True))
         return chat_messages
 
-    async def _construct_request_body(self, conversation: MutableSequence[Message], is_json_response: bool) -> dict:
+    async def _construct_request_body(
+        self, *, conversation: MutableSequence[Message], json_config: JsonResponseConfig
+    ) -> dict:
         messages = await self._build_chat_messages_async(conversation)
+        response_format = self._build_response_format(json_config)
 
         body_parameters = {
             "model": self._model_name,
@@ -404,7 +407,7 @@ async def _construct_request_body(self, conversation: MutableSequence[Message],
             "seed": self._seed,
             "n": self._n,
             "messages": messages,
-            "response_format": {"type": "json_object"} if is_json_response else None,
+            "response_format": response_format,
         }
 
         if self._extra_body_parameters:
@@ -432,3 +435,28 @@ def _validate_request(self, *, message: Message) -> None:
         for prompt_data_type in converted_prompt_data_types:
             if prompt_data_type not in ["text", "image_path"]:
                 raise ValueError(f"This target only supports text and image_path. Received: {prompt_data_type}.")
+
+    def _build_response_format(self, json_config: JsonResponseConfig) -> Optional[Dict[str, Any]]:
+        """
+        Build the ``response_format`` body parameter for the Chat Completions API.
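+
+        For example, a config carrying a schema produces::
+
+            {"type": "json_schema", "json_schema": {"name": ..., "schema": ..., "strict": ...}}
+
+        while an enabled config without a schema produces ``{"type": "json_object"}``.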
+        """
+        if not json_config.enabled:
+            return None
+
+        if json_config.schema:
+            return {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": json_config.schema_name,
+                    "schema": json_config.schema,
+                    "strict": json_config.strict,
+                },
+            }
+
+        return {"type": "json_object"}
diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py
index eb4ce472b..5a3fc7894 100644
--- a/pyrit/prompt_target/openai/openai_response_target.py
+++ b/pyrit/prompt_target/openai/openai_response_target.py
@@ -21,6 +21,7 @@
     pyrit_target_retry,
 )
 from pyrit.models import (
+    JsonResponseConfig,
     Message,
     MessagePiece,
     PromptDataType,
@@ -315,7 +316,9 @@ async def _build_input_for_multi_modal_async(self, conversation: MutableSequence
 
         return input_items
 
-    async def _construct_request_body(self, conversation: MutableSequence[Message], is_json_response: bool) -> dict:
+    async def _construct_request_body(
+        self, *, conversation: MutableSequence[Message], json_config: JsonResponseConfig
+    ) -> dict:
         """
         Construct the request body to send to the Responses API.
 
@@ -324,13 +327,15 @@ async def _construct_request_body(self, conversation: MutableSequence[Message],
 
         Args:
             conversation: The full conversation history.
-            is_json_response: Whether the response should be formatted as JSON.
+            json_config: The JSON response configuration to apply to the request.
 
         Returns:
             dict: The request body to send to the Responses API.
         """
         input_items = await self._build_input_for_multi_modal_async(conversation)
 
+        text_format = self._build_text_format(json_config=json_config)
+
         body_parameters = {
             "model": self._model_name,
             "max_output_tokens": self._max_output_tokens,
@@ -339,7 +344,7 @@ async def _construct_request_body(self, conversation: MutableSequence[Message],
             "stream": False,
             "input": input_items,
-            # Correct JSON response format per Responses API
-            "response_format": {"type": "json_object"} if is_json_response else None,
+            # The Responses API configures JSON output via the "text" parameter
+            "text": text_format,
         }
 
         if self._extra_body_parameters:
@@ -348,6 +353,32 @@ async def _construct_request_body(self, conversation: MutableSequence[Message],
         # Filter out None values
         return {k: v for k, v in body_parameters.items() if v is not None}
 
+    def _build_text_format(self, json_config: JsonResponseConfig) -> Optional[Dict[str, Any]]:
+        """
+        Build the ``text`` body parameter that the Responses API uses for output formatting.
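+
+        For example, an enabled config with a schema yields::
+
+            {"format": {"type": "json_schema", "name": ..., "schema": ..., "strict": ...}}
+
+        and an enabled config without one yields ``{"format": {"type": "json_object"}}``.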
+        """
+        if not json_config.enabled:
+            return None
+
+        if json_config.schema:
+            return {
+                "format": {
+                    "type": "json_schema",
+                    "name": json_config.schema_name,
+                    "schema": json_config.schema,
+                    "strict": json_config.strict,
+                }
+            }
+
+        logger.info("Using json_object format without a schema; consider providing a schema for more reliable results")
+        return {"format": {"type": "json_object"}}
+
     def _check_content_filter(self, response: Any) -> bool:
         """
         Check if a Response API response has a content filter error.
@@ -445,7 +476,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
         self._validate_request(message=message)
 
         message_piece: MessagePiece = message.message_pieces[0]
-        is_json_response = self.is_response_format_json(message_piece)
+        json_config = self.get_json_response_config(message_piece=message.message_pieces[-1])
 
         # Get full conversation history from memory and append the current message
         conversation: MutableSequence[Message] = self._memory.get_conversation(
@@ -462,7 +493,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
         while True:
             logger.info(f"Sending conversation with {len(conversation)} messages to the prompt target")
 
-            body = await self._construct_request_body(conversation=conversation, is_json_response=is_json_response)
+            body = await self._construct_request_body(conversation=conversation, json_config=json_config)
 
             # Use unified error handling - automatically detects Response and validates
             result = await self._handle_openai_request(
diff --git a/tests/integration/targets/test_openai_responses_gpt5.py b/tests/integration/targets/test_openai_responses_gpt5.py
index 08f0b52a7..2e241beed 100644
--- a/tests/integration/targets/test_openai_responses_gpt5.py
+++ b/tests/integration/targets/test_openai_responses_gpt5.py
@@ -2,24 +2,31 @@
 # Licensed under the MIT license.
 
+import json
 import os
 import uuid
 
+import jsonschema
 import pytest
 
+from pyrit.auth import get_azure_openai_auth
 from pyrit.models import MessagePiece
 from pyrit.prompt_target import OpenAIResponseTarget
 
 
-@pytest.mark.asyncio
-async def test_openai_responses_gpt5(sqlite_instance):
-    args = {
-        "endpoint": os.getenv("AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT"),
+@pytest.fixture()
+def gpt5_args():
+    endpoint_value = os.environ["AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT"]
+    return {
+        "endpoint": endpoint_value,
         "model_name": os.getenv("AZURE_OPENAI_GPT5_MODEL"),
-        "api_key": os.getenv("AZURE_OPENAI_GPT5_KEY"),
+        "api_key": get_azure_openai_auth(endpoint_value),
     }
 
-    target = OpenAIResponseTarget(**args)
+
+@pytest.mark.asyncio
+async def test_openai_responses_gpt5(sqlite_instance, gpt5_args):
+    target = OpenAIResponseTarget(**gpt5_args)
 
     conv_id = str(uuid.uuid4())
@@ -47,3 +54,93 @@ async def test_openai_responses_gpt5(sqlite_instance):
     assert result[0].message_pieces[1].role == "assistant"
     # Hope that the model manages to give the correct answer somewhere (GPT-5 really should)
     assert "Paris" in result[0].message_pieces[1].converted_value
+
+
+@pytest.mark.asyncio
+async def test_openai_responses_gpt5_json_schema(sqlite_instance, gpt5_args):
+    target = OpenAIResponseTarget(**gpt5_args)
+
+    conv_id = str(uuid.uuid4())
+
+    developer_piece = MessagePiece(
+        role="developer",
+        original_value="You are an expert in the lore of cats.",
+        original_value_data_type="text",
+        conversation_id=conv_id,
+        attack_identifier={"id": str(uuid.uuid4())},
+    )
+
+    sqlite_instance.add_message_to_memory(request=developer_piece.to_message())
+
+    cat_schema = {
+        "type": "object",
+        "properties": {
+            "name": {"type": "string", "minLength": 12},
+            "age": {"type": "integer", "minimum": 0, "maximum": 20},
+            "colour": {
+                "type": "array",
+                "items": {"type": "integer", "minimum": 0, "maximum": 255},
+                "minItems": 3,
+                "maxItems": 3,
+            },
+        },
+        "required": ["name", "age", "colour"],
+        "additionalProperties": False,
+    }
+
+    prompt = "Create a JSON object that describes a mystical cat "
+    prompt += "with the following properties: name, age, colour."
+
+    user_piece = MessagePiece(
+        role="user",
+        original_value=prompt,
+        original_value_data_type="text",
+        conversation_id=conv_id,
+        prompt_metadata={"response_format": "json", "json_schema": json.dumps(cat_schema)},
+    )
+
+    response = await target.send_prompt_async(message=user_piece.to_message())
+
+    assert len(response) == 1
+    assert len(response[0].message_pieces) == 2
+    response_piece = response[0].message_pieces[1]
+    assert response_piece.role == "assistant"
+    response_json = json.loads(response_piece.converted_value)
+    jsonschema.validate(instance=response_json, schema=cat_schema)
+
+
+@pytest.mark.asyncio
+async def test_openai_responses_gpt5_json_object(sqlite_instance, gpt5_args):
+    target = OpenAIResponseTarget(**gpt5_args)
+
+    conv_id = str(uuid.uuid4())
+
+    developer_piece = MessagePiece(
+        role="developer",
+        original_value="You are an expert in the lore of cats.",
+        original_value_data_type="text",
+        conversation_id=conv_id,
+        attack_identifier={"id": str(uuid.uuid4())},
+    )
+
+    sqlite_instance.add_message_to_memory(request=developer_piece.to_message())
+
+    prompt = "Create a JSON object that describes a mystical cat "
+    prompt += "with the following properties: name, age, colour."
+
+    user_piece = MessagePiece(
+        role="user",
+        original_value=prompt,
+        original_value_data_type="text",
+        conversation_id=conv_id,
+        prompt_metadata={"response_format": "json"},
+    )
+
+    response = await target.send_prompt_async(message=user_piece.to_message())
+
+    assert len(response) == 1
+    assert len(response[0].message_pieces) == 2
+    response_piece = response[0].message_pieces[1]
+    assert response_piece.role == "assistant"
+    # Can't assert more than JSON validity, since a failure could just be a bad generation by the model
+    _ = json.loads(response_piece.converted_value)
diff --git a/tests/unit/models/test_json_response_config.py b/tests/unit/models/test_json_response_config.py
new file mode 100644
index 000000000..f715907ab
--- /dev/null
+++ b/tests/unit/models/test_json_response_config.py
@@ -0,0 +1,76 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import json
+
+import pytest
+
+from pyrit.models import JsonResponseConfig
+
+
+def test_with_none():
+    config = JsonResponseConfig.from_metadata(metadata=None)
+    assert config.enabled is False
+    assert config.schema is None
+    assert config.schema_name == "CustomSchema"
+    assert config.strict is True
+
+
+def test_with_json_object():
+    metadata = {
+        "response_format": "json",
+    }
+    config = JsonResponseConfig.from_metadata(metadata=metadata)
+    assert config.enabled is True
+    assert config.schema is None
+    assert config.schema_name == "CustomSchema"
+    assert config.strict is True
+
+
+def test_with_json_string_schema():
+    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
+    metadata = {
+        "response_format": "json",
+        "json_schema": json.dumps(schema),
+        "schema_name": "TestSchema",
+        "strict": False,
+    }
+    config = JsonResponseConfig.from_metadata(metadata=metadata)
+    assert config.enabled is True
+    assert config.schema == schema
+    assert config.schema_name == "TestSchema"
+    assert config.strict is False
+
+
+def test_with_json_schema_object():
+    schema = {"type": "object", "properties": {"age": {"type": "integer"}}}
+    metadata = {
+        "response_format": "json",
+        "json_schema": schema,
+    }
+    config = JsonResponseConfig.from_metadata(metadata=metadata)
+    assert config.enabled is True
+    assert config.schema == schema
+    assert config.schema_name == "CustomSchema"
+    assert config.strict is True
+
+
+def test_with_invalid_json_schema_string():
+    metadata = {
+        "response_format": "json",
+        "json_schema": "{invalid_json: true}",
+    }
+    with pytest.raises(ValueError) as e:
+        JsonResponseConfig.from_metadata(metadata=metadata)
+    assert "Invalid JSON schema provided" in str(e.value)
+
+
+def test_other_response_format():
+    metadata = {
+        "response_format": "something_really_improbable_to_have_here",
+    }
+    config = JsonResponseConfig.from_metadata(metadata=metadata)
+    assert config.enabled is False
+    assert config.schema is None
+    assert config.schema_name == "CustomSchema"
+    assert config.strict is True
diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py
index 7db148143..76f26c849 100644
--- a/tests/unit/target/test_openai_chat_target.py
+++ b/tests/unit/target/test_openai_chat_target.py
@@ -23,7 +23,7 @@
     RateLimitException,
 )
 from pyrit.memory.memory_interface import MemoryInterface
-from pyrit.models import Message, MessagePiece
+from pyrit.models import JsonResponseConfig, Message, MessagePiece
 from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget
@@ -130,7 +130,6 @@ def test_init_is_json_supported_can_be_set_to_true(patch_central_database):
 
 @pytest.mark.asyncio()
 async def test_build_chat_messages_for_multi_modal(target: OpenAIChatTarget):
-
     image_request = get_image_message_piece()
     entries = [
         Message(
@@ -183,23 +182,31 @@ async def test_construct_request_body_includes_extra_body_params(
 
     request = Message(message_pieces=[dummy_text_message_piece])
 
-    body = await target._construct_request_body(conversation=[request], is_json_response=False)
+    jrc = JsonResponseConfig.from_metadata(metadata=None)
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
     assert body["key"] == "value"
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("is_json", [True, False])
-async def test_construct_request_body_includes_json(
-    is_json, target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece
-):
+async def test_construct_request_body_json_object(target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece):
+    request = Message(message_pieces=[dummy_text_message_piece])
+    jrc = JsonResponseConfig.from_metadata(metadata={"response_format": "json"})
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
+    assert body["response_format"] == {"type": "json_object"}
+
+
+@pytest.mark.asyncio
+async def test_construct_request_body_json_schema(target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece):
+    schema_obj = {"type": "object", "properties": {"name": {"type": "string"}}}
     request = Message(message_pieces=[dummy_text_message_piece])
+    jrc = JsonResponseConfig.from_metadata(metadata={"response_format": "json", "json_schema": schema_obj})
 
-    body = await target._construct_request_body(conversation=[request], is_json_response=is_json)
-    if is_json:
-        assert body["response_format"] == {"type": "json_object"}
-    else:
-        assert "response_format" not in body
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
+    assert body["response_format"] == {
+        "type": "json_schema",
+        "json_schema": {"name": "CustomSchema", "schema": schema_obj, "strict": True},
+    }
@@ -208,13 +215,15 @@ async def test_construct_request_body_removes_empty_values(
 ):
     request = Message(message_pieces=[dummy_text_message_piece])
 
-    body = await target._construct_request_body(conversation=[request], is_json_response=False)
+    jrc = JsonResponseConfig.from_metadata(metadata=None)
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
     assert "max_completion_tokens" not in body
     assert "max_tokens" not in body
     assert "temperature" not in body
     assert "top_p" not in body
     assert "frequency_penalty" not in body
     assert "presence_penalty" not in body
+    assert "response_format" not in body
@@ -222,8 +231,9 @@ async def test_construct_request_body_serializes_text_message(
     target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece
 ):
     request = Message(message_pieces=[dummy_text_message_piece])
+    jrc = JsonResponseConfig.from_metadata(metadata=None)
 
-    body = await target._construct_request_body(conversation=[request], is_json_response=False)
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
     assert (
         body["messages"][0]["content"] == "dummy text"
     ), "Text messages are serialized in a simple way that's more broadly supported"
@@ -236,8 +246,9 @@ async def test_construct_request_body_serializes_complex_message(
     image_piece = get_image_message_piece()
     image_piece.conversation_id = dummy_text_message_piece.conversation_id  # Match conversation IDs
     request = Message(message_pieces=[dummy_text_message_piece, image_piece])
+    jrc = JsonResponseConfig.from_metadata(metadata=None)
 
-    body = await target._construct_request_body(conversation=[request], is_json_response=False)
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
     messages = body["messages"][0]["content"]
     assert len(messages) == 2, "Complex messages are serialized as a list"
     assert messages[0]["type"] == "text", "Text messages are serialized properly when multi-modal"
@@ -441,7 +452,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di
 
 @pytest.mark.asyncio
 async def test_send_prompt_async_rate_limit_exception_retries(target: OpenAIChatTarget):
-
     message = Message(message_pieces=[MessagePiece(role="user", conversation_id="12345", original_value="Hello")])
 
     # Create proper mock request and response for RateLimitError
@@ -503,7 +513,6 @@ async def test_send_prompt_async_content_filter_200(target: OpenAIChatTarget):
 
 
 def test_validate_request_unsupported_data_types(target: OpenAIChatTarget):
-
     image_piece = get_image_message_piece()
     image_piece.converted_value_data_type = "new_unknown_type"  # type: ignore
     message = Message(
@@ -548,7 +557,6 @@ def test_inheritance_from_prompt_chat_target_base():
 
 
 def test_is_response_format_json_supported(target: OpenAIChatTarget):
-
     message_piece = MessagePiece(
         role="user",
         original_value="original prompt text",
@@ -559,10 +567,28 @@ def test_is_response_format_json_supported(target: OpenAIChatTarget):
         converted_value="Hello, how are you?",
         conversation_id="conversation_1",
         sequence=0,
         prompt_metadata={"response_format": "json"},
     )
 
     result = target.is_response_format_json(message_piece)
-
+    assert isinstance(result, bool)
     assert result is True
 
 
+def test_is_response_format_json_schema_supported(target: OpenAIChatTarget):
+    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
+    message_piece = MessagePiece(
+        role="user",
+        original_value="original prompt text",
+        converted_value="Hello, how are you?",
+        conversation_id="conversation_1",
+        sequence=0,
+        prompt_metadata={
+            "response_format": "json",
+            "json_schema": json.dumps(schema),
+        },
+    )
+
+    result = target.is_response_format_json(message_piece)
+    assert result
+
+
 def test_is_response_format_json_no_metadata(target: OpenAIChatTarget):
     message_piece = MessagePiece(
         role="user",
diff --git a/tests/unit/target/test_openai_response_target.py b/tests/unit/target/test_openai_response_target.py
index 4f905a718..f5ad4bb03 100644
--- a/tests/unit/target/test_openai_response_target.py
+++ b/tests/unit/target/test_openai_response_target.py
@@ -22,7 +22,7 @@
     RateLimitException,
 )
 from pyrit.memory.memory_interface import MemoryInterface
-from pyrit.models import Message, MessagePiece
+from pyrit.models import JsonResponseConfig, Message, MessagePiece
 from pyrit.prompt_target import OpenAIResponseTarget, PromptChatTarget
@@ -138,7 +138,6 @@ def test_init_with_no_additional_request_headers_var_raises():
 
 @pytest.mark.asyncio()
 async def test_build_input_for_multi_modal(target: OpenAIResponseTarget):
-
     image_request = get_image_message_piece()
     conversation_id = image_request.conversation_id
     entries = [
@@ -217,23 +216,38 @@ async def test_construct_request_body_includes_extra_body_params(
 
     request = Message(message_pieces=[dummy_text_message_piece])
 
-    body = await target._construct_request_body(conversation=[request], is_json_response=False)
+    jrc = JsonResponseConfig.from_metadata(metadata=None)
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
     assert body["key"] == "value"
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("is_json", [True, False])
-async def test_construct_request_body_includes_json(
-    is_json, target: OpenAIResponseTarget, dummy_text_message_piece: MessagePiece
-):
+async def test_construct_request_body_json_object(target: OpenAIResponseTarget, dummy_text_message_piece: MessagePiece):
+    json_response_config = JsonResponseConfig(enabled=True)
+    request = Message(message_pieces=[dummy_text_message_piece])
+
+    body = await target._construct_request_body(conversation=[request], json_config=json_response_config)
+    assert body["text"] == {"format": {"type": "json_object"}}
+
+
+@pytest.mark.asyncio
+async def test_construct_request_body_json_schema(target: OpenAIResponseTarget, dummy_text_message_piece: MessagePiece):
+    schema_object = {"type": "object", "properties": {"name": {"type": "string"}}}
+    json_response_config = JsonResponseConfig.from_metadata(
metadata={"response_format": "json", "json_schema": schema_object} + ) request = Message(message_pieces=[dummy_text_message_piece]) - body = await target._construct_request_body(conversation=[request], is_json_response=is_json) - if is_json: - assert body["response_format"] == {"type": "json_object"} - else: - assert "response_format" not in body + body = await target._construct_request_body(conversation=[request], json_config=json_response_config) + assert body["text"] == { + "format": { + "type": "json_schema", + "schema": schema_object, + "name": "CustomSchema", + "strict": True, + } + } @pytest.mark.asyncio @@ -242,13 +255,15 @@ async def test_construct_request_body_removes_empty_values( ): request = Message(message_pieces=[dummy_text_message_piece]) - body = await target._construct_request_body(conversation=[request], is_json_response=False) + json_response_config = JsonResponseConfig(enabled=False) + body = await target._construct_request_body(conversation=[request], json_config=json_response_config) assert "max_completion_tokens" not in body assert "max_tokens" not in body assert "temperature" not in body assert "top_p" not in body assert "frequency_penalty" not in body assert "presence_penalty" not in body + assert "text" not in body @pytest.mark.asyncio @@ -257,7 +272,8 @@ async def test_construct_request_body_serializes_text_message( ): request = Message(message_pieces=[dummy_text_message_piece]) - body = await target._construct_request_body(conversation=[request], is_json_response=False) + jrc = JsonResponseConfig.from_metadata(metadata=None) + body = await target._construct_request_body(conversation=[request], json_config=jrc) assert body["input"][0]["content"][0]["text"] == "dummy text" @@ -265,13 +281,13 @@ async def test_construct_request_body_serializes_text_message( async def test_construct_request_body_serializes_complex_message( target: OpenAIResponseTarget, dummy_text_message_piece: MessagePiece ): - image_piece = get_image_message_piece() dummy_text_message_piece.conversation_id = image_piece.conversation_id request = Message(message_pieces=[dummy_text_message_piece, image_piece]) + jrc = JsonResponseConfig.from_metadata(metadata=None) - body = await target._construct_request_body(conversation=[request], is_json_response=False) + body = await target._construct_request_body(conversation=[request], json_config=jrc) messages = body["input"][0]["content"] assert len(messages) == 2 assert messages[0]["type"] == "input_text" @@ -479,7 +495,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di @pytest.mark.asyncio async def test_send_prompt_async_rate_limit_exception_retries(target: OpenAIResponseTarget): - message = Message(message_pieces=[MessagePiece(role="user", conversation_id="12345", original_value="Hello")]) # Mock SDK to raise RateLimitError @@ -556,7 +571,6 @@ async def test_send_prompt_async_content_filter(target: OpenAIResponseTarget): def test_validate_request_unsupported_data_types(target: OpenAIResponseTarget): - image_piece = get_image_message_piece() image_piece.converted_value_data_type = "new_unknown_type" # type: ignore message = Message( @@ -589,7 +603,6 @@ def test_inheritance_from_prompt_chat_target(target: OpenAIResponseTarget): def test_is_response_format_json_supported(target: OpenAIResponseTarget): - message_piece = MessagePiece( role="user", original_value="original prompt text", @@ -601,9 +614,28 @@ def test_is_response_format_json_supported(target: OpenAIResponseTarget): result = 
     result = target.is_response_format_json(message_piece)
+    assert isinstance(result, bool)
     assert result is True
 
 
+def test_is_response_format_json_schema_supported(target: OpenAIResponseTarget):
+    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
+    message_piece = MessagePiece(
+        role="user",
+        original_value="original prompt text",
+        converted_value="Hello, how are you?",
+        conversation_id="conversation_1",
+        sequence=0,
+        prompt_metadata={
+            "response_format": "json",
+            "json_schema": json.dumps(schema),
+        },
+    )
+
+    result = target.is_response_format_json(message_piece)
+    assert result
+
+
 def test_is_response_format_json_no_metadata(target: OpenAIResponseTarget):
     message_piece = MessagePiece(
         role="user",
@@ -683,7 +716,8 @@ async def test_construct_request_body_filters_none(
     target: OpenAIResponseTarget, dummy_text_message_piece: MessagePiece
 ):
     req = Message(message_pieces=[dummy_text_message_piece])
-    body = await target._construct_request_body([req], is_json_response=False)
+    jrc = JsonResponseConfig.from_metadata(metadata=None)
+    body = await target._construct_request_body(conversation=[req], json_config=jrc)
     assert "max_output_tokens" not in body or body["max_output_tokens"] is None
     assert "temperature" not in body or body["temperature"] is None
     assert "top_p" not in body or body["top_p"] is None