2121#include "util/arr_rm_alloc.h"
2222#include "util/dict.h"
2323#include "util/queue.h"
24- #include <ctype.h>
25- #include <errno.h>
2624#include <pthread.h>
2725#include <stdio.h>
28- #include <stdlib.h>
2926#include <string.h>
3027#include <unistd.h>
28+ #include <errno.h>
29+ #include <stdlib.h>
30+ #include <ctype.h>
3131
3232int freeRunQueueInfo (RunQueueInfo * info ) {
3333 int result = REDISMODULE_OK ;
@@ -132,7 +132,6 @@ void *RedisAI_Run_ThreadMain(void *arg) {
132132 // There might be more than one thread operating on the same
133133 // queue, according to the THREADS_PER_QUEUE config variable.
134134 long long run_queue_len = queueLength (run_queue_info -> run_queue );
135-
136135 while (run_queue_len > 0 ) {
137136 // We first peek the front of the queue
138137 queueItem * item = queueFront (run_queue_info -> run_queue );
@@ -176,15 +175,17 @@ void *RedisAI_Run_ThreadMain(void *arg) {
176175 if (timedOut == 1 ) {
177176 queueEvict (run_queue_info -> run_queue , item );
178177
178+ RedisAI_RunInfo * orig = rinfo -> orig_copy ;
179179 long long dagRefCount = RAI_DagRunInfoFreeShallowCopy (rinfo );
180- if (dagRefCount == 0 && rinfo -> client ) {
181- RedisModule_UnblockClient (rinfo -> client , rinfo );
180+ if (dagRefCount == 0 ) {
181+ RedisAI_OnFinishCtx finish_ctx = (RedisAI_RunInfo * )orig ;
182+ orig -> OnFinish (finish_ctx , orig -> private_data );
182183 }
183184
184185 queueItem * evicted_item = item ;
185186 item = item -> next ;
186187 RedisModule_Free (evicted_item );
187-
188+ // Continue with the next item in queue (if exists)
188189 continue ;
189190 }
190191 }
@@ -244,9 +245,9 @@ void *RedisAI_Run_ThreadMain(void *arg) {
244245 int currentOpReady , currentOpBatchable ;
245246 RedisAI_DagCurrentOpInfo (rinfo , & currentOpReady , & currentOpBatchable );
246247
247- // If any of the inputs of the current op is not in the context, it
248- // means that some parent ops did not execute. In this case we don't
249- // schedule to run, but we will place the entry back on the queue
248+ // If any of the inputs of the current op is not in the context, it means
249+ // that some parent ops did not execute. In this case we don't schedule
250+ // to run, but we will place the entry back on the queue
250251 if (currentOpReady == 0 ) {
251252 do_run = 0 ;
252253 do_retry = 1 ;
@@ -256,18 +257,16 @@ void *RedisAI_Run_ThreadMain(void *arg) {
256257 // If we made it this far, we will run the currentOp
257258 do_run = 1 ;
258259
259- // If the current op is not batchable (that is, if it's not a modelrun
260- // or if it's a modelrun but batchsize was set to 0), we stop looking
261- // further
260+ // If the current op is not batchable (that is, if it's not a modelrun or
261+ // if it's a modelrun but batchsize was set to 0), we stop looking further
262262 if (currentOpBatchable == 0 ) {
263263 break ;
264264 }
265265
266266 // If we are here, then we scheduled to run and we currently have an
267267 // operation that can be batched.
268268
269- // Since the current op can be batched, then we collect info on
270- // batching, namely
269+ // Since the current op can be batched, then we collect info on batching, namely
271270 // - batchsize
272271 // - minbatchsize
273272 // - minbatchtimeout
@@ -276,8 +275,8 @@ void *RedisAI_Run_ThreadMain(void *arg) {
276275 RedisAI_DagOpBatchInfo (rinfo , currentOp , & batchsize , & minbatchsize ,
277276 & minbatchtimeout , & inbatchsize );
278277
279- // Get the size of the batch so far, that is, the size of the first
280- // input tensor in the 0-th dimension
278+ // Get the size of the batch so far, that is, the size of the first input
279+ // tensor in the 0-th dimension
281280 size_t current_batchsize = inbatchsize ;
282281
283282 // If the size is zero or if it already exceeds the desired batch size
@@ -396,8 +395,8 @@ void *RedisAI_Run_ThreadMain(void *arg) {
396395 RedisAI_DagRunSessionStep (batch_rinfo [0 ], run_queue_info -> devicestr );
397396 }
398397
399- // Lock the queue again: we're done operating on evicted items only, we
400- // need to update the queue with the new information after run
398+ // Lock the queue again: we're done operating on evicted items only, we need
399+ // to update the queue with the new information after run
401400 pthread_mutex_lock (& run_queue_info -> run_queue_mutex );
402401
403402 // Run is over, now iterate over the run info structs in the batch
@@ -415,9 +414,11 @@ void *RedisAI_Run_ThreadMain(void *arg) {
415414 // If there was an error and the reference count for the dag
416415 // has gone to zero and the client is still around, we unblock
417416 if (dagError ) {
417+ RedisAI_RunInfo * orig = rinfo -> orig_copy ;
418418 long long dagRefCount = RAI_DagRunInfoFreeShallowCopy (rinfo );
419- if (dagRefCount == 0 && rinfo -> client ) {
420- RedisModule_UnblockClient (rinfo -> client , rinfo );
419+ if (dagRefCount == 0 ) {
420+ RedisAI_OnFinishCtx finish_ctx = (RedisAI_RunInfo * )orig ;
421+ orig -> OnFinish (finish_ctx , orig -> private_data );
421422 }
422423 } else {
423424 rinfo -> dagDeviceCompleteOpCount += 1 ;
@@ -426,29 +427,30 @@ void *RedisAI_Run_ThreadMain(void *arg) {
426427 }
427428 }
428429
429- // We initialize variables where we'll store the fact hat, after the
430- // current run, all ops for the device or all ops in the dag could be
431- // complete. This way we can avoid placing the op back on the queue if
432- // there's nothing left to do.
430+ // We initialize variables where we'll store the fact hat, after the current
431+ // run, all ops for the device or all ops in the dag could be complete. This
432+ // way we can avoid placing the op back on the queue if there's nothing left
433+ // to do.
433434 int device_complete_after_run = RedisAI_DagDeviceComplete (batch_rinfo [0 ]);
434435 int dag_complete_after_run = RedisAI_DagComplete (batch_rinfo [0 ]);
435436
436437 long long dagRefCount = -1 ;
437-
438+ RedisAI_RunInfo * orig ;
438439 if (device_complete == 1 || device_complete_after_run == 1 ) {
439440 RedisAI_RunInfo * evicted_rinfo = (RedisAI_RunInfo * )(evicted_items [0 ]-> value );
440- // We decrease and get the reference count for the DAG
441+ orig = evicted_rinfo -> orig_copy ;
442+ // We decrease and get the reference count for the DAG.
441443 dagRefCount = RAI_DagRunInfoFreeShallowCopy (evicted_rinfo );
442444 }
443445
444446 // If the DAG was complete, then it's time to unblock the client
445447 if (do_unblock == 1 || dag_complete_after_run == 1 ) {
446- RedisAI_RunInfo * evicted_rinfo = (RedisAI_RunInfo * )(evicted_items [0 ]-> value );
447448
448- // If the reference count for the DAG is zero and the client is still
449- // around, then we actually unblock the client
450- if (dagRefCount == 0 && evicted_rinfo -> client ) {
451- RedisModule_UnblockClient (evicted_rinfo -> client , evicted_rinfo );
449+ // If the reference count for the DAG is zero and the client is still around,
450+ // then we actually unblock the client
451+ if (dagRefCount == 0 ) {
452+ RedisAI_OnFinishCtx finish_ctx = (RedisAI_RunInfo * )orig ;
453+ orig -> OnFinish (finish_ctx , orig -> private_data );
452454 }
453455 }
454456
@@ -463,31 +465,28 @@ void *RedisAI_Run_ThreadMain(void *arg) {
463465 queueItem * next_item = queuePop (run_queue_info -> run_queue );
464466 RedisAI_RunInfo * next_rinfo = (RedisAI_RunInfo * )next_item -> value ;
465467 // Push the DAG to the front of the queue, and then the item we just
466- // popped in front of it, so that it becomes the first item in the
467- // queue. The rationale is, since the DAG needs to wait for other
468- // workers, we are giving way to the next item and we'll get back to
469- // the DAG when that is done
468+ // popped in front of it, so that it becomes the first item in the queue.
469+ // The rationale is, since the DAG needs to wait for other workers, we are
470+ // giving way to the next item and we'll get back to the DAG when that is done
470471 queuePushFront (run_queue_info -> run_queue , evicted_rinfo );
471472 queuePushFront (run_queue_info -> run_queue , next_rinfo );
472473 }
473474 // If there's nothing else in the queue
474475 else {
475476 // We push the DAG back at the front
476477 queuePushFront (run_queue_info -> run_queue , evicted_rinfo );
477- // Since there's nothing else on the queue we just break out and give
478- // other workers a chance to produce the inputs needed for this DAG
479- // step
478+ // Since there's nothing else on the queue we just break out and give other
479+ // workers a chance to produce the inputs needed for this DAG step
480480 break ;
481481 }
482482 }
483483
484- // If the op was ran successfully and without any error, then put the
485- // entry back on the queue unless all ops for the device have been
486- // executed
484+ // If the op was ran successfully and without any error, then put the entry back
485+ // on the queue unless all ops for the device have been executed
487486 if (do_run == 1 && run_error == 0 ) {
488487 // Here we iterate backwards to keep the first evicted on top
489- // A side effect of this is that we are potentially changing priority in
490- // the queue We could solve this using a priority queue, TODO for later
488+ // A side effect of this is that we are potentially changing priority in the queue
489+ // We could solve this using a priority queue, TODO for later
491490 for (long long i = array_len (evicted_items ) - 1 ; i >= 0 ; i -- ) {
492491 // Get the current evicted run info
493492 RedisAI_RunInfo * evicted_rinfo = (RedisAI_RunInfo * )(evicted_items [i ]-> value );
@@ -500,24 +499,22 @@ void *RedisAI_Run_ThreadMain(void *arg) {
500499 }
501500 }
502501
503- // TODO now we can figure out of the device is complete or the dag is
504- // complete if (dag_complete_op_count == evicted_rinfo[0]->dagOpCount) ->
505- // ublock, free if (dag_device_complete_op_count ==
506- // evicted_rinfo[0]->dagDeviceOpCount) -> device complete
502+ // TODO now we can figure out of the device is complete or the dag is complete
503+ // if (dag_complete_op_count == evicted_rinfo[0]->dagOpCount) -> ublock, free
504+ // if (dag_device_complete_op_count == evicted_rinfo[0]->dagDeviceOpCount) -> device
505+ // complete
507506
508- // If there's nothing else to do for the DAG in the current worker or if
509- // an error occurred in any worker, we just move on
507+ // If there's nothing else to do for the DAG in the current worker or if an error
508+ // occurred in any worker, we just move on
510509 if (device_complete == 1 || device_complete_after_run == 1 || do_unblock == 1 ||
511510 run_error == 1 ) {
512511 for (long long i = 0 ; i < array_len (evicted_items ); i ++ ) {
513512 RedisModule_Free (evicted_items [i ]);
514513 }
515514 }
516-
517515 run_queue_len = queueLength (run_queue_info -> run_queue );
518516 }
519517 }
520-
521518 array_free (evicted_items );
522519 array_free (batch_rinfo );
523520}
0 commit comments