@@ -1,18 +1,19 @@
 #include "torch_c.h"
 #include <torch/torch.h>
 #include <torch/csrc/jit/import.h>
+#include <torch/csrc/jit/script/compilation_unit.h>
 #include <iostream>
 #include <sstream>
 
 #include <ATen/Functions.h>
 
 namespace {
 
-static DLDataType getDLDataType(const at::Type& type) {
+static DLDataType getDLDataType(const at::Tensor& t) {
   DLDataType dtype;
   dtype.lanes = 1;
-  dtype.bits = type.elementSizeInBytes() * 8;
-  switch (type.scalarType()) {
+  dtype.bits = t.element_size() * 8;
+  switch (t.scalar_type()) {
     case at::ScalarType::Byte:
       dtype.code = DLDataTypeCode::kDLUInt;
       break;
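Reviewer note on the hunk above: the at::Type-based accessors went away in the PyTorch 1.1 C++ API, so element size and scalar type are now queried on the tensor itself. A minimal standalone sketch of the replacement calls (assumes a libtorch >= 1.1 build):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      torch::Tensor t = torch::ones({2, 3}, torch::kFloat);
      // Replaces t.type().elementSizeInBytes() and t.type().scalarType().
      std::cout << "bits: " << t.element_size() * 8 << "\n"          // 32
                << "scalar type: " << t.scalar_type() << std::endl;  // Float
      return 0;
    }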
@@ -37,6 +38,10 @@ static DLDataType getDLDataType(const at::Type& type) {
     case at::ScalarType::Half:
       dtype.code = DLDataTypeCode::kDLFloat;
       break;
+    case at::ScalarType::Bool:
+      throw std::logic_error("Bool is not supported by dlpack");
+    case at::ScalarType::QInt8:
+      throw std::logic_error("QInt8 is not supported by dlpack");
     case at::ScalarType::ComplexHalf:
       throw std::logic_error("ComplexHalf is not supported by dlpack");
     case at::ScalarType::ComplexFloat:
@@ -51,10 +56,10 @@ static DLDataType getDLDataType(const at::Type& type) {
   return dtype;
 }
 
-static DLContext getDLContext(const at::Type& type, const int64_t& device_id) {
+static DLContext getDLContext(const at::Tensor& tensor, const int64_t& device_id) {
   DLContext ctx;
   ctx.device_id = device_id;
-  if (type.is_cuda()) {
+  if (tensor.is_cuda()) {
     ctx.device_type = DLDeviceType::kDLGPU;
   } else {
     ctx.device_type = DLDeviceType::kDLCPU;
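Same migration for the device query: tensor.is_cuda() replaces type.is_cuda(). For illustration, the mapping in isolation (deviceTypeOf is a hypothetical helper name, and this assumes dlpack.h is on the include path):

    #include <torch/torch.h>
    #include <dlpack/dlpack.h>

    // Sketch of the mapping getDLContext performs: DLPack only
    // distinguishes GPU and CPU here, so any CUDA tensor maps to
    // kDLGPU and everything else is treated as host memory.
    static DLDeviceType deviceTypeOf(const torch::Tensor& t) {
      return t.is_cuda() ? DLDeviceType::kDLGPU : DLDeviceType::kDLCPU;
    }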
@@ -134,8 +139,8 @@ torch::Tensor fromDLPack(const DLTensor* src) {
   at::DeviceType device_type = getATenDeviceType(src->ctx.device_type);
   at::ScalarType stype = toScalarType(src->dtype);
   return torch::from_blob(src->data,
-      at::IntList(src->shape, src->ndim),
-      at::IntList(src->strides, src->ndim),
+      at::IntArrayRef(src->shape, src->ndim),
+      at::IntArrayRef(src->strides, src->ndim),
       torch::device(device_type).dtype(stype));
 }
 
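at::IntList became at::IntArrayRef in this API revision; both are non-owning (pointer, length) views, which is what lets from_blob borrow the DLTensor's shape and stride buffers without copying. A small usage sketch of the same pattern:

    #include <torch/torch.h>

    int main() {
      float data[6] = {0, 1, 2, 3, 4, 5};
      int64_t shape[2] = {2, 3};
      int64_t strides[2] = {3, 1};  // row-major strides, in elements
      torch::Tensor t = torch::from_blob(
          data,
          at::IntArrayRef(shape, 2),    // non-owning view of the shape array
          at::IntArrayRef(strides, 2),
          torch::device(torch::kCPU).dtype(torch::kFloat));
      // from_blob does not take ownership: `data` must outlive `t`.
      return 0;
    }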
@@ -158,9 +163,9 @@ DLManagedTensor* toManagedDLPack(const torch::Tensor& src) {
   if (src.is_cuda()) {
     device_id = src.get_device();
   }
-  atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id);
+  atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src, device_id);
   atDLMTensor->tensor.dl_tensor.ndim = src.dim();
-  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src.type());
+  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
   atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
   atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
   atDLMTensor->tensor.dl_tensor.byte_offset = 0;
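These file-local helpers parallel ATen's own DLPack convertor (at::toDLPack / at::fromDLPack in ATen/DLConvertor.h, if that header is available in your build). A hedged round-trip sketch using the stock API:

    #include <torch/torch.h>
    #include <ATen/DLConvertor.h>
    #include <cassert>

    int main() {
      torch::Tensor src = torch::rand({4, 4});
      DLManagedTensor* managed = at::toDLPack(src);  // shares src's storage
      torch::Tensor back = at::fromDLPack(managed);  // wraps it again
      assert(back.equal(src));  // same memory, so the round trip is exact
      return 0;
    }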
@@ -169,6 +174,7 @@ DLManagedTensor* toManagedDLPack(const torch::Tensor& src) {
 
 struct ModuleContext {
   std::shared_ptr<torch::jit::script::Module> module;
+  std::shared_ptr<torch::jit::script::CompilationUnit> cu;
   DLDeviceType device;
 };
 
@@ -191,8 +197,6 @@ void torchRunModule(ModuleContext* ctx, const char* fnName,
     throw std::runtime_error(std::string("Unsupported device ") + std::to_string(ctx->device));
   }
 
-  torch::jit::script::Method& method = ctx->module->get_method(fnName);
-
   torch::jit::Stack stack;
 
   for (int i=0; i<nInputs; i++) {
@@ -201,7 +205,14 @@ void torchRunModule(ModuleContext* ctx, const char* fnName,
     stack.push_back(tensor.to(device));
   }
 
-  method.run(stack);
+  if (ctx->module) {
+    torch::jit::script::Method& method = ctx->module->get_method(fnName);
+    method.run(stack);
+  }
+  else {
+    torch::jit::script::Function& fn = ctx->cu->get_function(fnName);
+    fn.run(stack);
+  }
 
   torch::DeviceType output_device = torch::kCPU;
 
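This branch is the core of the change: a context filled by torchLoadModel runs a Method on the module, while one filled by torchCompileScript resolves a free function in the CompilationUnit (exactly one of module/cu is non-null). Both paths use TorchScript's Stack convention: inputs are pushed in order and run() replaces them with the outputs. A sketch against the 1.1 script API (function name and body are illustrative):

    #include <torch/torch.h>
    #include <torch/csrc/jit/script/compilation_unit.h>

    int main() {
      // torch::jit::compile now yields a CompilationUnit, not a Module.
      auto cu = torch::jit::compile(R"JIT(
    def addone(x):
        return x + 1
    )JIT");
      torch::jit::Stack stack;
      stack.push_back(torch::ones({2}));
      cu->get_function("addone").run(stack);        // outputs replace inputs
      torch::Tensor out = stack.back().toTensor();  // == [2., 2.]
      return 0;
    }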
@@ -254,8 +265,8 @@ extern "C" DLManagedTensor* torchNewTensor(DLDataType dtype, long ndims, int64_t
   at::DeviceType device_type = getATenDeviceType(kDLCPU);
   at::ScalarType stype = toScalarType(dtype);
   torch::Tensor tensor = torch::from_blob(data,
-      at::IntList(shape, ndims),
-      at::IntList(strides, ndims),
+      at::IntArrayRef(shape, ndims),
+      at::IntArrayRef(strides, ndims),
       torch::device(at::DeviceType::CPU).dtype(stype));
 
   DLManagedTensor *dl_tensor = toManagedDLPack(tensor);
@@ -269,8 +280,9 @@ extern "C" void* torchCompileScript(const char* script, DLDeviceType device,
   ModuleContext* ctx = new ModuleContext();
   ctx->device = device;
   try {
-    auto module = torch::jit::compile(script);
-    ctx->module = module;
+    auto cu = torch::jit::compile(script);
+    ctx->cu = cu;
+    ctx->module = nullptr;
   }
   catch (std::exception& e) {
     size_t len = strlen(e.what());
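Because the entry point is extern "C", the exception cannot propagate; the pattern here is catch, measure e.what(), and copy the message out to the caller. The real out-parameter is truncated in the hunk header above, so the sketch below uses hypothetical err/maxlen names purely for illustration:

    #include <torch/torch.h>
    #include <torch/csrc/jit/script/compilation_unit.h>
    #include <cstring>

    // Illustrative error convention for a C-linkage wrapper: return
    // nullptr on failure and hand the exception text back through a
    // caller-supplied buffer (names are hypothetical).
    extern "C" void* compileOrReportError(const char* script,
                                          char* err, size_t maxlen) {
      try {
        auto cu = torch::jit::compile(script);
        return new std::shared_ptr<torch::jit::script::CompilationUnit>(cu);
      }
      catch (std::exception& e) {
        strncpy(err, e.what(), maxlen - 1);
        err[maxlen - 1] = '\0';
        return nullptr;
      }
    }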
@@ -297,6 +309,7 @@ extern "C" void* torchLoadModel(const char* graph, size_t graphlen, DLDeviceType
     }
     module->to(aten_device);
     ctx->module = module;
+    ctx->cu = nullptr;
   }
   catch (std::exception& e) {
     size_t len = strlen(e.what());
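The loader is the mirror image: module set, cu left null, so torchRunModule takes the Method branch. A minimal load-and-run sketch (path and method name are illustrative; in this API revision torch::jit::load returns a shared_ptr<script::Module>):

    #include <torch/torch.h>
    #include <torch/csrc/jit/import.h>

    int main() {
      auto module = torch::jit::load("model.pt");  // illustrative path
      torch::jit::Stack stack;
      stack.push_back(torch::rand({1, 3}));
      module->get_method("forward").run(stack);  // same call path as above
      return 0;
    }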