|
47 | 47 | #include "dag_parser.h" |
48 | 48 | #include "util/string_utils.h" |
49 | 49 |
|
50 | | -/** |
51 | | - * Execution of a TENSORSET DAG step. |
52 | | - * If an error occurs, it is recorded in the DagOp struct. |
53 | | - * |
54 | | - * @param rinfo context in which RedisAI blocking commands operate. |
55 | | - * @param currentOp TENSORSET DagOp to be executed |
56 | | - * @return |
57 | | - */ |
58 | | -void RedisAI_DagRunSession_TensorSet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp) { |
59 | | - RAI_Tensor *t = NULL; |
60 | | - const int parse_result = |
61 | | - RAI_parseTensorSetArgs(NULL, currentOp->argv, currentOp->argc, &t, 0, currentOp->err); |
62 | | - if (parse_result > 0) { |
63 | | - RedisModuleString *key_string = currentOp->outkeys[0]; |
64 | | - RAI_ContextWriteLock(rinfo); |
65 | | - AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, t); |
66 | | - RAI_ContextUnlock(rinfo); |
67 | | - currentOp->result = REDISMODULE_OK; |
68 | | - } else { |
69 | | - currentOp->result = REDISMODULE_ERR; |
70 | | - } |
71 | | -} |
72 | | - |
73 | | -/** |
74 | | - * Execution of a TENSORGET DAG step. |
75 | | - * If an error occurs, it is recorded in the DagOp struct. |
76 | | - * |
77 | | - * @param rinfo context in which RedisAI blocking commands operate. |
78 | | - * @param currentOp TENSORGET DagOp to be executed |
79 | | - * @return |
80 | | - */ |
81 | | -void RedisAI_DagRunSession_TensorGet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp) { |
82 | | - RedisModuleString *key_string = currentOp->inkeys[0]; |
83 | | - RAI_Tensor *t = NULL; |
84 | | - RAI_ContextReadLock(rinfo); |
85 | | - currentOp->result = RAI_getTensorFromLocalContext(NULL, rinfo->dagTensorsContext, key_string, |
86 | | - &t, currentOp->err); |
87 | | - RAI_ContextUnlock(rinfo); |
88 | | - if (currentOp->result == REDISMODULE_OK) { |
89 | | - RAI_Tensor *outTensor = NULL; |
90 | | - // TODO: check tensor copy return value |
91 | | - RAI_TensorDeepCopy(t, &outTensor); |
92 | | - currentOp->outTensors = array_append(currentOp->outTensors, outTensor); |
93 | | - } |
94 | | -} |
95 | | - |
96 | 50 | static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp) { |
97 | 51 | uint n_inkeys = array_len(currentOp->inkeys); |
98 | 52 | uint n_outkeys = array_len(currentOp->outkeys); |
@@ -477,11 +431,13 @@ void RedisAI_DagRunSessionStep(RedisAI_RunInfo *rinfo, const char *devicestr) { |
477 | 431 |
|
478 | 432 | switch (currentOp->commandType) { |
479 | 433 | case REDISAI_DAG_CMD_TENSORSET: { |
480 | | - RedisAI_DagRunSession_TensorSet_Step(rinfo, currentOp); |
| 434 | + // TENSORSET op is done in parsing stage (consider removing it from dag ops). |
| 435 | + currentOp->result = REDISMODULE_OK; |
481 | 436 | break; |
482 | 437 | } |
483 | 438 | case REDISAI_DAG_CMD_TENSORGET: { |
484 | | - RedisAI_DagRunSession_TensorGet_Step(rinfo, currentOp); |
| 439 | + // TENSORSET op is done when we finish (consider removing it from dag ops). |
| 440 | + currentOp->result = REDISMODULE_OK; |
485 | 441 | break; |
486 | 442 | } |
487 | 443 | case REDISAI_DAG_CMD_MODELRUN: { |
@@ -680,18 +636,14 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc |
680 | 636 |
|
681 | 637 | case REDISAI_DAG_CMD_TENSORGET: { |
682 | 638 | rinfo->dagReplyLength++; |
683 | | - if (currentOp->result == REDISMODULE_ERR) { |
| 639 | + RAI_Tensor *t; |
| 640 | + int res = RAI_getTensorFromLocalContext(NULL, rinfo->dagTensorsContext, |
| 641 | + currentOp->inkeys[0], &t, currentOp->err); |
| 642 | + if (res != REDISMODULE_OK) { |
684 | 643 | RedisModule_ReplyWithError(ctx, currentOp->err->detail_oneline); |
685 | 644 | dag_error = 1; |
686 | 645 | } else { |
687 | | - if (array_len(currentOp->outTensors) > 0) { |
688 | | - RAI_Tensor *tensor = currentOp->outTensors[0]; |
689 | | - RAI_parseTensorGetArgs(ctx, currentOp->argv, currentOp->argc, tensor); |
690 | | - } else if (currentOp->result == -1) { |
691 | | - RedisModule_ReplyWithSimpleString(ctx, "NA"); |
692 | | - } else { |
693 | | - RedisModule_ReplyWithError(ctx, "ERR error getting tensor from local context"); |
694 | | - } |
| 646 | + ReplyWithTensor(ctx, currentOp->fmt, t); |
695 | 647 | } |
696 | 648 | break; |
697 | 649 | } |
|
0 commit comments