@@ -15,10 +15,34 @@ limitations under the License.
 
 #include "mposition.h"
 
+#include <absl/strings/match.h>
+
 #include "framework/model/model_args.h"
 #include "framework/request/sequence.h"
+
 namespace xllm {
 
+namespace {
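+// Run-length encode the per-token modality labels into half-open
+// (type, start, end) ranges, e.g. {"text", 0, 4}, {"image", 4, 8}, ...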
+std::vector<std::tuple<std::string, int, int>> groupByTokenType(
+    const std::vector<std::string>& token_types) {
+  std::vector<std::tuple<std::string, int, int>> groups;
+  if (token_types.empty()) return groups;
+
+  std::string current_key = token_types[0];
+  int start = 0;
+
+  for (int i = 1; i < static_cast<int>(token_types.size()); ++i) {
+    if (token_types[i] != current_key) {
+      groups.emplace_back(current_key, start, i);
+      current_key = token_types[i];
+      start = i;
+    }
+  }
+  groups.emplace_back(current_key, start, static_cast<int>(token_types.size()));
+  return groups;
+}
+}  // namespace
+
 torch::Tensor MPositionHelper::get_positions() {
   // if (seq_.is_chunked_prefill_stage()) {
   if (seq_.kv_state().kv_cache_tokens_num() < seq_.num_prompt_tokens()) {
@@ -35,16 +59,128 @@ torch::Tensor MPositionHelper::get_positions() {
     torch::Tensor second_per_grid_ts;
     if (auto res = mm_data.get<torch::Tensor>("second_per_grid_ts"))
       second_per_grid_ts = res.value();
-    auto res =
-        get_positions_p(image_grid_thw, video_grid_thw, second_per_grid_ts);
+    std::tuple<torch::Tensor, int> res;
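+    // GLM-4V lays out its M-RoPE positions differently, so dispatch on
+    // the model type here.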
+    if (!absl::StartsWith(args_.model_type(), "glm4v")) {
+      res = get_positions_p(image_grid_thw, video_grid_thw, second_per_grid_ts);
+    } else {
+      res = get_positions_glm(image_grid_thw, video_grid_thw);
+    }
     seq_.set_mrope_position_delta(std::get<1>(res));
-
     return std::get<0>(res);
   } else {
     return get_positions_d();
   }
 }
 
+std::tuple<torch::Tensor, int> MPositionHelper::get_positions_glm(
+    torch::Tensor image_grid_thw,
+    torch::Tensor video_grid_thw) {
+  auto input_tokens = seq_.tokens();
+  auto spatial_merge_size = args_.mm_spatial_merge_size();
+  auto image_token_id = args_.image_token_id();
+  auto video_start_token_id = args_.video_start_token_id();
+  auto video_end_token_id = args_.video_end_token_id();
+
+  auto dtype = torch::kInt32;
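+
+  // Label each prompt token as "text", "image", or "video": image
+  // placeholder tokens between <video_start> and <video_end> belong to a
+  // video frame rather than a standalone image.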
+  std::vector<std::string> input_token_type;
+  bool in_video = false;
+  int num_tokens = static_cast<int>(input_tokens.size());
+
+  for (int index = 0; index < num_tokens; ++index) {
+    auto token = input_tokens[index];
+    if (token == video_start_token_id) {
+      in_video = true;
+    } else if (token == video_end_token_id) {
+      in_video = false;
+    }
+
+    if (token == image_token_id && !in_video) {
+      input_token_type.push_back("image");
+    } else if (token == image_token_id && in_video) {
+      input_token_type.push_back("video");
+    } else {
+      input_token_type.push_back("text");
+    }
+  }
+  auto input_type_group = groupByTokenType(input_token_type);
+  int image_index = 0;
+  int video_index = 0;
+  int video_group_index = 0;
+
+  std::vector<torch::Tensor> llm_pos_ids_list;
+  int video_frame_num = 1;
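+
+  // Walk the modality groups and build the 3 x N (temporal, height, width)
+  // M-RoPE position rows; each group starts one past the largest position
+  // emitted so far.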
+  for (const auto& group : input_type_group) {
+    const auto& modality_type = std::get<0>(group);
+    int start_idx = std::get<1>(group);
+    int end_idx = std::get<2>(group);
+    int st_idx = 0;
+    if (!llm_pos_ids_list.empty()) {
+      st_idx = llm_pos_ids_list.back().max().item<int>() + 1;
+    }
+
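+    // Image: enumerate the merged (t, h, w) vision grid as three rows.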
+    if (modality_type == "image") {
+      auto grid = image_grid_thw[image_index];
+      int t = grid[0].item<int>();
+      int h = grid[1].item<int>() / spatial_merge_size;
+      int w = grid[2].item<int>() / spatial_merge_size;
+
+      auto t_arange =
+          torch::arange(t, dtype).view({-1, 1}).expand({-1, h * w}).flatten();
+      auto h_arange =
+          torch::arange(h, dtype).view({1, -1, 1}).expand({t, -1, w}).flatten();
+      auto w_arange =
+          torch::arange(w, dtype).view({1, 1, -1}).expand({t, h, -1}).flatten();
+
+      auto pos = torch::stack({t_arange, h_arange, w_arange}) + st_idx;
+      llm_pos_ids_list.push_back(pos);
+      video_frame_num = 1;
+      image_index++;
+    } else if (modality_type == "video") {
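+      // Video: emit one h*w block per frame index, with the temporal row
+      // fixed to that frame index.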
+      int t = video_frame_num;
+      int h = video_grid_thw[video_index][1].item<int>() / spatial_merge_size;
+      int w = video_grid_thw[video_index][2].item<int>() / spatial_merge_size;
+
+      for (int t_idx = 0; t_idx < t; ++t_idx) {
+        auto t_tensor = torch::full({1, h * w}, t_idx, dtype).flatten();
+        auto h_tensor = torch::arange(h, dtype)
+                            .view({1, -1, 1})
+                            .expand({1, -1, w})
+                            .flatten();
+        auto w_tensor = torch::arange(w, dtype)
+                            .view({1, 1, -1})
+                            .expand({1, h, -1})
+                            .flatten();
+
+        auto pos = torch::stack({t_tensor, h_tensor, w_tensor}) + st_idx;
+        llm_pos_ids_list.push_back(pos);
+      }
+
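+      // grid_thw[0] holds the frame count: once every frame of the current
+      // video has been consumed, move on to the next one.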
+      video_group_index++;
+      if (video_group_index >= video_grid_thw[video_index][0].item<int>()) {
+        video_index++;
+        video_group_index = 0;
+      }
+      video_frame_num++;
+    } else {  // text
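+      // Text: all three rows share the same sequential positions.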
+      int text_len = end_idx - start_idx;
+      auto arange =
+          torch::arange(text_len, dtype).view({1, -1}).expand({3, -1}) + st_idx;
+      llm_pos_ids_list.push_back(arange);
+      video_frame_num = 1;
+    }
+  }
+
+  torch::Tensor llm_positions =
+      torch::cat(llm_pos_ids_list, /*dim=*/1).reshape({3, -1});
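+  // Delta between one past the largest position used and the prompt
+  // length; decode-time positions continue from this offset.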
+  int mrope_position_delta = llm_positions.max().item<int>() + 1 -
+                             static_cast<int>(input_tokens.size());
+
+  return std::make_pair(llm_positions, mrope_position_delta);
+}
+
 std::tuple<torch::Tensor, int> MPositionHelper::get_positions_p(
     torch::Tensor image_grid_thw,
     torch::Tensor video_grid_thw,