@@ -95,6 +95,12 @@ int Tensor_DataTypeStr(DLDataType dtype, char *dtypestr) {
9595 return result ;
9696}
9797
98+ RAI_Tensor * RAI_TensorNew (void ) {
99+ RAI_Tensor * ret = RedisModule_Calloc (1 , sizeof (* ret ));
100+ ret -> refCount = 1 ;
101+ ret -> len = LEN_UNKOWN ;
102+ }
103+
98104RAI_Tensor * RAI_TensorCreateWithDLDataType (DLDataType dtype , long long * dims , int ndims ,
99105 int tensorAllocMode ) {
100106
@@ -103,7 +109,7 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
103109 return NULL ;
104110 }
105111
106- RAI_Tensor * ret = RedisModule_Alloc ( sizeof ( * ret ) );
112+ RAI_Tensor * ret = RAI_TensorNew ( );
107113 int64_t * shape = RedisModule_Alloc (ndims * sizeof (* shape ));
108114 int64_t * strides = RedisModule_Alloc (ndims * sizeof (* strides ));
109115
@@ -144,7 +150,6 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
144150 .manager_ctx = NULL ,
145151 .deleter = NULL };
146152
147- ret -> refCount = 1 ;
148153 return ret ;
149154}
150155
@@ -195,7 +200,7 @@ RAI_Tensor *_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, size_t dtype
195200 memcpy (data , blob , nbytes );
196201 RAI_HoldString (NULL , rstr );
197202
198- RAI_Tensor * ret = RedisModule_Alloc ( sizeof ( * ret ) );
203+ RAI_Tensor * ret = RAI_TensorNew ( );
199204 ret -> tensor = (DLManagedTensor ){.dl_tensor = (DLTensor ){.ctx = ctx ,
200205 .data = data ,
201206 .ndim = ndims ,
@@ -206,7 +211,6 @@ RAI_Tensor *_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, size_t dtype
206211 .manager_ctx = rstr ,
207212 .deleter = RAI_RStringDataTensorDeleter };
208213
209- ret -> refCount = 1 ;
210214 return ret ;
211215}
212216
@@ -335,7 +339,7 @@ int RAI_TensorDeepCopy(RAI_Tensor *t, RAI_Tensor **dest) {
335339// Beware: this will take ownership of dltensor
336340RAI_Tensor * RAI_TensorCreateFromDLTensor (DLManagedTensor * dl_tensor ) {
337341
338- RAI_Tensor * ret = RedisModule_Calloc ( 1 , sizeof ( * ret ) );
342+ RAI_Tensor * ret = RAI_TensorNew ( );
339343
340344 ret -> tensor =
341345 (DLManagedTensor ){.dl_tensor = (DLTensor ){.ctx = dl_tensor -> dl_tensor .ctx ,
@@ -348,7 +352,6 @@ RAI_Tensor *RAI_TensorCreateFromDLTensor(DLManagedTensor *dl_tensor) {
348352 .manager_ctx = dl_tensor -> manager_ctx ,
349353 .deleter = dl_tensor -> deleter };
350354
351- ret -> refCount = 1 ;
352355 return ret ;
353356}
354357
@@ -361,12 +364,15 @@ int RAI_TensorIsDataTypeEqual(RAI_Tensor *t1, RAI_Tensor *t2) {
361364}
362365
363366size_t RAI_TensorLength (RAI_Tensor * t ) {
364- int64_t * shape = t -> tensor .dl_tensor .shape ;
365- size_t len = 1 ;
366- for (size_t i = 0 ; i < t -> tensor .dl_tensor .ndim ; ++ i ) {
367- len *= shape [i ];
367+ if (t -> len == LEN_UNKOWN ) {
368+ int64_t * shape = t -> tensor .dl_tensor .shape ;
369+ size_t len = 1 ;
370+ for (size_t i = 0 ; i < t -> tensor .dl_tensor .ndim ; ++ i ) {
371+ len *= shape [i ];
372+ }
373+ t -> len = len ;
368374 }
369- return len ;
375+ return t -> len ;
370376}
371377
372378size_t RAI_TensorDataSize (RAI_Tensor * t ) { return Tensor_DataTypeSize (RAI_TensorDataType (t )); }
0 commit comments