Skip to content

Commit a191e32

Browse files
authored
Merge pull request #545 from RedisAI/Fix_unsafe_dag_release_upon_error
Fix unsafe dag release upon error
2 parents cb7f664 + 9c8ac4a commit a191e32

File tree

2 files changed

+30
-19
lines changed

2 files changed

+30
-19
lines changed

src/background_workers.c

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -391,19 +391,17 @@ void *RedisAI_Run_ThreadMain(void *arg) {
391391

392392
// Run is over, now iterate over the run info structs in the batch
393393
// and see if any error was generated
394-
int dagError = 0;
394+
bool first_dag_error = false;
395395
for (long long i = 0; i < array_len(batch_rinfo); i++) {
396396
RedisAI_RunInfo *rinfo = batch_rinfo[i];
397-
// We lock on the DAG because error could be set from
398-
// other threads operating on the same DAG (TODO: use atomic)
399-
dagError = __atomic_load_n(rinfo->dagError, __ATOMIC_RELAXED);
400-
401397
// We record that there was an error for later on
402-
run_error = dagError;
403-
398+
run_error = __atomic_load_n(rinfo->dagError, __ATOMIC_RELAXED);
399+
if (i == 0 && run_error == 1) {
400+
first_dag_error = true;
401+
}
404402
// If there was an error and the reference count for the dag
405403
// has gone to zero and the client is still around, we unblock
406-
if (dagError) {
404+
if (run_error) {
407405
RedisAI_RunInfo *orig = rinfo->orig_copy;
408406
long long dagRefCount = RAI_DagRunInfoFreeShallowCopy(rinfo);
409407
if (dagRefCount == 0) {
@@ -415,12 +413,17 @@ void *RedisAI_Run_ThreadMain(void *arg) {
415413
__atomic_add_fetch(rinfo->dagCompleteOpCount, 1, __ATOMIC_RELAXED);
416414
}
417415
}
416+
if (first_dag_error) {
417+
run_queue_len = queueLength(run_queue_info->run_queue);
418+
continue;
419+
}
418420
}
419421

420422
// We initialize variables where we'll store the fact that, after the current
421423
// run, all ops for the device or all ops in the dag could be complete. This
422424
// way we can avoid placing the op back on the queue if there's nothing left
423425
// to do.
426+
RedisModule_Assert(run_error == 0);
424427
int device_complete_after_run = RedisAI_DagDeviceComplete(batch_rinfo[0]);
425428
int dag_complete_after_run = RedisAI_DagComplete(batch_rinfo[0]);
426429

tests/flow/tests_onnx.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,21 +63,21 @@ def test_onnx_modelrun_mnist(env):
6363
except Exception as e:
6464
exception = e
6565
env.assertEqual(type(exception), redis.exceptions.ResponseError)
66-
env.assertEqual("No graph was found in the protobuf.", exception.__str__())
66+
env.assertEqual("No graph was found in the protobuf.", str(exception))
6767

6868
try:
6969
con.execute_command('AI.MODELSET', 'm_1{1}', 'ONNX', 'BLOB', model_pb)
7070
except Exception as e:
7171
exception = e
7272
env.assertEqual(type(exception), redis.exceptions.ResponseError)
73-
env.assertEqual("Invalid DEVICE", exception.__str__())
73+
env.assertEqual("Invalid DEVICE", str(exception))
7474

7575
try:
7676
con.execute_command('AI.MODELSET', 'm_2{1}', model_pb)
7777
except Exception as e:
7878
exception = e
7979
env.assertEqual(type(exception), redis.exceptions.ResponseError)
80-
env.assertEqual("wrong number of arguments for 'AI.MODELSET' command", exception.__str__())
80+
env.assertEqual("wrong number of arguments for 'AI.MODELSET' command", str(exception))
8181

8282
con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)
8383

@@ -86,56 +86,64 @@ def test_onnx_modelrun_mnist(env):
8686
except Exception as e:
8787
exception = e
8888
env.assertEqual(type(exception), redis.exceptions.ResponseError)
89-
env.assertEqual("model key is empty", exception.__str__())
89+
env.assertEqual("model key is empty", str(exception))
9090

9191
try:
9292
con.execute_command('AI.MODELRUN', 'm_2{1}', 'INPUTS', 'a{1}', 'b{1}', 'c{1}')
9393
except Exception as e:
9494
exception = e
9595
env.assertEqual(type(exception), redis.exceptions.ResponseError)
96-
env.assertEqual("model key is empty", exception.__str__())
96+
env.assertEqual("model key is empty", str(exception))
9797

9898
try:
9999
con.execute_command('AI.MODELRUN', 'm_3{1}', 'a{1}', 'b{1}', 'c{1}')
100100
except Exception as e:
101101
exception = e
102102
env.assertEqual(type(exception), redis.exceptions.ResponseError)
103-
env.assertEqual("model key is empty", exception.__str__())
103+
env.assertEqual("model key is empty", str(exception))
104104

105105
try:
106106
con.execute_command('AI.MODELRUN', 'm_1{1}', 'OUTPUTS', 'c{1}')
107107
except Exception as e:
108108
exception = e
109109
env.assertEqual(type(exception), redis.exceptions.ResponseError)
110-
env.assertEqual("model key is empty", exception.__str__())
110+
env.assertEqual("model key is empty", str(exception))
111111

112112
try:
113113
con.execute_command('AI.MODELRUN', 'm{1}', 'OUTPUTS', 'c{1}')
114114
except Exception as e:
115115
exception = e
116116
env.assertEqual(type(exception), redis.exceptions.ResponseError)
117-
env.assertEqual("INPUTS not specified", exception.__str__())
117+
env.assertEqual("INPUTS not specified", str(exception))
118118

119119
try:
120120
con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'b{1}')
121121
except Exception as e:
122122
exception = e
123123
env.assertEqual(type(exception), redis.exceptions.ResponseError)
124-
env.assertEqual("tensor key is empty", exception.__str__())
124+
env.assertEqual("tensor key is empty", str(exception))
125125

126126
try:
127127
con.execute_command('AI.MODELRUN', 'm_1{1}', 'INPUTS', 'OUTPUTS')
128128
except Exception as e:
129129
exception = e
130130
env.assertEqual(type(exception), redis.exceptions.ResponseError)
131-
env.assertEqual("model key is empty", exception.__str__())
131+
env.assertEqual("model key is empty", str(exception))
132132

133133
try:
134134
con.execute_command('AI.MODELRUN', 'm_1{1}', 'INPUTS', 'a{1}', 'OUTPUTS', 'b{1}')
135135
except Exception as e:
136136
exception = e
137137
env.assertEqual(type(exception), redis.exceptions.ResponseError)
138-
env.assertEqual("model key is empty", exception.__str__())
138+
env.assertEqual("model key is empty", str(exception))
139+
140+
# This error is caught after the model is sent to the backend, not in parsing like before.
141+
try:
142+
con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'a{1}', 'OUTPUTS', 'b{1}')
143+
except Exception as e:
144+
exception = e
145+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
146+
env.assertEqual('Expected 1 inputs but got 2', str(exception))
139147

140148
con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'OUTPUTS', 'b{1}')
141149

0 commit comments

Comments
 (0)