2121#include <pthread.h>
2222#include "DAG/dag.h"
2323
24- RedisModuleType * RedisAI_ModelType = NULL ;
25-
26- static void * RAI_Model_RdbLoad (struct RedisModuleIO * io , int encver ) {
27- // if (encver != RAI_ENC_VER) {
28- // /* We should actually log an error here, or try to implement
29- // the ability to load older versions of our data structure. */
30- // return NULL;
31- // }
32-
33- RAI_Backend backend = RedisModule_LoadUnsigned (io );
34- const char * devicestr = RedisModule_LoadStringBuffer (io , NULL );
35-
36- RedisModuleString * tag = RedisModule_LoadString (io );
37-
38- const size_t batchsize = RedisModule_LoadUnsigned (io );
39- const size_t minbatchsize = RedisModule_LoadUnsigned (io );
40-
41- const size_t ninputs = RedisModule_LoadUnsigned (io );
42- const char * * inputs = RedisModule_Alloc (ninputs * sizeof (char * ));
43-
44- for (size_t i = 0 ; i < ninputs ; i ++ ) {
45- inputs [i ] = RedisModule_LoadStringBuffer (io , NULL );
46- }
47-
48- const size_t noutputs = RedisModule_LoadUnsigned (io );
49-
50- const char * * outputs = RedisModule_Alloc (ninputs * sizeof (char * ));
51-
52- for (size_t i = 0 ; i < noutputs ; i ++ ) {
53- outputs [i ] = RedisModule_LoadStringBuffer (io , NULL );
54- }
55-
56- RAI_ModelOpts opts = {
57- .batchsize = batchsize ,
58- .minbatchsize = minbatchsize ,
59- .backends_intra_op_parallelism = getBackendsIntraOpParallelism (),
60- .backends_inter_op_parallelism = getBackendsInterOpParallelism (),
61- };
62-
63- size_t len ;
64- char * buffer = NULL ;
65-
66- if (encver <= 100 ) {
67- buffer = RedisModule_LoadStringBuffer (io , & len );
68- } else {
69- len = RedisModule_LoadUnsigned (io );
70- buffer = RedisModule_Alloc (len );
71- const size_t n_chunks = RedisModule_LoadUnsigned (io );
72- long long chunk_offset = 0 ;
73- for (size_t i = 0 ; i < n_chunks ; i ++ ) {
74- size_t chunk_len ;
75- char * chunk_buffer = RedisModule_LoadStringBuffer (io , & chunk_len );
76- memcpy (buffer + chunk_offset , chunk_buffer , chunk_len );
77- chunk_offset += chunk_len ;
78- RedisModule_Free (chunk_buffer );
79- }
80- }
81-
82- RAI_Error err = {0 };
83-
84- RAI_Model * model = RAI_ModelCreate (backend , devicestr , tag , opts , ninputs , inputs , noutputs ,
85- outputs , buffer , len , & err );
86-
87- if (err .code == RAI_EBACKENDNOTLOADED ) {
88- RedisModuleCtx * ctx = RedisModule_GetContextFromIO (io );
89- int ret = RAI_LoadDefaultBackend (ctx , backend );
90- if (ret == REDISMODULE_ERR ) {
91- RedisModule_Log (ctx , "error" , "Could not load default backend" );
92- RAI_ClearError (& err );
93- return NULL ;
94- }
95- RAI_ClearError (& err );
96- model = RAI_ModelCreate (backend , devicestr , tag , opts , ninputs , inputs , noutputs , outputs ,
97- buffer , len , & err );
98- }
99-
100- if (err .code != RAI_OK ) {
101- RedisModuleCtx * ctx = RedisModule_GetContextFromIO (io );
102- RedisModule_Log (ctx , "error" , "%s" , err .detail );
103- RAI_ClearError (& err );
104- if (buffer ) {
105- RedisModule_Free (buffer );
106- }
107- return NULL ;
108- }
109-
110- for (size_t i = 0 ; i < ninputs ; i ++ ) {
111- RedisModule_Free (inputs [i ]);
112- }
113- for (size_t i = 0 ; i < noutputs ; i ++ ) {
114- RedisModule_Free (outputs [i ]);
115- }
116- RedisModule_Free (inputs );
117- RedisModule_Free (outputs );
118- RedisModule_Free (buffer );
119-
120- RedisModuleCtx * stats_ctx = RedisModule_GetContextFromIO (io );
121- RedisModuleString * stats_keystr =
122- RedisModule_CreateStringFromString (stats_ctx , RedisModule_GetKeyNameFromIO (io ));
123-
124- model -> infokey = RAI_AddStatsEntry (stats_ctx , stats_keystr , RAI_MODEL , backend , devicestr , tag );
125-
126- RedisModule_FreeString (NULL , tag );
127- RedisModule_Free (devicestr );
128- RedisModule_FreeString (NULL , stats_keystr );
129-
130- return model ;
131- }
132-
133- static void RAI_Model_RdbSave (RedisModuleIO * io , void * value ) {
134- RAI_Model * model = (RAI_Model * )value ;
135- char * buffer = NULL ;
136- size_t len = 0 ;
137- RAI_Error err = {0 };
138-
139- int ret = RAI_ModelSerialize (model , & buffer , & len , & err );
140-
141- if (err .code != RAI_OK ) {
142- RedisModuleCtx * stats_ctx = RedisModule_GetContextFromIO (io );
143- printf ("ERR: %s\n" , err .detail );
144- RAI_ClearError (& err );
145- if (buffer ) {
146- RedisModule_Free (buffer );
147- }
148- return ;
149- }
150-
151- RedisModule_SaveUnsigned (io , model -> backend );
152- RedisModule_SaveStringBuffer (io , model -> devicestr , strlen (model -> devicestr ) + 1 );
153- RedisModule_SaveString (io , model -> tag );
154- RedisModule_SaveUnsigned (io , model -> opts .batchsize );
155- RedisModule_SaveUnsigned (io , model -> opts .minbatchsize );
156- RedisModule_SaveUnsigned (io , model -> ninputs );
157- for (size_t i = 0 ; i < model -> ninputs ; i ++ ) {
158- RedisModule_SaveStringBuffer (io , model -> inputs [i ], strlen (model -> inputs [i ]) + 1 );
159- }
160- RedisModule_SaveUnsigned (io , model -> noutputs );
161- for (size_t i = 0 ; i < model -> noutputs ; i ++ ) {
162- RedisModule_SaveStringBuffer (io , model -> outputs [i ], strlen (model -> outputs [i ]) + 1 );
163- }
164- long long chunk_size = getModelChunkSize ();
165- const size_t n_chunks = len / chunk_size + 1 ;
166- RedisModule_SaveUnsigned (io , len );
167- RedisModule_SaveUnsigned (io , n_chunks );
168- for (size_t i = 0 ; i < n_chunks ; i ++ ) {
169- size_t chunk_len = i < n_chunks - 1 ? chunk_size : len % chunk_size ;
170- RedisModule_SaveStringBuffer (io , buffer + i * chunk_size , chunk_len );
171- }
172-
173- if (buffer ) {
174- RedisModule_Free (buffer );
175- }
176- }
177-
178- static void RAI_Model_AofRewrite (RedisModuleIO * aof , RedisModuleString * key , void * value ) {
179- RAI_Model * model = (RAI_Model * )value ;
180-
181- char * buffer = NULL ;
182- size_t len = 0 ;
183- RAI_Error err = {0 };
184-
185- int ret = RAI_ModelSerialize (model , & buffer , & len , & err );
186-
187- if (err .code != RAI_OK ) {
188-
189- printf ("ERR: %s\n" , err .detail );
190- RAI_ClearError (& err );
191- if (buffer ) {
192- RedisModule_Free (buffer );
193- }
194- return ;
195- }
196-
197- // AI.MODELSET model_key backend device [INPUTS name1 name2 ... OUTPUTS name1
198- // name2 ...] model_blob
199-
200- RedisModuleString * * inputs_ = array_new (RedisModuleString * , model -> ninputs );
201- RedisModuleString * * outputs_ = array_new (RedisModuleString * , model -> noutputs );
202-
203- RedisModuleCtx * ctx = RedisModule_GetContextFromIO (aof );
204-
205- for (size_t i = 0 ; i < model -> ninputs ; i ++ ) {
206- inputs_ = array_append (
207- inputs_ , RedisModule_CreateString (ctx , model -> inputs [i ], strlen (model -> inputs [i ])));
208- }
209-
210- for (size_t i = 0 ; i < model -> noutputs ; i ++ ) {
211- outputs_ = array_append (
212- outputs_ , RedisModule_CreateString (ctx , model -> outputs [i ], strlen (model -> outputs [i ])));
213- }
214-
215- long long chunk_size = getModelChunkSize ();
216- const size_t n_chunks = len / chunk_size + 1 ;
217- RedisModuleString * * buffers_ = array_new (RedisModuleString * , n_chunks );
218-
219- for (size_t i = 0 ; i < n_chunks ; i ++ ) {
220- size_t chunk_len = i < n_chunks - 1 ? chunk_size : len % chunk_size ;
221- buffers_ = array_append (buffers_ ,
222- RedisModule_CreateString (ctx , buffer + i * chunk_size , chunk_len ));
223- }
224-
225- if (buffer ) {
226- RedisModule_Free (buffer );
227- }
228-
229- const char * backendstr = RAI_BackendName (model -> backend );
230-
231- RedisModule_EmitAOF (aof , "AI.MODELSET" , "sccsclclcvcvcv" , key , backendstr , model -> devicestr ,
232- model -> tag , "BATCHSIZE" , model -> opts .batchsize , "MINBATCHSIZE" ,
233- model -> opts .minbatchsize , "INPUTS" , inputs_ , model -> ninputs , "OUTPUTS" ,
234- outputs_ , model -> noutputs , "BLOB" , buffers_ , n_chunks );
235-
236- for (size_t i = 0 ; i < model -> ninputs ; i ++ ) {
237- RedisModule_FreeString (ctx , inputs_ [i ]);
238- }
239- array_free (inputs_ );
240-
241- for (size_t i = 0 ; i < model -> noutputs ; i ++ ) {
242- RedisModule_FreeString (ctx , outputs_ [i ]);
243- }
244- array_free (outputs_ );
245-
246- for (size_t i = 0 ; i < n_chunks ; i ++ ) {
247- RedisModule_FreeString (ctx , buffers_ [i ]);
248- }
249- array_free (buffers_ );
250- }
251-
25224/* Return REDISMODULE_ERR if there was an error getting the Model.
25325 * Return REDISMODULE_OK if the model value stored at key was correctly
25426 * returned and available at *model variable. */
@@ -270,29 +42,6 @@ int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, Re
27042 return REDISMODULE_OK ;
27143}
27244
273- // TODO: pass err in?
274- static void RAI_Model_DTFree (void * value ) {
275- RAI_Error err = {0 };
276- RAI_ModelFree (value , & err );
277- if (err .code != RAI_OK ) {
278- printf ("ERR: %s\n" , err .detail );
279- RAI_ClearError (& err );
280- }
281- }
282-
283- int RAI_ModelInit (RedisModuleCtx * ctx ) {
284- RedisModuleTypeMethods tmModel = {.version = REDISMODULE_TYPE_METHOD_VERSION ,
285- .rdb_load = RAI_Model_RdbLoad ,
286- .rdb_save = RAI_Model_RdbSave ,
287- .aof_rewrite = RAI_Model_AofRewrite ,
288- .mem_usage = NULL ,
289- .free = RAI_Model_DTFree ,
290- .digest = NULL };
291-
292- RedisAI_ModelType = RedisModule_CreateDataType (ctx , "AI__MODEL" , RAI_ENC_VER_MM , & tmModel );
293- return RedisAI_ModelType != NULL ;
294- }
295-
29645RAI_Model * RAI_ModelCreate (RAI_Backend backend , const char * devicestr , RedisModuleString * tag ,
29746 RAI_ModelOpts opts , size_t ninputs , const char * * inputs , size_t noutputs ,
29847 const char * * outputs , const char * modeldef , size_t modellen ,
0 commit comments