
Commit 7947e6e
Author: DvirDukhan
Merge pull request #489 from RedisAI/torchscript_extensions
Execute Redis commands in Torch Script
2 parents 2fc1827 + e50e5cf

File tree: 9 files changed, +347 −50 lines
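
The change set below wires the RedisModule call API into the Torch backend and registers two custom TorchScript operators, redis.execute and redis.asList, so that a running script can issue Redis commands. As a minimal sketch of the resulting capability, modeled on the test script added at the end of this commit (the key name and helper are illustrative, not part of the diff):

def get_stored_counter(key: str):
    # redis.execute issues a Redis command from inside TorchScript and
    # returns the reply (string, integer, list or nil) as a generic value.
    reply = redis.execute("GET", key)
    return torch.tensor(int(str(reply)))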

src/backends/torch.c

Lines changed: 14 additions & 0 deletions

@@ -11,6 +11,20 @@ int RAI_InitBackendTorch(int (*get_api_fn)(const char *, void *)) {
     get_api_fn("RedisModule_Free", ((void **)&RedisModule_Free));
     get_api_fn("RedisModule_Realloc", ((void **)&RedisModule_Realloc));
     get_api_fn("RedisModule_Strdup", ((void **)&RedisModule_Strdup));
+    get_api_fn("RedisModule_CreateString", ((void **)&RedisModule_CreateString));
+    get_api_fn("RedisModule_FreeString", ((void **)&RedisModule_FreeString));
+    get_api_fn("RedisModule_Call", ((void **)&RedisModule_Call));
+    get_api_fn("RedisModule_CallReplyType", ((void **)&RedisModule_CallReplyType));
+    get_api_fn("RedisModule_CallReplyStringPtr", ((void **)&RedisModule_CallReplyStringPtr));
+    get_api_fn("RedisModule_CallReplyInteger", ((void **)&RedisModule_CallReplyInteger));
+    get_api_fn("RedisModule_CallReplyLength", ((void **)&RedisModule_CallReplyLength));
+    get_api_fn("RedisModule_CallReplyArrayElement", ((void **)&RedisModule_CallReplyArrayElement));
+    get_api_fn("RedisModule_FreeCallReply", ((void **)&RedisModule_FreeCallReply));
+    get_api_fn("RedisModule_GetThreadSafeContext", ((void **)&RedisModule_GetThreadSafeContext));
+    get_api_fn("RedisModule_ThreadSafeContextLock", ((void **)&RedisModule_ThreadSafeContextLock));
+    get_api_fn("RedisModule_ThreadSafeContextUnlock",
+               ((void **)&RedisModule_ThreadSafeContextUnlock));
+    get_api_fn("RedisModule_FreeThreadSafeContext", ((void **)&RedisModule_FreeThreadSafeContext));

     return REDISMODULE_OK;
 }

src/command_parser.c

Lines changed: 59 additions & 25 deletions

@@ -164,7 +164,7 @@ static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **
                                        RedisModuleString **runkey, char const **func_name,
                                        long long *timeout, int *variadic) {

-    if (argc < 5) {
+    if (argc < 3) {
         RAI_SetError(error, RAI_ESCRIPTRUN,
                      "ERR wrong number of arguments for 'AI.SCRIPTRUN' command");
         return REDISMODULE_ERR;
@@ -181,49 +181,83 @@ static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **
     *runkey = argv[argpos];

     const char *arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
-    if (!strcasecmp(arg_string, "TIMEOUT") || !strcasecmp(arg_string, "INPUTS")) {
+    if (!strcasecmp(arg_string, "TIMEOUT") || !strcasecmp(arg_string, "INPUTS") ||
+        !strcasecmp(arg_string, "OUTPUTS")) {
         RAI_SetError(error, RAI_ESCRIPTRUN, "ERR function name not specified");
         return REDISMODULE_ERR;
     }
     *func_name = arg_string;
-    arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);

-    // Parse timeout arg if given and store it in timeout
-    if (!strcasecmp(arg_string, "TIMEOUT")) {
-        if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR)
-            return REDISMODULE_ERR;
-        arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
-    }
-    if (strcasecmp(arg_string, "INPUTS") != 0) {
-        RAI_SetError(error, RAI_ESCRIPTRUN, "ERR INPUTS not specified");
-        return REDISMODULE_ERR;
-    }
-
-    bool is_input = true, is_output = false;
+    bool is_input = false;
+    bool is_output = false;
+    bool timeout_set = false;
+    bool inputs_done = false;
     size_t ninputs = 0, noutputs = 0;
     int varidic_start_pos = -1;

     while (++argpos < argc) {
         arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);
-        if (!strcasecmp(arg_string, "OUTPUTS") && !is_output) {
+
+        // Parse timeout arg if given and store it in timeout
+        if (!strcasecmp(arg_string, "TIMEOUT") && !timeout_set) {
+            if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR)
+                return REDISMODULE_ERR;
+            timeout_set = true;
+            continue;
+        }
+
+        if (!strcasecmp(arg_string, "INPUTS")) {
+            if (inputs_done) {
+                RAI_SetError(error, RAI_ESCRIPTRUN,
+                             "ERR Already encountered an INPUTS section in SCRIPTRUN");
+                return REDISMODULE_ERR;
+            }
+            if (is_input) {
+                RAI_SetError(error, RAI_ESCRIPTRUN,
+                             "ERR Already encountered an INPUTS keyword in SCRIPTRUN");
+                return REDISMODULE_ERR;
+            }
+            is_input = true;
+            is_output = false;
+            continue;
+        }
+        if (!strcasecmp(arg_string, "OUTPUTS")) {
+            if (is_output) {
+                RAI_SetError(error, RAI_ESCRIPTRUN,
+                             "ERR Already encountered an OUTPUTS keyword in SCRIPTRUN");
+                return REDISMODULE_ERR;
+            }
             is_input = false;
             is_output = true;
-        } else if (!strcasecmp(arg_string, "$")) {
+            inputs_done = true;
+            continue;
+        }
+        if (!strcasecmp(arg_string, "$")) {
+            if (!is_input) {
+                RAI_SetError(
+                    error, RAI_ESCRIPTRUN,
+                    "ERR Encountered a variable size list of tensors outside of input section");
+                return REDISMODULE_ERR;
+            }
             if (varidic_start_pos > -1) {
                 RAI_SetError(error, RAI_ESCRIPTRUN,
                              "ERR Already encountered a variable size list of tensors");
                 return REDISMODULE_ERR;
             }
             varidic_start_pos = ninputs;
+            continue;
+        }
+        // Parse argument name
+        RAI_HoldString(NULL, argv[argpos]);
+        if (is_input) {
+            ninputs++;
+            *inkeys = array_append(*inkeys, argv[argpos]);
+        } else if (is_output) {
+            noutputs++;
+            *outkeys = array_append(*outkeys, argv[argpos]);
         } else {
-            RAI_HoldString(NULL, argv[argpos]);
-            if (is_input) {
-                ninputs++;
-                *inkeys = array_append(*inkeys, argv[argpos]);
-            } else {
-                noutputs++;
-                *outkeys = array_append(*outkeys, argv[argpos]);
-            }
+            RAI_SetError(error, RAI_ESCRIPTRUN, "ERR Unrecongnized parameter to SCRIPTRUN");
+            return REDISMODULE_ERR;
         }
     }
     *variadic = varidic_start_pos;
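
With the relaxed parser above, AI.SCRIPTRUN no longer requires an INPUTS section, TIMEOUT may appear anywhere among the trailing keywords, and any token outside an INPUTS or OUTPUTS section is rejected explicitly. A few hedged examples of invocations the new grammar accepts (key, function, and tensor names are illustrative, not taken from the commit); the last one uses '$' to mark the start of a variadic input list:

AI.SCRIPTRUN myscript store_values
AI.SCRIPTRUN myscript myfunc TIMEOUT 1000 INPUTS a b OUTPUTS out
AI.SCRIPTRUN myscript addn INPUTS base $ t1 t2 t3 OUTPUTS sum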

src/libtorch_c/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
-add_library(torch_c STATIC torch_c.cpp)
+add_library(torch_c STATIC torch_c.cpp torch_extensions/torch_redis.cpp)
 target_link_libraries(torch_c "${TORCH_LIBRARIES}")
 set_property(TARGET torch_c PROPERTY CXX_STANDARD 14)

src/libtorch_c/torch_c.cpp

Lines changed: 32 additions & 21 deletions

@@ -8,6 +8,7 @@
 #include <iostream>
 #include <sstream>

+#include "torch_extensions/torch_redis.h"
 namespace {

 static DLDataType getDLDataType(const at::Tensor &t) {
@@ -246,6 +247,7 @@ void torchRunModule(ModuleContext *ctx, const char *fnName, int variadic, long n
     torch::DeviceType output_device_type = torch::kCPU;
     torch::Device output_device(output_device_type, -1);

+    if(nOutputs == 0) return;
     int count = 0;
     for (size_t i = 0; i < stack.size(); i++) {
         if (count > nOutputs - 1) {
@@ -304,28 +306,37 @@ extern "C" DLManagedTensor *torchNewTensor(DLDataType dtype, long ndims, int64_t
     return dl_tensor;
 }

-extern "C" void *torchCompileScript(const char *script, DLDeviceType device, int64_t device_id,
-                                    char **error, void *(*alloc)(size_t)) {
-    ModuleContext *ctx = new ModuleContext();
-    ctx->device = device;
-    ctx->device_id = device_id;
-    try {
-        auto cu = torch::jit::compile(script);
-        auto aten_device_type = getATenDeviceType(device);
-        if (aten_device_type == at::DeviceType::CUDA && !torch::cuda::is_available()) {
-            throw std::logic_error("GPU requested but Torch couldn't find CUDA");
-        }
-        ctx->cu = cu;
-        ctx->module = nullptr;
-    } catch (std::exception &e) {
-        size_t len = strlen(e.what()) + 1;
-        *error = (char *)alloc(len * sizeof(char));
-        strcpy(*error, e.what());
-        (*error)[len - 1] = '\0';
-        delete ctx;
-        return NULL;
+extern "C" void* torchCompileScript(const char* script, DLDeviceType device, int64_t device_id,
+                                    char **error, void* (*alloc)(size_t))
+{
+    ModuleContext* ctx = new ModuleContext();
+    ctx->device = device;
+    ctx->device_id = device_id;
+    try {
+        auto cu = std::make_shared<torch::jit::script::CompilationUnit>();
+        cu->define(
+            c10::nullopt,
+            script,
+            torch::jit::script::redisResolver(),
+            nullptr);
+        auto aten_device_type = getATenDeviceType(device);
+
+        if (aten_device_type == at::DeviceType::CUDA && !torch::cuda::is_available()) {
+            throw std::logic_error("GPU requested but Torch couldn't find CUDA");
         }
-    return ctx;
+        ctx->cu = cu;
+        ctx->module = nullptr;
+
+    }
+    catch(std::exception& e) {
+        size_t len = strlen(e.what()) +1;
+        *error = (char*)alloc(len * sizeof(char));
+        strcpy(*error, e.what());
+        (*error)[len-1] = '\0';
+        delete ctx;
+        return NULL;
+    }
+    return ctx;
 }

 extern "C" void *torchLoadModel(const char *graph, size_t graphlen, DLDeviceType device,
src/libtorch_c/torch_extensions/torch_redis.cpp (new file, per the CMakeLists.txt change above)

Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
#include <string>
#include "torch_redis.h"
#include "../../redismodule.h"

torch::IValue IValueFromRedisReply(RedisModuleCallReply *reply){

    int reply_type = RedisModule_CallReplyType(reply);
    switch(reply_type) {
        case REDISMODULE_REPLY_NULL: {
            return torch::IValue();
        }
        case REDISMODULE_REPLY_STRING: {
            size_t len;
            const char *replyStr = RedisModule_CallReplyStringPtr(reply, &len);
            std::string str = replyStr;
            return torch::IValue(str.substr(0,len));
        }
        case REDISMODULE_REPLY_INTEGER: {
            int intValue = RedisModule_CallReplyInteger(reply);
            return torch::IValue(intValue);
        }
        case REDISMODULE_REPLY_ARRAY: {
            c10::impl::GenericList vec = c10::impl::GenericList(c10::AnyType::create());
            size_t len = RedisModule_CallReplyLength(reply);
            for (auto i = 0; i < len; ++i) {
                RedisModuleCallReply *subReply = RedisModule_CallReplyArrayElement(reply, i);
                torch::IValue value = IValueFromRedisReply(subReply);
                vec.push_back(value);
            }
            return torch::IValue(vec);
        }
        case REDISMODULE_REPLY_ERROR: {
            size_t len;
            const char *replyStr = RedisModule_CallReplyStringPtr(reply, &len);
            throw std::runtime_error(replyStr);
            break;
        }
        default:{
            throw(std::runtime_error("Unsupported redis type"));
        }
    }
}

torch::IValue redisExecute(std::string fn_name, std::vector<std::string> args ) {
    RedisModuleCtx* ctx = RedisModule_GetThreadSafeContext(NULL);
    RedisModule_ThreadSafeContextLock(ctx);
    size_t len = args.size();
    RedisModuleString* arguments[len];
    len = 0;
    for (std::vector<std::string>::iterator it = args.begin(); it != args.end(); it++) {
        const std::string arg = *it;
        const char* str = arg.c_str();
        arguments[len++] = RedisModule_CreateString(ctx, str, strlen(str));
    }

    RedisModuleCallReply *reply = RedisModule_Call(ctx, fn_name.c_str(), "!v", arguments, len);
    RedisModule_ThreadSafeContextUnlock(ctx);
    torch::IValue value = IValueFromRedisReply(reply);
    RedisModule_FreeThreadSafeContext(ctx);
    RedisModule_FreeCallReply(reply);
    for(int i= 0; i < len; i++){
        RedisModule_FreeString(NULL, arguments[i]);
    }
    return value;
}

torch::List<torch::IValue> asList(torch::IValue v) {
    return v.toList();
}
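
IValueFromRedisReply above maps Redis replies onto TorchScript values: nil becomes None, bulk strings become str, integers become int, array replies become a generic list that is converted recursively, and error replies are rethrown as a runtime_error inside the script. A hedged TorchScript sketch of consuming an array reply, written in the style of the helpers in the test script further down (the key name is illustrative):

def sum_stored_list(key: str):
    # LRANGE returns an array reply; redis.asList turns the generic reply
    # value into a list the script can iterate over.
    values = redis.asList(redis.execute("LRANGE", key, "0", "-1"))
    l = [torch.tensor(int(str(v))).reshape(1, 1) for v in values]
    return torch.cat(l, dim=0)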
src/libtorch_c/torch_extensions/torch_redis.h (new file, included from torch_c.cpp above)

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
#include "torch/jit.h"
#include "torch/script.h"
#include "torch/csrc/jit/frontend/resolver.h"

namespace torch {
namespace jit {
namespace script {
struct RedisResolver : public Resolver {

    std::shared_ptr<SugaredValue> resolveValue(const std::string &name, Function &m,
                                               const SourceRange &loc) override {
        if (strcasecmp(name.c_str(), "torch") == 0) {
            return std::make_shared<BuiltinModule>("aten");
        } else if (strcasecmp(name.c_str(), "redis") == 0) {
            return std::make_shared<BuiltinModule>("redis");
        }
        return nullptr;
    }

    TypePtr resolveType(const std::string &name, const SourceRange &loc) override {
        return nullptr;
    }
};
inline std::shared_ptr<RedisResolver> redisResolver() { return std::make_shared<RedisResolver>(); }
} // namespace script
} // namespace jit
} // namespace torch

torch::IValue redisExecute(std::string fn_name, std::vector<std::string> args);
torch::List<torch::IValue> asList(torch::IValue);

static auto registry =
    torch::RegisterOperators("redis::execute", &redisExecute).op("redis::asList", &asList);
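
RedisResolver is what lets the bare names torch and redis appear in script source compiled through torchCompileScript: torch resolves to the aten builtin module and redis to the builtin module backed by the two operators registered above. A hedged sketch of a script that relies on both resolutions (the tensor argument and key name are illustrative, not part of the commit):

def scale_by_stored_factor(t: Tensor, key: str):
    # 'redis' resolves to the redis::execute / redis::asList operators,
    # 'torch' resolves to the regular aten builtins.
    factor = float(str(redis.execute("GET", key)))
    return t * factor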

src/redisai.c

Lines changed: 0 additions & 3 deletions

@@ -585,9 +585,6 @@ int RedisAI_ScriptRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
         return RedisAI_ScriptRun_IsKeysPositionRequest_ReportKeys(ctx, argv, argc);
     }

-    if (argc < 6)
-        return RedisModule_WrongArity(ctx);
-
     // Convert The script run command into a DAG command that contains a single op.
     return RedisAI_ExecuteCommand(ctx, argv, argc, CMD_SCRIPTRUN, false);
 }
Lines changed: 66 additions & 0 deletions (new TorchScript test script; path not shown in this view)

@@ -0,0 +1,66 @@

def redis_string_int_to_tensor(redis_value: Any):
    return torch.tensor(int(str(redis_value)))


def redis_string_float_to_tensor(redis_value: Any):
    return torch.tensor(float(str((redis_value))))


def redis_int_to_tensor(redis_value: int):
    return torch.tensor(redis_value)


def redis_int_list_to_tensor(redis_value: Any):
    values = redis.asList(redis_value)
    l = [torch.tensor(int(str(v))).reshape(1,1) for v in values]
    return torch.cat(l, dim=0)


def redis_hash_to_tensor(redis_value: Any):
    values = redis.asList(redis_value)
    l = [torch.tensor(int(str(v))).reshape(1,1) for v in values]
    return torch.cat(l, dim=0)

def test_redis_error():
    redis.execute("SET", "x")

def test_int_set_get():
    redis.execute("SET", "x", "1")
    res = redis.execute("GET", "x",)
    redis.execute("DEL", "x")
    return redis_string_int_to_tensor(res)

def test_int_set_incr():
    redis.execute("SET", "x", "1")
    res = redis.execute("INCR", "x")
    redis.execute("DEL", "x")
    return redis_string_int_to_tensor(res)

def test_float_set_get():
    redis.execute("SET", "x", "1.1")
    res = redis.execute("GET", "x",)
    redis.execute("DEL", "x")
    return redis_string_float_to_tensor(res)

def test_int_list():
    redis.execute("RPUSH", "x", "1")
    redis.execute("RPUSH", "x", "2")
    res = redis.execute("LRANGE", "x", "0", "2")
    redis.execute("DEL", "x")
    return redis_int_list_to_tensor(res)


def test_hash():
    redis.execute("HSET", "x", "field1", "1", "field2", "2")
    res = redis.execute("HVALS", "x")
    redis.execute("DEL", "x")
    return redis_hash_to_tensor(res)


def test_set_key():
    redis.execute("SET", ["x{1}", "1"])


def test_del_key():
    redis.execute("DEL", ["x"])
