Skip to content

Commit afc7ffd

Browse files
pwilkin, CISC, and aldehir
authored and committed
common/parser: handle reasoning budget (ggml-org#20297)
* v1 * Finished! * Handlie cli * Reasoning sampler * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Less explosive terminology :) * Add utf-8 case and tests * common : migrate reasoning budget sampler to common * cont : clean up * cont : expose state and allow passing as initial state * cont : remove unused imports * cont : update state machine doc string --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> Co-authored-by: Alde Rojas <hello@alde.dev>
1 parent e35a39e commit afc7ffd

18 files changed

Lines changed: 670 additions & 10 deletions

common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ add_library(${TARGET} STATIC
8181
preset.cpp
8282
preset.h
8383
regex-partial.cpp
84+
reasoning-budget.cpp
85+
reasoning-budget.h
8486
regex-partial.h
8587
sampling.cpp
8688
sampling.h

common/arg.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2913,6 +2913,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
29132913
[](common_params & params, const std::string & value) {
29142914
auto parsed = json::parse(value);
29152915
for (const auto & item : parsed.items()) {
2916+
if (item.key() == "enable_thinking") {
2917+
LOG_WRN("Setting 'enable_thinking' via --chat-template-kwargs is deprecated. "
2918+
"Use --reasoning on / --reasoning off instead.\n");
2919+
}
29162920
params.default_template_kwargs[item.key()] = item.value().dump();
29172921
}
29182922
}
@@ -3048,14 +3052,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
30483052
params.reasoning_format = common_reasoning_format_from_name(value);
30493053
}
30503054
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK"));
3055+
add_opt(common_arg(
3056+
{"-rea", "--reasoning"}, "[on|off|auto]",
3057+
"Use reasoning/thinking in the chat ('on', 'off', or 'auto', default: 'auto' (detect from template))",
3058+
[](common_params & params, const std::string & value) {
3059+
if (is_truthy(value)) {
3060+
params.enable_reasoning = 1;
3061+
params.default_template_kwargs["enable_thinking"] = "true";
3062+
} else if (is_falsey(value)) {
3063+
params.enable_reasoning = 0;
3064+
params.default_template_kwargs["enable_thinking"] = "false";
3065+
} else if (is_autoy(value)) {
3066+
params.enable_reasoning = -1;
3067+
} else {
3068+
throw std::invalid_argument(
3069+
string_format("error: unknown value for --reasoning: '%s'\n", value.c_str()));
3070+
}
3071+
}
3072+
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING"));
30513073
add_opt(common_arg(
30523074
{"--reasoning-budget"}, "N",
3053-
"controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)",
3075+
"token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)",
30543076
[](common_params & params, int value) {
3055-
if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); }
3077+
if (value < -1) { throw std::invalid_argument("invalid value"); }
30563078
params.reasoning_budget = value;
30573079
}
30583080
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET"));
3081+
add_opt(common_arg(
3082+
{"--reasoning-budget-message"}, "MESSAGE",
3083+
"message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)",
3084+
[](common_params & params, const std::string & value) {
3085+
params.reasoning_budget_message = value;
3086+
}
3087+
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
30593088
add_opt(common_arg(
30603089
{"--chat-template"}, "JINJA_TEMPLATE",
30613090
string_format(

common/chat-auto-parser-generator.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
135135
if (thinking_forced_open || thinking_forced_closed) {
136136
// Thinking is forced open OR forced closed with enable_thinking=true
137137
// In both cases, expect only the closing tag (opening was in template)
138-
return p.reasoning(p.until(end)) + end;
138+
// However, since we might have incorrectly detected the open/close pattern,
139+
// we admit an optional starting marker
140+
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
139141
}
140142
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
141143
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)

common/chat.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -857,7 +857,9 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
857857
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
858858
auto include_grammar = true;
859859

860-
data.supports_thinking = true;
860+
data.supports_thinking = true;
861+
data.thinking_start_tag = "[THINK]";
862+
data.thinking_end_tag = "[/THINK]";
861863
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
862864
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
863865
data.preserved_tokens = {
@@ -1165,9 +1167,11 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
11651167
const autoparser::templates_params & inputs) {
11661168
common_chat_params data;
11671169

1168-
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
1169-
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
1170-
data.supports_thinking = true;
1170+
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
1171+
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
1172+
data.supports_thinking = true;
1173+
data.thinking_start_tag = "<think>";
1174+
data.thinking_end_tag = "</think>";
11711175
data.preserved_tokens = {
11721176
"<|tool_calls_section_begin|>",
11731177
"<|tool_calls_section_end|>",
@@ -1527,6 +1531,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
15271531
autoparser.analyze_template(tmpl);
15281532
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
15291533
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
1534+
if (auto_params.supports_thinking) {
1535+
auto_params.thinking_start_tag = autoparser.reasoning.start;
1536+
auto_params.thinking_end_tag = autoparser.reasoning.end;
1537+
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
1538+
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
1539+
// but forces <think> open when thinking is enabled)
1540+
auto_params.thinking_forced_open =
1541+
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
1542+
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
1543+
}
15301544
return auto_params;
15311545
} catch (const std::exception & e) {
15321546
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());

common/chat.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ struct common_chat_params {
213213
bool grammar_lazy = false;
214214
bool thinking_forced_open = false;
215215
bool supports_thinking = false;
216+
std::string thinking_start_tag; // e.g., "<think>"
217+
std::string thinking_end_tag; // e.g., "</think>"
216218
std::vector<common_grammar_trigger> grammar_triggers;
217219
std::vector<std::string> preserved_tokens;
218220
std::vector<std::string> additional_stops;

common/common.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,14 @@ struct common_params_sampling {
235235
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
236236
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
237237

238+
// reasoning budget sampler parameters
239+
// these are populated by the server/CLI based on chat template params
240+
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
241+
bool reasoning_budget_activate_immediately = false;
242+
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
243+
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
244+
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
245+
238246
bool backend_sampling = false;
239247

240248
bool has_logit_bias() const {
@@ -536,7 +544,9 @@ struct common_params {
536544
bool use_jinja = true; // NOLINT
537545
bool enable_chat_template = true;
538546
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
547+
int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
539548
int reasoning_budget = -1;
549+
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
540550
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
541551
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
542552

common/reasoning-budget.cpp

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"

#include "log.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>
12+
struct token_matcher {
13+
std::vector<llama_token> tokens;
14+
size_t pos = 0;
15+
16+
bool advance(llama_token token) {
17+
if (tokens.empty()) {
18+
return false;
19+
}
20+
21+
if (token == tokens[pos]) {
22+
pos++;
23+
if (pos >= tokens.size()) {
24+
pos = 0;
25+
return true;
26+
}
27+
} else {
28+
pos = 0;
29+
if (token == tokens[0]) {
30+
pos = 1;
31+
}
32+
}
33+
return false;
34+
}
35+
36+
void reset() { pos = 0; }
37+
};
38+
39+
// Shared state for the reasoning-budget sampler.
//
// The sampler watches the generated stream for the thinking start tag,
// counts tokens while the reasoning block is open, and — once the budget is
// exhausted — forces a fixed closing sequence (optional message + end tag).
struct common_reasoning_budget_ctx {
    const llama_vocab * vocab; // used to detokenize pieces for the UTF-8 completeness check; may be nullptr to skip it

    token_matcher start_matcher;            // detects the thinking start tag
    token_matcher end_matcher;              // detects a natural (model-emitted) end tag
    std::vector<llama_token> forced_tokens; // sequence forced once the budget runs out

    int32_t budget; // maximum tokens in reasoning block
    int32_t remaining; // tokens remaining in budget

    common_reasoning_budget_state state; // current position in the state machine

    // for forcing
    size_t force_pos; // next position in forced_tokens to force
};
54+
55+
// Sampler interface: human-readable name of this sampler.
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
    static constexpr const char name[] = "reasoning-budget";
    return name;
}
58+
59+
// Sampler interface: observe one accepted token and advance the budget
// state machine.
//
// Transitions:
//   IDLE         -> COUNTING      when the start tag sequence completes
//                                 (straight to FORCING if budget <= 0)
//   COUNTING     -> DONE          when the model emits the end tag on its own
//   COUNTING     -> FORCING       when the budget runs out and the last piece
//                                 is UTF-8 complete
//   COUNTING     -> WAITING_UTF8  when the budget runs out mid-UTF-8 sequence
//   WAITING_UTF8 -> FORCING       once a UTF-8-complete piece arrives
//                                 (or -> DONE on a natural end tag)
//   FORCING/DONE -> (unchanged)   forcing progress is advanced in apply()
static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;

    switch (ctx->state) {
        case REASONING_BUDGET_IDLE:
            {
                if (ctx->start_matcher.advance(token)) {
                    ctx->state = REASONING_BUDGET_COUNTING;
                    ctx->remaining = ctx->budget;
                    LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);

                    // a zero (or negative) budget means the thinking block must
                    // be closed immediately
                    if (ctx->remaining <= 0) {
                        ctx->state = REASONING_BUDGET_FORCING;
                        ctx->force_pos = 0;
                        LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
                    }
                }
                break;
            }
        case REASONING_BUDGET_COUNTING:
        case REASONING_BUDGET_WAITING_UTF8:
            {
                // the model may close the thinking block on its own at any point
                if (ctx->end_matcher.advance(token)) {
                    ctx->state = REASONING_BUDGET_DONE;
                    LOG_INF("reasoning-budget: deactivated (natural end)\n");
                    break;
                }

                // check whether this token's decoded piece ends on a UTF-8
                // boundary — forcing the end sequence mid-codepoint would
                // corrupt the output text; without a vocab we assume complete
                bool utf8_complete = true;
                if (ctx->vocab != nullptr) {
                    const std::string piece = common_token_to_piece(ctx->vocab, token, false);
                    utf8_complete = common_utf8_is_complete(piece);
                }

                if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
                    // budget already exhausted; start forcing as soon as it is
                    // safe to cut the stream
                    if (utf8_complete) {
                        ctx->state = REASONING_BUDGET_FORCING;
                        ctx->force_pos = 0;
                        ctx->end_matcher.reset();
                        LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
                    }
                } else if (ctx->state == REASONING_BUDGET_COUNTING) {
                    ctx->remaining--;
                    if (ctx->remaining <= 0) {
                        if (utf8_complete) {
                            ctx->state = REASONING_BUDGET_FORCING;
                            ctx->force_pos = 0;
                            ctx->end_matcher.reset();
                            LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
                        } else {
                            // defer forcing until the codepoint is finished
                            ctx->state = REASONING_BUDGET_WAITING_UTF8;
                            ctx->end_matcher.reset();
                            LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
                        }
                    }
                }
                break;
            }
        case REASONING_BUDGET_FORCING:
            // force_pos is advanced in apply(), not here.
            // This ensures the first forced token isn't skipped when the sampler
            // is initialized directly in FORCING state (e.g. COUNTING + budget=0)
            break;
        case REASONING_BUDGET_DONE:
            break;
    }
}
126+
127+
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
128+
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
129+
130+
if (ctx->state != REASONING_BUDGET_FORCING) {
131+
// passthrough — don't modify logits
132+
return;
133+
}
134+
135+
if (ctx->force_pos >= ctx->forced_tokens.size()) {
136+
return;
137+
}
138+
139+
const llama_token forced = ctx->forced_tokens[ctx->force_pos];
140+
141+
// set all logits to -inf except the forced token
142+
for (size_t i = 0; i < cur_p->size; i++) {
143+
if (cur_p->data[i].id != forced) {
144+
cur_p->data[i].logit = -INFINITY;
145+
}
146+
}
147+
148+
// advance to next forced token (done here rather than in accept so that
149+
// the first forced token isn't skipped when starting in FORCING state)
150+
ctx->force_pos++;
151+
if (ctx->force_pos >= ctx->forced_tokens.size()) {
152+
ctx->state = REASONING_BUDGET_DONE;
153+
LOG_INF("reasoning-budget: forced sequence complete, done\n");
154+
}
155+
}
156+
157+
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
158+
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
159+
ctx->state = REASONING_BUDGET_IDLE;
160+
ctx->remaining = ctx->budget;
161+
ctx->start_matcher.reset();
162+
ctx->end_matcher.reset();
163+
ctx->force_pos = 0;
164+
}
165+
166+
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
167+
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
168+
return common_reasoning_budget_init(
169+
ctx->vocab,
170+
ctx->start_matcher.tokens,
171+
ctx->end_matcher.tokens,
172+
ctx->forced_tokens,
173+
ctx->budget,
174+
ctx->state);
175+
}
176+
177+
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
178+
delete (common_reasoning_budget_ctx *) smpl->ctx;
179+
}
180+
181+
// llama_sampler vtable wiring the reasoning-budget callbacks into the sampler
// interface; the backend_* hooks are left unset (nullptr)
static struct llama_sampler_i common_reasoning_budget_i = {
    /* .name = */ common_reasoning_budget_name,
    /* .accept = */ common_reasoning_budget_accept,
    /* .apply = */ common_reasoning_budget_apply,
    /* .reset = */ common_reasoning_budget_reset,
    /* .clone = */ common_reasoning_budget_clone,
    /* .free = */ common_reasoning_budget_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
193+
194+
struct llama_sampler * common_reasoning_budget_init(
195+
const struct llama_vocab * vocab,
196+
const std::vector<llama_token> & start_tokens,
197+
const std::vector<llama_token> & end_tokens,
198+
const std::vector<llama_token> & forced_tokens,
199+
int32_t budget,
200+
common_reasoning_budget_state initial_state) {
201+
// promote COUNTING with budget <= 0 to FORCING
202+
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
203+
initial_state = REASONING_BUDGET_FORCING;
204+
}
205+
206+
return llama_sampler_init(
207+
/* .iface = */ &common_reasoning_budget_i,
208+
/* .ctx = */ new common_reasoning_budget_ctx {
209+
/* .vocab = */ vocab,
210+
/* .start_matcher = */ { start_tokens, 0 },
211+
/* .end_matcher = */ { end_tokens, 0 },
212+
/* .forced_tokens = */ forced_tokens,
213+
/* .budget = */ budget,
214+
/* .remaining = */ budget,
215+
/* .state = */ initial_state,
216+
/* .force_pos = */ 0,
217+
}
218+
);
219+
}

0 commit comments

Comments (0)