Skip to content

Commit c05aa69

Browse files
authored
common : add nemotron 3 parsing (#18077)
* common : expose json-schema functionality to extract type info * common : fix peg parser negation during needs_more_input * common : add some defensive measures in constructed peg parser * common : add nemotron nano 3 support * common : add nemotron nano 3 tests * remove debug line
1 parent 279cef2 commit c05aa69

File tree

8 files changed

+741
-6
lines changed

8 files changed

+741
-6
lines changed

common/chat-peg-parser.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44

55
using json = nlohmann::json;
66

7-
static std::string_view trim_trailing_space(std::string_view sv) {
7+
static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
8+
int count = 0;
89
while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.back()))) {
10+
if (max != -1 && count <= max) {
11+
break;
12+
}
913
sv.remove_suffix(1);
14+
count++;
1015
}
1116
return sv;
1217
}
@@ -93,14 +98,15 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
9398

9499
if (is_arg_string && current_tool) {
95100
// Serialize to JSON, but exclude the end quote
96-
std::string dumped = json(node.text).dump();
101+
std::string dumped = json(trim_trailing_space(node.text)).dump();
97102
current_tool->arguments += dumped.substr(0, dumped.size() - 1);
98103
needs_closing_quote = true;
99104
}
100105

101106
if (is_arg_close && current_tool) {
102107
if (needs_closing_quote) {
103108
current_tool->arguments += "\"";
109+
needs_closing_quote = false;
104110
}
105111
}
106112

@@ -109,6 +115,10 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
109115
}
110116

111117
if (is_tool_close && current_tool) {
118+
if (needs_closing_quote) {
119+
current_tool->arguments += "\"";
120+
needs_closing_quote = false;
121+
}
112122
current_tool->arguments += "}";
113123
}
114124
}

common/chat.cpp

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,25 @@ static void foreach_function(const json & tools, const std::function<void(const
711711
}
712712
}
713713

714+
static void foreach_parameter(const json & function, const std::function<void(const std::string &, const json &, bool)> & fn) {
715+
if (!function.contains("parameters") || !function.at("parameters").is_object()) {
716+
return;
717+
}
718+
const auto & params = function.at("parameters");
719+
if (!params.contains("properties") || !params.at("properties").is_object()) {
720+
return;
721+
}
722+
const auto & props = params.at("properties");
723+
std::set<std::string> required;
724+
if (params.contains("required") && params.at("required").is_array()) {
725+
params.at("required").get_to(required);
726+
}
727+
for (const auto & [name, prop] : props.items()) {
728+
bool is_required = (required.find(name) != required.end());
729+
fn(name, prop, is_required);
730+
}
731+
}
732+
714733
static std::string apply(
715734
const common_chat_template & tmpl,
716735
const struct templates_params & inputs,
@@ -1409,6 +1428,123 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
14091428
return data;
14101429
}
14111430

1431+
static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_template & tmpl, const struct templates_params & inputs) {
1432+
common_chat_params data;
1433+
1434+
data.prompt = apply(tmpl, inputs);
1435+
data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED;
1436+
1437+
// Handle thinking tags appropriately based on inputs.enable_thinking
1438+
if (string_ends_with(data.prompt, "<think>\n")) {
1439+
if (!inputs.enable_thinking) {
1440+
data.prompt += "</think>";
1441+
} else {
1442+
data.thinking_forced_open = true;
1443+
}
1444+
}
1445+
1446+
data.preserved_tokens = {
1447+
"<think>",
1448+
"</think>",
1449+
"<tool_call>",
1450+
"</tool_call>",
1451+
};
1452+
1453+
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
1454+
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
1455+
auto include_grammar = true;
1456+
1457+
auto parser = build_chat_peg_constructed_parser([&](auto & p) {
1458+
auto reasoning = p.eps();
1459+
if (inputs.enable_thinking && extract_reasoning) {
1460+
auto reasoning_content = p.reasoning(p.until("</think>")) + ("</think>" | p.end());
1461+
if (data.thinking_forced_open) {
1462+
reasoning = reasoning_content;
1463+
}
1464+
}
1465+
1466+
// Response format parser
1467+
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
1468+
return reasoning << p.content(p.schema(p.json(), "response-format", inputs.json_schema));
1469+
}
1470+
1471+
// Tool call parser
1472+
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
1473+
auto tool_choice = p.choice();
1474+
foreach_function(inputs.tools, [&](const json & tool) {
1475+
const auto & function = tool.at("function");
1476+
std::string name = function.at("name");
1477+
auto parameters = function.at("parameters");
1478+
1479+
auto schema_info = common_schema_info();
1480+
schema_info.resolve_refs(parameters);
1481+
1482+
auto tool_open = "<function=" + p.tool_name(p.literal(name)) + ">\n";
1483+
auto tool_close = p.literal("</function>\n");
1484+
auto args = p.sequence();
1485+
auto arg_string = p.rule("xml-arg-string", p.until_one_of({
1486+
"\n</parameter>",
1487+
"\n<parameter=",
1488+
"\n</function>"
1489+
}));
1490+
1491+
foreach_parameter(function, [&](const auto & param_name, const json & param_schema, bool is_required) {
1492+
auto rule_name = "tool-" + name + "-arg-" + param_name;
1493+
1494+
auto arg_open = "<parameter=" + p.tool_arg_name(p.literal(param_name)) + ">\n";
1495+
auto arg_close = p.literal("</parameter>\n");
1496+
auto arg_value = p.eps();
1497+
1498+
if (schema_info.resolves_to_string(param_schema)) {
1499+
arg_value = p.tool_arg_string_value(arg_string) + "\n";
1500+
} else {
1501+
arg_value = p.tool_arg_json_value(p.schema(p.json(), rule_name + "-schema", param_schema));
1502+
}
1503+
1504+
// Model may or my not close with </parameter>
1505+
auto arg_rule = p.rule(rule_name, p.tool_arg_open(arg_open) + arg_value + p.optional(p.tool_arg_close(arg_close)));
1506+
args += p.repeat(arg_rule, /* min = */ is_required ? 1 : 0, /* max = */ 1);
1507+
});
1508+
1509+
tool_choice |= p.rule("tool-" + name, p.tool_open(tool_open) + args + p.tool_close(tool_close));
1510+
});
1511+
1512+
auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
1513+
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
1514+
auto tool_call = p.rule("tool-call", "<tool_call>\n" + tool_choice + "</tool_call>" + p.space());
1515+
auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
1516+
1517+
return reasoning << p.content(p.until("<tool_call>")) << tool_calls;
1518+
}
1519+
1520+
// Content only parser
1521+
include_grammar = false;
1522+
return reasoning << p.content(p.rest());
1523+
});
1524+
1525+
data.parser = parser.save();
1526+
1527+
if (include_grammar) {
1528+
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
1529+
1530+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1531+
foreach_function(inputs.tools, [&](const json & tool) {
1532+
const auto & function = tool.at("function");
1533+
auto schema = function.at("parameters");
1534+
builder.resolve_refs(schema);
1535+
});
1536+
parser.build_grammar(builder, data.grammar_lazy);
1537+
});
1538+
1539+
data.grammar_triggers = {
1540+
{COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"}
1541+
};
1542+
}
1543+
1544+
return data;
1545+
}
1546+
1547+
14121548
static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) {
14131549
common_chat_params data;
14141550

@@ -2534,6 +2670,10 @@ static common_chat_params common_chat_templates_apply_jinja(
25342670
src.find("<function=") != std::string::npos &&
25352671
src.find("<parameters>") != std::string::npos &&
25362672
src.find("<parameter=") != std::string::npos) {
2673+
// Nemotron 3 Nano 30B A3B
2674+
if (src.find("<think>") != std::string::npos) {
2675+
return common_chat_params_init_nemotron_v3(tmpl, params);
2676+
}
25372677
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
25382678
}
25392679

common/json-schema-to-grammar.cpp

Lines changed: 132 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,9 @@ static std::string format_literal(const std::string & literal) {
305305

306306
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
307307

308-
class SchemaConverter {
308+
class common_schema_converter {
309309
private:
310+
friend class common_schema_info;
310311
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
311312
std::function<json(const std::string &)> _fetch_json;
312313
bool _dotall;
@@ -729,7 +730,7 @@ class SchemaConverter {
729730
}
730731

731732
public:
732-
SchemaConverter(
733+
common_schema_converter(
733734
const std::function<json(const std::string &)> & fetch_json,
734735
bool dotall)
735736
: _fetch_json(fetch_json), _dotall(dotall)
@@ -990,6 +991,134 @@ class SchemaConverter {
990991
}
991992
};
992993

994+
// common_schema_info implementation (pimpl)
995+
996+
common_schema_info::common_schema_info()
997+
: impl_(std::make_unique<common_schema_converter>(
998+
[](const std::string &) { return json(); },
999+
false)) {}
1000+
1001+
common_schema_info::~common_schema_info() = default;
1002+
1003+
common_schema_info::common_schema_info(common_schema_info &&) noexcept = default;
1004+
common_schema_info & common_schema_info::operator=(common_schema_info &&) noexcept = default;
1005+
1006+
void common_schema_info::resolve_refs(nlohmann::ordered_json & schema) {
1007+
impl_->resolve_refs(schema, "");
1008+
}
1009+
1010+
// Determines if a JSON schema can resolve to a string type through any path.
1011+
// Some models emit raw string values rather than JSON-encoded strings for string parameters.
1012+
// If any branch of the schema (via oneOf, anyOf, $ref, etc.) permits a string, this returns
1013+
// true, allowing callers to handle the value as a raw string for simplicity.
1014+
bool common_schema_info::resolves_to_string(const nlohmann::ordered_json & schema) {
1015+
std::unordered_set<std::string> visited_refs;
1016+
1017+
std::function<bool(const json &)> check = [&](const json & s) -> bool {
1018+
if (!s.is_object()) {
1019+
return false;
1020+
}
1021+
1022+
// Handle $ref
1023+
if (s.contains("$ref")) {
1024+
const std::string & ref = s["$ref"];
1025+
if (visited_refs.find(ref) != visited_refs.end()) {
1026+
// Circular reference, assume not a string to be safe
1027+
return false;
1028+
}
1029+
visited_refs.insert(ref);
1030+
auto it = impl_->_refs.find(ref);
1031+
if (it != impl_->_refs.end()) {
1032+
return check(it->second);
1033+
}
1034+
return false;
1035+
}
1036+
1037+
// Check type field
1038+
if (s.contains("type")) {
1039+
const json & schema_type = s["type"];
1040+
if (schema_type.is_string()) {
1041+
if (schema_type == "string") {
1042+
return true;
1043+
}
1044+
} else if (schema_type.is_array()) {
1045+
// Type can be an array like ["string", "null"]
1046+
for (const auto & t : schema_type) {
1047+
if (t == "string") {
1048+
return true;
1049+
}
1050+
}
1051+
}
1052+
}
1053+
1054+
// Check oneOf/anyOf - if any alternative can be a string
1055+
if (s.contains("oneOf")) {
1056+
for (const auto & alt : s["oneOf"]) {
1057+
if (check(alt)) {
1058+
return true;
1059+
}
1060+
}
1061+
}
1062+
if (s.contains("anyOf")) {
1063+
for (const auto & alt : s["anyOf"]) {
1064+
if (check(alt)) {
1065+
return true;
1066+
}
1067+
}
1068+
}
1069+
1070+
// Check allOf - all components must be compatible with string type
1071+
if (s.contains("allOf")) {
1072+
bool all_string = true;
1073+
for (const auto & component : s["allOf"]) {
1074+
if (!check(component)) {
1075+
all_string = false;
1076+
break;
1077+
}
1078+
}
1079+
if (all_string) {
1080+
return true;
1081+
}
1082+
}
1083+
1084+
// Check const - if the constant value is a string
1085+
if (s.contains("const")) {
1086+
if (s["const"].is_string()) {
1087+
return true;
1088+
}
1089+
}
1090+
1091+
// Check enum - if any enum value is a string
1092+
if (s.contains("enum")) {
1093+
for (const auto & val : s["enum"]) {
1094+
if (val.is_string()) {
1095+
return true;
1096+
}
1097+
}
1098+
}
1099+
1100+
// String-specific keywords imply string type
1101+
if (s.contains("pattern") || s.contains("minLength") || s.contains("maxLength")) {
1102+
return true;
1103+
}
1104+
1105+
// Check format - many formats imply string
1106+
if (s.contains("format")) {
1107+
const std::string & fmt = s["format"];
1108+
if (fmt == "date" || fmt == "time" || fmt == "date-time" ||
1109+
fmt == "uri" || fmt == "email" || fmt == "hostname" ||
1110+
fmt == "ipv4" || fmt == "ipv6" || fmt == "uuid" ||
1111+
fmt.find("uuid") == 0) {
1112+
return true;
1113+
}
1114+
}
1115+
1116+
return false;
1117+
};
1118+
1119+
return check(schema);
1120+
}
1121+
9931122
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
9941123
#ifdef LLAMA_USE_LLGUIDANCE
9951124
if (!force_gbnf) {
@@ -1006,7 +1135,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
10061135
}
10071136

10081137
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
1009-
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
1138+
common_schema_converter converter([&](const std::string &) { return json(); }, options.dotall);
10101139
common_grammar_builder builder {
10111140
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
10121141
return converter._add_rule(name, rule);

0 commit comments

Comments
 (0)