 #include "model.h"

 #include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api.h"
+
+#define RAI_TF_FN_NAME "rai_tf_forward"

 int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)) {
     get_api_fn("RedisModule_Alloc", ((void **)&RedisModule_Alloc));
@@ -223,19 +226,15 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
         RAI_SetError(error, RAI_EMODELIMPORT, "ERR unsupported device");
     }

-    TF_Graph *model = TF_NewGraph();
+    TF_Graph *graph = TF_NewGraph();
+    TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions();
     TF_Status *status = TF_NewStatus();
     TF_Buffer *tfbuffer = TF_NewBuffer();
-    TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions();
-    TF_Status *optionsStatus = NULL;
-    TF_SessionOptions *sessionOptions = NULL;
-    TF_Status *sessionStatus = NULL;
-    TF_Session *session = NULL;

     tfbuffer->length = modellen;
     tfbuffer->data = modeldef;

-    TF_GraphImportGraphDef(model, tfbuffer, options, status);
+    TF_GraphImportGraphDef(graph, tfbuffer, options, status);

     if (TF_GetCode(status) != TF_OK) {
         char *errorMessage = RedisModule_Strdup(TF_Message(status));
@@ -245,26 +244,26 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     }

     for (size_t i = 0; i < ninputs; ++i) {
-        TF_Operation *oper = TF_GraphOperationByName(model, inputs[i]);
+        TF_Operation *oper = TF_GraphOperationByName(graph, inputs[i]);
         if (oper == NULL || strcmp(TF_OperationOpType(oper), "Placeholder") != 0) {
             size_t len = strlen(inputs[i]);
             char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
             sprintf(msg, "ERR Input node named \"%s\" not found in TF graph.", inputs[i]);
             RAI_SetError(error, RAI_EMODELIMPORT, msg);
             RedisModule_Free(msg);
-            goto cleanup;
+            return NULL;
         }
     }

     for (size_t i = 0; i < noutputs; ++i) {
-        TF_Operation *oper = TF_GraphOperationByName(model, outputs[i]);
+        TF_Operation *oper = TF_GraphOperationByName(graph, outputs[i]);
         if (oper == NULL) {
             size_t len = strlen(outputs[i]);
             char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
             sprintf(msg, "ERR Output node named \"%s\" not found in TF graph", outputs[i]);
             RAI_SetError(error, RAI_EMODELIMPORT, msg);
             RedisModule_Free(msg);
-            goto cleanup;
+            return NULL;
         }
     }

@@ -275,6 +274,65 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     TF_DeleteStatus(status);
     status = NULL;

+    TF_Output tf_inputs[ninputs];
+    TF_Output tf_outputs[noutputs];
+
+    for (size_t i = 0; i < ninputs; ++i) {
+        TF_Output port;
+        port.oper = TF_GraphOperationByName(graph, inputs[i]);
+        port.index = 0;
+        if (port.oper == NULL) {
+            return NULL;
+        }
+        tf_inputs[i] = port;
+    }
+
+    for (size_t i = 0; i < noutputs; ++i) {
+        TF_Output port;
+        port.oper = TF_GraphOperationByName(graph, outputs[i]);
+        port.index = 0;
+        if (port.oper == NULL) {
+            return NULL;
+        }
+        tf_outputs[i] = port;
+    }
+
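+    // Note (sketch, not in the original patch): `status` was deleted and set to NULL
+    // just above, so a fresh status object is presumably needed for the calls below.
+    status = TF_NewStatus();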
+    TF_Function *function = TF_GraphToFunction(
+        graph,                // fn_body
+        RAI_TF_FN_NAME, 0,    // fn_name, append_hash_to_fn_name
+        -1, NULL,             // num_opers, opers
+        ninputs, tf_inputs,   // ninputs, inputs
+        noutputs, tf_outputs, // noutputs, outputs
+        outputs,              // output_names
+        NULL,                 // opts
+        "",                   // description
+        status                // status
+    );
+    // TODO EAGER
+    // check status and return error
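+    // Sketch of the missing check, mirroring the import error handling above
+    // (the error code and the early-return style are assumptions):
+    if (TF_GetCode(status) != TF_OK) {
+        char *errorMessage = RedisModule_Strdup(TF_Message(status));
+        RAI_SetError(error, RAI_EMODELIMPORT, errorMessage);
+        RedisModule_Free(errorMessage);
+        return NULL;
+    }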
+
+    TFE_ContextOptions *context_opts = TFE_NewContextOptions();
+    // TFE_ContextOptionsSetConfig(context_opts, proto, proto_len, status);
+    // TFE_ContextOptionsSetAsync(context_opts, 0);
+    TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
+
+    TFE_Context *context = TFE_NewContext(context_opts, status);
+    // TODO EAGER
+    // check status and return error
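+    // Sketch: the same RAI_EMODELIMPORT status check as above would apply here.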
+
+    TFE_ContextAddFunction(context, function, status);
+    // TODO EAGER
+    // check status and return error
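+    // Sketch: same status check as above; once the function has been added to the
+    // context, the local TF_Function handle can presumably be released as well:
+    TF_DeleteFunction(function);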
+
+    TFE_DeleteContextOptions(context_opts);
+    // Deleting the context here would leave ret->session dangling below; it is
+    // released in RAI_ModelFreeTF instead.
+    // TFE_DeleteContext(context);
+
+#if 0
+    TF_Status *optionsStatus = NULL;
+    TF_SessionOptions *sessionOptions = NULL;
+    TF_Status *sessionStatus = NULL;
+    TF_Session *session = NULL;
+
     optionsStatus = TF_NewStatus();
     sessionOptions = TF_NewSessionOptions();

@@ -340,7 +398,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     optionsStatus = NULL;

     sessionStatus = TF_NewStatus();
-    session = TF_NewSession(model, sessionOptions, sessionStatus);
+    session = TF_NewSession(graph, sessionOptions, sessionStatus);

     TF_Status *deviceListStatus = TF_NewStatus();
     TF_DeviceList *deviceList = TF_SessionListDevices(session, deviceListStatus);
@@ -370,6 +428,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod

     TF_DeleteSessionOptions(sessionOptions);
     TF_DeleteStatus(sessionStatus);
+#endif

     char **inputs_ = array_new(char *, ninputs);
     for (long long i = 0; i < ninputs; i++) {
@@ -385,8 +444,8 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     memcpy(buffer, modeldef, modellen);

     RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret));
-    ret->model = model;
-    ret->session = session;
+    ret->model = graph;
+    ret->session = context;
     ret->backend = backend;
     ret->devicestr = RedisModule_Strdup(devicestr);
     ret->ninputs = ninputs;
@@ -401,22 +460,23 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     return ret;

 cleanup:
-    TF_DeleteGraph(model);
+    TF_DeleteGraph(graph);
     if (options)
         TF_DeleteImportGraphDefOptions(options);
     if (tfbuffer)
         TF_DeleteBuffer(tfbuffer);
     if (status)
         TF_DeleteStatus(status);
-    if (sessionOptions)
-        TF_DeleteSessionOptions(sessionOptions);
-    if (sessionStatus)
-        TF_DeleteStatus(sessionStatus);
+    // if (sessionOptions)
+    //     TF_DeleteSessionOptions(sessionOptions);
+    // if (sessionStatus)
+    //     TF_DeleteStatus(sessionStatus);
     return NULL;
 }

 void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
     TF_Status *status = TF_NewStatus();
+#if 0
     TF_CloseSession(model->session, status);

     if (TF_GetCode(status) != TF_OK) {
@@ -425,12 +485,14 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
     }

     TF_DeleteSession(model->session, status);
+#endif
+    TFE_DeleteContext(model->session);
     model->session = NULL;

-    if (TF_GetCode(status) != TF_OK) {
-        RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status)));
-        return;
-    }
+    // if (TF_GetCode(status) != TF_OK) {
+    //     RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status)));
+    //     return;
+    // }

     TF_DeleteGraph(model->model);
     model->model = NULL;
@@ -457,7 +519,9 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
         RedisModule_Free(model->data);
     }

+#if 0
     TF_DeleteStatus(status);
+#endif
 }

 int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
@@ -472,9 +536,9 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
     const size_t ninputs = array_len(mctxs[0]->inputs);
     const size_t noutputs = array_len(mctxs[0]->outputs);
     TF_Tensor *inputTensorsValues[ninputs];
-    TF_Output inputs[ninputs];
     TF_Tensor *outputTensorsValues[noutputs];
-    TF_Output outputs[noutputs];
+    TFE_TensorHandle *inputTensorsHandles[ninputs];
+    TFE_TensorHandle *outputTensorsHandles[noutputs];

     size_t batch_sizes[nbatches];
     size_t batch_offsets[nbatches];
@@ -497,30 +561,28 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
             batched_input_tensors[b] = mctxs[b]->inputs[i].tensor;
         }
         inputTensorsValues[i] = RAI_TFTensorFromTensors(batched_input_tensors, nbatches);
-        TF_Output port;
-        port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->inputs[i].name);
-        port.index = 0;
-        if (port.oper == NULL) {
-            return 1;
-        }
-        inputs[i] = port;
+        inputTensorsHandles[i] = TFE_NewTensorHandle(inputTensorsValues[i], status);
+        // TODO EAGER
+        // check status and return error
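+        // Sketch of the missing check, following the error handling pattern used
+        // later in this function:
+        if (TF_GetCode(status) != TF_OK) {
+            char *errorMessage = RedisModule_Strdup(TF_Message(status));
+            RAI_SetError(error, RAI_EMODELRUN, errorMessage);
+            TF_DeleteStatus(status);
+            RedisModule_Free(errorMessage);
+            return 1;
+        }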
     }

-    for (size_t i = 0; i < noutputs; ++i) {
-        TF_Output port;
-        port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->outputs[i].name);
-        port.index = 0;
-        if (port.oper == NULL) {
-            return 1;
-        }
-        outputs[i] = port;
-    }
+    TFE_Op *fn_op = TFE_NewOp(mctxs[0]->model->session, RAI_TF_FN_NAME, status);
+    // TODO EAGER
+    // check status and return error
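+    // Sketch: the same RAI_EMODELRUN status check as above would apply here.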

-    TF_SessionRun(mctxs[0]->model->session, NULL /* run_options */, inputs, inputTensorsValues,
-                  ninputs, outputs, outputTensorsValues, noutputs, NULL /* target_opers */,
-                  0 /* ntargets */, NULL /* run_Metadata */, status);
+    TFE_OpAddInputList(fn_op, inputTensorsHandles, ninputs, status);
+    // TODO EAGER
+    // check status and return error
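+    // Sketch: and here as well; on error, fn_op and the input handles would also
+    // need to be released.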
+
+    // TODO EAGER: send tensors to device (as long as we keep device allocation EXPLICIT)
+
+    int noutputs_ = noutputs;
+    TFE_Execute(fn_op, outputTensorsHandles, &noutputs_, status);
+    // TODO EAGER
+    // check status and return error
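+    // Note: the existing TF_GetCode(status) check below (after the input cleanup
+    // loop) appears to already cover the TFE_Execute status.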

     for (size_t i = 0; i < ninputs; ++i) {
+        TFE_DeleteTensorHandle(inputTensorsHandles[i]);
         TF_DeleteTensor(inputTensorsValues[i]);
     }

@@ -532,13 +594,25 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
         return 1;
     }

+    for (size_t i = 0; i < noutputs; ++i) {
+        outputTensorsValues[i] = TFE_TensorHandleResolve(outputTensorsHandles[i], status);
+
+        if (TF_GetCode(status) != TF_OK) {
+            char *errorMessage = RedisModule_Strdup(TF_Message(status));
+            RAI_SetError(error, RAI_EMODELRUN, errorMessage);
+            TF_DeleteStatus(status);
+            RedisModule_Free(errorMessage);
+            return 1;
+        }
+    }
+
     for (size_t i = 0; i < noutputs; ++i) {
         if (nbatches > 1) {
             if (TF_NumDims(outputTensorsValues[i]) == 0) {
                 continue;
             }
             if (TF_Dim(outputTensorsValues[i], 0) != total_batch_size) {
-                TF_DeleteTensor(outputTensorsValues[i]);
+                // TF_DeleteTensor(outputTensorsValues[i]);
                 TF_DeleteStatus(status);
                 RAI_SetError(error, RAI_EMODELRUN,
                              "ERR Model did not generate the expected batch size");
@@ -553,7 +627,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
             mctxs[0]->outputs[i].tensor =
                 RAI_TensorCreateFromTFTensor(outputTensorsValues[i], 0, -1);
         }
-        TF_DeleteTensor(outputTensorsValues[i]);
+        // TF_DeleteTensor(outputTensorsValues[i]);
+        TFE_DeleteTensorHandle(outputTensorsHandles[i]);
     }

     TF_DeleteStatus(status);