@@ -94,14 +94,20 @@ int Tensor_DataTypeStr(DLDataType dtype, char *dtypestr) {
9494 return result ;
9595}
9696
97+ RAI_Tensor * RAI_TensorNew (void ) {
98+ RAI_Tensor * ret = RedisModule_Calloc (1 , sizeof (* ret ));
99+ ret -> refCount = 1 ;
100+ ret -> len = LEN_UNKOWN ;
101+ }
102+
97103RAI_Tensor * RAI_TensorCreateWithDLDataType (DLDataType dtype , long long * dims , int ndims ,
98104 int tensorAllocMode ) {
99105 const size_t dtypeSize = Tensor_DataTypeSize (dtype );
100106 if (dtypeSize == 0 ) {
101107 return NULL ;
102108 }
103109
104- RAI_Tensor * ret = RedisModule_Alloc ( sizeof ( * ret ) );
110+ RAI_Tensor * ret = RAI_TensorNew ( );
105111 int64_t * shape = RedisModule_Alloc (ndims * sizeof (* shape ));
106112 int64_t * strides = RedisModule_Alloc (ndims * sizeof (* strides ));
107113
@@ -147,7 +153,6 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
147153 .manager_ctx = NULL ,
148154 .deleter = NULL };
149155
150- ret -> refCount = 1 ;
151156 return ret ;
152157}
153158
@@ -173,7 +178,7 @@ RAI_Tensor *RAI_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, long long
173178 return NULL ;
174179 }
175180
176- RAI_Tensor * ret = RedisModule_Alloc ( sizeof ( * ret ) );
181+ RAI_Tensor * ret = RAI_TensorNew ( );
177182 int64_t * shape = RedisModule_Alloc (ndims * sizeof (* shape ));
178183 int64_t * strides = RedisModule_Alloc (ndims * sizeof (* strides ));
179184
@@ -201,7 +206,6 @@ RAI_Tensor *RAI_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, long long
201206 .manager_ctx = rstr ,
202207 .deleter = RAI_RStringDataTensorDeleter };
203208
204- ret -> refCount = 1 ;
205209 return ret ;
206210}
207211
@@ -330,7 +334,7 @@ int RAI_TensorDeepCopy(RAI_Tensor *t, RAI_Tensor **dest) {
330334// Beware: this will take ownership of dltensor
331335RAI_Tensor * RAI_TensorCreateFromDLTensor (DLManagedTensor * dl_tensor ) {
332336
333- RAI_Tensor * ret = RedisModule_Calloc ( 1 , sizeof ( * ret ) );
337+ RAI_Tensor * ret = RAI_TensorNew ( );
334338
335339 ret -> tensor =
336340 (DLManagedTensor ){.dl_tensor = (DLTensor ){.ctx = dl_tensor -> dl_tensor .ctx ,
@@ -343,7 +347,6 @@ RAI_Tensor *RAI_TensorCreateFromDLTensor(DLManagedTensor *dl_tensor) {
343347 .manager_ctx = dl_tensor -> manager_ctx ,
344348 .deleter = dl_tensor -> deleter };
345349
346- ret -> refCount = 1 ;
347350 return ret ;
348351}
349352
@@ -356,12 +359,15 @@ int RAI_TensorIsDataTypeEqual(RAI_Tensor *t1, RAI_Tensor *t2) {
356359}
357360
358361size_t RAI_TensorLength (RAI_Tensor * t ) {
359- int64_t * shape = t -> tensor .dl_tensor .shape ;
360- size_t len = 1 ;
361- for (size_t i = 0 ; i < t -> tensor .dl_tensor .ndim ; ++ i ) {
362- len *= shape [i ];
362+ if (t -> len == LEN_UNKOWN ) {
363+ int64_t * shape = t -> tensor .dl_tensor .shape ;
364+ size_t len = 1 ;
365+ for (size_t i = 0 ; i < t -> tensor .dl_tensor .ndim ; ++ i ) {
366+ len *= shape [i ];
367+ }
368+ t -> len = len ;
363369 }
364- return len ;
370+ return t -> len ;
365371}
366372
367373size_t RAI_TensorDataSize (RAI_Tensor * t ) { return Tensor_DataTypeSize (RAI_TensorDataType (t )); }
0 commit comments