Skip to content

Commit 3891559

Browse files
author
DvirDukhan
committed
simple scalars round trip
1 parent 152092a commit 3891559

File tree

8 files changed

+110
-96
lines changed

8 files changed

+110
-96
lines changed

src/backends/torch.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ int RAI_InitBackendTorch(int (*get_api_fn)(const char *, void *)) {
1414
get_api_fn("RedisModule_CreateString", ((void **)&RedisModule_CreateString));
1515
get_api_fn("RedisModule_GetThreadSafeContext", ((void **)&RedisModule_GetThreadSafeContext));
1616
get_api_fn("RedisModule_Call", ((void **)&RedisModule_Call));
17+
get_api_fn("RedisModule_CallReplyType", ((void **)&RedisModule_CallReplyType));
18+
get_api_fn("RedisModule_CallReplyStringPtr", ((void **)&RedisModule_CallReplyStringPtr));
19+
get_api_fn("RedisModule_CallReplyInteger", ((void **)&RedisModule_CallReplyInteger));
20+
get_api_fn("RedisModule_CallReplyLength", ((void **)&RedisModule_CallReplyLength));
21+
get_api_fn("RedisModule_CallReplyArrayElement", ((void **)&RedisModule_CallReplyArrayElement));
22+
get_api_fn("RedisModule_FreeCallReply", ((void **)&RedisModule_FreeCallReply));
1723
get_api_fn("RedisModule_FreeThreadSafeContext", ((void **)&RedisModule_FreeThreadSafeContext));
1824

1925
return REDISMODULE_OK;

src/libtorch_c/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
add_library(torch_c STATIC torch_c.cpp)
1+
add_library(torch_c STATIC torch_c.cpp torch_extensions/torch_redis.cpp)
22
target_link_libraries(torch_c "${TORCH_LIBRARIES}")
33
set_property(TARGET torch_c PROPERTY CXX_STANDARD 14)

src/libtorch_c/torch_c.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,13 @@ extern "C" void* torchCompileScript(const char* script, DLDeviceType device, int
316316
torch::jit::script::redisResolver(),
317317
nullptr);
318318
auto aten_device_type = getATenDeviceType(device);
319+
319320
if (aten_device_type == at::DeviceType::CUDA && !torch::cuda::is_available()) {
320321
throw std::logic_error("GPU requested but Torch couldn't find CUDA");
321322
}
322323
ctx->cu = cu;
323324
ctx->module = nullptr;
325+
324326
}
325327
catch(std::exception& e) {
326328
size_t len = strlen(e.what()) +1;
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#include <string>
2+
#include "torch_redis.h"
3+
#include "../../redismodule.h"
4+
5+
torch::IValue IValueFromRedisReply(RedisModuleCallReply *reply){
6+
7+
int reply_type = RedisModule_CallReplyType(reply);
8+
switch(reply_type) {
9+
case REDISMODULE_REPLY_NULL: {
10+
return torch::IValue();
11+
}
12+
case REDISMODULE_REPLY_STRING: {
13+
size_t len;
14+
const char *replyStr = RedisModule_CallReplyStringPtr(reply, &len);
15+
std::string str = replyStr;
16+
return torch::IValue(str.substr(0,len));
17+
}
18+
case REDISMODULE_REPLY_INTEGER: {
19+
int intValue = RedisModule_CallReplyInteger(reply);
20+
return torch::IValue(intValue);
21+
}
22+
case REDISMODULE_REPLY_ARRAY: {
23+
c10::impl::GenericList vec = c10::impl::GenericList(c10::AnyType::create());
24+
size_t len = RedisModule_CallReplyLength(reply);
25+
for (auto i = 0; i < len; ++i) {
26+
RedisModuleCallReply *subReply = RedisModule_CallReplyArrayElement(reply, i);
27+
torch::IValue value = IValueFromRedisReply(subReply);
28+
vec.push_back(value);
29+
}
30+
return torch::IValue(vec);
31+
}
32+
case REDISMODULE_REPLY_ERROR: {
33+
size_t len;
34+
const char *replyStr = RedisModule_CallReplyStringPtr(reply, &len);
35+
throw std::runtime_error(replyStr);
36+
break;
37+
}
38+
default:{
39+
throw(std::runtime_error("Unsupported redis type"));
40+
}
41+
}
42+
}
43+
44+
torch::IValue redisExecute(std::string fn_name, std::vector<std::string> args ) {
45+
RedisModuleCtx* ctx = RedisModule_GetThreadSafeContext(NULL);
46+
size_t len = args.size();
47+
RedisModuleString* arguments[len];
48+
len = 0;
49+
for (std::vector<std::string>::iterator it = args.begin(); it != args.end(); it++) {
50+
const std::string arg = *it;
51+
const char* str = arg.c_str();
52+
arguments[len++] = RedisModule_CreateString(ctx, str, strlen(str));
53+
}
54+
55+
RedisModuleCallReply *reply = RedisModule_Call(ctx, fn_name.c_str(), "v", arguments, len);
56+
// RedisValue value = RedisValue::fromRedisReply(RedisModule_Call(ctx, fn_name.c_str(), "v", arguments, len));
57+
torch::IValue value = IValueFromRedisReply(reply);
58+
RedisModule_FreeThreadSafeContext(ctx);
59+
RedisModule_FreeCallReply(reply);
60+
return value;
61+
}

src/libtorch_c/torch_extensions/torch_redis.h

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
#include "../../redismodule.h"
2-
#include "../../util/arr.h"
3-
4-
#include "torch/csrc/jit/frontend/resolver.h"
5-
#include "torch/script.h"
61
#include "torch/jit.h"
2+
#include "torch/script.h"
3+
#include "torch/csrc/jit/frontend/resolver.h"
4+
5+
#include "torch_redis_value.h"
76

87
namespace torch {
98
namespace jit {
@@ -32,18 +31,11 @@ namespace torch {
3231
}
3332
}
3433

35-
void redisExecute(std::string fn_name, std::vector<std::string> args ) {
36-
RedisModuleCtx* ctx = RedisModule_GetThreadSafeContext(NULL);
37-
size_t len = args.size();
38-
RedisModuleString* arguments[len];
39-
len = 0;
40-
for (std::vector<std::string>::iterator it = args.begin(); it != args.end(); it++) {
41-
const std::string arg = *it;
42-
const char* str = arg.c_str();
43-
arguments[len++] = RedisModule_CreateString(ctx, str, strlen(str));
44-
}
45-
RedisModule_Call(ctx, fn_name.c_str(), "v", arguments, len);
46-
RedisModule_FreeThreadSafeContext(ctx);
47-
}
34+
35+
// c10::intrusive_ptr<RedisValue> redisExecute(std::string fn_name, std::vector<std::string> args );
36+
37+
torch::IValue redisExecute(std::string fn_name, std::vector<std::string> args );
38+
39+
4840

4941
static auto registry = torch::RegisterOperators("redis::execute", &redisExecute);

src/libtorch_c/torch_extensions/torch_redis_value.h

Lines changed: 0 additions & 46 deletions
This file was deleted.

tests/flow/test_data/redis_scripts.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,10 @@
1-
# def redis_string_to_int(redis_value: RedisValue):
2-
# return int(redis_value.stringValue())
31

2+
def redis_string_int_to_tensor(redis_value: Any):
3+
return torch.tensor(int(str(redis_value)))
44

5-
# def redis_string_to_float(redis_value: RedisValue):
6-
# return float(redis_value.stringValue())
75

8-
9-
# def redis_string_int_to_tensor(redis_value: RedisValue):
10-
# return tensor(redis_string_to_int(redis_value))
11-
12-
13-
# def redis_string_float_to_tensor(redis_value: RedisValue):
14-
# return tensor(redis_string_to_float(redis_value))
6+
def redis_string_float_to_tensor(redis_value: Any):
7+
return torch.tensor(float(str((redis_value))))
158

169

1710
# def redis_int_to_tensor(redis_value: RedisValue):
@@ -45,17 +38,17 @@
4538
# res = redis.executeCommand("SET", "x")
4639
# return tensor(res.getValueType())
4740

48-
# def test_int_set_get():
49-
# redis.executeCommand("SET", "x", "1")
50-
# res = redis.executeCommand("GET", "x",)
51-
# redis.executeCommand("DEL", "x")
52-
# return redis_string_int_to_tensor(res)
53-
54-
# def test_float_set_get():
55-
# redis.executeCommand("SET", "x", "1.1")
56-
# res = redis.executeCommand("GET", "x",)
57-
# redis.executeCommand("DEL", "x")
58-
# return redis_string_int_to_tensor(res)
41+
def test_int_set_get():
42+
redis.execute("SET", "x", "1")
43+
res = redis.execute("GET", "x",)
44+
redis.execute("DEL", "x")
45+
return redis_string_int_to_tensor(res)
46+
47+
def test_float_set_get():
48+
redis.execute("SET", "x", "1.1")
49+
res = redis.execute("GET", "x",)
50+
redis.execute("DEL", "x")
51+
return redis_string_float_to_tensor(res)
5952

6053
# def test_int_list():
6154
# redis.executeCommand("LPUSH", "x", "1")

tests/flow/test_torchscript_extensions.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import redis
23

34
from includes import *
@@ -28,17 +29,22 @@ def __init__(self):
2829
self.env.assertEqual(ret, b'OK')
2930
# self.env.ensureSlaveSynced(self.con, self.env)
3031

31-
# def test_float_get_set(self):
32-
# self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_float_set_get', 'INPUTS', 'OUTPUTS', 'y')
33-
# y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
34-
# self.env.assertEqual(y, ["dtype", "FLOAT", "shape", [0, 1], "VALUES", "1.1"])
35-
3632
def test_simple_test_set(self):
3733
self.con.execute_command(
3834
'AI.SCRIPTRUN', 'redis_scripts', 'test_set_key')
3935
self.env.assertEqual(b"1", self.con.get("x"))
4036

41-
# def test_int_get_set(self):
42-
# self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_int_set_get', 'OUTPUTS', 'y')
43-
# y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
44-
# self.env.assertEqual(y, ["dtype", "INT64", "shape", [0, 1], "VALUES", "1"] )
37+
def test_int_get_set(self):
38+
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_int_set_get', 'OUTPUTS', 'y')
39+
y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
40+
self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [], b"values", [1]] )
41+
42+
def test_float_get_set(self):
43+
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_float_set_get', 'OUTPUTS', 'y')
44+
y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
45+
self.env.assertEqual(y[0], b"dtype")
46+
self.env.assertEqual(y[1], b"FLOAT")
47+
self.env.assertEqual(y[2], b"shape")
48+
self.env.assertEqual(y[3], [])
49+
self.env.assertEqual(y[4], b"values")
50+
self.env.assertAlmostEqual(float(y[5][0]), 1.1, 0.1)

0 commit comments

Comments
 (0)