Skip to content

Commit e50e5cf

Browse files
author
DvirDukhan
authored
Merge branch 'master' into torchscript_extensions
2 parents f89a033 + 2fc1827 commit e50e5cf

File tree

16 files changed

+384
-46
lines changed

16 files changed

+384
-46
lines changed

src/backends/onnxruntime.c

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
288288

289289
RAI_Device device;
290290
int64_t deviceid;
291+
char **inputs_ = NULL;
292+
char **outputs_ = NULL;
291293

292294
if (!parseDeviceStr(devicestr, &device, &deviceid)) {
293295
RAI_SetError(error, RAI_EMODELCREATE, "ERR unsupported device");
@@ -352,6 +354,41 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
352354
goto error;
353355
}
354356

357+
size_t n_input_nodes;
358+
status = ort->SessionGetInputCount(session, &n_input_nodes);
359+
if (status != NULL) {
360+
goto error;
361+
}
362+
363+
size_t n_output_nodes;
364+
status = ort->SessionGetOutputCount(session, &n_output_nodes);
365+
if (status != NULL) {
366+
goto error;
367+
}
368+
369+
OrtAllocator *allocator;
370+
status = ort->GetAllocatorWithDefaultOptions(&allocator);
371+
372+
inputs_ = array_new(char *, n_input_nodes);
373+
for (long long i = 0; i < n_input_nodes; i++) {
374+
char *input_name;
375+
status = ort->SessionGetInputName(session, i, allocator, &input_name);
376+
if (status != NULL) {
377+
goto error;
378+
}
379+
inputs_ = array_append(inputs_, input_name);
380+
}
381+
382+
outputs_ = array_new(char *, n_output_nodes);
383+
for (long long i = 0; i < n_output_nodes; i++) {
384+
char *output_name;
385+
status = ort->SessionGetOutputName(session, i, allocator, &output_name);
386+
if (status != NULL) {
387+
goto error;
388+
}
389+
outputs_ = array_append(outputs_, output_name);
390+
}
391+
355392
// Since ONNXRuntime doesn't have a re-serialization function,
356393
// we cache the blob in order to re-serialize it.
357394
// Not optimal for storage purposes, but again, it may be temporary
@@ -367,11 +404,29 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
367404
ret->opts = opts;
368405
ret->data = buffer;
369406
ret->datalen = modellen;
407+
ret->ninputs = n_input_nodes;
408+
ret->noutputs = n_output_nodes;
409+
ret->inputs = inputs_;
410+
ret->outputs = outputs_;
370411

371412
return ret;
372413

373414
error:
374415
RAI_SetError(error, RAI_EMODELCREATE, ort->GetErrorMessage(status));
416+
if (inputs_) {
417+
n_input_nodes = array_len(inputs_);
418+
for (uint32_t i = 0; i < n_input_nodes; i++) {
419+
status = ort->AllocatorFree(allocator, inputs_[i]);
420+
}
421+
array_free(inputs_);
422+
}
423+
if (outputs_) {
424+
n_output_nodes = array_len(outputs_);
425+
for (uint32_t i = 0; i < n_output_nodes; i++) {
426+
status = ort->AllocatorFree(allocator, outputs_[i]);
427+
}
428+
array_free(outputs_);
429+
}
375430
ort->ReleaseStatus(status);
376431
return NULL;
377432
}
@@ -381,6 +436,19 @@ void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) {
381436

382437
RedisModule_Free(model->data);
383438
RedisModule_Free(model->devicestr);
439+
OrtAllocator *allocator;
440+
OrtStatus *status = NULL;
441+
status = ort->GetAllocatorWithDefaultOptions(&allocator);
442+
for (uint32_t i = 0; i < model->ninputs; i++) {
443+
status = ort->AllocatorFree(allocator, model->inputs[i]);
444+
}
445+
array_free(model->inputs);
446+
447+
for (uint32_t i = 0; i < model->noutputs; i++) {
448+
status = ort->AllocatorFree(allocator, model->outputs[i]);
449+
}
450+
array_free(model->outputs);
451+
384452
ort->ReleaseSession(model->session);
385453

386454
model->model = NULL;

src/backends/tensorflow.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
390390
ret->session = session;
391391
ret->backend = backend;
392392
ret->devicestr = RedisModule_Strdup(devicestr);
393+
ret->ninputs = ninputs;
393394
ret->inputs = inputs_;
395+
ret->noutputs = noutputs;
394396
ret->outputs = outputs_;
395397
ret->opts = opts;
396398
ret->refCount = 1;

src/backends/tflite.c

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ int RAI_InitBackendTFLite(int (*get_api_fn)(const char *, void *)) {
1818
RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts,
1919
const char *modeldef, size_t modellen, RAI_Error *error) {
2020
DLDeviceType dl_device;
21-
2221
RAI_Device device;
2322
int64_t deviceid;
23+
char **inputs_ = NULL;
24+
char **outputs_ = NULL;
2425
if (!parseDeviceStr(devicestr, &device, &deviceid)) {
2526
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR Unsupported device");
2627
return NULL;
@@ -47,6 +48,36 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI
4748
return NULL;
4849
}
4950

51+
size_t ninputs = tfliteModelNumInputs(model, &error_descr);
52+
if (error_descr) {
53+
goto cleanup;
54+
}
55+
56+
size_t noutputs = tfliteModelNumOutputs(model, &error_descr);
57+
if (error_descr) {
58+
goto cleanup;
59+
}
60+
61+
inputs_ = array_new(char *, ninputs);
62+
outputs_ = array_new(char *, noutputs);
63+
64+
for (size_t i = 0; i < ninputs; i++) {
65+
const char *input = tfliteModelInputNameAtIndex(model, i, &error_descr);
66+
if (error_descr) {
67+
goto cleanup;
68+
}
69+
inputs_ = array_append(inputs_, RedisModule_Strdup(input));
70+
}
71+
72+
for (size_t i = 0; i < noutputs; i++) {
73+
const char *output = tfliteModelOutputNameAtIndex(model, i, &error_descr);
74+
;
75+
if (error_descr) {
76+
goto cleanup;
77+
}
78+
outputs_ = array_append(outputs_, RedisModule_Strdup(output));
79+
}
80+
5081
char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
5182
memcpy(buffer, modeldef, modellen);
5283

@@ -55,20 +86,51 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI
5586
ret->session = NULL;
5687
ret->backend = backend;
5788
ret->devicestr = RedisModule_Strdup(devicestr);
58-
ret->inputs = NULL;
59-
ret->outputs = NULL;
89+
ret->ninputs = ninputs;
90+
ret->inputs = inputs_;
91+
ret->noutputs = noutputs;
92+
ret->outputs = outputs_;
6093
ret->refCount = 1;
6194
ret->opts = opts;
6295
ret->data = buffer;
6396
ret->datalen = modellen;
64-
6597
return ret;
98+
99+
cleanup:
100+
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
101+
RedisModule_Free(error_descr);
102+
if (inputs_) {
103+
ninputs = array_len(inputs_);
104+
for (size_t i = 0; i < ninputs; i++) {
105+
RedisModule_Free(inputs_[i]);
106+
}
107+
array_free(inputs_);
108+
}
109+
if (outputs_) {
110+
noutputs = array_len(outputs_);
111+
for (size_t i = 0; i < noutputs; i++) {
112+
RedisModule_Free(outputs_[i]);
113+
}
114+
array_free(outputs_);
115+
}
116+
return NULL;
66117
}
67118

68119
void RAI_ModelFreeTFLite(RAI_Model *model, RAI_Error *error) {
69120
RedisModule_Free(model->data);
70121
RedisModule_Free(model->devicestr);
71122
tfliteDeallocContext(model->model);
123+
size_t ninputs = model->ninputs;
124+
for (size_t i = 0; i < ninputs; i++) {
125+
RedisModule_Free(model->inputs[i]);
126+
}
127+
array_free(model->inputs);
128+
129+
size_t noutputs = model->noutputs;
130+
for (size_t i = 0; i < noutputs; i++) {
131+
RedisModule_Free(model->outputs[i]);
132+
}
133+
array_free(model->outputs);
72134

73135
model->model = NULL;
74136
}

src/backends/torch.c

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
3636
RAI_Device device = RAI_DEVICE_CPU;
3737
int64_t deviceid = 0;
3838

39+
char **inputs_ = NULL;
40+
char **outputs_ = NULL;
41+
3942
if (!parseDeviceStr(devicestr, &device, &deviceid)) {
4043
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR unsupported device");
4144
return NULL;
@@ -67,7 +70,7 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
6770
if (opts.backends_intra_op_parallelism > 0) {
6871
torchSetIntraOpThreads(opts.backends_intra_op_parallelism, &error_descr, RedisModule_Alloc);
6972
}
70-
if (error_descr != NULL) {
73+
if (error_descr) {
7174
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
7275
RedisModule_Free(error_descr);
7376
return NULL;
@@ -76,10 +79,37 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
7679
void *model =
7780
torchLoadModel(modeldef, modellen, dl_device, deviceid, &error_descr, RedisModule_Alloc);
7881

79-
if (model == NULL) {
80-
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
81-
RedisModule_Free(error_descr);
82-
return NULL;
82+
if (error_descr) {
83+
goto cleanup;
84+
}
85+
86+
size_t ninputs = torchModelNumInputs(model, &error_descr);
87+
if (error_descr) {
88+
goto cleanup;
89+
}
90+
91+
size_t noutputs = torchModelNumOutputs(model, &error_descr);
92+
if (error_descr) {
93+
goto cleanup;
94+
}
95+
96+
inputs_ = array_new(char *, ninputs);
97+
outputs_ = array_new(char *, noutputs);
98+
99+
for (size_t i = 0; i < ninputs; i++) {
100+
const char *input = torchModelInputNameAtIndex(model, i, &error_descr);
101+
if (error_descr) {
102+
goto cleanup;
103+
}
104+
inputs_ = array_append(inputs_, RedisModule_Strdup(input));
105+
}
106+
107+
for (size_t i = 0; i < noutputs; i++) {
108+
const char *output = "";
109+
if (error_descr) {
110+
goto cleanup;
111+
}
112+
outputs_ = array_append(outputs_, RedisModule_Strdup(output));
83113
}
84114

85115
char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
@@ -90,14 +120,34 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
90120
ret->session = NULL;
91121
ret->backend = backend;
92122
ret->devicestr = RedisModule_Strdup(devicestr);
93-
ret->inputs = NULL;
94-
ret->outputs = NULL;
123+
ret->ninputs = ninputs;
124+
ret->inputs = inputs_;
125+
ret->noutputs = noutputs;
126+
ret->outputs = outputs_;
95127
ret->opts = opts;
96128
ret->refCount = 1;
97129
ret->data = buffer;
98130
ret->datalen = modellen;
99-
100131
return ret;
132+
133+
cleanup:
134+
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
135+
RedisModule_Free(error_descr);
136+
if (inputs_) {
137+
ninputs = array_len(inputs_);
138+
for (size_t i = 0; i < ninputs; i++) {
139+
RedisModule_Free(inputs_[i]);
140+
}
141+
array_free(inputs_);
142+
}
143+
if (outputs_) {
144+
noutputs = array_len(outputs_);
145+
for (size_t i = 0; i < noutputs; i++) {
146+
RedisModule_Free(outputs_[i]);
147+
}
148+
array_free(outputs_);
149+
}
150+
return NULL;
101151
}
102152

103153
void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) {
@@ -107,6 +157,18 @@ void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) {
107157
if (model->data) {
108158
RedisModule_Free(model->data);
109159
}
160+
size_t ninputs = model->ninputs;
161+
for (size_t i = 0; i < ninputs; i++) {
162+
RedisModule_Free(model->inputs[i]);
163+
}
164+
array_free(model->inputs);
165+
166+
size_t noutputs = model->noutputs;
167+
for (size_t i = 0; i < noutputs; i++) {
168+
RedisModule_Free(model->outputs[i]);
169+
}
170+
array_free(model->outputs);
171+
110172
torchDeallocContext(model->model);
111173
}
112174

src/command_parser.c

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,15 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **a
6868
}
6969
}
7070
}
71-
if ((*model)->inputs && (*model)->ninputs != ninputs) {
71+
if ((*model)->ninputs != ninputs) {
7272
RAI_SetError(error, RAI_EMODELRUN,
73-
"Number of names given as INPUTS during MODELSET and keys given as "
74-
"INPUTS here do not match");
73+
"Number of keys given as INPUTS here does not match model definition");
7574
return REDISMODULE_ERR;
7675
}
7776

78-
if ((*model)->outputs && (*model)->noutputs != noutputs) {
77+
if ((*model)->noutputs != noutputs) {
7978
RAI_SetError(error, RAI_EMODELRUN,
80-
"Number of names given as OUTPUTS during MODELSET and keys given as "
81-
"OUTPUTS here do not match");
79+
"Number of keys given as OUTPUTS here does not match model definition");
8280
return REDISMODULE_ERR;
8381
}
8482
return REDISMODULE_OK;

0 commit comments

Comments
 (0)