Skip to content

Commit 34d153b

Browse files
committed
Support Async script run via LLAPI
1 parent bf2b13e commit 34d153b

File tree

10 files changed

+228
-18
lines changed

10 files changed

+228
-18
lines changed

src/DAG/dag.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
257257
RAI_ContextUnlock(rinfo);
258258

259259
for (uint i = 0; i < n_inkeys; i++) {
260-
RAI_ScriptRunCtxAddInput(currentOp->sctx, inputTensors[i]);
260+
RAI_ScriptRunCtxAddInput(currentOp->sctx, inputTensors[i], currentOp->err);
261261
}
262262
for (uint i = 0; i < n_outkeys; i++) {
263263
RAI_ScriptRunCtxAddOutput(currentOp->sctx);
@@ -593,16 +593,16 @@ static void _PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
593593
"ERR specified persistent key that was not used in DAG");
594594
rinfo->dagReplyLength++;
595595
RedisModule_Log(ctx, "warning",
596-
"on DAGRUN's PERSIST pecified persistent key (%s) that "
596+
"on DAGRUN's PERSIST specified persistent key (%s) that "
597597
"was not used on DAG. Logging all local context keys",
598-
persist_key_name);
598+
RedisModule_StringPtrLen(persist_key_name, NULL));
599599
AI_dictIterator *local_iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext);
600600
AI_dictEntry *local_entry = AI_dictNext(local_iter);
601601

602602
while (local_entry) {
603603
RedisModuleString *localcontext_key_name = AI_dictGetKey(local_entry);
604604
RedisModule_Log(ctx, "warning", "DAG's local context key (%s)",
605-
localcontext_key_name);
605+
RedisModule_StringPtrLen(localcontext_key_name, NULL));
606606
local_entry = AI_dictNext(local_iter);
607607
}
608608
AI_dictReleaseIterator(local_iter);

src/command_parser.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ static int _ScriptRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inke
263263

264264
RAI_Tensor *t;
265265
RedisModuleKey *key;
266+
RAI_Error *err;
266267
size_t ninputs = array_len(inkeys), noutputs = array_len(outkeys);
267268
for (size_t i = 0; i < ninputs; i++) {
268269
const int status = RAI_GetTensorFromKeyspace(ctx, inkeys[i], &key, &t, REDISMODULE_READ);
@@ -271,7 +272,7 @@ static int _ScriptRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inke
271272
RedisModule_StringPtrLen(inkeys[i], NULL));
272273
return REDISMODULE_ERR;
273274
}
274-
RAI_ScriptRunCtxAddInput(sctx, t);
275+
RAI_ScriptRunCtxAddInput(sctx, t, err);
275276
}
276277
for (size_t i = 0; i < noutputs; i++) {
277278
RAI_ScriptRunCtxAddOutput(sctx);

src/redisai.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,7 @@ static int RedisAI_RegisterApi(RedisModuleCtx *ctx) {
969969
REGISTER_API(ModelGetShallowCopy, ctx);
970970
REGISTER_API(ModelRedisType, ctx);
971971
REGISTER_API(ModelRunAsync, ctx);
972-
REGISTER_API(GetAsModelRunCtx, ctx)
972+
REGISTER_API(GetAsModelRunCtx, ctx);
973973

974974
REGISTER_API(ScriptCreate, ctx);
975975
REGISTER_API(ScriptFree, ctx);
@@ -983,6 +983,8 @@ static int RedisAI_RegisterApi(RedisModuleCtx *ctx) {
983983
REGISTER_API(ScriptRun, ctx);
984984
REGISTER_API(ScriptGetShallowCopy, ctx);
985985
REGISTER_API(ScriptRedisType, ctx);
986+
REGISTER_API(ScriptRunAsync, ctx);
987+
REGISTER_API(GetAsScriptRunCtx, ctx);
986988

987989
return REDISMODULE_OK;
988990
}

src/redisai.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ void MODULE_API_FUNC(RedisAI_ScriptRunCtxFree)(RAI_ScriptRunCtx *sctx);
121121
int MODULE_API_FUNC(RedisAI_ScriptRun)(RAI_ScriptRunCtx *sctx, RAI_Error *err);
122122
RAI_Script *MODULE_API_FUNC(RedisAI_ScriptGetShallowCopy)(RAI_Script *script);
123123
RedisModuleType *MODULE_API_FUNC(RedisAI_ScriptRedisType)(void);
124+
int MODULE_API_FUNC(RedisAI_ScriptRunAsync)(RAI_ScriptRunCtx *sctx, RAI_OnFinishCB DAGAsyncFinish,
125+
void *private_data);
126+
RAI_ScriptRunCtx *MODULE_API_FUNC(RedisAI_GetAsScriptRunCtx)(RAI_OnFinishCtx *ctx, RAI_Error *err);
124127

125128
int MODULE_API_FUNC(RedisAI_GetLLAPIVersion)();
126129

@@ -204,6 +207,8 @@ static int RedisAI_Initialize(RedisModuleCtx *ctx) {
204207
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRun);
205208
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptGetShallowCopy);
206209
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRedisType);
210+
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunAsync);
211+
REDISAI_MODULE_INIT_FUNCTION(ctx, GetAsScriptRunCtx);
207212

208213
if (RedisAI_GetLLAPIVersion() < REDISAI_LLAPI_VERSION) {
209214
return REDISMODULE_ERR;

src/run_info.c

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ int RAI_RunInfoBatchable(struct RAI_DagOp *op1, struct RAI_DagOp *op2) {
329329

330330
return 1;
331331
}
332+
332333
RAI_ModelRunCtx *RAI_GetAsModelRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err) {
333334

334335
RAI_DagOp *op = rinfo->dagOps[0];
@@ -342,3 +343,17 @@ RAI_ModelRunCtx *RAI_GetAsModelRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err) {
342343
RAI_FreeRunInfo(rinfo);
343344
return mctx;
344345
}
346+
347+
RAI_ScriptRunCtx *RAI_GetAsScriptRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err) {
348+
349+
RAI_DagOp *op = rinfo->dagOps[0];
350+
if (!rinfo->single_op_dag || !op->sctx) {
351+
RAI_SetError(err, RedisAI_ErrorCode_EFINISHCTX, "Finish ctx is not a script run ctx");
352+
return NULL;
353+
}
354+
RAI_SetError(err, RAI_GetErrorCode(op->err), RAI_GetError(op->err));
355+
RAI_ScriptRunCtx *sctx = op->sctx;
356+
rinfo->dagOps[0]->sctx = NULL;
357+
RAI_FreeRunInfo(rinfo);
358+
return sctx;
359+
}

src/run_info.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,14 @@ int RAI_RunInfoBatchable(struct RAI_DagOp *op1, struct RAI_DagOp *op2);
188188
*/
189189
RAI_ModelRunCtx *RAI_GetAsModelRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err);
190190

191+
/**
192+
* Retreive the ScriptRunCtx of a DAG runInfo that contains a single op of type
193+
* SCRIPTRUN.
194+
* @param DAG runInfo.
195+
* @return Pointer to the ScriptRunCtx in DAG's single op.
196+
*/
197+
RAI_ScriptRunCtx *RAI_GetAsScriptRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err);
198+
191199
#ifdef __cplusplus
192200
} // extern "C"
193201
#endif

src/script.c

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
*/
88

99
#include "script.h"
10-
10+
#include "run_info.h"
11+
#include "DAG/dag.h"
1112
#include "backends.h"
1213
#include "rmutil/alloc.h"
1314
#include "script_struct.h"
@@ -164,7 +165,7 @@ static int _Script_RunCtxAddParam(RAI_ScriptCtxParam **paramArr, RAI_Tensor *ten
164165
return 1;
165166
}
166167

167-
int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor) {
168+
int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor, RAI_Error *error) {
168169
// Even if variadic is set, we still allow to add inputs in the LLAPI
169170
return _Script_RunCtxAddParam(&sctx->inputs, inputTensor);
170171
}
@@ -352,3 +353,26 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString
352353
}
353354

354355
RedisModuleType *RAI_ScriptRedisType(void) { return RedisAI_ScriptType; }
356+
357+
int RAI_ScriptRunAsync(RAI_ScriptRunCtx *sctx, RAI_OnFinishCB ScriptAsyncFinish,
358+
void *private_data) {
359+
360+
RedisAI_RunInfo *rinfo = NULL;
361+
if (RAI_InitRunInfo(&rinfo) == REDISMODULE_ERR) {
362+
return REDISMODULE_ERR;
363+
}
364+
rinfo->single_op_dag = 1;
365+
rinfo->OnFinish = (RedisAI_OnFinishCB)ScriptAsyncFinish;
366+
rinfo->private_data = private_data;
367+
368+
RAI_DagOp *op;
369+
if (RAI_InitDagOp(&op) == REDISMODULE_ERR) {
370+
return REDISMODULE_ERR;
371+
}
372+
op->commandType = REDISAI_DAG_CMD_SCRIPTRUN;
373+
Dag_PopulateOp(op, sctx, NULL, NULL, NULL);
374+
375+
rinfo->dagOps = array_append(rinfo->dagOps, op);
376+
rinfo->dagOpCount = 1;
377+
return DAG_InsertDAGToQueue(rinfo);
378+
}

src/script.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "redismodule.h"
1515
#include "script_struct.h"
1616
#include "tensor.h"
17+
#include "run_info.h"
1718

1819
extern RedisModuleType *RedisAI_ScriptType;
1920

@@ -68,7 +69,7 @@ RAI_ScriptRunCtx *RAI_ScriptRunCtxCreate(RAI_Script *script, const char *fnname)
6869
* @param inputTensor input tensor structure
6970
* @return returns 1 on success, 0 in case of error.
7071
*/
71-
int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor);
72+
int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor, RAI_Error *error);
7273

7374
/**
7475
* For each Allocates a RAI_ScriptCtxParam data structure, and enforces a
@@ -212,4 +213,16 @@ void RedisAI_ReplyOrSetError(RedisModuleCtx *ctx, RAI_Error *error, RAI_ErrorCod
212213
*/
213214
RedisModuleType *RAI_ScriptRedisType(void);
214215

216+
/**
217+
* Insert the ScriptRunCtx to the run queues so it will run asynchronously.
218+
*
219+
* @param sctx SodelRunCtx to execute
220+
* @param ScriptAsyncFinish A callback that will be called when the execution is finished.
221+
* @param private_data This is going to be sent to to the ScriptAsyncFinish.
222+
* @return REDISMODULE_OK if the sctx was insert to the queues successfully, REDISMODULE_ERR
223+
* otherwise.
224+
*/
225+
int RAI_ScriptRunAsync(RAI_ScriptRunCtx *sctx, RAI_OnFinishCB ScriptAsyncFinish,
226+
void *private_data);
227+
215228
#endif /* SRC_SCRIPT_H_ */

tests/flow/tests_llapi.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,45 @@
22

33
from includes import *
44
import os
5+
from functools import wraps
56

67
'''
78
python -m RLTest --test tests_llapi.py --module path/to/redisai.so
89
'''
910

10-
goal_dir = os.path.join(os.getcwd(), "../module/LLAPI.so")
11-
TEST_MODULE_PATH = os.path.abspath(goal_dir)
1211

12+
def ensure_test_module_loaded(f):
13+
@wraps(f)
14+
def wrapper(env, *args, **kwargs):
15+
goal_dir = os.path.join(os.getcwd(), "../module/LLAPI.so")
16+
TEST_MODULE_PATH = os.path.abspath(goal_dir)
17+
con = env.getConnection()
18+
modules = con.execute_command("MODULE", "LIST")
19+
if b'RAI_llapi' in [module[1] for module in modules]:
20+
return f(env, *args, **kwargs)
21+
try:
22+
ret = con.execute_command('MODULE', 'LOAD', TEST_MODULE_PATH)
23+
env.assertEqual(ret, b'OK')
24+
return f(env, *args, **kwargs)
25+
except Exception as e:
26+
env.assertFalse(True)
27+
env.debugPrint(str(e), force=True)
28+
return
29+
return wrapper
1330

31+
32+
@ensure_test_module_loaded
1433
def test_basic_check(env):
1534

1635
con = env.getConnection()
17-
ret = con.execute_command("MODULE", "LOAD", TEST_MODULE_PATH)
18-
env.assertEqual(ret, b'OK')
1936
ret = con.execute_command("RAI_llapi.basic_check")
2037
env.assertEqual(ret, b'OK')
2138

2239

40+
@ensure_test_module_loaded
2341
def test_model_run_async(env):
2442

2543
con = env.getConnection()
26-
ret = con.execute_command("MODULE", "LOAD", TEST_MODULE_PATH)
27-
env.assertEqual(ret, b'OK')
28-
2944
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
3045
model_filename = os.path.join(test_data_path, 'graph.pb')
3146

@@ -39,3 +54,25 @@ def test_model_run_async(env):
3954
con.execute_command('AI.TENSORSET', 'b{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
4055
ret = con.execute_command("RAI_llapi.modelRun")
4156
env.assertEqual(ret, b'Async run success')
57+
58+
59+
@ensure_test_module_loaded
60+
def test_script_run_async(env):
61+
62+
con = env.getConnection()
63+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
64+
script_filename = os.path.join(test_data_path, 'script.txt')
65+
66+
with open(script_filename, 'rb') as f:
67+
script = f.read()
68+
69+
ret = con.execute_command('AI.SCRIPTSET', 'myscript{1}', DEVICE, 'TAG', 'version1', 'SOURCE', script)
70+
env.assertEqual(ret, b'OK')
71+
72+
ret = con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
73+
env.assertEqual(ret, b'OK')
74+
ret = con.execute_command('AI.TENSORSET', 'b{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
75+
env.assertEqual(ret, b'OK')
76+
77+
ret = con.execute_command("RAI_llapi.scriptRun")
78+
env.assertEqual(ret, b'Async run success')

0 commit comments

Comments
 (0)