Skip to content

Commit 9790faf

Browse files
sayanshaw24Sayan Shaw
andauthored
add qwen3 chat template support (#1001)
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
1 parent 43387b1 commit 9790faf

4 files changed

Lines changed: 757849 additions & 65 deletions

File tree

shared/api/minja.hpp

Lines changed: 72 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,9 +1666,9 @@ namespace minja
16661666
{
16671667
public:
16681668
DEFINE_EXPRESION_ID();
1669-
std::shared_ptr<Expression> start, end;
1670-
SliceExpr(const Location &location, std::shared_ptr<Expression> &&s, std::shared_ptr<Expression> &&e)
1671-
: Expression(location), start(std::move(s)), end(std::move(e)) {}
1669+
std::shared_ptr<Expression> start, end, step;
1670+
SliceExpr(const Location &location, std::shared_ptr<Expression> &&s, std::shared_ptr<Expression> &&e, std::shared_ptr<Expression> && st = nullptr)
1671+
: Expression(location), start(std::move(s)), end(std::move(e)), step(std::move(st)) {}
16721672
Value do_evaluate(const std::shared_ptr<Context> &) const override
16731673
{
16741674
throw std::runtime_error("SliceExpr not implemented");
@@ -1692,25 +1692,37 @@ namespace minja
16921692
auto target_value = base->evaluate(context);
16931693
if (auto slice = expr_cast<SliceExpr *>(index.get()))
16941694
{
1695-
auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
1696-
auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t)target_value.size();
1695+
auto len = target_value.size();
1696+
auto wrap = [len](int64_t i) -> int64_t {
1697+
if (i < 0) {
1698+
return i + len;
1699+
}
1700+
return i;
1701+
};
1702+
int64_t step = slice->step ? slice->step->evaluate(context).get<int64_t>() : 1;
1703+
if (!step) {
1704+
throw std::runtime_error("slice step cannot be zero");
1705+
}
1706+
int64_t start = slice->start ? wrap(slice->start->evaluate(context).get<int64_t>()) : (step < 0 ? len - 1 : 0);
1707+
int64_t end = slice->end ? wrap(slice->end->evaluate(context).get<int64_t>()) : (step < 0 ? -1 : len);
16971708
if (target_value.is_string())
16981709
{
16991710
std::string s = target_value.get<std::string>();
1700-
if (start < 0)
1701-
start = s.size() + start;
1702-
if (end < 0)
1703-
end = s.size() + end;
1704-
return s.substr(start, end - start);
1711+
1712+
std::string result;
1713+
if (start < end && step == 1) {
1714+
result = s.substr(start, end - start);
1715+
} else {
1716+
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
1717+
result += s[i];
1718+
}
1719+
}
1720+
return result;
17051721
}
17061722
else if (target_value.is_array())
17071723
{
1708-
if (start < 0)
1709-
start = target_value.size() + start;
1710-
if (end < 0)
1711-
end = target_value.size() + end;
17121724
auto result = Value::array();
1713-
for (auto i = start; i < end; ++i)
1725+
for (int64_t i = start; step > 0 ? i < end : i > end; i += step)
17141726
{
17151727
result.push_back(target_value.at(i));
17161728
}
@@ -2161,6 +2173,12 @@ namespace minja
21612173
auto suffix = vargs.args[0].get<std::string>();
21622174
return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
21632175
}
2176+
else if (method->get_name() == "startswith")
2177+
{
2178+
vargs.expectArgs("startswith method", {1, 1}, {0, 0});
2179+
auto prefix = vargs.args[0].get<std::string>();
2180+
return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin());
2181+
}
21642182
else if (method->get_name() == "title")
21652183
{
21662184
vargs.expectArgs("title method", {0, 0}, {0, 0});
@@ -2916,67 +2934,56 @@ namespace minja
29162934

29172935
auto value = parseValue();
29182936

2919-
while (it != end && consumeSpaces() && peekSymbols({"[", "."}))
2920-
{
2921-
if (!consumeToken("[").empty())
2922-
{
2937+
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
2938+
if (!consumeToken("[").empty()) {
29232939
std::shared_ptr<Expression> index;
2924-
if (!consumeToken(":").empty())
2925-
{
2926-
auto slice_end = parseExpression();
2927-
index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
2940+
auto slice_loc = get_location();
2941+
std::shared_ptr<Expression> start, end, step;
2942+
bool has_first_colon = false, has_second_colon = false;
2943+
2944+
if (!peekSymbols({ ":" })) {
2945+
start = parseExpression();
29282946
}
2929-
else
2930-
{
2931-
auto slice_start = parseExpression();
2932-
if (!consumeToken(":").empty())
2933-
{
2934-
consumeSpaces();
2935-
if (peekSymbols({"]"}))
2936-
{
2937-
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
2938-
}
2939-
else
2940-
{
2941-
auto slice_end = parseExpression();
2942-
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
2943-
}
2947+
2948+
if (!consumeToken(":").empty()) {
2949+
has_first_colon = true;
2950+
if (!peekSymbols({ ":", "]" })) {
2951+
end = parseExpression();
29442952
}
2945-
else
2946-
{
2947-
index = std::move(slice_start);
2953+
if (!consumeToken(":").empty()) {
2954+
has_second_colon = true;
2955+
if (!peekSymbols({ "]" })) {
2956+
step = parseExpression();
2957+
}
29482958
}
29492959
}
2950-
if (!index)
2951-
throw std::runtime_error("Empty index in subscript");
2952-
if (consumeToken("]").empty())
2953-
throw std::runtime_error("Expected closing bracket in subscript");
2960+
2961+
if ((has_first_colon || has_second_colon)) {
2962+
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
2963+
} else {
2964+
index = std::move(start);
2965+
}
2966+
if (!index) throw std::runtime_error("Empty index in subscript");
2967+
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
29542968

29552969
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
2956-
}
2957-
else if (!consumeToken(".").empty())
2958-
{
2959-
auto identifier = parseIdentifier();
2960-
if (!identifier)
2961-
throw std::runtime_error("Expected identifier in subscript");
2962-
2963-
consumeSpaces();
2964-
if (peekSymbols({"("}))
2965-
{
2966-
auto callParams = parseCallArgs();
2967-
value = std::make_shared<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
2968-
}
2969-
else
2970-
{
2971-
auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
2972-
value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
2973-
}
2970+
} else if (!consumeToken(".").empty()) {
2971+
auto identifier = parseIdentifier();
2972+
if (!identifier) throw std::runtime_error("Expected identifier in subscript");
2973+
2974+
consumeSpaces();
2975+
if (peekSymbols({ "(" })) {
2976+
auto callParams = parseCallArgs();
2977+
value = std::make_shared<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
2978+
} else {
2979+
auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
2980+
value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
2981+
}
29742982
}
29752983
consumeSpaces();
29762984
}
29772985

2978-
if (peekSymbols({"("}))
2979-
{
2986+
if (peekSymbols({ "(" })) {
29802987
auto location = get_location();
29812988
auto callParams = parseCallArgs();
29822989
value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));

0 commit comments

Comments
 (0)