@@ -123,7 +123,8 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) {
123123 }
124124 rinfo -> dagError = RedisModule_Calloc (1 , sizeof (int ));
125125 rinfo -> dagLock = RedisModule_Alloc (sizeof (pthread_rwlock_t ));
126- rinfo -> dagRefCount = RedisModule_Calloc (1 , sizeof (long long ));
126+ rinfo -> dagRefCount = RedisModule_Alloc (sizeof (long long ));
127+ * (rinfo -> dagRefCount ) = 0 ;
127128 rinfo -> dagOpCount = 0 ;
128129 rinfo -> dagCompleteOpCount = RedisModule_Calloc (1 , sizeof (long long ));
129130 rinfo -> dagDeviceOpCount = 0 ;
@@ -145,6 +146,7 @@ int RAI_ShallowCopyDagRunInfo(RedisAI_RunInfo **result, RedisAI_RunInfo *src) {
145146 if (!(rinfo -> dagDeviceOps )) {
146147 return REDISMODULE_ERR ;
147148 }
149+ (* rinfo -> dagRefCount )++ ;
148150 rinfo -> dagDeviceOpCount = 0 ;
149151 rinfo -> dagDeviceCompleteOpCount = 0 ;
150152 * result = rinfo ;
@@ -197,20 +199,26 @@ void RAI_FreeDagOp(RedisModuleCtx *ctx, RAI_DagOp *dagOp) {
197199 }
198200}
199201
200- void RAI_FreeRunInfo (RedisModuleCtx * ctx , struct RedisAI_RunInfo * rinfo ) {
201- if (!rinfo ) {
202- return ;
202+ long long RAI_DagRunInfoFreeShallowCopy (RedisAI_RunInfo * rinfo ) {
203+ long long ref_count = __atomic_sub_fetch (rinfo -> dagRefCount , 1 , __ATOMIC_RELAXED );
204+ RedisModule_Assert (ref_count >= 0 && "Tried to free the original RunInfo object" );
205+ if (rinfo -> dagDeviceOps ) {
206+ array_free (rinfo -> dagDeviceOps );
203207 }
204- if (* rinfo -> dagRefCount > 0 ) {
205- if (rinfo -> dagDeviceOps ) {
206- array_free (rinfo -> dagDeviceOps );
207- }
208+ // If this is the last run info copy we do not free it, the OnFinish callback may free.
209+ if (ref_count > 0 )
208210 RedisModule_Free (rinfo );
211+ return ref_count ;
212+ }
213+
214+ void RAI_FreeRunInfo (RedisModuleCtx * ctx , struct RedisAI_RunInfo * rinfo ) {
215+ if (!rinfo ) {
209216 return ;
210- } else {
211- pthread_rwlock_destroy (rinfo -> dagLock );
212- RedisModule_Free (rinfo -> dagLock );
213217 }
218+ long long ref_count = * rinfo -> dagRefCount ;
219+ RedisModule_Assert (ref_count == 0 );
220+ pthread_rwlock_destroy (rinfo -> dagLock );
221+ RedisModule_Free (rinfo -> dagLock );
214222
215223 if (rinfo -> dagTensorsContext ) {
216224 AI_dictRelease (rinfo -> dagTensorsContext );
@@ -225,10 +233,6 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) {
225233 array_free (rinfo -> dagOps );
226234 }
227235
228- if (rinfo -> dagDeviceOps ) {
229- array_free (rinfo -> dagDeviceOps );
230- }
231-
232236 if (rinfo -> dagError ) {
233237 RedisModule_Free (rinfo -> dagError );
234238 }
0 commit comments