@@ -22,6 +22,9 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
2222 RAI_Device device = RAI_DEVICE_CPU ;
2323 int64_t deviceid = 0 ;
2424
25+ char * * inputs_ = NULL ;
26+ char * * outputs_ = NULL ;
27+
2528 if (!parseDeviceStr (devicestr , & device , & deviceid )) {
2629 RAI_SetError (error , RAI_EMODELCONFIGURE , "ERR unsupported device" );
2730 return NULL ;
@@ -53,7 +56,7 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
5356 if (opts .backends_intra_op_parallelism > 0 ) {
5457 torchSetIntraOpThreads (opts .backends_intra_op_parallelism , & error_descr , RedisModule_Alloc );
5558 }
56- if (error_descr != NULL ) {
59+ if (error_descr ) {
5760 RAI_SetError (error , RAI_EMODELCREATE , error_descr );
5861 RedisModule_Free (error_descr );
5962 return NULL ;
@@ -62,10 +65,37 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
6265 void * model =
6366 torchLoadModel (modeldef , modellen , dl_device , deviceid , & error_descr , RedisModule_Alloc );
6467
65- if (model == NULL ) {
66- RAI_SetError (error , RAI_EMODELCREATE , error_descr );
67- RedisModule_Free (error_descr );
68- return NULL ;
68+ if (error_descr ) {
69+ goto cleanup ;
70+ }
71+
72+ size_t ninputs = torchModelNumInputs (model , & error_descr );
73+ if (error_descr ) {
74+ goto cleanup ;
75+ }
76+
77+ size_t noutputs = torchModelNumOutputs (model , & error_descr );
78+ if (error_descr ) {
79+ goto cleanup ;
80+ }
81+
82+ inputs_ = array_new (char * , ninputs );
83+ outputs_ = array_new (char * , noutputs );
84+
85+ for (size_t i = 0 ; i < ninputs ; i ++ ) {
86+ const char * input = torchModelInputNameAtIndex (model , i , & error_descr );
87+ if (error_descr ) {
88+ goto cleanup ;
89+ }
90+ inputs_ = array_append (inputs_ , RedisModule_Strdup (input ));
91+ }
92+
93+ for (size_t i = 0 ; i < noutputs ; i ++ ) {
94+ const char * output = "" ;
95+ if (error_descr ) {
96+ goto cleanup ;
97+ }
98+ outputs_ = array_append (outputs_ , RedisModule_Strdup (output ));
6999 }
70100
71101 char * buffer = RedisModule_Calloc (modellen , sizeof (* buffer ));
@@ -76,14 +106,34 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char *devicestr, RAI_
76106 ret -> session = NULL ;
77107 ret -> backend = backend ;
78108 ret -> devicestr = RedisModule_Strdup (devicestr );
79- ret -> inputs = NULL ;
80- ret -> outputs = NULL ;
109+ ret -> ninputs = ninputs ;
110+ ret -> inputs = inputs_ ;
111+ ret -> noutputs = noutputs ;
112+ ret -> outputs = outputs_ ;
81113 ret -> opts = opts ;
82114 ret -> refCount = 1 ;
83115 ret -> data = buffer ;
84116 ret -> datalen = modellen ;
85-
86117 return ret ;
118+
119+ cleanup :
120+ RAI_SetError (error , RAI_EMODELCREATE , error_descr );
121+ RedisModule_Free (error_descr );
122+ if (inputs_ ) {
123+ ninputs = array_len (inputs_ );
124+ for (size_t i = 0 ; i < ninputs ; i ++ ) {
125+ RedisModule_Free (inputs_ [i ]);
126+ }
127+ array_free (inputs_ );
128+ }
129+ if (outputs_ ) {
130+ noutputs = array_len (outputs_ );
131+ for (size_t i = 0 ; i < noutputs ; i ++ ) {
132+ RedisModule_Free (outputs_ [i ]);
133+ }
134+ array_free (outputs_ );
135+ }
136+ return NULL ;
87137}
88138
89139void RAI_ModelFreeTorch (RAI_Model * model , RAI_Error * error ) {
@@ -93,6 +143,18 @@ void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error) {
93143 if (model -> data ) {
94144 RedisModule_Free (model -> data );
95145 }
146+ size_t ninputs = model -> ninputs ;
147+ for (size_t i = 0 ; i < ninputs ; i ++ ) {
148+ RedisModule_Free (model -> inputs [i ]);
149+ }
150+ array_free (model -> inputs );
151+
152+ size_t noutputs = model -> noutputs ;
153+ for (size_t i = 0 ; i < noutputs ; i ++ ) {
154+ RedisModule_Free (model -> outputs [i ]);
155+ }
156+ array_free (model -> outputs );
157+
96158 torchDeallocContext (model -> model );
97159}
98160
0 commit comments