Skip to content

Commit 5b38e19

Browse files
author
DvirDukhan
committed
correct redis usage
1 parent ae398b9 commit 5b38e19

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

src/backends/torch.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@ int RAI_InitBackendTorch(int (*get_api_fn)(const char *, void *)) {
1212
get_api_fn("RedisModule_Realloc", ((void **)&RedisModule_Realloc));
1313
get_api_fn("RedisModule_Strdup", ((void **)&RedisModule_Strdup));
1414
get_api_fn("RedisModule_CreateString", ((void **)&RedisModule_CreateString));
15-
get_api_fn("RedisModule_GetThreadSafeContext", ((void **)&RedisModule_GetThreadSafeContext));
15+
get_api_fn("RedisModule_FreeString", ((void **)&RedisModule_FreeString));
1616
get_api_fn("RedisModule_Call", ((void **)&RedisModule_Call));
1717
get_api_fn("RedisModule_CallReplyType", ((void **)&RedisModule_CallReplyType));
1818
get_api_fn("RedisModule_CallReplyStringPtr", ((void **)&RedisModule_CallReplyStringPtr));
1919
get_api_fn("RedisModule_CallReplyInteger", ((void **)&RedisModule_CallReplyInteger));
2020
get_api_fn("RedisModule_CallReplyLength", ((void **)&RedisModule_CallReplyLength));
2121
get_api_fn("RedisModule_CallReplyArrayElement", ((void **)&RedisModule_CallReplyArrayElement));
2222
get_api_fn("RedisModule_FreeCallReply", ((void **)&RedisModule_FreeCallReply));
23+
get_api_fn("RedisModule_GetThreadSafeContext", ((void **)&RedisModule_GetThreadSafeContext));
24+
get_api_fn("RedisModule_ThreadSafeContextLock", ((void **)&RedisModule_ThreadSafeContextLock));
25+
get_api_fn("RedisModule_ThreadSafeContextUnlock", ((void **)&RedisModule_ThreadSafeContextUnlock));
2326
get_api_fn("RedisModule_FreeThreadSafeContext", ((void **)&RedisModule_FreeThreadSafeContext));
2427

2528
return REDISMODULE_OK;

src/libtorch_c/torch_extensions/torch_redis.cpp

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,25 +42,28 @@ torch::IValue IValueFromRedisReply(RedisModuleCallReply *reply){
4242
}
4343

4444
torch::IValue redisExecute(std::string fn_name, std::vector<std::string> args ) {
45-
RedisModuleCtx* ctx = RedisModule_GetThreadSafeContext(NULL);
46-
size_t len = args.size();
47-
RedisModuleString* arguments[len];
48-
len = 0;
49-
for (std::vector<std::string>::iterator it = args.begin(); it != args.end(); it++) {
50-
const std::string arg = *it;
51-
const char* str = arg.c_str();
52-
arguments[len++] = RedisModule_CreateString(ctx, str, strlen(str));
53-
}
45+
RedisModuleCtx* ctx = RedisModule_GetThreadSafeContext(NULL);
46+
RedisModule_ThreadSafeContextLock(ctx);
47+
size_t len = args.size();
48+
RedisModuleString* arguments[len];
49+
len = 0;
50+
for (std::vector<std::string>::iterator it = args.begin(); it != args.end(); it++) {
51+
const std::string arg = *it;
52+
const char* str = arg.c_str();
53+
arguments[len++] = RedisModule_CreateString(ctx, str, strlen(str));
54+
}
5455

55-
RedisModuleCallReply *reply = RedisModule_Call(ctx, fn_name.c_str(), "v", arguments, len);
56-
// RedisValue value = RedisValue::fromRedisReply(RedisModule_Call(ctx, fn_name.c_str(), "v", arguments, len));
57-
torch::IValue value = IValueFromRedisReply(reply);
58-
RedisModule_FreeThreadSafeContext(ctx);
59-
RedisModule_FreeCallReply(reply);
60-
return value;
56+
RedisModuleCallReply *reply = RedisModule_Call(ctx, fn_name.c_str(), "!v", arguments, len);
57+
RedisModule_ThreadSafeContextUnlock(ctx);
58+
torch::IValue value = IValueFromRedisReply(reply);
59+
RedisModule_FreeThreadSafeContext(ctx);
60+
RedisModule_FreeCallReply(reply);
61+
for(int i= 0; i < len; i++){
62+
RedisModule_FreeString(NULL, arguments[i]);
63+
}
64+
return value;
6165
}
6266

63-
6467
torch::List<torch::IValue> asList(torch::IValue v) {
6568
return v.toList();
66-
}
69+
}

src/libtorch_c/torch_extensions/torch_redis.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,8 @@ inline std::shared_ptr<RedisResolver> redisResolver() { return std::make_shared<
2626
} // namespace jit
2727
} // namespace torch
2828

29-
// c10::intrusive_ptr<RedisValue> redisExecute(std::string fn_name, std::vector<std::string> args );
30-
3129
torch::IValue redisExecute(std::string fn_name, std::vector<std::string> args);
3230
torch::List<torch::IValue> asList(torch::IValue);
3331

3432
static auto registry =
3533
torch::RegisterOperators("redis::execute", &redisExecute).op("redis::asList", &asList);
36-
// registry = torch::RegisterOperators("torch::asList", &asList);

0 commit comments

Comments
 (0)