Skip to content

Commit 152092a

Browse files
author
DvirDukhan
committed
simple get
1 parent 19ef6a3 commit 152092a

File tree

6 files changed

+100
-95
lines changed

6 files changed

+100
-95
lines changed

src/command_parser.c

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **
181181
RedisModuleString **runkey, char const **func_name,
182182
long long *timeout, int *variadic) {
183183

184-
if (argc < 5) {
184+
if (argc < 3) {
185185
RAI_SetError(error, RAI_ESCRIPTRUN,
186186
"ERR wrong number of arguments for 'AI.SCRIPTRUN' command");
187187
return REDISMODULE_ERR;
@@ -198,41 +198,50 @@ static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **
198198
*runkey = argv[argpos];
199199

200200
const char *arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
201-
if (!strcasecmp(arg_string, "TIMEOUT") || !strcasecmp(arg_string, "INPUTS")) {
201+
if (!strcasecmp(arg_string, "TIMEOUT") || !strcasecmp(arg_string, "INPUTS") || !strcasecmp(arg_string, "OUTPUTS")) {
202202
RAI_SetError(error, RAI_ESCRIPTRUN, "ERR function name not specified");
203203
return REDISMODULE_ERR;
204204
}
205205
*func_name = arg_string;
206-
arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
207206

208-
// Parse timeout arg if given and store it in timeout
209-
if (!strcasecmp(arg_string, "TIMEOUT")) {
210-
if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR)
211-
return REDISMODULE_ERR;
212-
arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
213-
}
214-
if (strcasecmp(arg_string, "INPUTS") != 0) {
215-
RAI_SetError(error, RAI_ESCRIPTRUN, "ERR INPUTS not specified");
216-
return REDISMODULE_ERR;
217-
}
218-
219-
bool is_input = true, is_output = false;
207+
bool is_input = false;
208+
bool is_output = false;
209+
bool timeout_set = false;
220210
size_t ninputs = 0, noutputs = 0;
221211
int varidic_start_pos = -1;
222212

223213
while (++argpos < argc) {
224214
arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);
215+
216+
// Parse timeout arg if given and store it in timeout
217+
if (!strcasecmp(arg_string, "TIMEOUT") && !timeout_set) {
218+
if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR)
219+
return REDISMODULE_ERR;
220+
timeout_set = true;
221+
continue;
222+
}
223+
224+
if (!strcasecmp(arg_string, "INPUTS") && !is_input) {
225+
is_input = true;
226+
is_output = false;
227+
continue;
228+
}
225229
if (!strcasecmp(arg_string, "OUTPUTS") && !is_output) {
226230
is_input = false;
227231
is_output = true;
228-
} else if (!strcasecmp(arg_string, "$")) {
232+
continue;
233+
}
234+
if (!strcasecmp(arg_string, "$")) {
229235
if (varidic_start_pos > -1) {
230236
RAI_SetError(error, RAI_ESCRIPTRUN,
231237
"ERR Already encountered a variable size list of tensors");
232238
return REDISMODULE_ERR;
233239
}
234240
varidic_start_pos = ninputs;
235-
} else {
241+
continue;
242+
}
243+
// Parse argument name
244+
{
236245
RAI_HoldString(NULL, argv[argpos]);
237246
if (is_input) {
238247
ninputs++;

src/libtorch_c/torch_c.cpp

Lines changed: 2 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,43 +7,7 @@
77
#include <iostream>
88
#include <sstream>
99

10-
#include "torch_extensions/torch_redis_value.h"
11-
#include "../redismodule.h"
12-
#include "../util/arr.h"
13-
14-
#include "torch/csrc/jit/frontend/resolver.h"
15-
#include "torch/script.h"
16-
#include "torch/jit.h"
17-
18-
namespace torch {
19-
namespace jit {
20-
namespace script {
21-
struct RedisResolver : public Resolver {
22-
23-
std::shared_ptr<SugaredValue> resolveValue(const std::string& name, Function& m, const SourceRange& loc) override {
24-
if(strcasecmp(name.c_str(), "torch") == 0) {
25-
return std::make_shared<BuiltinModule>("aten");
26-
}
27-
else if (strcasecmp(name.c_str(), "redis") == 0) {
28-
return std::make_shared<BuiltinModule>("redis");
29-
}
30-
return nullptr;
31-
}
32-
33-
TypePtr resolveType(const std::string& name, const SourceRange& loc) override {
34-
return nullptr;
35-
}
36-
37-
};
38-
inline std::shared_ptr<RedisResolver> redisResolver() {
39-
return std::make_shared<RedisResolver>();
40-
}
41-
}
42-
}
43-
}
44-
45-
46-
10+
#include "torch_extensions/torch_redis.h"
4711
namespace {
4812

4913
static DLDataType getDLDataType(const at::Tensor &t) {
@@ -279,6 +243,7 @@ void torchRunModule(ModuleContext *ctx, const char *fnName, int variadic, long n
279243
torch::DeviceType output_device_type = torch::kCPU;
280244
torch::Device output_device(output_device_type, -1);
281245

246+
if(nOutputs == 0) return;
282247
int count = 0;
283248
for (size_t i = 0; i < stack.size(); i++) {
284249
if (count > nOutputs - 1) {
@@ -337,21 +302,6 @@ extern "C" DLManagedTensor *torchNewTensor(DLDataType dtype, long ndims, int64_t
337302
return dl_tensor;
338303
}
339304

340-
void redisExecute(std::string fn_name, std::vector<std::string> args ) {
341-
RedisModuleCtx* ctx = RedisModule_GetThreadSafeContext(NULL);
342-
size_t len = args.size();
343-
RedisModuleString* arguments[len];
344-
len = 0;
345-
for (std::vector<std::string>::iterator it = args.begin(); it != args.end(); it++) {
346-
const std::string arg = *it;
347-
const char* str = arg.c_str();
348-
arguments[len++] = RedisModule_CreateString(ctx, str, strlen(str));
349-
}
350-
RedisModule_Call(ctx, fn_name.c_str(), "v", arguments, len);
351-
RedisModule_FreeThreadSafeContext(ctx);
352-
}
353-
354-
static auto registry = torch::RegisterOperators("redis::execute", &redisExecute);
355305
extern "C" void* torchCompileScript(const char* script, DLDeviceType device, int64_t device_id,
356306
char **error, void* (*alloc)(size_t))
357307
{
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "../../redismodule.h"
2+
#include "../../util/arr.h"
3+
4+
#include "torch/csrc/jit/frontend/resolver.h"
5+
#include "torch/script.h"
6+
#include "torch/jit.h"
7+
8+
namespace torch {
9+
namespace jit {
10+
namespace script {
11+
struct RedisResolver : public Resolver {
12+
13+
std::shared_ptr<SugaredValue> resolveValue(const std::string& name, Function& m, const SourceRange& loc) override {
14+
if(strcasecmp(name.c_str(), "torch") == 0) {
15+
return std::make_shared<BuiltinModule>("aten");
16+
}
17+
else if (strcasecmp(name.c_str(), "redis") == 0) {
18+
return std::make_shared<BuiltinModule>("redis");
19+
}
20+
return nullptr;
21+
}
22+
23+
TypePtr resolveType(const std::string& name, const SourceRange& loc) override {
24+
return nullptr;
25+
}
26+
27+
};
28+
inline std::shared_ptr<RedisResolver> redisResolver() {
29+
return std::make_shared<RedisResolver>();
30+
}
31+
}
32+
}
33+
}
34+
35+
void redisExecute(std::string fn_name, std::vector<std::string> args ) {
36+
RedisModuleCtx* ctx = RedisModule_GetThreadSafeContext(NULL);
37+
size_t len = args.size();
38+
RedisModuleString* arguments[len];
39+
len = 0;
40+
for (std::vector<std::string>::iterator it = args.begin(); it != args.end(); it++) {
41+
const std::string arg = *it;
42+
const char* str = arg.c_str();
43+
arguments[len++] = RedisModule_CreateString(ctx, str, strlen(str));
44+
}
45+
RedisModule_Call(ctx, fn_name.c_str(), "v", arguments, len);
46+
RedisModule_FreeThreadSafeContext(ctx);
47+
}
48+
49+
static auto registry = torch::RegisterOperators("redis::execute", &redisExecute);

src/libtorch_c/torch_extensions/torch_redis_value.h

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,31 @@ struct RedisValue : torch::CustomClassHolder {
1313

1414
public:
1515
RedisValue(RedisModuleCallReply *reply) {
16-
if (RedisModule_CallReplyType(reply) == REDISMODULE_REPLY_ARRAY) {
17-
size_t len = RedisModule_CallReplyLength(reply);
18-
for (auto i = 0; i < len; ++i) {
19-
RedisModuleCallReply *subReply = RedisModule_CallReplyArrayElement(reply, i);
20-
RedisValue value(subReply);
21-
// arrayValue.push_back(value);
22-
}
23-
}
16+
// if (RedisModule_CallReplyType(reply) == REDISMODULE_REPLY_ARRAY) {
17+
// size_t len = RedisModule_CallReplyLength(reply);
18+
// for (auto i = 0; i < len; ++i) {
19+
// RedisModuleCallReply *subReply = RedisModule_CallReplyArrayElement(reply, i);
20+
// RedisValue value(subReply);
21+
// arrayValue.push_back(value);
22+
// }
23+
// return;
24+
// }
2425

2526
if (RedisModule_CallReplyType(reply) == REDISMODULE_REPLY_STRING ||
2627
RedisModule_CallReplyType(reply) == REDISMODULE_REPLY_ERROR) {
2728
size_t len;
2829
const char *replyStr = RedisModule_CallReplyStringPtr(reply, &len);
29-
// PyObject *ret = PyUnicode_FromStringAndSize(replyStr, len);
30-
// if (!ret) {
31-
// PyErr_Clear();
32-
// ret = PyByteArray_FromStringAndSize(replyStr, len);
33-
// }
34-
// return ret;
30+
stringValue= replyStr;
31+
return;
32+
3533
}
3634

3735
if (RedisModule_CallReplyType(reply) == REDISMODULE_REPLY_INTEGER) {
38-
long long val = RedisModule_CallReplyInteger(reply);
39-
// return PyLong_FromLongLong(val);
36+
intValue = RedisModule_CallReplyInteger(reply);
37+
return;
4038
}
39+
40+
throw(std::runtime_error("Unsupported redis type"));
4141
}
4242

4343
virtual ~RedisValue() {

src/redisai.c

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -578,9 +578,6 @@ int RedisAI_ScriptRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
578578
return RedisAI_ScriptRun_IsKeysPositionRequest_ReportKeys(ctx, argv, argc);
579579
}
580580

581-
if (argc < 6)
582-
return RedisModule_WrongArity(ctx);
583-
584581
// Convert The script run command into a DAG command that contains a single op.
585582
return RedisAI_ExecuteCommand(ctx, argv, argc, CMD_SCRIPTRUN, false);
586583
}

tests/flow/test_torchscript_extensions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,17 @@ def __init__(self):
2828
self.env.assertEqual(ret, b'OK')
2929
# self.env.ensureSlaveSynced(self.con, self.env)
3030

31-
# def test_int_get_set(self):
32-
# self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_int_set_get', 'INPUTS', 'OUTPUTS', 'y')
33-
# y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
34-
# self.env.assertEqual(y, ["dtype", "INT64", "shape", [0, 1], "VALUES", "1"] )
35-
3631
# def test_float_get_set(self):
3732
# self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_float_set_get', 'INPUTS', 'OUTPUTS', 'y')
3833
# y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
3934
# self.env.assertEqual(y, ["dtype", "FLOAT", "shape", [0, 1], "VALUES", "1.1"])
4035

4136
def test_simple_test_set(self):
4237
self.con.execute_command(
43-
'AI.SCRIPTRUN', 'redis_scripts', 'test_set_key', 'INPUTS', 'OUTPUTS', 'y')
44-
self.env.assertEqual("1", self.con.get("x"))
38+
'AI.SCRIPTRUN', 'redis_scripts', 'test_set_key')
39+
self.env.assertEqual(b"1", self.con.get("x"))
40+
41+
# def test_int_get_set(self):
42+
# self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_int_set_get', 'OUTPUTS', 'y')
43+
# y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
44+
# self.env.assertEqual(y, ["dtype", "INT64", "shape", [0, 1], "VALUES", "1"] )

0 commit comments

Comments
 (0)