Skip to content

Commit b8c4168

Browse files
wly-115JimHsiung
authored and committed
feat: support mpositions for glm4v.
1 parent 6a13161 commit b8c4168

File tree

2 files changed

+143
-3
lines changed

2 files changed

+143
-3
lines changed

xllm/core/framework/batch/mposition.cpp

Lines changed: 139 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,34 @@ limitations under the License.
1515

1616
#include "mposition.h"
1717

18+
#include <absl/strings/match.h>
19+
1820
#include "framework/model/model_args.h"
1921
#include "framework/request/sequence.h"
22+
2023
namespace xllm {
2124

25+
namespace {
// Splits `token_types` into maximal runs of identical consecutive values.
//
// Returns one (type, start, end) tuple per run, in order of appearance,
// where [start, end) is the half-open index range of the run.
// An empty input yields an empty result.
std::vector<std::tuple<std::string, int, int>> groupByTokenType(
    const std::vector<std::string>& token_types) {
  std::vector<std::tuple<std::string, int, int>> groups;
  if (token_types.empty()) return groups;

  std::string current_key = token_types[0];
  int start = 0;

  // Cast once so the loop compares int to int (avoids the signed/unsigned
  // mismatch of `i < token_types.size()`).
  const int n = static_cast<int>(token_types.size());
  for (int i = 1; i < n; ++i) {
    if (token_types[i] != current_key) {
      groups.emplace_back(current_key, start, i);
      current_key = token_types[i];
      start = i;
    }
  }
  // Flush the final run, which the loop never closes.
  groups.emplace_back(current_key, start, n);
  return groups;
}
}  // namespace
45+
2246
torch::Tensor MPositionHelper::get_positions() {
2347
// if (seq_.is_chunked_prefill_stage()) {
2448
if (seq_.kv_state().kv_cache_tokens_num() < seq_.num_prompt_tokens()) {
@@ -35,16 +59,128 @@ torch::Tensor MPositionHelper::get_positions() {
3559
torch::Tensor second_per_grid_ts;
3660
if (auto res = mm_data.get<torch::Tensor>("second_per_grid_ts"))
3761
second_per_grid_ts = res.value();
38-
auto res =
39-
get_positions_p(image_grid_thw, video_grid_thw, second_per_grid_ts);
62+
std::tuple<torch::Tensor, int> res;
63+
if (!absl::StartsWith(args_.model_type(), "glm4v")) {
64+
res = get_positions_p(image_grid_thw, video_grid_thw, second_per_grid_ts);
65+
} else {
66+
res = get_positions_glm(image_grid_thw, video_grid_thw);
67+
}
4068
seq_.set_mrope_position_delta(std::get<1>(res));
41-
4269
return std::get<0>(res);
4370
} else {
4471
return get_positions_d();
4572
}
4673
}
4774

75+
// Computes GLM-4V style M-RoPE position ids (3 rows: temporal / height /
// width) for the prefill tokens of `seq_`.
//
// The prompt is classified token-by-token into "text", "image", and "video"
// segments, then each maximal run gets its own position sub-tensor:
//   - text:  the same monotonically increasing index replicated on all 3 rows;
//   - image: a (t, h, w) index grid for one image, flattened;
//   - video: one (h, w) grid per frame group, with the temporal row carrying
//            the frame index.
// Each segment's positions start at (max position assigned so far) + 1.
//
// @param image_grid_thw per-image [t, h, w] patch grids.
// @param video_grid_thw per-video [t, h, w] patch grids.
// @return {positions [3, num_tokens], mrope_position_delta} where the delta
//         is (max position + 1 - prompt length), used by the decode phase to
//         continue the position sequence.
std::tuple<torch::Tensor, int> MPositionHelper::get_positions_glm(
    torch::Tensor image_grid_thw,
    torch::Tensor video_grid_thw) {
  auto input_tokens = seq_.tokens();
  auto spatial_merge_size = args_.mm_spatial_merge_size();
  auto image_token_id = args_.image_token_id();
  auto video_start_token_id = args_.video_start_token_id();
  auto video_end_token_id = args_.video_end_token_id();

  auto dtype = torch::kInt32;

  // Classify each prompt token. GLM-4V reuses the image placeholder token id
  // inside <video_start>...<video_end> spans for video content, so whether we
  // are inside a video span decides "image" vs "video".
  std::vector<std::string> input_token_type;
  bool in_video = false;
  int num_tokens = static_cast<int>(input_tokens.size());
  input_token_type.reserve(num_tokens);

  for (int index = 0; index < num_tokens; ++index) {
    auto token = input_tokens[index];
    if (token == video_start_token_id) {
      in_video = true;
    } else if (token == video_end_token_id) {
      in_video = false;
    }

    if (token == image_token_id && !in_video) {
      input_token_type.push_back("image");
    } else if (token == image_token_id && in_video) {
      input_token_type.push_back("video");
    } else {
      input_token_type.push_back("text");
    }
  }
  auto input_type_group = groupByTokenType(input_token_type);
  int image_index = 0;
  int video_index = 0;
  // A video is emitted as one placeholder group per frame; this tracks which
  // frame group of the current video we are on.
  int video_group_index = 0;

  std::vector<torch::Tensor> llm_pos_ids_list;
  int video_frame_num = 1;
  for (const auto& group : input_type_group) {
    const auto& modality_type = std::get<0>(group);
    int start_idx = std::get<1>(group);
    int end_idx = std::get<2>(group);
    // Each segment continues from the largest position assigned so far.
    int st_idx = 0;
    if (!llm_pos_ids_list.empty()) {
      st_idx = llm_pos_ids_list.back().max().item<int>() + 1;
    }

    if (modality_type == "image") {
      auto grid = image_grid_thw[image_index];
      int t = grid[0].item<int>();
      // h/w are patch counts; merged spatially before the language model.
      int h = grid[1].item<int>() / spatial_merge_size;
      int w = grid[2].item<int>() / spatial_merge_size;

      auto t_arange =
          torch::arange(t, dtype).view({-1, 1}).expand({-1, h * w}).flatten();
      auto h_arange =
          torch::arange(h, dtype).view({1, -1, 1}).expand({t, -1, w}).flatten();
      auto w_arange =
          torch::arange(w, dtype).view({1, 1, -1}).expand({t, h, -1}).flatten();

      auto pos = torch::stack({t_arange, h_arange, w_arange}) + st_idx;
      llm_pos_ids_list.push_back(pos);
      video_frame_num = 1;
      image_index++;
    } else if (modality_type == "video") {
      int t = video_frame_num;
      int h = video_grid_thw[video_index][1].item<int>() / spatial_merge_size;
      int w = video_grid_thw[video_index][2].item<int>() / spatial_merge_size;

      for (int t_idx = 0; t_idx < t; ++t_idx) {
        auto t_tensor = torch::full({1, h * w}, t_idx, dtype).flatten();
        auto h_tensor = torch::arange(h, dtype)
                            .view({1, -1, 1})
                            .expand({1, -1, w})
                            .flatten();
        auto w_tensor = torch::arange(w, dtype)
                            .view({1, 1, -1})
                            .expand({1, h, -1})
                            .flatten();

        auto pos = torch::stack({t_tensor, h_tensor, w_tensor}) + st_idx;
        llm_pos_ids_list.push_back(pos);
      }

      // Advance to the next video once all of its frame groups are consumed
      // (grid[0] is the frame count of the current video).
      video_group_index++;
      if (video_group_index >= video_grid_thw[video_index][0].item<int>()) {
        video_index++;
        video_group_index = 0;
      }
      video_frame_num++;
    } else {  // text
      int text_len = end_idx - start_idx;
      auto arange =
          torch::arange(text_len, dtype).view({1, -1}).expand({3, -1}) + st_idx;
      llm_pos_ids_list.push_back(arange);
      video_frame_num = 1;
    }
  }

  // NOTE(review): assumes the prompt is non-empty — torch::cat on an empty
  // list would throw; prefill sequences always carry at least one token here.
  torch::Tensor llm_positions =
      torch::cat(llm_pos_ids_list, /*dim=*/1).reshape({3, -1});
  // Cast the prompt length so the delta is computed in signed arithmetic
  // (it is negative whenever multimodal tokens compress the position range).
  int mrope_position_delta = llm_positions.max().item<int>() + 1 -
                             static_cast<int>(input_tokens.size());

  return std::make_tuple(llm_positions, mrope_position_delta);
}
183+
48184
std::tuple<torch::Tensor, int> MPositionHelper::get_positions_p(
49185
torch::Tensor image_grid_thw,
50186
torch::Tensor video_grid_thw,

xllm/core/framework/batch/mposition.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ class MPositionHelper {
3737
torch::Tensor image_grid_thw,
3838
torch::Tensor video_grid_thw,
3939
torch::Tensor second_per_grid_ts);
40+
std::tuple<torch::Tensor, int> get_positions_glm(
41+
torch::Tensor image_grid_thw,
42+
torch::Tensor video_grid_thw);
43+
4044
torch::Tensor get_positions_d();
4145

4246
private:

0 commit comments

Comments
 (0)