diff --git a/cozeloop/entities/prompt.py b/cozeloop/entities/prompt.py index b3c1d68..f708523 100644 --- a/cozeloop/entities/prompt.py +++ b/cozeloop/entities/prompt.py @@ -9,6 +9,7 @@ class TemplateType(str, Enum): NORMAL = "normal" + JINJA2 = "jinja2" class Role(str, Enum): @@ -26,6 +27,15 @@ class ToolType(str, Enum): class VariableType(str, Enum): STRING = "string" PLACEHOLDER = "placeholder" + BOOLEAN = "boolean" + INTEGER = "integer" + FLOAT = "float" + OBJECT = "object" + ARRAY_STRING = "array" + ARRAY_BOOLEAN = "array" + ARRAY_INTEGER = "array" + ARRAY_FLOAT = "array" + ARRAY_OBJECT = "array" class ToolChoiceType(str, Enum): diff --git a/cozeloop/internal/prompt/converter.py b/cozeloop/internal/prompt/converter.py index ff1f065..92e8738 100644 --- a/cozeloop/internal/prompt/converter.py +++ b/cozeloop/internal/prompt/converter.py @@ -59,7 +59,16 @@ def _convert_message(msg: OpenAPIMessage) -> EntityMessage: def _convert_variable_type(openapi_type: OpenAPIVariableType) -> EntityVariableType: type_mapping = { OpenAPIVariableType.STRING: EntityVariableType.STRING, - OpenAPIVariableType.PLACEHOLDER: EntityVariableType.PLACEHOLDER + OpenAPIVariableType.PLACEHOLDER: EntityVariableType.PLACEHOLDER, + OpenAPIVariableType.BOOLEAN: EntityVariableType.BOOLEAN, + OpenAPIVariableType.INTEGER: EntityVariableType.INTEGER, + OpenAPIVariableType.FLOAT: EntityVariableType.FLOAT, + OpenAPIVariableType.OBJECT: EntityVariableType.OBJECT, + OpenAPIVariableType.ARRAY_STRING: EntityVariableType.ARRAY_STRING, + OpenAPIVariableType.ARRAY_INTEGER: EntityVariableType.ARRAY_INTEGER, + OpenAPIVariableType.ARRAY_FLOAT: EntityVariableType.ARRAY_FLOAT, + OpenAPIVariableType.ARRAY_BOOLEAN: EntityVariableType.ARRAY_BOOLEAN, + OpenAPIVariableType.ARRAY_OBJECT: EntityVariableType.ARRAY_OBJECT } return type_mapping.get(openapi_type, EntityVariableType.STRING) # Default to STRING type @@ -122,7 +131,8 @@ def _convert_llm_config(config: OpenAPIModelConfig) -> EntityModelConfig: def _convert_template_type(openapi_template_type: OpenAPITemplateType) -> EntityTemplateType: template_mapping = { - OpenAPITemplateType.NORMAL: EntityTemplateType.NORMAL + OpenAPITemplateType.NORMAL: EntityTemplateType.NORMAL, + OpenAPITemplateType.JINJA2: EntityTemplateType.JINJA2 } return template_mapping.get(openapi_template_type, EntityTemplateType.NORMAL) # Default to NORMAL type diff --git a/cozeloop/internal/prompt/openapi.py b/cozeloop/internal/prompt/openapi.py index c9bf4ca..7335d4b 100644 --- a/cozeloop/internal/prompt/openapi.py +++ b/cozeloop/internal/prompt/openapi.py @@ -14,6 +14,7 @@ class TemplateType(str, Enum): NORMAL = "normal" + JINJA2 = "jinja2" class Role(str, Enum): @@ -31,6 +32,15 @@ class ToolType(str, Enum): class VariableType(str, Enum): STRING = "string" PLACEHOLDER = "placeholder" + BOOLEAN = "boolean" + INTEGER = "integer" + FLOAT = "float" + OBJECT = "object" + ARRAY_STRING = "array" + ARRAY_BOOLEAN = "array" + ARRAY_INTEGER = "array" + ARRAY_FLOAT = "array" + ARRAY_OBJECT = "array" class ToolChoiceType(str, Enum): diff --git a/cozeloop/internal/prompt/prompt.py b/cozeloop/internal/prompt/prompt.py index 84a777d..fd04232 100644 --- a/cozeloop/internal/prompt/prompt.py +++ b/cozeloop/internal/prompt/prompt.py @@ -6,6 +6,7 @@ from jinja2 import Environment, BaseLoader, Undefined from jinja2.utils import missing, object_type_repr +from jinja2.sandbox import SandboxedEnvironment from cozeloop.spec.tracespec import PROMPT_KEY, INPUT, PROMPT_VERSION, V_SCENE_PROMPT_TEMPLATE, V_SCENE_PROMPT_HUB from cozeloop.entities.prompt import (Prompt, Message, VariableDef, VariableType, TemplateType, Role, @@ -153,6 +154,27 @@ def _validate_variable_values_type(self, variable_defs: List[VariableDef], varia elif var_def.type == VariableType.PLACEHOLDER: if not (isinstance(val, Message) or (isinstance(val, List) and all(isinstance(item, Message) for item in val))): raise ValueError(f"type of variable '{var_def.key}' should be Message like object") + elif var_def.type == VariableType.BOOLEAN: + if not isinstance(val, bool): + raise ValueError(f"type of variable '{var_def.key}' should be bool") + elif var_def.type == VariableType.INTEGER: + if not isinstance(val, int): + raise ValueError(f"type of variable '{var_def.key}' should be int") + elif var_def.type == VariableType.FLOAT: + if not isinstance(val, float): + raise ValueError(f"type of variable '{var_def.key}' should be float") + elif var_def.type == VariableType.ARRAY_STRING: + if not isinstance(val, list) or not all(isinstance(item, str) for item in val): + raise ValueError(f"type of variable '{var_def.key}' should be array") + elif var_def.type == VariableType.ARRAY_BOOLEAN: + if not isinstance(val, list) or not all(isinstance(item, bool) for item in val): + raise ValueError(f"type of variable '{var_def.key}' should be array") + elif var_def.type == VariableType.ARRAY_INTEGER: + if not isinstance(val, list) or not all(isinstance(item, int) for item in val): + raise ValueError(f"type of variable '{var_def.key}' should be array") + elif var_def.type == VariableType.ARRAY_FLOAT: + if not isinstance(val, list) or not all(isinstance(item, float) for item in val): + raise ValueError(f"type of variable '{var_def.key}' should be array") def _format_normal_messages( self, @@ -217,7 +239,7 @@ def _render_text_content( ) -> str: if template_type == TemplateType.NORMAL: # Create custom Environment using DebugUndefined to preserve original form of undefined variables - env = Environment( + env = SandboxedEnvironment( loader=BaseLoader(), undefined=CustomUndefined, variable_start_string='{{', @@ -230,10 +252,20 @@ def _render_text_content( render_vars = {k: variables.get(k, '') for k in variable_def_map.keys()} # Render template return template.render(**render_vars) + elif template_type == TemplateType.JINJA2: + return self._render_jinja2_template(template_str, variable_def_map, variables) else: raise ValueError(f"text render unsupported template type: {template_type}") + def _render_jinja2_template(self, template_str: str, variable_def_map: Dict[str, VariableDef], + variables: Dict[str, Any]) -> str: + """渲染 Jinja2 模板""" + env = SandboxedEnvironment() + template = env.from_string(template_str) + render_vars = {k: variables[k] for k in variable_def_map.keys() if variables is not None and k in variables} + return template.render(**render_vars) + class CustomUndefined(Undefined): __slots__ = () diff --git a/examples/prompt/advance/prompt_hub_with_jinja.py b/examples/prompt/advance/prompt_hub_with_jinja.py new file mode 100644 index 0000000..9914145 --- /dev/null +++ b/examples/prompt/advance/prompt_hub_with_jinja.py @@ -0,0 +1,155 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import json +import time +from typing import List + +import cozeloop +from cozeloop import Message +from cozeloop.entities.prompt import Role +from cozeloop.spec.tracespec import CALL_OPTIONS, ModelCallOption, ModelMessage, ModelInput + + +def convert_model_input(messages: List[Message]) -> ModelInput: + model_messages = [] + for message in messages: + model_messages.append(ModelMessage( + role=str(message.role), + content=message.content if message.content is not None else "" + )) + + return ModelInput( + messages=model_messages + ) + + +class LLMRunner: + def __init__(self, client): + self.client = client + + def llm_call(self, input_data): + """ + Simulate an LLM call and set relevant span tags. + """ + span = self.client.start_span("llmCall", "model") + try: + # Assuming llm is processing + # output = ChatOpenAI().invoke(input=input_data) + + # mock resp + time.sleep(1) + output = "I'm a robot. I don't have a specific name. You can give me one." + input_token = 232 + output_token = 1211 + + # set tag key: `input` + span.set_input(convert_model_input(input_data)) + # set tag key: `output` + span.set_output(output) + # set tag key: `model_provider`, e.g., openai, etc. + span.set_model_provider("openai") + # set tag key: `start_time_first_resp` + # Timestamp of the first packet return from LLM, unit: microseconds. + # When `start_time_first_resp` is set, a tag named `latency_first_resp` calculated + # based on the span's StartTime will be added, meaning the latency for the first packet. + span.set_start_time_first_resp(int(time.time() * 1000000)) + # set tag key: `input_tokens`. The amount of input tokens. + # when the `input_tokens` value is set, it will automatically sum with the `output_tokens` to calculate the `tokens` tag. + span.set_input_tokens(input_token) + # set tag key: `output_tokens`. The amount of output tokens. + # when the `output_tokens` value is set, it will automatically sum with the `input_tokens` to calculate the `tokens` tag. + span.set_output_tokens(output_token) + # set tag key: `model_name`, e.g., gpt-4-1106-preview, etc. + span.set_model_name("gpt-4-1106-preview") + span.set_tags({CALL_OPTIONS: ModelCallOption( + temperature=0.5, + top_p=0.5, + top_k=10, + presence_penalty=0.5, + frequency_penalty=0.5, + max_tokens=1024, + )}) + + return None + except Exception as e: + raise e + finally: + span.finish() + +# If you want to use the jinja templates in prompts, you can refer to the following. +if __name__ == '__main__': + # 1.Create a prompt on the platform + # You can create a Prompt on the platform's Prompt development page (set Prompt Key to 'prompt_hub_demo'), + # add the following messages to the template, and submit a version. + # System: You are a helpful bot, the conversation topic is {{var1}}. + # Placeholder: placeholder1 + # User: My question is {{var2}} + # Placeholder: placeholder2 + + # Set the following environment variables first. + # COZELOOP_WORKSPACE_ID=your workspace id + # COZELOOP_API_TOKEN=your token + # 2.New loop client + client = cozeloop.new_client( + # Set whether to report a trace span when get or format prompt. + # Default value is false. + prompt_trace=True) + + # 3. new root span + rootSpan = client.start_span("root_span", "main_span") + + # 4. Get the prompt + # If no specific version is specified, the latest version of the corresponding prompt will be obtained + prompt = client.get_prompt(prompt_key="prompt_hub_demo", version="0.0.1") + if prompt is not None: + # Get messages of the prompt + if prompt.prompt_template is not None: + messages = prompt.prompt_template.messages + print( + f"prompt messages: {json.dumps([message.model_dump(exclude_none=True) for message in messages], ensure_ascii=False)}") + # Get llm config of the prompt + if prompt.llm_config is not None: + llm_config = prompt.llm_config + print(f"prompt llm_config: {llm_config.model_dump_json(exclude_none=True)}") + + # 5.Format messages of the prompt + formatted_messages = client.prompt_format(prompt, { + "var_string": "hi", + "var_int": 5, + "var_bool": True, + "var_float": 1.0, + "var_object": { + "name": "John", + "age": 30, + "hobbies": ["reading", "coding"], + "address": { + "city": "bejing", + "street": "123 Main", + }, + }, + "var_array_string": ["hello", "nihao"], + "var_array_boolean": [True, False, True], + "var_array_int": [1, 2, 3, 4], + "var_array_float": [1.0, 2.0], + "var_array_object": [{"key": "123"}, {"value": 100}], + # Placeholder variable type should be Message/List[Message] + "placeholder1": [Message(role=Role.USER, content="Hello!"), + Message(role=Role.ASSISTANT, content="Hello!")] + # Other variables in the prompt template that are not provided with corresponding values will be + # considered as empty values. + }) + print( + f"formatted_messages: {json.dumps([message.model_dump(exclude_none=True) for message in formatted_messages], ensure_ascii=False)}") + + # 6.LLM call + llm_runner = LLMRunner(client) + llm_runner.llm_call(formatted_messages) + + rootSpan.finish() + # 4. (optional) flush or close + # -- force flush, report all traces in the queue + # Warning! In general, this method is not needed to be call, as spans will be automatically reported in batches. + # Note that flush will block and wait for the report to complete, and it may cause frequent reporting, + # affecting performance. + client.flush() diff --git a/tests/internal/prompt/test_prompt.py b/tests/internal/prompt/test_prompt.py index 3b8d203..b309c9f 100644 --- a/tests/internal/prompt/test_prompt.py +++ b/tests/internal/prompt/test_prompt.py @@ -634,3 +634,493 @@ def test_format_normal_messages_null_message(prompt_provider): assert result[0].content == "Hello" assert result[1].role == Role.USER assert result[1].content == "World" +def test_validate_variable_values_type_boolean_valid(prompt_provider): + """测试有效的 boolean 类型变量""" + var_defs = [VariableDef(key="enabled", desc="Enable feature", type=VariableType.BOOLEAN)] + variables = {"enabled": True} + + # 应该不抛出异常 + prompt_provider._validate_variable_values_type(var_defs, variables) + + +def test_validate_variable_values_type_boolean_invalid(prompt_provider): + """测试无效的 boolean 类型变量""" + var_defs = [VariableDef(key="enabled", desc="Enable feature", type=VariableType.BOOLEAN)] + variables = {"enabled": "true"} # 字符串而不是布尔值 + + with pytest.raises(ValueError) as excinfo: + prompt_provider._validate_variable_values_type(var_defs, variables) + + assert "type of variable 'enabled' should be bool" in str(excinfo.value) + + +def test_validate_variable_values_type_integer_valid(prompt_provider): + """测试有效的 integer 类型变量""" + var_defs = [VariableDef(key="count", desc="Item count", type=VariableType.INTEGER)] + variables = {"count": 42} + + # 应该不抛出异常 + prompt_provider._validate_variable_values_type(var_defs, variables) + + +def test_validate_variable_values_type_integer_invalid(prompt_provider): + """测试无效的 integer 类型变量""" + var_defs = [VariableDef(key="count", desc="Item count", type=VariableType.INTEGER)] + variables = {"count": "42"} # 字符串而不是整数 + + with pytest.raises(ValueError) as excinfo: + prompt_provider._validate_variable_values_type(var_defs, variables) + + assert "type of variable 'count' should be int" in str(excinfo.value) + + +def test_validate_variable_values_type_float_valid(prompt_provider): + """测试有效的 float 类型变量""" + var_defs = [VariableDef(key="temperature", desc="Temperature value", type=VariableType.FLOAT)] + variables = {"temperature": 3.14} + + # 应该不抛出异常 + prompt_provider._validate_variable_values_type(var_defs, variables) + + +def test_validate_variable_values_type_float_invalid(prompt_provider): + """测试无效的 float 类型变量""" + var_defs = [VariableDef(key="temperature", desc="Temperature value", type=VariableType.FLOAT)] + variables = {"temperature": "3.14"} # 字符串而不是浮点数 + + with pytest.raises(ValueError) as excinfo: + prompt_provider._validate_variable_values_type(var_defs, variables) + + assert "type of variable 'temperature' should be float" in str(excinfo.value) + + +def test_validate_variable_values_type_array_string_valid(prompt_provider): + """测试有效的 array 类型变量""" + var_defs = [VariableDef(key="tags", desc="Tag list", type=VariableType.ARRAY_STRING)] + variables = {"tags": ["tag1", "tag2", "tag3"]} + + # 应该不抛出异常 + prompt_provider._validate_variable_values_type(var_defs, variables) + + +def test_validate_variable_values_type_array_string_invalid_not_list(prompt_provider): + """测试无效的 array 类型变量 - 不是列表""" + var_defs = [VariableDef(key="tags", desc="Tag list", type=VariableType.ARRAY_STRING)] + variables = {"tags": "tag1,tag2,tag3"} # 字符串而不是列表 + + with pytest.raises(ValueError) as excinfo: + prompt_provider._validate_variable_values_type(var_defs, variables) + + assert "type of variable 'tags' should be array" in str(excinfo.value) + + +def test_validate_variable_values_type_array_string_invalid_wrong_element_type(prompt_provider): + """测试无效的 array 类型变量 - 元素类型错误""" + var_defs = [VariableDef(key="tags", desc="Tag list", type=VariableType.ARRAY_STRING)] + variables = {"tags": ["tag1", 123, "tag3"]} # 包含非字符串元素 + + with pytest.raises(ValueError) as excinfo: + prompt_provider._validate_variable_values_type(var_defs, variables) + + assert "type of variable 'tags' should be array" in str(excinfo.value) + + +def test_validate_variable_values_type_array_boolean_valid(prompt_provider): + """测试有效的 array 类型变量""" + var_defs = [VariableDef(key="flags", desc="Boolean flags", type=VariableType.ARRAY_BOOLEAN)] + variables = {"flags": [True, False, True]} + + # 应该不抛出异常 + prompt_provider._validate_variable_values_type(var_defs, variables) + + +def test_validate_variable_values_type_array_boolean_invalid(prompt_provider): + """测试无效的 array 类型变量""" + var_defs = [VariableDef(key="flags", desc="Boolean flags", type=VariableType.ARRAY_BOOLEAN)] + variables = {"flags": [True, "false", True]} # 包含字符串而不是布尔值 + + with pytest.raises(ValueError) as excinfo: + prompt_provider._validate_variable_values_type(var_defs, variables) + + assert "type of variable 'flags' should be array" in str(excinfo.value) + + +def test_validate_variable_values_type_array_integer_valid(prompt_provider): + """测试有效的 array 类型变量""" + var_defs = [VariableDef(key="numbers", desc="Number list", type=VariableType.ARRAY_INTEGER)] + variables = {"numbers": [1, 2, 3, 4, 5]} + + # 应该不抛出异常 + prompt_provider._validate_variable_values_type(var_defs, variables) + + +def test_validate_variable_values_type_array_integer_invalid(prompt_provider): + """测试无效的 array 类型变量""" + var_defs = [VariableDef(key="numbers", desc="Number list", type=VariableType.ARRAY_INTEGER)] + variables = {"numbers": [1, "2", 3]} # 包含字符串而不是整数 + + with pytest.raises(ValueError) as excinfo: + prompt_provider._validate_variable_values_type(var_defs, variables) + + assert "type of variable 'numbers' should be array" in str(excinfo.value) + + +def test_validate_variable_values_type_array_float_valid(prompt_provider): + """测试有效的 array 类型变量""" + var_defs = [VariableDef(key="scores", desc="Score list", type=VariableType.ARRAY_FLOAT)] + variables = {"scores": [1.5, 2.7, 3.14]} + + # 应该不抛出异常 + prompt_provider._validate_variable_values_type(var_defs, variables) + + +def test_validate_variable_values_type_array_float_invalid(prompt_provider): + """测试无效的 array 类型变量""" + var_defs = [VariableDef(key="scores", desc="Score list", type=VariableType.ARRAY_FLOAT)] + variables = {"scores": [1.5, "2.7", 3.14]} # 包含字符串而不是浮点数 + + with pytest.raises(ValueError) as excinfo: + prompt_provider._validate_variable_values_type(var_defs, variables) + + assert "type of variable 'scores' should be array" in str(excinfo.value) + + +def test_validate_variable_values_type_mixed_valid(prompt_provider): + """测试多种类型变量的混合验证""" + var_defs = [ + VariableDef(key="name", desc="User name", type=VariableType.STRING), + VariableDef(key="enabled", desc="Enable feature", type=VariableType.BOOLEAN), + VariableDef(key="count", desc="Item count", type=VariableType.INTEGER), + VariableDef(key="temperature", desc="Temperature", type=VariableType.FLOAT), + VariableDef(key="tags", desc="Tag list", type=VariableType.ARRAY_STRING), + ] + variables = { + "name": "Alice", + "enabled": True, + "count": 42, + "temperature": 3.14, + "tags": ["tag1", "tag2"] + } + + # 应该不抛出异常 + prompt_provider._validate_variable_values_type(var_defs, variables) + + +# ============================================================================= +# Jinja2模板渲染测试 +# ============================================================================= + +def test_render_jinja2_template_basic(prompt_provider): + """测试基本的 Jinja2 模板渲染""" + template_str = "Hello {{ name }}!" + variable_def_map = {"name": VariableDef(key="name", desc="User name", type=VariableType.STRING)} + variables = {"name": "Alice"} + + result = prompt_provider._render_jinja2_template(template_str, variable_def_map, variables) + assert result == "Hello Alice!" + + +def test_render_jinja2_template_variable_substitution(prompt_provider): + """测试 Jinja2 模板变量替换""" + template_str = "Welcome {{ name }}, you have {{ count }} messages." + variable_def_map = { + "name": VariableDef(key="name", desc="User name", type=VariableType.STRING), + "count": VariableDef(key="count", desc="Message count", type=VariableType.INTEGER) + } + variables = {"name": "Bob", "count": 5} + + result = prompt_provider._render_jinja2_template(template_str, variable_def_map, variables) + assert result == "Welcome Bob, you have 5 messages." + + +def test_render_jinja2_template_with_loops(prompt_provider): + """测试 Jinja2 模板循环语句""" + template_str = "Items: {% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}" + variable_def_map = {"items": VariableDef(key="items", desc="Item list", type=VariableType.ARRAY_STRING)} + variables = {"items": ["apple", "banana", "orange"]} + + result = prompt_provider._render_jinja2_template(template_str, variable_def_map, variables) + assert result == "Items: apple, banana, orange" + + +def test_render_jinja2_template_with_conditionals(prompt_provider): + """测试 Jinja2 模板条件语句""" + template_str = "{% if enabled %}Feature is enabled{% else %}Feature is disabled{% endif %}" + variable_def_map = {"enabled": VariableDef(key="enabled", desc="Feature enabled", type=VariableType.BOOLEAN)} + variables = {"enabled": True} + + result = prompt_provider._render_jinja2_template(template_str, variable_def_map, variables) + assert result == "Feature is enabled" + + +def test_render_jinja2_template_with_conditionals_false(prompt_provider): + """测试 Jinja2 模板条件语句 - false 分支""" + template_str = "{% if enabled %}Feature is enabled{% else %}Feature is disabled{% endif %}" + variable_def_map = {"enabled": VariableDef(key="enabled", desc="Feature enabled", type=VariableType.BOOLEAN)} + variables = {"enabled": False} + + result = prompt_provider._render_jinja2_template(template_str, variable_def_map, variables) + assert result == "Feature is disabled" + + +def test_render_jinja2_template_undefined_variable(prompt_provider): + """测试 Jinja2 模板未定义变量处理""" + template_str = "Hello {{ name }}!" + variable_def_map = {"name": VariableDef(key="name", desc="User name", type=VariableType.STRING)} + variables = {} # 没有提供 name 变量 + + result = prompt_provider._render_jinja2_template(template_str, variable_def_map, variables) + assert result == "Hello !" # 未定义变量被替换为空字符串 + + +def test_render_jinja2_template_complex(prompt_provider): + """测试复杂的 Jinja2 模板""" + template_str = """ +{%- if user -%} +Hello {{ user }}! +{%- if items -%} +Your items: +{%- for item in items %} +- {{ item }} +{%- endfor -%} +{%- else -%} +You have no items. +{%- endif -%} +{%- else -%} +Hello stranger! +{%- endif -%} +""".strip() + + variable_def_map = { + "user": VariableDef(key="user", desc="User name", type=VariableType.STRING), + "items": VariableDef(key="items", desc="Item list", type=VariableType.ARRAY_STRING) + } + variables = {"user": "Alice", "items": ["book", "pen"]} + + result = prompt_provider._render_jinja2_template(template_str, variable_def_map, variables) + expected = "Hello Alice!Your items:\n- book\n- pen" + assert result == expected + + +def test_render_jinja2_template_sandbox_security(prompt_provider): + """测试 Jinja2 模板沙箱环境安全性""" + # 尝试访问不安全的内置函数 + template_str = "{{ ''.__class__.__mro__[1].__subclasses__() }}" + variable_def_map = {} + variables = {} + + # 沙箱环境应该阻止这种访问 + with pytest.raises(Exception): # SandboxedEnvironment 会抛出异常 + prompt_provider._render_jinja2_template(template_str, variable_def_map, variables) + + +def test_render_text_content_jinja2(prompt_provider): + """测试通过 _render_text_content 方法使用 Jinja2 模板""" + template_str = "Hello {{ name }}! You have {{ count }} items." + variable_def_map = { + "name": VariableDef(key="name", desc="User name", type=VariableType.STRING), + "count": VariableDef(key="count", desc="Item count", type=VariableType.INTEGER) + } + variables = {"name": "Charlie", "count": 3} + + result = prompt_provider._render_text_content( + TemplateType.JINJA2, + template_str, + variable_def_map, + variables + ) + + assert result == "Hello Charlie! You have 3 items." + + +# ============================================================================= +# 集成测试 +# ============================================================================= + +def test_prompt_format_jinja2_integration(prompt_provider): + """测试使用 Jinja2 模板的完整 prompt 格式化""" + # 创建 Jinja2 模板消息 + system_message = Message(role=Role.SYSTEM, content="You are a helpful assistant for {{ domain }}.") + user_message = Message(role=Role.USER, content="{% if urgent %}URGENT: {% endif %}{{ question }}") + + var_defs = [ + VariableDef(key="domain", desc="Domain", type=VariableType.STRING), + VariableDef(key="urgent", desc="Is urgent", type=VariableType.BOOLEAN), + VariableDef(key="question", desc="User question", type=VariableType.STRING) + ] + + prompt_template = MagicMock() + prompt_template.template_type = TemplateType.JINJA2 + prompt_template.messages = [system_message, user_message] + prompt_template.variable_defs = var_defs + + prompt = MagicMock(spec=Prompt) + prompt.prompt_template = prompt_template + + variables = { + "domain": "programming", + "urgent": True, + "question": "How to fix this bug?" + } + + # 调用方法 + result = prompt_provider.prompt_format(prompt, variables) + + # 验证结果 + assert len(result) == 2 + assert result[0].role == Role.SYSTEM + assert result[0].content == "You are a helpful assistant for programming." + assert result[1].role == Role.USER + assert result[1].content == "URGENT: How to fix this bug?" + + +def test_prompt_format_jinja2_with_arrays(prompt_provider): + """测试使用数组变量的 Jinja2 模板格式化""" + message = Message( + role=Role.USER, + content="Please process these items: {% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}" + ) + + var_defs = [ + VariableDef(key="items", desc="Item list", type=VariableType.ARRAY_STRING) + ] + + prompt_template = MagicMock() + prompt_template.template_type = TemplateType.JINJA2 + prompt_template.messages = [message] + prompt_template.variable_defs = var_defs + + prompt = MagicMock(spec=Prompt) + prompt.prompt_template = prompt_template + + variables = {"items": ["task1", "task2", "task3"]} + + # 调用方法 + result = prompt_provider.prompt_format(prompt, variables) + + # 验证结果 + assert len(result) == 1 + assert result[0].role == Role.USER + assert result[0].content == "Please process these items: task1, task2, task3" + + +def test_prompt_format_mixed_template_types(prompt_provider): + """测试混合使用 NORMAL 和 JINJA2 模板的场景(通过不同的 prompt)""" + # 测试 NORMAL 模板 + normal_message = Message(role=Role.USER, content="Hello {{name}}!") + normal_var_defs = [VariableDef(key="name", desc="User name", type=VariableType.STRING)] + + normal_template = MagicMock() + normal_template.template_type = TemplateType.NORMAL + normal_template.messages = [normal_message] + normal_template.variable_defs = normal_var_defs + + normal_prompt = MagicMock(spec=Prompt) + normal_prompt.prompt_template = normal_template + + normal_variables = {"name": "Alice"} + normal_result = prompt_provider.prompt_format(normal_prompt, normal_variables) + + # 测试 JINJA2 模板 + jinja2_message = Message(role=Role.USER, content="Hello {{ name }}!") + jinja2_var_defs = [VariableDef(key="name", desc="User name", type=VariableType.STRING)] + + jinja2_template = MagicMock() + jinja2_template.template_type = TemplateType.JINJA2 + jinja2_template.messages = [jinja2_message] + jinja2_template.variable_defs = jinja2_var_defs + + jinja2_prompt = MagicMock(spec=Prompt) + jinja2_prompt.prompt_template = jinja2_template + + jinja2_variables = {"name": "Bob"} + jinja2_result = prompt_provider.prompt_format(jinja2_prompt, jinja2_variables) + + # 验证两种模板都能正常工作 + assert len(normal_result) == 1 + assert normal_result[0].content == "Hello Alice!" + + assert len(jinja2_result) == 1 + assert jinja2_result[0].content == "Hello Bob!" + + +def test_prompt_format_jinja2_with_placeholder(prompt_provider): + """测试 Jinja2 模板与 placeholder 消息的组合""" + system_message = Message(role=Role.SYSTEM, content="You are helping with {{ task_type }}.") + placeholder_message = Message(role=Role.PLACEHOLDER, content="history") + user_message = Message(role=Role.USER, content="{% if urgent %}URGENT: {% endif %}{{ question }}") + + var_defs = [ + VariableDef(key="task_type", desc="Task type", type=VariableType.STRING), + VariableDef(key="urgent", desc="Is urgent", type=VariableType.BOOLEAN), + VariableDef(key="question", desc="User question", type=VariableType.STRING), + VariableDef(key="history", desc="Chat history", type=VariableType.PLACEHOLDER) + ] + + prompt_template = MagicMock() + prompt_template.template_type = TemplateType.JINJA2 + prompt_template.messages = [system_message, placeholder_message, user_message] + prompt_template.variable_defs = var_defs + + prompt = MagicMock(spec=Prompt) + prompt.prompt_template = prompt_template + + # 创建历史消息 + history_messages = [ + Message(role=Role.USER, content="Previous question"), + Message(role=Role.ASSISTANT, content="Previous answer") + ] + + variables = { + "task_type": "debugging", + "urgent": False, + "question": "What's the issue?", + "history": history_messages + } + + # 调用方法 + result = prompt_provider.prompt_format(prompt, variables) + + # 验证结果 + assert len(result) == 4 + assert result[0].role == Role.SYSTEM + assert result[0].content == "You are helping with debugging." + assert result[1].role == Role.USER + assert result[1].content == "Previous question" + assert result[2].role == Role.ASSISTANT + assert result[2].content == "Previous answer" + assert result[3].role == Role.USER + assert result[3].content == "What's the issue?" + + +def test_prompt_format_jinja2_edge_cases(prompt_provider): + """测试 Jinja2 模板的边界情况""" + # 空模板 + empty_message = Message(role=Role.USER, content="") + + # 只有空格的模板 + whitespace_message = Message(role=Role.USER, content=" ") + + # 没有变量的模板 + no_vars_message = Message(role=Role.USER, content="Static text") + + var_defs = [] + + prompt_template = MagicMock() + prompt_template.template_type = TemplateType.JINJA2 + prompt_template.messages = [empty_message, whitespace_message, no_vars_message] + prompt_template.variable_defs = var_defs + + prompt = MagicMock(spec=Prompt) + prompt.prompt_template = prompt_template + + variables = {} + + # 调用方法 + result = prompt_provider.prompt_format(prompt, variables) + + # 验证结果 + assert len(result) == 3 + assert result[0].content == "" + assert result[1].content == " " + assert result[2].content == "Static text" \ No newline at end of file