Skip to content

Commit 15c6bf1

Browse files
authored
Merge pull request #520 from RedisAI/Make_ref_count_atomic_read
Make atomic read for ref_count in DAG run_info free function.
2 parents 1d173ff + 8654840 commit 15c6bf1

File tree

4 files changed

+31
-25
lines changed

4 files changed

+31
-25
lines changed

src/background_workers.c

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,7 @@ void *RedisAI_Run_ThreadMain(void *arg) {
176176
if (timedOut == 1) {
177177
queueEvict(run_queue_info->run_queue, item);
178178

179-
int dagRefCount =
180-
__atomic_sub_fetch(rinfo->dagRefCount, 1, __ATOMIC_RELAXED);
179+
long long dagRefCount = RAI_DagRunInfoFreeShallowCopy(rinfo);
181180
if (dagRefCount == 0 && rinfo->client) {
182181
RedisModule_UnblockClient(rinfo->client, rinfo);
183182
}
@@ -416,9 +415,7 @@ void *RedisAI_Run_ThreadMain(void *arg) {
416415
// If there was an error and the reference count for the dag
417416
// has gone to zero and the client is still around, we unblock
418417
if (dagError) {
419-
int dagRefCount =
420-
__atomic_sub_fetch(rinfo->dagRefCount, 1, __ATOMIC_RELAXED);
421-
418+
long long dagRefCount = RAI_DagRunInfoFreeShallowCopy(rinfo);
422419
if (dagRefCount == 0 && rinfo->client) {
423420
RedisModule_UnblockClient(rinfo->client, rinfo);
424421
}
@@ -436,12 +433,12 @@ void *RedisAI_Run_ThreadMain(void *arg) {
436433
int device_complete_after_run = RedisAI_DagDeviceComplete(batch_rinfo[0]);
437434
int dag_complete_after_run = RedisAI_DagComplete(batch_rinfo[0]);
438435

439-
int dagRefCount = -1;
436+
long long dagRefCount = -1;
440437

441438
if (device_complete == 1 || device_complete_after_run == 1) {
442439
RedisAI_RunInfo *evicted_rinfo = (RedisAI_RunInfo *)(evicted_items[0]->value);
443440
// We decrease and get the reference count for the DAG
444-
dagRefCount = __atomic_sub_fetch(evicted_rinfo->dagRefCount, 1, __ATOMIC_RELAXED);
441+
dagRefCount = RAI_DagRunInfoFreeShallowCopy(evicted_rinfo);
445442
}
446443

447444
// If the DAG was complete, then it's time to unblock the client

src/dag.c

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,10 +1231,8 @@ int RedisAI_DagRunSyntaxParser(RedisModuleCtx *ctx, RedisModuleString **argv, in
12311231
}
12321232

12331233
size_t ndevices = array_len(devices);
1234-
1235-
*rinfo->dagRefCount = ndevices;
1236-
12371234
RedisAI_RunInfo **rinfo_copies = array_new(RedisAI_RunInfo *, ndevices);
1235+
*rinfo->dagRefCount = 1;
12381236
rinfo_copies = array_append(rinfo_copies, rinfo);
12391237

12401238
for (long long i = 1; i < ndevices; i++) {

src/run_info.c

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) {
123123
}
124124
rinfo->dagError = RedisModule_Calloc(1, sizeof(int));
125125
rinfo->dagLock = RedisModule_Alloc(sizeof(pthread_rwlock_t));
126-
rinfo->dagRefCount = RedisModule_Calloc(1, sizeof(long long));
126+
rinfo->dagRefCount = RedisModule_Alloc(sizeof(long long));
127+
*(rinfo->dagRefCount) = 0;
127128
rinfo->dagOpCount = 0;
128129
rinfo->dagCompleteOpCount = RedisModule_Calloc(1, sizeof(long long));
129130
rinfo->dagDeviceOpCount = 0;
@@ -145,6 +146,7 @@ int RAI_ShallowCopyDagRunInfo(RedisAI_RunInfo **result, RedisAI_RunInfo *src) {
145146
if (!(rinfo->dagDeviceOps)) {
146147
return REDISMODULE_ERR;
147148
}
149+
(*rinfo->dagRefCount)++;
148150
rinfo->dagDeviceOpCount = 0;
149151
rinfo->dagDeviceCompleteOpCount = 0;
150152
*result = rinfo;
@@ -197,20 +199,26 @@ void RAI_FreeDagOp(RedisModuleCtx *ctx, RAI_DagOp *dagOp) {
197199
}
198200
}
199201

200-
void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) {
201-
if (!rinfo) {
202-
return;
202+
long long RAI_DagRunInfoFreeShallowCopy(RedisAI_RunInfo *rinfo) {
203+
long long ref_count = __atomic_sub_fetch(rinfo->dagRefCount, 1, __ATOMIC_RELAXED);
204+
RedisModule_Assert(ref_count >= 0 && "Tried to free the original RunInfo object");
205+
if (rinfo->dagDeviceOps) {
206+
array_free(rinfo->dagDeviceOps);
203207
}
204-
if (*rinfo->dagRefCount > 0) {
205-
if (rinfo->dagDeviceOps) {
206-
array_free(rinfo->dagDeviceOps);
207-
}
208+
// If this is the last run info copy we do not free it, the OnFinish callback may free.
209+
if (ref_count > 0)
208210
RedisModule_Free(rinfo);
211+
return ref_count;
212+
}
213+
214+
void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) {
215+
if (!rinfo) {
209216
return;
210-
} else {
211-
pthread_rwlock_destroy(rinfo->dagLock);
212-
RedisModule_Free(rinfo->dagLock);
213217
}
218+
long long ref_count = *rinfo->dagRefCount;
219+
RedisModule_Assert(ref_count == 0);
220+
pthread_rwlock_destroy(rinfo->dagLock);
221+
RedisModule_Free(rinfo->dagLock);
214222

215223
if (rinfo->dagTensorsContext) {
216224
AI_dictRelease(rinfo->dagTensorsContext);
@@ -225,10 +233,6 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) {
225233
array_free(rinfo->dagOps);
226234
}
227235

228-
if (rinfo->dagDeviceOps) {
229-
array_free(rinfo->dagDeviceOps);
230-
}
231-
232236
if (rinfo->dagError) {
233237
RedisModule_Free(rinfo->dagError);
234238
}

src/run_info.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result);
103103

104104
int RAI_ShallowCopyDagRunInfo(RedisAI_RunInfo **result, RedisAI_RunInfo *src);
105105

106+
/**
107+
* Frees the shallow copy of RedisAI_RunInfo pointed by rinfo.
108+
* @param rinfo copy to be freed.
109+
* @retval The ref_count of the rinfo object after freeing this copy.
110+
*/
111+
long long RAI_DagRunInfoFreeShallowCopy(RedisAI_RunInfo *rinfo);
112+
106113
/**
107114
* Frees the memory allocated on RedisAI_RunInfo
108115
* @param ctx Context in which Redis modules operate

0 commit comments

Comments
 (0)