@@ -19,6 +19,7 @@ limitations under the License.
1919#include " butil/base64.h"
2020#include " core/common/instance_name.h"
2121#include " core/common/macros.h"
22+ #include " core/util/utils.h"
2223#include " core/util/uuid.h"
2324#include " mm_codec.h"
2425#include " request.h"
@@ -34,165 +35,6 @@ std::string generate_image_generation_request_id() {
3435
3536} // namespace
3637
37- static torch::ScalarType datatype_proto_to_torch (
38- const std::string& proto_datatype) {
39- static const std::unordered_map<std::string, torch::ScalarType> kDatatypeMap =
40- {{" BOOL" , torch::kBool },
41- {" INT32" , torch::kInt },
42- {" INT64" , torch::kLong },
43- {" UINT32" , torch::kInt32 },
44- {" UINT64" , torch::kInt64 },
45- {" FP32" , torch::kFloat },
46- {" FP64" , torch::kDouble },
47- {" BYTES" , torch::kByte }};
48-
49- auto iter = kDatatypeMap .find (proto_datatype);
50- if (iter == kDatatypeMap .end ()) {
51- LOG (FATAL)
52- << " Unsupported proto datatype: " << proto_datatype
53- << " (supported types: BOOL/INT32/INT64/UINT32/UINT64/FP32/FP64/BYTES)" ;
54- }
55- return iter->second ;
56- }
57-
58- template <typename T>
59- static const void * get_data_from_contents (const proto::TensorContents& contents,
60- const std::string& datatype) {
61- if constexpr (std::is_same_v<T, bool >) {
62- if (contents.bool_contents ().empty ()) {
63- LOG (ERROR) << " TensorContents.bool_contents is empty (datatype="
64- << datatype << " )" ;
65- return nullptr ;
66- }
67- return contents.bool_contents ().data ();
68- } else if constexpr (std::is_same_v<T, int32_t >) {
69- if (contents.int_contents ().empty ()) {
70- LOG (ERROR) << " TensorContents.int_contents is empty (datatype="
71- << datatype << " )" ;
72- return nullptr ;
73- }
74- return contents.int_contents ().data ();
75- } else if constexpr (std::is_same_v<T, int64_t >) {
76- if (contents.int64_contents ().empty ()) {
77- LOG (ERROR) << " TensorContents.int64_contents is empty (datatype="
78- << datatype << " )" ;
79- return nullptr ;
80- }
81- return contents.int64_contents ().data ();
82- } else if constexpr (std::is_same_v<T, uint32_t >) {
83- if (contents.uint_contents ().empty ()) {
84- LOG (ERROR) << " TensorContents.uint_contents is empty (datatype="
85- << datatype << " )" ;
86- return nullptr ;
87- }
88- return contents.uint_contents ().data ();
89- } else if constexpr (std::is_same_v<T, uint64_t >) {
90- if (contents.uint64_contents ().empty ()) {
91- LOG (ERROR) << " TensorContents.uint64_contents is empty (datatype="
92- << datatype << " )" ;
93- return nullptr ;
94- }
95- return contents.uint64_contents ().data ();
96- } else if constexpr (std::is_same_v<T, float >) {
97- if (contents.fp32_contents ().empty ()) {
98- LOG (ERROR) << " TensorContents.fp32_contents is empty (datatype="
99- << datatype << " )" ;
100- return nullptr ;
101- }
102- return contents.fp32_contents ().data ();
103- } else if constexpr (std::is_same_v<T, double >) {
104- if (contents.fp64_contents ().empty ()) {
105- LOG (ERROR) << " TensorContents.fp64_contents is empty (datatype="
106- << datatype << " )" ;
107- return nullptr ;
108- }
109- return contents.fp64_contents ().data ();
110- } else {
111- LOG (FATAL) << " Unsupported data type for TensorContents: "
112- << typeid (T).name ();
113- return nullptr ;
114- }
115- }
116-
117- torch::Tensor proto_to_torch (const proto::Tensor& proto_tensor) {
118- if (proto_tensor.datatype ().empty ()) {
119- LOG (ERROR) << " Proto Tensor missing required field: datatype (e.g., "
120- " \" FP32\" , \" INT64\" )" ;
121- return torch::Tensor ();
122- }
123- if (proto_tensor.shape ().empty ()) {
124- LOG (ERROR) << " Proto Tensor has empty shape (invalid tensor)" ;
125- return torch::Tensor ();
126- }
127- if (!proto_tensor.has_contents ()) {
128- LOG (ERROR)
129- << " Proto Tensor missing required field: contents (TensorContents)" ;
130- return torch::Tensor ();
131- }
132- const auto & proto_contents = proto_tensor.contents ();
133-
134- const std::string& proto_datatype = proto_tensor.datatype ();
135- torch::ScalarType torch_dtype = datatype_proto_to_torch (proto_datatype);
136- const size_t element_size = torch::elementSize (torch_dtype);
137-
138- std::vector<int64_t > torch_shape;
139- int64_t total_elements = 1 ;
140- for (const auto & dim : proto_tensor.shape ()) {
141- if (dim <= 0 ) {
142- LOG (ERROR) << " Proto Tensor has invalid dimension: " << dim
143- << " (must be positive, datatype=" << proto_datatype << " )" ;
144- return torch::Tensor ();
145- }
146- torch_shape.emplace_back (dim);
147- total_elements *= dim;
148- }
149- torch::IntArrayRef tensor_shape (torch_shape);
150-
151- const void * data_ptr = nullptr ;
152- size_t data_count = 0 ;
153- if (proto_datatype == " BOOL" ) {
154- data_ptr = get_data_from_contents<bool >(proto_contents, proto_datatype);
155- data_count = proto_contents.bool_contents_size ();
156- } else if (proto_datatype == " INT32" ) {
157- data_ptr = get_data_from_contents<int32_t >(proto_contents, proto_datatype);
158- data_count = proto_contents.int_contents_size ();
159- } else if (proto_datatype == " INT64" ) {
160- data_ptr = get_data_from_contents<int64_t >(proto_contents, proto_datatype);
161- data_count = proto_contents.int64_contents_size ();
162- } else if (proto_datatype == " UINT32" ) {
163- data_ptr = get_data_from_contents<uint32_t >(proto_contents, proto_datatype);
164- data_count = proto_contents.uint_contents_size ();
165- } else if (proto_datatype == " UINT64" ) {
166- data_ptr = get_data_from_contents<uint64_t >(proto_contents, proto_datatype);
167- data_count = proto_contents.uint64_contents_size ();
168- } else if (proto_datatype == " FP32" ) {
169- data_ptr = get_data_from_contents<float >(proto_contents, proto_datatype);
170- data_count = proto_contents.fp32_contents_size ();
171- } else if (proto_datatype == " FP64" ) {
172- data_ptr = get_data_from_contents<double >(proto_contents, proto_datatype);
173- data_count = proto_contents.fp64_contents_size ();
174- }
175-
176- if (data_ptr == nullptr ) {
177- LOG (ERROR) << " Failed to get data from TensorContents (datatype="
178- << proto_datatype << " )" ;
179- return torch::Tensor ();
180- }
181- if (data_count != static_cast <size_t >(total_elements)) {
182- LOG (ERROR) << " Proto Tensor data count mismatch (datatype="
183- << proto_datatype << " ): "
184- << " expected " << total_elements
185- << " elements (shape=" << tensor_shape << " ), "
186- << " got " << data_count << " elements" ;
187- return torch::Tensor ();
188- }
189-
190- torch::Tensor tensor =
191- torch::from_blob (const_cast <void *>(data_ptr), tensor_shape, torch_dtype)
192- .clone ();
193- return tensor;
194- }
195-
19638std::pair<int , int > splitResolution (const std::string& s) {
19739 size_t pos = s.find (' *' );
19840 int width = std::stoi (s.substr (0 , pos));
@@ -227,26 +69,26 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request,
22769 }
22870
22971 if (input.has_prompt_embed ()) {
230- input_params.prompt_embed = proto_to_torch (input.prompt_embed ());
72+ input_params.prompt_embed = util:: proto_to_torch (input.prompt_embed ());
23173 }
23274 if (input.has_pooled_prompt_embed ()) {
23375 input_params.pooled_prompt_embed =
234- proto_to_torch (input.pooled_prompt_embed ());
76+ util:: proto_to_torch (input.pooled_prompt_embed ());
23577 }
23678 if (input.has_negative_prompt_embed ()) {
23779 input_params.negative_prompt_embed =
238- proto_to_torch (input.negative_prompt_embed ());
80+ util:: proto_to_torch (input.negative_prompt_embed ());
23981 }
24082 if (input.has_negative_pooled_prompt_embed ()) {
24183 input_params.negative_pooled_prompt_embed =
242- proto_to_torch (input.negative_pooled_prompt_embed ());
84+ util:: proto_to_torch (input.negative_pooled_prompt_embed ());
24385 }
24486 if (input.has_latent ()) {
245- input_params.latent = proto_to_torch (input.latent ());
87+ input_params.latent = util:: proto_to_torch (input.latent ());
24688 }
24789 if (input.has_masked_image_latent ()) {
24890 input_params.masked_image_latent =
249- proto_to_torch (input.masked_image_latent ());
91+ util:: proto_to_torch (input.masked_image_latent ());
25092 }
25193
25294 OpenCVImageDecoder decoder;
0 commit comments