@@ -91,7 +91,7 @@ static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *curre
9191
9292static void Dag_StoreOutputsFromModelRunCtx (RedisAI_RunInfo * rinfo , RAI_DagOp * currentOp ) {
9393
94- RAI_ContextReadLock (rinfo );
94+ RAI_ContextWriteLock (rinfo );
9595 const size_t noutputs = RAI_ModelRunCtxNumOutputs (currentOp -> mctx );
9696 for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
9797 RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (currentOp -> mctx , outputNumber );
@@ -177,6 +177,9 @@ void RedisAI_BatchedDagRunSession_ModelRun_Step(RedisAI_RunInfo **batched_rinfo,
177177 if (rinfo -> single_op_dag == 0 )
178178 Dag_StoreOutputsFromModelRunCtx (rinfo , currentOp );
179179 }
180+ // Clear the result in case of an error.
181+ if (result == REDISMODULE_ERR )
182+ RAI_ClearError (& err );
180183}
181184
182185/**
@@ -346,16 +349,20 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
346349 return 1 ;
347350}
348351
349- int RedisAI_DagDeviceComplete (RedisAI_RunInfo * rinfo ) {
352+ bool RedisAI_DagDeviceComplete (RedisAI_RunInfo * rinfo ) {
350353 return rinfo -> dagDeviceCompleteOpCount == rinfo -> dagDeviceOpCount ;
351354}
352355
353- int RedisAI_DagComplete (RedisAI_RunInfo * rinfo ) {
356+ bool RedisAI_DagComplete (RedisAI_RunInfo * rinfo ) {
354357 int completeOpCount = __atomic_load_n (rinfo -> dagCompleteOpCount , __ATOMIC_RELAXED );
355358
356359 return completeOpCount == rinfo -> dagOpCount ;
357360}
358361
362+ bool RedisAI_DagError (RedisAI_RunInfo * rinfo ) {
363+ return __atomic_load_n (rinfo -> dagError , __ATOMIC_RELAXED ) != 0 ;
364+ }
365+
359366RAI_DagOp * RedisAI_DagCurrentOp (RedisAI_RunInfo * rinfo ) {
360367 if (rinfo -> dagDeviceCompleteOpCount == rinfo -> dagDeviceOpCount ) {
361368 return NULL ;
@@ -364,21 +371,21 @@ RAI_DagOp *RedisAI_DagCurrentOp(RedisAI_RunInfo *rinfo) {
364371 return rinfo -> dagDeviceOps [rinfo -> dagDeviceCompleteOpCount ];
365372}
366373
367- void RedisAI_DagCurrentOpInfo (RedisAI_RunInfo * rinfo , int * currentOpReady ,
368- int * currentOpBatchable ) {
374+ void RedisAI_DagCurrentOpInfo (RedisAI_RunInfo * rinfo , bool * currentOpReady ,
375+ bool * currentOpBatchable ) {
369376 RAI_DagOp * currentOp_ = RedisAI_DagCurrentOp (rinfo );
370377
371- * currentOpReady = 0 ;
372- * currentOpBatchable = 0 ;
378+ * currentOpReady = false ;
379+ * currentOpBatchable = false ;
373380
374381 if (currentOp_ == NULL ) {
375382 return ;
376383 }
377384
378385 if (currentOp_ -> mctx && currentOp_ -> mctx -> model -> opts .batchsize > 0 ) {
379- * currentOpBatchable = 1 ;
386+ * currentOpBatchable = true ;
380387 }
381- * currentOpReady = 1 ;
388+ * currentOpReady = true ;
382389 // If this is a single op dag, the op is definitely ready.
383390 if (rinfo -> single_op_dag == 1 )
384391 return ;
@@ -389,7 +396,7 @@ void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady,
389396 for (int i = 0 ; i < n_inkeys ; i ++ ) {
390397 if (AI_dictFind (rinfo -> dagTensorsContext , currentOp_ -> inkeys [i ]) == NULL ) {
391398 RAI_ContextUnlock (rinfo );
392- * currentOpReady = 0 ;
399+ * currentOpReady = false ;
393400 return ;
394401 }
395402 }
@@ -577,7 +584,6 @@ static void _ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
577584 const size_t noutputs = RAI_ModelRunCtxNumOutputs (op -> mctx );
578585 for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
579586 RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (op -> mctx , outputNumber );
580- tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
581587 if (tensor )
582588 _StoreTensorInKeySpace (ctx , tensor , op -> outkeys [outputNumber ], false);
583589 }
@@ -587,7 +593,6 @@ static void _ScriptSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
587593 const size_t noutputs = RAI_ScriptRunCtxNumOutputs (op -> sctx );
588594 for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
589595 RAI_Tensor * tensor = RAI_ScriptRunCtxOutputTensor (op -> sctx , outputNumber );
590- tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
591596 if (tensor )
592597 _StoreTensorInKeySpace (ctx , tensor , op -> outkeys [outputNumber ], false);
593598 }
@@ -600,7 +605,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
600605
601606 if (RAI_GetErrorCode (rinfo -> err ) == RAI_EDAGRUN ) {
602607 RedisModule_ReplyWithError (ctx , RAI_GetErrorOneLine (rinfo -> err ));
603- RAI_FreeRunInfo (rinfo );
604608 return REDISMODULE_ERR ;
605609 }
606610 int dag_error = 0 ;
@@ -610,7 +614,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
610614
611615 if (* rinfo -> timedOut ) {
612616 RedisModule_ReplyWithSimpleString (ctx , "TIMEDOUT" );
613- RAI_FreeRunInfo (rinfo );
614617 return REDISMODULE_OK ;
615618 }
616619
@@ -701,7 +704,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
701704 if (rinfo -> single_op_dag == 0 ) {
702705 RedisModule_ReplySetArrayLength (ctx , rinfo -> dagReplyLength );
703706 }
704- RAI_FreeRunInfo (rinfo );
705707 return REDISMODULE_ERR ;
706708 }
707709
@@ -718,7 +720,6 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
718720 }
719721 }
720722
721- RAI_FreeRunInfo (rinfo );
722723 return REDISMODULE_OK ;
723724}
724725
@@ -746,11 +747,7 @@ int RedisAI_DagRun_IsKeysPositionRequest_ReportKeys(RedisModuleCtx *ctx, RedisMo
746747 return REDISMODULE_OK ;
747748}
748749
749- void RunInfo_FreeData (RedisModuleCtx * ctx , void * rinfo ) {}
750-
751- void RedisAI_Disconnected (RedisModuleCtx * ctx , RedisModuleBlockedClient * bc ) {
752- RedisModule_Log (ctx , "warning" , "Blocked client %p disconnected!" , (void * )bc );
753- }
750+ void RunInfo_FreeData (RedisModuleCtx * ctx , void * rinfo ) { RAI_FreeRunInfo (rinfo ); }
754751
755752// Add Shallow copies of the DAG run info to the devices' queues.
756753// Return REDISMODULE_OK in case of success, REDISMODULE_ERR if (at least) one insert op had
0 commit comments