diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index d31fb90..3e73638 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -191,17 +191,27 @@ class chat_template { }}, }; }; + auto make_tool_call_response = [](const std::string & tool_call_id, const std::string & tool_name, const std::string & content) { + return json { + {"role", "tool"}, + {"name", tool_name}, + {"content", content}, + {"tool_call_id", tool_call_id}, + }; + }; const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}}; // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want. out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})), + make_tool_call_response("call_1___", "ipython", "Hello, World!"), }), {}, false); auto tool_call_renders_str_arguments = contains(out, "") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})), + make_tool_call_response("call_1___", "ipython", "Hello, World!"), }), {}, false); auto tool_call_renders_obj_arguments = contains(out, "") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); @@ -215,18 +225,14 @@ class chat_template { auto out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({tc1, tc2})), + dummy_user_msg, }), {}, false); caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2"); out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({tc1})), - { - {"role", "tool"}, - {"name", "test_tool1"}, - {"content", "Some response!"}, - {"tool_call_id", "call_911_"}, - } + make_tool_call_response("call_911_", "test_tool1", "Some response!"), }), {}, false); caps_.supports_tool_responses = contains(out, "Some response!"); caps_.supports_tool_call_id = contains(out, "call_911_"); diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index acaf969..1a500b6 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -185,17 +185,26 @@ def make_tool_call(tool_name, arguments): "name": tool_name, } } + def make_tool_call_response(tool_call_id, tool_name, content): + return { + "role": "tool", + "name": tool_name, + "content": content, + "tool_call_id": tool_call_id, + } dummy_args_obj = {"argument_needle": "print('Hello, World!')"} out = self.try_raw_render([ dummy_user_msg, make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]), + make_tool_call_response("call_1___", "ipython", "Hello, world!"), ]) tool_call_renders_str_arguments = "" in out or '"argument_needle":' in out or "'argument_needle':" in out out = self.try_raw_render([ dummy_user_msg, make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]), + make_tool_call_response("call_1___", "ipython", "Hello, world!"), ]) tool_call_renders_obj_arguments = "" in out or '"argument_needle":' in out or "'argument_needle':" in out @@ -209,18 +218,14 @@ def make_tool_call(tool_name, arguments): out = self.try_raw_render([ dummy_user_msg, make_tool_calls_msg([tc1, tc2]), + dummy_user_msg, ]) caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out out = self.try_raw_render([ dummy_user_msg, make_tool_calls_msg([tc1]), - { - "role": "tool", - "name": "test_tool1", - "content": "Some response!", - "tool_call_id": "call_911_", - } + make_tool_call_response("call_911_", "test_tool1", "Some response!"), ]) caps.supports_tool_responses = "Some response!" in out caps.supports_tool_call_id = "call_911_" in out diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0388a74..4743428 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -223,6 +223,7 @@ set(MODEL_IDS nvidia/Eagle2-1B nvidia/Eagle2-9B nvidia/Llama-3.1-Nemotron-70B-Instruct-HF + nvidia/NVIDIA-Nemotron-Nano-9B-v2 OnlyCheeini/greesychat-turbo onnx-community/DeepSeek-R1-Distill-Qwen-1.5B-ONNX open-thoughts/OpenThinker-7B diff --git a/tests/test-capabilities.cpp b/tests/test-capabilities.cpp index 458f9b9..45d01be 100644 --- a/tests/test-capabilities.cpp +++ b/tests/test-capabilities.cpp @@ -222,6 +222,18 @@ TEST(CapabilitiesTest, NousResearchHermes2ProLlama3_8BToolUse) { EXPECT_FALSE(caps.requires_typed_content); } +TEST(CapabilitiesTest, NvidiaNemotronNano_9BToolUse) { + auto caps = get_caps("tests/nvidia-NVIDIA-Nemotron-Nano-9B-v2.jinja"); + EXPECT_TRUE(caps.supports_system_role); + EXPECT_TRUE(caps.supports_tools); + EXPECT_TRUE(caps.supports_tool_calls); + EXPECT_TRUE(caps.supports_tool_responses); + EXPECT_TRUE(caps.supports_parallel_tool_calls); + EXPECT_FALSE(caps.requires_object_arguments); + // EXPECT_TRUE(caps.requires_non_null_content); + EXPECT_FALSE(caps.requires_typed_content); +} + TEST(CapabilitiesTest, CommandRPlusDefault) { auto caps = get_caps("tests/CohereForAI-c4ai-command-r-plus-default.jinja"); EXPECT_TRUE(caps.supports_system_role);