@@ -237,53 +237,52 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
237237 uint n_inkeys = array_len (currentOp -> inkeys );
238238 uint n_outkeys = array_len (currentOp -> outkeys );
239239
240- RAI_ContextReadLock ( rinfo );
240+ if (! rinfo -> single_op_dag ) {
241241
242- RAI_Tensor * inputTensors [n_inkeys ];
243- for (uint i = 0 ; i < n_inkeys ; i ++ ) {
244- RAI_Tensor * inputTensor ;
245- const int get_result = RAI_getTensorFromLocalContext (
246- NULL , rinfo -> dagTensorsContext , currentOp -> inkeys [i ], & inputTensor , currentOp -> err );
247- if (get_result == REDISMODULE_ERR ) {
248- // We check for this outside the function
249- // this check cannot be covered by tests
250- currentOp -> result = REDISMODULE_ERR ;
251- RAI_ContextUnlock (rinfo );
252- return ;
242+ RAI_ContextReadLock (rinfo );
243+ RAI_Tensor * inputTensors [n_inkeys ];
244+ for (uint i = 0 ; i < n_inkeys ; i ++ ) {
245+ RAI_Tensor * inputTensor ;
246+ const int get_result = RAI_getTensorFromLocalContext (
247+ NULL , rinfo -> dagTensorsContext , currentOp -> inkeys [i ], & inputTensor , currentOp -> err );
248+ if (get_result == REDISMODULE_ERR ) {
249+ // We check for this outside the function
250+ // this check cannot be covered by tests
251+ currentOp -> result = REDISMODULE_ERR ;
252+ RAI_ContextUnlock (rinfo );
253+ return ;
254+ }
255+ inputTensors [i ] = inputTensor ;
253256 }
254- inputTensors [i ] = inputTensor ;
255- }
256-
257- RAI_ContextUnlock (rinfo );
258-
259- for (uint i = 0 ; i < n_inkeys ; i ++ ) {
260- RAI_ScriptRunCtxAddInput (currentOp -> sctx , inputTensors [i ], currentOp -> err );
261- }
257+ RAI_ContextUnlock (rinfo );
262258
263- for (uint i = 0 ; i < n_outkeys ; i ++ ) {
264- RAI_ScriptRunCtxAddOutput (currentOp -> sctx );
259+ for (uint i = 0 ; i < n_inkeys ; i ++ ) {
260+ RAI_ScriptRunCtxAddInput (currentOp -> sctx , inputTensors [i ], currentOp -> err );
261+ }
262+ for (uint i = 0 ; i < n_outkeys ; i ++ ) {
263+ RAI_ScriptRunCtxAddOutput (currentOp -> sctx );
264+ }
265265 }
266266
267267 const long long start = ustime ();
268268 int result = RAI_ScriptRun (currentOp -> sctx , currentOp -> err );
269269 const long long end = ustime ();
270270
271- RAI_ContextWriteLock (rinfo );
272-
273- const size_t noutputs = RAI_ScriptRunCtxNumOutputs (currentOp -> sctx );
274- for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
275- RAI_Tensor * tensor = RAI_ScriptRunCtxOutputTensor (currentOp -> sctx , outputNumber );
276- RedisModuleString * key_string = currentOp -> outkeys [outputNumber ];
277- tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
278- AI_dictReplace (rinfo -> dagTensorsContext , (void * )key_string , tensor );
279- }
280-
281271 currentOp -> result = result ;
282272 currentOp -> duration_us = end - start ;
283273
284- RAI_ContextUnlock ( rinfo );
274+ if (! rinfo -> single_op_dag ) {
285275
286- return ;
276+ RAI_ContextWriteLock (rinfo );
277+ const size_t noutputs = RAI_ScriptRunCtxNumOutputs (currentOp -> sctx );
278+ for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
279+ RAI_Tensor * tensor = RAI_ScriptRunCtxOutputTensor (currentOp -> sctx , outputNumber );
280+ RedisModuleString * key_string = currentOp -> outkeys [outputNumber ];
281+ tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
282+ AI_dictReplace (rinfo -> dagTensorsContext , (void * )key_string , tensor );
283+ }
284+ RAI_ContextUnlock (rinfo );
285+ }
287286}
288287
289288size_t RAI_DagOpBatchSize (RAI_DagOp * op , RedisAI_RunInfo * rinfo ) {
@@ -572,17 +571,16 @@ static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor,
572571 return ret ;
573572}
574573
575- static void PersistTensors (RedisModuleCtx * ctx , RedisAI_RunInfo * rinfo ) {
574+ static void _PersistTensors (RedisModuleCtx * ctx , RedisAI_RunInfo * rinfo ) {
575+
576576 AI_dictIterator * persist_iter = AI_dictGetSafeIterator (rinfo -> dagTensorsPersistedContext );
577577 AI_dictEntry * persist_entry = AI_dictNext (persist_iter );
578+
578579 while (persist_entry ) {
579580 RedisModuleString * persist_key_name = AI_dictGetKey (persist_entry );
580-
581581 AI_dictEntry * tensor_entry = AI_dictFind (rinfo -> dagTensorsContext , persist_key_name );
582-
583582 if (tensor_entry ) {
584583 RAI_Tensor * tensor = AI_dictGetVal (tensor_entry );
585-
586584 if (tensor == NULL ) {
587585 persist_entry = AI_dictNext (persist_iter );
588586 continue ;
@@ -594,17 +592,17 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
594592 RedisModule_ReplyWithError (ctx ,
595593 "ERR specified persistent key that was not used in DAG" );
596594 rinfo -> dagReplyLength ++ ;
597-
598595 RedisModule_Log (ctx , "warning" ,
599- "on DAGRUN's PERSIST pecified persistent key (%s) that "
596+ "on DAGRUN's PERSIST specified persistent key (%s) that "
600597 "was not used on DAG. Logging all local context keys" ,
601- persist_key_name );
598+ RedisModule_StringPtrLen ( persist_key_name , NULL ) );
602599 AI_dictIterator * local_iter = AI_dictGetSafeIterator (rinfo -> dagTensorsContext );
603600 AI_dictEntry * local_entry = AI_dictNext (local_iter );
601+
604602 while (local_entry ) {
605603 RedisModuleString * localcontext_key_name = AI_dictGetKey (local_entry );
606604 RedisModule_Log (ctx , "warning" , "DAG's local context key (%s)" ,
607- localcontext_key_name );
605+ RedisModule_StringPtrLen ( localcontext_key_name , NULL ) );
608606 local_entry = AI_dictNext (local_iter );
609607 }
610608 AI_dictReleaseIterator (local_iter );
@@ -619,7 +617,7 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
619617 AI_dictReleaseIterator (persist_iter );
620618}
621619
622- static void ModelSingleOp_PersistTensors (RedisModuleCtx * ctx , RAI_DagOp * op ) {
620+ static void _ModelSingleOp_PersistTensors (RedisModuleCtx * ctx , RAI_DagOp * op ) {
623621 const size_t noutputs = RAI_ModelRunCtxNumOutputs (op -> mctx );
624622 for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
625623 RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (op -> mctx , outputNumber );
@@ -629,6 +627,16 @@ static void ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
629627 }
630628}
631629
630+ static void _ScriptSingleOp_PersistTensors (RedisModuleCtx * ctx , RAI_DagOp * op ) {
631+ const size_t noutputs = RAI_ScriptRunCtxNumOutputs (op -> sctx );
632+ for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
633+ RAI_Tensor * tensor = RAI_ScriptRunCtxOutputTensor (op -> sctx , outputNumber );
634+ tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
635+ if (tensor )
636+ _StoreTensorInKeySpace (ctx , tensor , op -> outkeys [outputNumber ], false);
637+ }
638+ }
639+
632640int RedisAI_DagRun_Reply (RedisModuleCtx * ctx , RedisModuleString * * argv , int argc ) {
633641 REDISMODULE_NOT_USED (argv );
634642 REDISMODULE_NOT_USED (argc );
@@ -650,7 +658,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
650658 return REDISMODULE_OK ;
651659 }
652660
653- if (rinfo -> single_op_dag == 0 ) {
661+ if (! rinfo -> single_op_dag ) {
654662 RedisModule_ReplyWithArray (ctx , REDISMODULE_POSTPONED_ARRAY_LEN );
655663 }
656664
@@ -745,18 +753,20 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
745753 return REDISMODULE_ERR ;
746754 }
747755
748- // TODO: Take care of script single op
749- if (rinfo -> single_op_dag == 0 || rinfo -> dagOps [0 ]-> commandType == REDISAI_DAG_CMD_SCRIPTRUN ) {
756+ if (!rinfo -> single_op_dag ) {
750757 // Save the required tensors in redis key space.
751- PersistTensors (ctx , rinfo );
752- if (rinfo -> single_op_dag == 0 )
753- RedisModule_ReplySetArrayLength (ctx , rinfo -> dagReplyLength );
758+ _PersistTensors (ctx , rinfo );
759+ RedisModule_ReplySetArrayLength (ctx , rinfo -> dagReplyLength );
754760 } else {
755- ModelSingleOp_PersistTensors (ctx , rinfo -> dagOps [0 ]);
761+ if (rinfo -> dagOps [0 ]-> commandType == REDISAI_DAG_CMD_MODELRUN ) {
762+ _ModelSingleOp_PersistTensors (ctx , rinfo -> dagOps [0 ]);
763+ } else {
764+ RedisModule_Assert (rinfo -> dagOps [0 ]-> commandType == REDISAI_DAG_CMD_SCRIPTRUN );
765+ _ScriptSingleOp_PersistTensors (ctx , rinfo -> dagOps [0 ]);
766+ }
756767 }
757768
758769 RAI_FreeRunInfo (rinfo );
759-
760770 return REDISMODULE_OK ;
761771}
762772
0 commit comments