@@ -36,6 +36,9 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
3636 RAI_Device device = RAI_DEVICE_CPU ;
3737 int64_t deviceid = 0 ;
3838
39+ char * * inputs_ = NULL ;
40+ char * * outputs_ = NULL ;
41+
3942 if (!parseDeviceStr (devicestr , & device , & deviceid )) {
4043 RAI_SetError (error , RAI_EMODELCONFIGURE , "ERR unsupported device" );
4144 return NULL ;
@@ -67,7 +70,7 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
6770 if (opts .backends_intra_op_parallelism > 0 ) {
6871 torchSetIntraOpThreads (opts .backends_intra_op_parallelism , & error_descr , RedisModule_Alloc );
6972 }
70- if (error_descr != NULL ) {
73+ if (error_descr ) {
7174 RAI_SetError (error , RAI_EMODELCREATE , error_descr );
7275 RedisModule_Free (error_descr );
7376 return NULL ;
@@ -76,10 +79,37 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
7679 void * model =
7780 torchLoadModel (modeldef , modellen , dl_device , deviceid , & error_descr , RedisModule_Alloc );
7881
79- if (model == NULL ) {
80- RAI_SetError (error , RAI_EMODELCREATE , error_descr );
81- RedisModule_Free (error_descr );
82- return NULL ;
82+ if (error_descr ) {
83+ goto cleanup ;
84+ }
85+
86+ size_t ninputs = torchModelNumInputs (model , & error_descr );
87+ if (error_descr ) {
88+ goto cleanup ;
89+ }
90+
91+ size_t noutputs = torchModelNumOutputs (model , & error_descr );
92+ if (error_descr ) {
93+ goto cleanup ;
94+ }
95+
96+ inputs_ = array_new (char * , ninputs );
97+ outputs_ = array_new (char * , noutputs );
98+
99+ for (size_t i = 0 ; i < ninputs ; i ++ ) {
100+ const char * input = torchModelInputNameAtIndex (model , i , & error_descr );
101+ if (error_descr ) {
102+ goto cleanup ;
103+ }
104+ inputs_ = array_append (inputs_ , RedisModule_Strdup (input ));
105+ }
106+
107+ for (size_t i = 0 ; i < noutputs ; i ++ ) {
108+ const char * output = "" ;
109+ if (error_descr ) {
110+ goto cleanup ;
111+ }
112+ outputs_ = array_append (outputs_ , RedisModule_Strdup (output ));
83113 }
84114
85115 char * buffer = RedisModule_Calloc (modellen , sizeof (* buffer ));
@@ -90,14 +120,34 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
90120 ret -> session = NULL ;
91121 ret -> backend = backend ;
92122 ret -> devicestr = RedisModule_Strdup (devicestr );
93- ret -> inputs = NULL ;
94- ret -> outputs = NULL ;
123+ ret -> ninputs = ninputs ;
124+ ret -> inputs = inputs_ ;
125+ ret -> noutputs = noutputs ;
126+ ret -> outputs = outputs_ ;
95127 ret -> opts = opts ;
96128 ret -> refCount = 1 ;
97129 ret -> data = buffer ;
98130 ret -> datalen = modellen ;
99-
100131 return ret ;
132+
133+ cleanup :
134+ RAI_SetError (error , RAI_EMODELCREATE , error_descr );
135+ RedisModule_Free (error_descr );
136+ if (inputs_ ) {
137+ ninputs = array_len (inputs_ );
138+ for (size_t i = 0 ; i < ninputs ; i ++ ) {
139+ RedisModule_Free (inputs_ [i ]);
140+ }
141+ array_free (inputs_ );
142+ }
143+ if (outputs_ ) {
144+ noutputs = array_len (outputs_ );
145+ for (size_t i = 0 ; i < noutputs ; i ++ ) {
146+ RedisModule_Free (outputs_ [i ]);
147+ }
148+ array_free (outputs_ );
149+ }
150+ return NULL ;
101151}
102152
103153void RAI_ModelFreeTorch (RAI_Model * model , RAI_Error * error ) {
@@ -107,6 +157,18 @@ void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) {
107157 if (model -> data ) {
108158 RedisModule_Free (model -> data );
109159 }
160+ size_t ninputs = model -> ninputs ;
161+ for (size_t i = 0 ; i < ninputs ; i ++ ) {
162+ RedisModule_Free (model -> inputs [i ]);
163+ }
164+ array_free (model -> inputs );
165+
166+ size_t noutputs = model -> noutputs ;
167+ for (size_t i = 0 ; i < noutputs ; i ++ ) {
168+ RedisModule_Free (model -> outputs [i ]);
169+ }
170+ array_free (model -> outputs );
171+
110172 torchDeallocContext (model -> model );
111173}
112174
0 commit comments