Skip to content

Commit 2fc1827

Browse files
author
DvirDukhan
authored
Merge pull request #552 from RedisAI/general_model_inputs_and_outputs_names
expose model inputs and outputs with respect to model definition
2 parents bd9f462 + 21b39cc commit 2fc1827

File tree

16 files changed

+384
-46
lines changed

16 files changed

+384
-46
lines changed

src/backends/onnxruntime.c

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
288288

289289
RAI_Device device;
290290
int64_t deviceid;
291+
char **inputs_ = NULL;
292+
char **outputs_ = NULL;
291293

292294
if (!parseDeviceStr(devicestr, &device, &deviceid)) {
293295
RAI_SetError(error, RAI_EMODELCREATE, "ERR unsupported device");
@@ -352,6 +354,41 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
352354
goto error;
353355
}
354356

357+
size_t n_input_nodes;
358+
status = ort->SessionGetInputCount(session, &n_input_nodes);
359+
if (status != NULL) {
360+
goto error;
361+
}
362+
363+
size_t n_output_nodes;
364+
status = ort->SessionGetOutputCount(session, &n_output_nodes);
365+
if (status != NULL) {
366+
goto error;
367+
}
368+
369+
OrtAllocator *allocator;
370+
status = ort->GetAllocatorWithDefaultOptions(&allocator);
371+
372+
inputs_ = array_new(char *, n_input_nodes);
373+
for (long long i = 0; i < n_input_nodes; i++) {
374+
char *input_name;
375+
status = ort->SessionGetInputName(session, i, allocator, &input_name);
376+
if (status != NULL) {
377+
goto error;
378+
}
379+
inputs_ = array_append(inputs_, input_name);
380+
}
381+
382+
outputs_ = array_new(char *, n_output_nodes);
383+
for (long long i = 0; i < n_output_nodes; i++) {
384+
char *output_name;
385+
status = ort->SessionGetOutputName(session, i, allocator, &output_name);
386+
if (status != NULL) {
387+
goto error;
388+
}
389+
outputs_ = array_append(outputs_, output_name);
390+
}
391+
355392
// Since ONNXRuntime doesn't have a re-serialization function,
356393
// we cache the blob in order to re-serialize it.
357394
// Not optimal for storage purposes, but again, it may be temporary
@@ -367,11 +404,29 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
367404
ret->opts = opts;
368405
ret->data = buffer;
369406
ret->datalen = modellen;
407+
ret->ninputs = n_input_nodes;
408+
ret->noutputs = n_output_nodes;
409+
ret->inputs = inputs_;
410+
ret->outputs = outputs_;
370411

371412
return ret;
372413

373414
error:
374415
RAI_SetError(error, RAI_EMODELCREATE, ort->GetErrorMessage(status));
416+
if (inputs_) {
417+
n_input_nodes = array_len(inputs_);
418+
for (uint32_t i = 0; i < n_input_nodes; i++) {
419+
status = ort->AllocatorFree(allocator, inputs_[i]);
420+
}
421+
array_free(inputs_);
422+
}
423+
if (outputs_) {
424+
n_output_nodes = array_len(outputs_);
425+
for (uint32_t i = 0; i < n_output_nodes; i++) {
426+
status = ort->AllocatorFree(allocator, outputs_[i]);
427+
}
428+
array_free(outputs_);
429+
}
375430
ort->ReleaseStatus(status);
376431
return NULL;
377432
}
@@ -381,6 +436,19 @@ void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) {
381436

382437
RedisModule_Free(model->data);
383438
RedisModule_Free(model->devicestr);
439+
OrtAllocator *allocator;
440+
OrtStatus *status = NULL;
441+
status = ort->GetAllocatorWithDefaultOptions(&allocator);
442+
for (uint32_t i = 0; i < model->ninputs; i++) {
443+
status = ort->AllocatorFree(allocator, model->inputs[i]);
444+
}
445+
array_free(model->inputs);
446+
447+
for (uint32_t i = 0; i < model->noutputs; i++) {
448+
status = ort->AllocatorFree(allocator, model->outputs[i]);
449+
}
450+
array_free(model->outputs);
451+
384452
ort->ReleaseSession(model->session);
385453

386454
model->model = NULL;

src/backends/tensorflow.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
390390
ret->session = session;
391391
ret->backend = backend;
392392
ret->devicestr = RedisModule_Strdup(devicestr);
393+
ret->ninputs = ninputs;
393394
ret->inputs = inputs_;
395+
ret->noutputs = noutputs;
394396
ret->outputs = outputs_;
395397
ret->opts = opts;
396398
ret->refCount = 1;

src/backends/tflite.c

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ int RAI_InitBackendTFLite(int (*get_api_fn)(const char *, void *)) {
1818
RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts,
1919
const char *modeldef, size_t modellen, RAI_Error *error) {
2020
DLDeviceType dl_device;
21-
2221
RAI_Device device;
2322
int64_t deviceid;
23+
char **inputs_ = NULL;
24+
char **outputs_ = NULL;
2425
if (!parseDeviceStr(devicestr, &device, &deviceid)) {
2526
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR Unsupported device");
2627
return NULL;
@@ -47,6 +48,36 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI
4748
return NULL;
4849
}
4950

51+
size_t ninputs = tfliteModelNumInputs(model, &error_descr);
52+
if (error_descr) {
53+
goto cleanup;
54+
}
55+
56+
size_t noutputs = tfliteModelNumOutputs(model, &error_descr);
57+
if (error_descr) {
58+
goto cleanup;
59+
}
60+
61+
inputs_ = array_new(char *, ninputs);
62+
outputs_ = array_new(char *, noutputs);
63+
64+
for (size_t i = 0; i < ninputs; i++) {
65+
const char *input = tfliteModelInputNameAtIndex(model, i, &error_descr);
66+
if (error_descr) {
67+
goto cleanup;
68+
}
69+
inputs_ = array_append(inputs_, RedisModule_Strdup(input));
70+
}
71+
72+
for (size_t i = 0; i < noutputs; i++) {
73+
const char *output = tfliteModelOutputNameAtIndex(model, i, &error_descr);
74+
;
75+
if (error_descr) {
76+
goto cleanup;
77+
}
78+
outputs_ = array_append(outputs_, RedisModule_Strdup(output));
79+
}
80+
5081
char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
5182
memcpy(buffer, modeldef, modellen);
5283

@@ -55,20 +86,51 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI
5586
ret->session = NULL;
5687
ret->backend = backend;
5788
ret->devicestr = RedisModule_Strdup(devicestr);
58-
ret->inputs = NULL;
59-
ret->outputs = NULL;
89+
ret->ninputs = ninputs;
90+
ret->inputs = inputs_;
91+
ret->noutputs = noutputs;
92+
ret->outputs = outputs_;
6093
ret->refCount = 1;
6194
ret->opts = opts;
6295
ret->data = buffer;
6396
ret->datalen = modellen;
64-
6597
return ret;
98+
99+
cleanup:
100+
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
101+
RedisModule_Free(error_descr);
102+
if (inputs_) {
103+
ninputs = array_len(inputs_);
104+
for (size_t i = 0; i < ninputs; i++) {
105+
RedisModule_Free(inputs_[i]);
106+
}
107+
array_free(inputs_);
108+
}
109+
if (outputs_) {
110+
noutputs = array_len(outputs_);
111+
for (size_t i = 0; i < noutputs; i++) {
112+
RedisModule_Free(outputs_[i]);
113+
}
114+
array_free(outputs_);
115+
}
116+
return NULL;
66117
}
67118

68119
void RAI_ModelFreeTFLite(RAI_Model *model, RAI_Error *error) {
69120
RedisModule_Free(model->data);
70121
RedisModule_Free(model->devicestr);
71122
tfliteDeallocContext(model->model);
123+
size_t ninputs = model->ninputs;
124+
for (size_t i = 0; i < ninputs; i++) {
125+
RedisModule_Free(model->inputs[i]);
126+
}
127+
array_free(model->inputs);
128+
129+
size_t noutputs = model->noutputs;
130+
for (size_t i = 0; i < noutputs; i++) {
131+
RedisModule_Free(model->outputs[i]);
132+
}
133+
array_free(model->outputs);
72134

73135
model->model = NULL;
74136
}

src/backends/torch.c

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
2222
RAI_Device device = RAI_DEVICE_CPU;
2323
int64_t deviceid = 0;
2424

25+
char **inputs_ = NULL;
26+
char **outputs_ = NULL;
27+
2528
if (!parseDeviceStr(devicestr, &device, &deviceid)) {
2629
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR unsupported device");
2730
return NULL;
@@ -53,7 +56,7 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
5356
if (opts.backends_intra_op_parallelism > 0) {
5457
torchSetIntraOpThreads(opts.backends_intra_op_parallelism, &error_descr, RedisModule_Alloc);
5558
}
56-
if (error_descr != NULL) {
59+
if (error_descr) {
5760
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
5861
RedisModule_Free(error_descr);
5962
return NULL;
@@ -62,10 +65,37 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
6265
void *model =
6366
torchLoadModel(modeldef, modellen, dl_device, deviceid, &error_descr, RedisModule_Alloc);
6467

65-
if (model == NULL) {
66-
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
67-
RedisModule_Free(error_descr);
68-
return NULL;
68+
if (error_descr) {
69+
goto cleanup;
70+
}
71+
72+
size_t ninputs = torchModelNumInputs(model, &error_descr);
73+
if (error_descr) {
74+
goto cleanup;
75+
}
76+
77+
size_t noutputs = torchModelNumOutputs(model, &error_descr);
78+
if (error_descr) {
79+
goto cleanup;
80+
}
81+
82+
inputs_ = array_new(char *, ninputs);
83+
outputs_ = array_new(char *, noutputs);
84+
85+
for (size_t i = 0; i < ninputs; i++) {
86+
const char *input = torchModelInputNameAtIndex(model, i, &error_descr);
87+
if (error_descr) {
88+
goto cleanup;
89+
}
90+
inputs_ = array_append(inputs_, RedisModule_Strdup(input));
91+
}
92+
93+
for (size_t i = 0; i < noutputs; i++) {
94+
const char *output = "";
95+
if (error_descr) {
96+
goto cleanup;
97+
}
98+
outputs_ = array_append(outputs_, RedisModule_Strdup(output));
6999
}
70100

71101
char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
@@ -76,14 +106,34 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
76106
ret->session = NULL;
77107
ret->backend = backend;
78108
ret->devicestr = RedisModule_Strdup(devicestr);
79-
ret->inputs = NULL;
80-
ret->outputs = NULL;
109+
ret->ninputs = ninputs;
110+
ret->inputs = inputs_;
111+
ret->noutputs = noutputs;
112+
ret->outputs = outputs_;
81113
ret->opts = opts;
82114
ret->refCount = 1;
83115
ret->data = buffer;
84116
ret->datalen = modellen;
85-
86117
return ret;
118+
119+
cleanup:
120+
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
121+
RedisModule_Free(error_descr);
122+
if (inputs_) {
123+
ninputs = array_len(inputs_);
124+
for (size_t i = 0; i < ninputs; i++) {
125+
RedisModule_Free(inputs_[i]);
126+
}
127+
array_free(inputs_);
128+
}
129+
if (outputs_) {
130+
noutputs = array_len(outputs_);
131+
for (size_t i = 0; i < noutputs; i++) {
132+
RedisModule_Free(outputs_[i]);
133+
}
134+
array_free(outputs_);
135+
}
136+
return NULL;
87137
}
88138

89139
void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) {
@@ -93,6 +143,18 @@ void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) {
93143
if (model->data) {
94144
RedisModule_Free(model->data);
95145
}
146+
size_t ninputs = model->ninputs;
147+
for (size_t i = 0; i < ninputs; i++) {
148+
RedisModule_Free(model->inputs[i]);
149+
}
150+
array_free(model->inputs);
151+
152+
size_t noutputs = model->noutputs;
153+
for (size_t i = 0; i < noutputs; i++) {
154+
RedisModule_Free(model->outputs[i]);
155+
}
156+
array_free(model->outputs);
157+
96158
torchDeallocContext(model->model);
97159
}
98160

src/command_parser.c

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,15 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **a
6868
}
6969
}
7070
}
71-
if ((*model)->inputs && (*model)->ninputs != ninputs) {
71+
if ((*model)->ninputs != ninputs) {
7272
RAI_SetError(error, RAI_EMODELRUN,
73-
"Number of names given as INPUTS during MODELSET and keys given as "
74-
"INPUTS here do not match");
73+
"Number of keys given as INPUTS here does not match model definition");
7574
return REDISMODULE_ERR;
7675
}
7776

78-
if ((*model)->outputs && (*model)->noutputs != noutputs) {
77+
if ((*model)->noutputs != noutputs) {
7978
RAI_SetError(error, RAI_EMODELRUN,
80-
"Number of names given as OUTPUTS during MODELSET and keys given as "
81-
"OUTPUTS here do not match");
79+
"Number of keys given as OUTPUTS here does not match model definition");
8280
return REDISMODULE_ERR;
8381
}
8482
return REDISMODULE_OK;

0 commit comments

Comments
 (0)