Skip to content

Commit 247c4e5

Browse files
authored
Merge pull request #551 from RedisAI/DAGRUN_command_refactor
Refactor DAGRUN command parsing
2 parents 671e157 + e56046f commit 247c4e5

File tree

16 files changed

+492
-553
lines changed

16 files changed

+492
-553
lines changed

src/DAG/dag.c

Lines changed: 9 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -47,52 +47,6 @@
4747
#include "dag_parser.h"
4848
#include "util/string_utils.h"
4949

50-
/**
51-
* Execution of a TENSORSET DAG step.
52-
* If an error occurs, it is recorded in the DagOp struct.
53-
*
54-
* @param rinfo context in which RedisAI blocking commands operate.
55-
* @param currentOp TENSORSET DagOp to be executed
56-
* @return
57-
*/
58-
void RedisAI_DagRunSession_TensorSet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp) {
59-
RAI_Tensor *t = NULL;
60-
const int parse_result =
61-
RAI_parseTensorSetArgs(NULL, currentOp->argv, currentOp->argc, &t, 0, currentOp->err);
62-
if (parse_result > 0) {
63-
RedisModuleString *key_string = currentOp->outkeys[0];
64-
RAI_ContextWriteLock(rinfo);
65-
AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, t);
66-
RAI_ContextUnlock(rinfo);
67-
currentOp->result = REDISMODULE_OK;
68-
} else {
69-
currentOp->result = REDISMODULE_ERR;
70-
}
71-
}
72-
73-
/**
74-
* Execution of a TENSORGET DAG step.
75-
* If an error occurs, it is recorded in the DagOp struct.
76-
*
77-
* @param rinfo context in which RedisAI blocking commands operate.
78-
* @param currentOp TENSORGET DagOp to be executed
79-
* @return
80-
*/
81-
void RedisAI_DagRunSession_TensorGet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp) {
82-
RedisModuleString *key_string = currentOp->inkeys[0];
83-
RAI_Tensor *t = NULL;
84-
RAI_ContextReadLock(rinfo);
85-
currentOp->result = RAI_getTensorFromLocalContext(NULL, rinfo->dagTensorsContext, key_string,
86-
&t, currentOp->err);
87-
RAI_ContextUnlock(rinfo);
88-
if (currentOp->result == REDISMODULE_OK) {
89-
RAI_Tensor *outTensor = NULL;
90-
// TODO: check tensor copy return value
91-
RAI_TensorDeepCopy(t, &outTensor);
92-
currentOp->outTensors = array_append(currentOp->outTensors, outTensor);
93-
}
94-
}
95-
9650
static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp) {
9751
uint n_inkeys = array_len(currentOp->inkeys);
9852
uint n_outkeys = array_len(currentOp->outkeys);
@@ -477,11 +431,13 @@ void RedisAI_DagRunSessionStep(RedisAI_RunInfo *rinfo, const char *devicestr) {
477431

478432
switch (currentOp->commandType) {
479433
case REDISAI_DAG_CMD_TENSORSET: {
480-
RedisAI_DagRunSession_TensorSet_Step(rinfo, currentOp);
434+
// TENSORSET op is done in parsing stage (consider removing it from dag ops).
435+
currentOp->result = REDISMODULE_OK;
481436
break;
482437
}
483438
case REDISAI_DAG_CMD_TENSORGET: {
484-
RedisAI_DagRunSession_TensorGet_Step(rinfo, currentOp);
439+
// TENSORSET op is done when we finish (consider removing it from dag ops).
440+
currentOp->result = REDISMODULE_OK;
485441
break;
486442
}
487443
case REDISAI_DAG_CMD_MODELRUN: {
@@ -680,18 +636,14 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
680636

681637
case REDISAI_DAG_CMD_TENSORGET: {
682638
rinfo->dagReplyLength++;
683-
if (currentOp->result == REDISMODULE_ERR) {
639+
RAI_Tensor *t;
640+
int res = RAI_getTensorFromLocalContext(NULL, rinfo->dagTensorsContext,
641+
currentOp->inkeys[0], &t, currentOp->err);
642+
if (res != REDISMODULE_OK) {
684643
RedisModule_ReplyWithError(ctx, currentOp->err->detail_oneline);
685644
dag_error = 1;
686645
} else {
687-
if (array_len(currentOp->outTensors) > 0) {
688-
RAI_Tensor *tensor = currentOp->outTensors[0];
689-
RAI_parseTensorGetArgs(ctx, currentOp->argv, currentOp->argc, tensor);
690-
} else if (currentOp->result == -1) {
691-
RedisModule_ReplyWithSimpleString(ctx, "NA");
692-
} else {
693-
RedisModule_ReplyWithError(ctx, "ERR error getting tensor from local context");
694-
}
646+
ReplyWithTensor(ctx, currentOp->fmt, t);
695647
}
696648
break;
697649
}

0 commit comments

Comments
 (0)