Skip to content

Commit c36e28e

Browse files
author
DvirDukhan
committed
expose model inputs and outputs with respect to model definition
1 parent 64001de commit c36e28e

File tree

14 files changed

+356
-38
lines changed

14 files changed

+356
-38
lines changed

src/backends/onnxruntime.c

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
288288

289289
RAI_Device device;
290290
int64_t deviceid;
291+
char** inputs_ = NULL;
292+
char** outputs_ = NULL;
291293

292294
if (!parseDeviceStr(devicestr, &device, &deviceid)) {
293295
RAI_SetError(error, RAI_EMODELCREATE, "ERR unsupported device");
@@ -352,6 +354,41 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
352354
goto error;
353355
}
354356

357+
size_t n_input_nodes;
358+
status = ort->SessionGetInputCount(session, &n_input_nodes);
359+
if (status != NULL) {
360+
goto error;
361+
}
362+
363+
size_t n_output_nodes;
364+
status = ort->SessionGetOutputCount(session, &n_output_nodes);
365+
if (status != NULL) {
366+
goto error;
367+
}
368+
369+
OrtAllocator *allocator;
370+
status = ort->GetAllocatorWithDefaultOptions(&allocator);
371+
372+
inputs_ = array_new(char*, n_input_nodes);
373+
for (long long i = 0; i < n_input_nodes; i++) {
374+
char* input_name;
375+
status = ort->SessionGetInputName(session, i, allocator, &input_name);
376+
if (status != NULL) {
377+
goto error;
378+
}
379+
inputs_ = array_append(inputs_, input_name);
380+
}
381+
382+
outputs_ = array_new(char *, n_output_nodes);
383+
for (long long i = 0; i < n_output_nodes; i++) {
384+
char* output_name;
385+
status = ort->SessionGetOutputName(session, i, allocator, &output_name);
386+
if (status != NULL) {
387+
goto error;
388+
}
389+
outputs_ = array_append(outputs_, output_name);
390+
}
391+
355392
// Since ONNXRuntime doesn't have a re-serialization function,
356393
// we cache the blob in order to re-serialize it.
357394
// Not optimal for storage purposes, but again, it may be temporary
@@ -367,11 +404,29 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
367404
ret->opts = opts;
368405
ret->data = buffer;
369406
ret->datalen = modellen;
407+
ret->ninputs = n_input_nodes;
408+
ret->noutputs = n_output_nodes;
409+
ret->inputs = inputs_;
410+
ret->outputs = outputs_;
370411

371412
return ret;
372413

373414
error:
374415
RAI_SetError(error, RAI_EMODELCREATE, ort->GetErrorMessage(status));
416+
if(inputs_) {
417+
n_input_nodes = array_len(inputs_);
418+
for(uint32_t i = 0; i <n_input_nodes; i++) {
419+
status = ort->AllocatorFree(allocator, inputs_[i]);
420+
}
421+
array_free(inputs_);
422+
}
423+
if(outputs_){
424+
n_output_nodes = array_len(outputs_);
425+
for(uint32_t i = 0; i <n_output_nodes; i++) {
426+
status = ort->AllocatorFree(allocator, outputs_[i]);
427+
}
428+
array_free(outputs_);
429+
}
375430
ort->ReleaseStatus(status);
376431
return NULL;
377432
}
@@ -381,6 +436,19 @@ void RAI_ModelFreeORT(RAI_Model *model, RAI_Error *error) {
381436

382437
RedisModule_Free(model->data);
383438
RedisModule_Free(model->devicestr);
439+
OrtAllocator *allocator;
440+
OrtStatus *status = NULL;
441+
status = ort->GetAllocatorWithDefaultOptions(&allocator);
442+
for(uint32_t i = 0; i < model->ninputs; i++) {
443+
status = ort->AllocatorFree(allocator, model->inputs[i]);
444+
}
445+
array_free(model->inputs);
446+
447+
for(uint32_t i = 0; i < model->noutputs; i++) {
448+
status = ort->AllocatorFree(allocator, model->outputs[i]);
449+
}
450+
array_free(model->outputs);
451+
384452
ort->ReleaseSession(model->session);
385453

386454
model->model = NULL;

src/backends/tensorflow.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
390390
ret->session = session;
391391
ret->backend = backend;
392392
ret->devicestr = RedisModule_Strdup(devicestr);
393+
ret->ninputs = ninputs;
393394
ret->inputs = inputs_;
395+
ret->noutputs = noutputs;
394396
ret->outputs = outputs_;
395397
ret->opts = opts;
396398
ret->refCount = 1;

src/backends/tflite.c

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ int RAI_InitBackendTFLite(int (*get_api_fn)(const char *, void *)) {
1818
RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts,
1919
const char *modeldef, size_t modellen, RAI_Error *error) {
2020
DLDeviceType dl_device;
21-
2221
RAI_Device device;
2322
int64_t deviceid;
23+
char** inputs_ = NULL;
24+
char** outputs_ = NULL;
2425
if (!parseDeviceStr(devicestr, &device, &deviceid)) {
2526
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR Unsupported device");
2627
return NULL;
@@ -48,6 +49,35 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI
4849
return NULL;
4950
}
5051

52+
size_t ninputs = tfliteModelNumInputs(model, &error_descr, RedisModule_Alloc);
53+
if(error_descr) {
54+
goto cleanup;
55+
}
56+
57+
size_t noutputs = tfliteModelNumOutputs(model, &error_descr, RedisModule_Alloc);
58+
if(error_descr) {
59+
goto cleanup;
60+
}
61+
62+
inputs_ = array_new(char*, ninputs);
63+
outputs_ = array_new(char*, noutputs);
64+
65+
for (size_t i = 0; i < ninputs; i++) {
66+
const char* input = tfliteModelInputNameAtIndex(model, i, &error_descr, RedisModule_Alloc);
67+
if(error_descr) {
68+
goto cleanup;
69+
}
70+
inputs_ = array_append(inputs_, RedisModule_Strdup(input));
71+
}
72+
73+
for (size_t i = 0; i < noutputs; i++) {
74+
const char* output = tfliteModelOutputNameAtIndex(model, i, &error_descr, RedisModule_Alloc);;
75+
if(error_descr) {
76+
goto cleanup;
77+
}
78+
outputs_ = array_append(outputs_, RedisModule_Strdup(output));
79+
}
80+
5181
char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
5282
memcpy(buffer, modeldef, modellen);
5383

@@ -56,20 +86,51 @@ RAI_Model *RAI_ModelCreateTFLite(RAI_Backend backend, const char *devicestr, RAI
5686
ret->session = NULL;
5787
ret->backend = backend;
5888
ret->devicestr = RedisModule_Strdup(devicestr);
59-
ret->inputs = NULL;
60-
ret->outputs = NULL;
89+
ret->ninputs = ninputs;
90+
ret->inputs = inputs_;
91+
ret->noutputs = noutputs;
92+
ret->outputs = outputs_;
6193
ret->refCount = 1;
6294
ret->opts = opts;
6395
ret->data = buffer;
6496
ret->datalen = modellen;
65-
6697
return ret;
98+
99+
cleanup:
100+
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
101+
RedisModule_Free(error_descr);
102+
if(inputs_) {
103+
ninputs = array_len(inputs_);
104+
for(size_t i =0 ; i < ninputs; i++) {
105+
RedisModule_Free(inputs_[i]);
106+
}
107+
array_free(inputs_);
108+
}
109+
if(outputs_) {
110+
noutputs = array_len(outputs_);
111+
for(size_t i =0 ; i < noutputs; i++) {
112+
RedisModule_Free(outputs_[i]);
113+
}
114+
array_free(outputs_);
115+
}
116+
return NULL;
67117
}
68118

69119
/*
 * Release the resources owned by a TFLite-backed model.
 *
 * Frees the cached model blob (model->data), the device string, the TFLite
 * context held in model->model, and the input/output name arrays together
 * with every name string stored in them. model->model is reset to NULL at
 * the end. The RAI_Model struct itself is not freed here.
 *
 * model: model whose members are released.
 * error: not read or written by this function.
 */
void RAI_ModelFreeTFLite(RAI_Model *model, RAI_Error *error) {
70120
RedisModule_Free(model->data);
71121
RedisModule_Free(model->devicestr);
72122
tfliteDeallocContext(model->model);
123+
/* Free each input name, then the array that held them. */
size_t ninputs = model->ninputs;
124+
for(size_t i =0 ; i < ninputs; i++) {
125+
RedisModule_Free(model->inputs[i]);
126+
}
127+
array_free(model->inputs);
128+
129+
/* Same two-step release for the output names. */
size_t noutputs = model->noutputs;
130+
for(size_t i =0 ; i < noutputs; i++) {
131+
RedisModule_Free(model->outputs[i]);
132+
}
133+
array_free(model->outputs);
73134

74135
/* Clear the pointer to the context that was deallocated above. */
model->model = NULL;
75136
}

src/backends/torch.c

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
2222
RAI_Device device = RAI_DEVICE_CPU;
2323
int64_t deviceid = 0;
2424

25+
char** inputs_ = NULL;
26+
char** outputs_ = NULL;
27+
2528
if (!parseDeviceStr(devicestr, &device, &deviceid)) {
2629
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR unsupported device");
2730
return NULL;
@@ -53,7 +56,7 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
5356
if (opts.backends_intra_op_parallelism > 0) {
5457
torchSetIntraOpThreads(opts.backends_intra_op_parallelism, &error_descr, RedisModule_Alloc);
5558
}
56-
if (error_descr != NULL) {
59+
if (error_descr) {
5760
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
5861
RedisModule_Free(error_descr);
5962
return NULL;
@@ -62,28 +65,76 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
6265
void *model =
6366
torchLoadModel(modeldef, modellen, dl_device, deviceid, &error_descr, RedisModule_Alloc);
6467

65-
if (model == NULL) {
66-
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
67-
RedisModule_Free(error_descr);
68-
return NULL;
68+
if (error_descr) {
69+
goto cleanup;
70+
}
71+
72+
size_t ninputs = torchModelNumInputs(model, &error_descr);
73+
if(error_descr) {
74+
goto cleanup;
75+
}
76+
77+
size_t noutputs = torchModelNumOutputs(model, &error_descr);
78+
if(error_descr) {
79+
goto cleanup;
80+
}
81+
82+
inputs_ = array_new(char*, ninputs);
83+
outputs_ = array_new(char*, noutputs);
84+
85+
for (size_t i = 0; i < ninputs; i++) {
86+
const char* input = torchModelInputNameAtIndex(model, i, &error_descr);
87+
if(error_descr) {
88+
goto cleanup;
89+
}
90+
inputs_ = array_append(inputs_, RedisModule_Strdup(input));
91+
}
92+
93+
for (size_t i = 0; i < noutputs; i++) {
94+
const char* output ="";
95+
if(error_descr) {
96+
goto cleanup;
97+
}
98+
outputs_ = array_append(outputs_, RedisModule_Strdup(output));
6999
}
70100

71101
char *buffer = RedisModule_Calloc(modellen, sizeof(*buffer));
72102
memcpy(buffer, modeldef, modellen);
73103

104+
74105
RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret));
75106
ret->model = model;
76107
ret->session = NULL;
77108
ret->backend = backend;
78109
ret->devicestr = RedisModule_Strdup(devicestr);
79-
ret->inputs = NULL;
80-
ret->outputs = NULL;
110+
ret->ninputs = ninputs;
111+
ret->inputs = inputs_;
112+
ret->noutputs = noutputs;
113+
ret->outputs = outputs_;
81114
ret->opts = opts;
82115
ret->refCount = 1;
83116
ret->data = buffer;
84117
ret->datalen = modellen;
85-
86118
return ret;
119+
120+
cleanup:
121+
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
122+
RedisModule_Free(error_descr);
123+
if(inputs_) {
124+
ninputs = array_len(inputs_);
125+
for(size_t i =0 ; i < ninputs; i++) {
126+
RedisModule_Free(inputs_[i]);
127+
}
128+
array_free(inputs_);
129+
}
130+
if(outputs_) {
131+
noutputs = array_len(outputs_);
132+
for(size_t i =0 ; i < noutputs; i++) {
133+
RedisModule_Free(outputs_[i]);
134+
}
135+
array_free(outputs_);
136+
}
137+
return NULL;
87138
}
88139

89140
void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) {
@@ -93,6 +144,18 @@ void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) {
93144
if (model->data) {
94145
RedisModule_Free(model->data);
95146
}
147+
size_t ninputs = model->ninputs;
148+
for(size_t i =0 ; i < ninputs; i++) {
149+
RedisModule_Free(model->inputs[i]);
150+
}
151+
array_free(model->inputs);
152+
153+
size_t noutputs = model->noutputs;
154+
for(size_t i =0 ; i < noutputs; i++) {
155+
RedisModule_Free(model->outputs[i]);
156+
}
157+
array_free(model->outputs);
158+
96159
torchDeallocContext(model->model);
97160
}
98161

src/command_parser.c

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,13 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **a
7171
}
7272
if ((*model)->inputs && (*model)->ninputs != ninputs) {
7373
RAI_SetError(error, RAI_EMODELRUN,
74-
"Number of names given as INPUTS during MODELSET and keys given as "
75-
"INPUTS here do not match");
74+
"Number of keys given as INPUTS here does not match model definition");
7675
return REDISMODULE_ERR;
7776
}
7877

7978
if ((*model)->outputs && (*model)->noutputs != noutputs) {
8079
RAI_SetError(error, RAI_EMODELRUN,
81-
"Number of names given as OUTPUTS during MODELSET and keys given as "
82-
"OUTPUTS here do not match");
80+
"Number of keys given as OUTPUTS here does not match model definition");
8381
return REDISMODULE_ERR;
8482
}
8583
return REDISMODULE_OK;

0 commit comments

Comments
 (0)