Skip to content

Commit 8a2110c

Browse files
bugfix: fix the issue of missing MMData input during engine -> worker transfer via brpc format. (#501)
1 parent 6e62d05 commit 8a2110c

File tree

5 files changed

+495
-166
lines changed

5 files changed

+495
-166
lines changed

xllm/core/framework/request/dit_request_params.cpp

Lines changed: 7 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include "butil/base64.h"
2020
#include "core/common/instance_name.h"
2121
#include "core/common/macros.h"
22+
#include "core/util/utils.h"
2223
#include "core/util/uuid.h"
2324
#include "mm_codec.h"
2425
#include "request.h"
@@ -34,165 +35,6 @@ std::string generate_image_generation_request_id() {
3435

3536
} // namespace
3637

37-
// Resolves the datatype name carried by a proto tensor to the torch scalar
// type used when the tensor is materialised.
//
// The unsigned proto names ("UINT32", "UINT64") are mapped onto the signed
// torch types of the same width (torch::kInt32, torch::kInt64).
//
// @param proto_datatype datatype name, e.g. "FP32" or "INT64".
// @return the torch::ScalarType registered for that name.
// An unrecognised name is reported through LOG(FATAL).
static torch::ScalarType datatype_proto_to_torch(
    const std::string& proto_datatype) {
  using ScalarTypeByName = std::unordered_map<std::string, torch::ScalarType>;
  static const ScalarTypeByName kScalarTypeByName = {
      {"BOOL", torch::kBool},
      {"INT32", torch::kInt},
      {"INT64", torch::kLong},
      {"UINT32", torch::kInt32},
      {"UINT64", torch::kInt64},
      {"FP32", torch::kFloat},
      {"FP64", torch::kDouble},
      {"BYTES", torch::kByte},
  };

  const auto entry = kScalarTypeByName.find(proto_datatype);
  const bool is_known = (entry != kScalarTypeByName.end());
  if (!is_known) {
    LOG(FATAL)
        << "Unsupported proto datatype: " << proto_datatype
        << " (supported types: BOOL/INT32/INT64/UINT32/UINT64/FP32/FP64/BYTES)";
  }
  return entry->second;
}
57-
58-
template <typename T>
59-
static const void* get_data_from_contents(const proto::TensorContents& contents,
60-
const std::string& datatype) {
61-
if constexpr (std::is_same_v<T, bool>) {
62-
if (contents.bool_contents().empty()) {
63-
LOG(ERROR) << "TensorContents.bool_contents is empty (datatype="
64-
<< datatype << ")";
65-
return nullptr;
66-
}
67-
return contents.bool_contents().data();
68-
} else if constexpr (std::is_same_v<T, int32_t>) {
69-
if (contents.int_contents().empty()) {
70-
LOG(ERROR) << "TensorContents.int_contents is empty (datatype="
71-
<< datatype << ")";
72-
return nullptr;
73-
}
74-
return contents.int_contents().data();
75-
} else if constexpr (std::is_same_v<T, int64_t>) {
76-
if (contents.int64_contents().empty()) {
77-
LOG(ERROR) << "TensorContents.int64_contents is empty (datatype="
78-
<< datatype << ")";
79-
return nullptr;
80-
}
81-
return contents.int64_contents().data();
82-
} else if constexpr (std::is_same_v<T, uint32_t>) {
83-
if (contents.uint_contents().empty()) {
84-
LOG(ERROR) << "TensorContents.uint_contents is empty (datatype="
85-
<< datatype << ")";
86-
return nullptr;
87-
}
88-
return contents.uint_contents().data();
89-
} else if constexpr (std::is_same_v<T, uint64_t>) {
90-
if (contents.uint64_contents().empty()) {
91-
LOG(ERROR) << "TensorContents.uint64_contents is empty (datatype="
92-
<< datatype << ")";
93-
return nullptr;
94-
}
95-
return contents.uint64_contents().data();
96-
} else if constexpr (std::is_same_v<T, float>) {
97-
if (contents.fp32_contents().empty()) {
98-
LOG(ERROR) << "TensorContents.fp32_contents is empty (datatype="
99-
<< datatype << ")";
100-
return nullptr;
101-
}
102-
return contents.fp32_contents().data();
103-
} else if constexpr (std::is_same_v<T, double>) {
104-
if (contents.fp64_contents().empty()) {
105-
LOG(ERROR) << "TensorContents.fp64_contents is empty (datatype="
106-
<< datatype << ")";
107-
return nullptr;
108-
}
109-
return contents.fp64_contents().data();
110-
} else {
111-
LOG(FATAL) << "Unsupported data type for TensorContents: "
112-
<< typeid(T).name();
113-
return nullptr;
114-
}
115-
}
116-
117-
// Converts a proto::Tensor into a torch::Tensor.
//
// Validation performed, in order: datatype is set, shape is non-empty,
// contents are present, every dimension is positive, the element count fits
// in int64, the typed contents field is non-empty, and the number of stored
// elements equals the product of the dimensions.
//
// @param proto_tensor the serialized tensor.
// @return the tensor built from the proto contents, or a default-constructed
//         (undefined) torch::Tensor after logging an error when validation
//         fails.
// An unknown datatype name is reported through LOG(FATAL) by
// datatype_proto_to_torch. "BYTES" is accepted by that lookup but has no
// contents field handled here, so it ends in the "Failed to get data" error.
torch::Tensor proto_to_torch(const proto::Tensor& proto_tensor) {
  if (proto_tensor.datatype().empty()) {
    LOG(ERROR) << "Proto Tensor missing required field: datatype (e.g., "
                  "\"FP32\", \"INT64\")";
    return torch::Tensor();
  }
  if (proto_tensor.shape().empty()) {
    LOG(ERROR) << "Proto Tensor has empty shape (invalid tensor)";
    return torch::Tensor();
  }
  if (!proto_tensor.has_contents()) {
    LOG(ERROR)
        << "Proto Tensor missing required field: contents (TensorContents)";
    return torch::Tensor();
  }
  const auto& proto_contents = proto_tensor.contents();

  const std::string& proto_datatype = proto_tensor.datatype();
  const torch::ScalarType torch_dtype = datatype_proto_to_torch(proto_datatype);

  std::vector<int64_t> torch_shape;
  int64_t total_elements = 1;
  for (const auto& dim : proto_tensor.shape()) {
    if (dim <= 0) {
      LOG(ERROR) << "Proto Tensor has invalid dimension: " << dim
                 << " (must be positive, datatype=" << proto_datatype << ")";
      return torch::Tensor();
    }
    // The dimensions come from the message, so guard the product: signed
    // overflow is undefined behaviour, and a wrapped product could match
    // data_count below and let from_blob read past the contents buffer.
    if (total_elements > INT64_MAX / dim) {
      LOG(ERROR) << "Proto Tensor element count overflows int64 (datatype="
                 << proto_datatype << ")";
      return torch::Tensor();
    }
    torch_shape.emplace_back(dim);
    total_elements *= dim;
  }
  torch::IntArrayRef tensor_shape(torch_shape);

  // Pick the repeated field that matches the declared datatype.
  const void* data_ptr = nullptr;
  size_t data_count = 0;
  if (proto_datatype == "BOOL") {
    data_ptr = get_data_from_contents<bool>(proto_contents, proto_datatype);
    data_count = proto_contents.bool_contents_size();
  } else if (proto_datatype == "INT32") {
    data_ptr = get_data_from_contents<int32_t>(proto_contents, proto_datatype);
    data_count = proto_contents.int_contents_size();
  } else if (proto_datatype == "INT64") {
    data_ptr = get_data_from_contents<int64_t>(proto_contents, proto_datatype);
    data_count = proto_contents.int64_contents_size();
  } else if (proto_datatype == "UINT32") {
    data_ptr = get_data_from_contents<uint32_t>(proto_contents, proto_datatype);
    data_count = proto_contents.uint_contents_size();
  } else if (proto_datatype == "UINT64") {
    data_ptr = get_data_from_contents<uint64_t>(proto_contents, proto_datatype);
    data_count = proto_contents.uint64_contents_size();
  } else if (proto_datatype == "FP32") {
    data_ptr = get_data_from_contents<float>(proto_contents, proto_datatype);
    data_count = proto_contents.fp32_contents_size();
  } else if (proto_datatype == "FP64") {
    data_ptr = get_data_from_contents<double>(proto_contents, proto_datatype);
    data_count = proto_contents.fp64_contents_size();
  }

  if (data_ptr == nullptr) {
    LOG(ERROR) << "Failed to get data from TensorContents (datatype="
               << proto_datatype << ")";
    return torch::Tensor();
  }
  if (data_count != static_cast<size_t>(total_elements)) {
    LOG(ERROR) << "Proto Tensor data count mismatch (datatype="
               << proto_datatype << "): "
               << "expected " << total_elements
               << " elements (shape=" << tensor_shape << "), "
               << "got " << data_count << " elements";
    return torch::Tensor();
  }

  // from_blob only wraps the proto-owned buffer; clone() gives the returned
  // tensor its own copy of the data.
  torch::Tensor tensor =
      torch::from_blob(const_cast<void*>(data_ptr), tensor_shape, torch_dtype)
          .clone();
  return tensor;
}
195-
19638
std::pair<int, int> splitResolution(const std::string& s) {
19739
size_t pos = s.find('*');
19840
int width = std::stoi(s.substr(0, pos));
@@ -227,26 +69,26 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request,
22769
}
22870

22971
if (input.has_prompt_embed()) {
230-
input_params.prompt_embed = proto_to_torch(input.prompt_embed());
72+
input_params.prompt_embed = util::proto_to_torch(input.prompt_embed());
23173
}
23274
if (input.has_pooled_prompt_embed()) {
23375
input_params.pooled_prompt_embed =
234-
proto_to_torch(input.pooled_prompt_embed());
76+
util::proto_to_torch(input.pooled_prompt_embed());
23577
}
23678
if (input.has_negative_prompt_embed()) {
23779
input_params.negative_prompt_embed =
238-
proto_to_torch(input.negative_prompt_embed());
80+
util::proto_to_torch(input.negative_prompt_embed());
23981
}
24082
if (input.has_negative_pooled_prompt_embed()) {
24183
input_params.negative_pooled_prompt_embed =
242-
proto_to_torch(input.negative_pooled_prompt_embed());
84+
util::proto_to_torch(input.negative_pooled_prompt_embed());
24385
}
24486
if (input.has_latent()) {
245-
input_params.latent = proto_to_torch(input.latent());
87+
input_params.latent = util::proto_to_torch(input.latent());
24688
}
24789
if (input.has_masked_image_latent()) {
24890
input_params.masked_image_latent =
249-
proto_to_torch(input.masked_image_latent());
91+
util::proto_to_torch(input.masked_image_latent());
25092
}
25193

25294
OpenCVImageDecoder decoder;

xllm/core/runtime/params_utils.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,10 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
319319
pb_forward_input->eplb_info().expert_ids().end());
320320
eplb_info.update_layer_id = pb_forward_input->eplb_info().update_layer_id();
321321

322+
if (pb_forward_input->has_mm_data()) {
323+
util::proto_to_mmdata(pb_forward_input->mm_data(), &input_params.mm_data);
324+
}
325+
322326
COUNTER_ADD(proto_latency_seconds_proto2i, timer.elapsed_seconds());
323327
}
324328

@@ -467,6 +471,10 @@ void forward_input_to_proto(const RawForwardInput& inputs,
467471
inputs.dst_block_indices);
468472
ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_cum_sum(), inputs.cum_sum);
469473

474+
if (inputs.mm_data.valid()) {
475+
util::mmdata_to_proto(inputs.mm_data, pb_forward_input->mutable_mm_data());
476+
}
477+
470478
COUNTER_ADD(proto_latency_seconds_i2proto, timer.elapsed_seconds());
471479
}
472480

0 commit comments

Comments
 (0)