Skip to content

Commit 85cb8c6

Browse files
author
DvirDukhan
committed
wip
1 parent 665b358 commit 85cb8c6

File tree

4 files changed

+184
-0
lines changed

4 files changed

+184
-0
lines changed

src/libtorch_c/torch_c.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,40 @@
77
#include <iostream>
88
#include <sstream>
99

10+
11+
#include "torch/csrc/jit/frontend/resolver.h"
12+
#include "torch/script.h"
13+
#include "torch/jit.h"
14+
15+
namespace torch {
16+
namespace jit {
17+
namespace script {
18+
struct RedisResolver : public Resolver {
19+
20+
std::shared_ptr<SugaredValue> resolveValue(const std::string& name, Function& m, const SourceRange& loc) override {
21+
if(strcasecmp(name.c_str(), "torch") == 0) {
22+
return std::make_shared<BuiltinModule>("aten");
23+
}
24+
else if (strcasecmp(name.c_str(), "redis") == 0) {
25+
return std::make_shared<BuiltinModule>("redis");
26+
}
27+
return nullptr;
28+
}
29+
30+
TypePtr resolveType(const std::string& name, const SourceRange& loc) override {
31+
return nullptr;
32+
}
33+
34+
inline std::shared_ptr<RedisResolver> redisResolver() {
35+
return std::make_shared<RedisResolver>();
36+
}
37+
38+
};
39+
}
40+
}
41+
}
42+
43+
1044
namespace {
1145

1246
static DLDataType getDLDataType(const at::Tensor &t) {
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#pragma once
2+
#include "torch/script.h"
3+
#include "torch/custom_class.h"
4+
#include "../../redismodule.h"
5+
6+
7+
8+
struct RedisValue: torch::CustomClassHolder {
9+
private:
10+
union {
11+
int intValue;
12+
std::string stringValue;
13+
std::vector<RedisValue*> arrayValue;
14+
};
15+
public:
16+
RedisValue(RedisModuleCallReply *reply) {
17+
if(RedisModule_CallReplyType(reply) == REDISMODULE_REPLY_ARRAY) {
18+
size_t len = RedisModule_CallReplyLength(reply);
19+
for(auto i = 0 ; i < len ; ++i){
20+
RedisModuleCallReply *subReply = RedisModule_CallReplyArrayElement(reply, i);
21+
RedisValue value(subReply);
22+
arrayValue.push_back(value);
23+
}
24+
}
25+
26+
if(RedisModule_CallReplyType(reply) == REDISMODULE_REPLY_STRING ||
27+
RedisModule_CallReplyType(reply) == REDISMODULE_REPLY_ERROR){
28+
size_t len;
29+
const char* replyStr = RedisModule_CallReplyStringPtr(reply, &len);
30+
PyObject* ret = PyUnicode_FromStringAndSize(replyStr, len);
31+
if(!ret){
32+
PyErr_Clear();
33+
ret = PyByteArray_FromStringAndSize(replyStr, len);
34+
}
35+
return ret;
36+
}
37+
38+
if(RedisModule_CallReplyType(reply) == REDISMODULE_REPLY_INTEGER){
39+
long long val = RedisModule_CallReplyInteger(reply);
40+
return PyLong_FromLongLong(val);
41+
}
42+
}
43+
44+
};

test/test_data/redis_scripts.txt

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
def redis_string_to_int(redis_value: RedisValue):
2+
return int(redis_value.stringValue())
3+
4+
def redis_string_to_float(redis_value: RedisValue):
5+
return float(redis_value.stringValue())
6+
7+
def redis_string_int_to_tensor(redis_value: RedisValue):
8+
return tensor(redis_string_to_int(redis_value))
9+
10+
def redis_string_float_to_tensor(redis_value: RedisValue):
11+
return tensor(redis_string_to_float(redis_value))
12+
13+
def redis_int_to_tensor(redis_value: RedisValue):
14+
return tensor(redis_value.intValue())
15+
16+
def redis_int_list_to_tensor(redis_value: RedisValue):
17+
len = len(redis_value.getList())
18+
l = []
19+
for v in redis_value.getList():
20+
l.append(redis_string_to_int(v))
21+
return torch.cat(l, dim=0)
22+
23+
def redis_float_list_to_tensor(redis_value: RedisValue):
24+
len = len(redis_value.getList())
25+
l = []
26+
for v in redis_value.getList():
27+
l.append(redis_string_to_float(v))
28+
return torch.cat(l, dim=0)
29+
30+
def redis_hash_to_tensor(redis_value: RedisValue):
31+
len = len(redis_value.getList())
32+
l = []
33+
for v in redis_value.getList():
34+
l.append(redis_string_to_float(v.getList()[1]))
35+
return torch.cat(l, dim=0)
36+
37+
def test_redis_error():
38+
res = redis.executeCommand("SET", "x")
39+
return tensor(res.getValueType())
40+
41+
def test_int_set_get():
42+
redis.executeCommand("SET", "x", "1")
43+
res = redis.executeCommand("GET", "x",)
44+
redis.executeCommand("DEL", "x")
45+
return redis_string_int_to_tensor(res)
46+
47+
def test_float_set_get():
48+
redis.executeCommand("SET", "x", "1.1")
49+
res = redis.executeCommand("GET", "x",)
50+
redis.executeCommand("DEL", "x")
51+
return redis_string_int_to_tensor(res)
52+
53+
def test_int_list():
54+
redis.executeCommand("LPUSH", "x", "1")
55+
redis.executeCommand("LPUSH", "x", "2")
56+
res = redis.executeCommand("LRANGE", "x")
57+
redis.executeCommand("DEL", "x")
58+
return redis_int_list_to_tensor(res)
59+
60+
def test_float_list():
61+
redis.executeCommand("LPUSH", "x", "1.1")
62+
redis.executeCommand("LPUSH", "x", "2.2")
63+
res = redis.executeCommand("LRANGE", "x")
64+
redis.executeCommand("DEL", "x")
65+
return redis_float_list_to_tensor(res)
66+
67+
def test_hash():
68+
redis.executeCommand("HSET", "x", "1", "2.2)
69+
res = redis.executeCommand("HGETALL", "x")
70+
redis.executeCommand("DEL", "x")
71+
return redis_float_list_to_tensor(res)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import redis
2+
3+
from includes import *
4+
5+
'''
6+
python -m RLTest --test tests_torchscript_extensions.py --module path/to/redisai.so
7+
'''
8+
9+
10+
class test_torch_script_extesions:
11+
12+
def __init__(self):
13+
self.env = Env()
14+
if not TEST_PT:
15+
self.env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
16+
self.env.skip()
17+
18+
self.con = self.env.getConnect()
19+
script_filename = os.path.join(test_data_path, 'redis_scripts.txt')
20+
with open(script_filename, 'rb') as f:
21+
script = f.read()
22+
23+
ret = self.con.execute_command('AI.SCRIPTSET', 'redis_scripts', DEVICE, 'SOURCE', script)
24+
self.env.assertEqual(ret, b'OK')
25+
self.env.ensureSlaveSynced(self.con, self.env)
26+
27+
def test_int_get_set(self):
28+
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_int_set_get', 'INPUTS', 'OUTPUTS', 'y')
29+
y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
30+
self.env.assertEqual(y, ["dtype", "INT64", "shape", [0, 1], "VALUES", "1"] )
31+
32+
def test_float_get_set(self):
33+
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_float_set_get', 'INPUTS', 'OUTPUTS', 'y')
34+
y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
35+
self.env.assertEqual(y, ["dtype", "FLOAT", "shape", [0, 1], "VALUES", "1.1"])

0 commit comments

Comments
 (0)