@@ -68,20 +68,29 @@ def __get_tensor_by_scan(self, attr: TensorAttr) -> FeatureTensorType | None:
6868 if indices .step is None or indices .step == 1 :
6969 indices = np .arange (indices .start , indices .stop , dtype = np .uint64 )
7070 else :
71- indices = np .arange (indices .start , indices .stop , indices .step , dtype = np .uint64 )
71+ indices = np .arange (
72+ indices .start , indices .stop , indices .step , dtype = np .uint64
73+ )
7274 elif isinstance (indices , int ):
7375 indices = np .array ([indices ])
7476
7577 if table_name not in self .node_properties_cache :
76- self .node_properties_cache [table_name ] = self .connection ._get_node_property_names (table_name )
78+ self .node_properties_cache [table_name ] = (
79+ self .connection ._get_node_property_names (table_name )
80+ )
7781 attr_info = self .node_properties_cache [table_name ][attr_name ]
7882
7983 flat_dim = 1
8084 if attr_info ["dimension" ] > 0 :
8185 for i in range (attr_info ["dimension" ]):
8286 flat_dim *= attr_info ["shape" ][i ]
8387 scan_result = self .connection .database ._scan_node_table (
84- table_name , attr_name , attr_info ["type" ], flat_dim , indices , self .num_threads
88+ table_name ,
89+ attr_name ,
90+ attr_info ["type" ],
91+ flat_dim ,
92+ indices ,
93+ self .num_threads ,
8594 )
8695
8796 if attr_info ["dimension" ] > 0 and "shape" in attr_info :
@@ -151,11 +160,16 @@ def _get_tensor_size(self, attr: TensorAttr) -> tuple[Any, ...]:
151160 return (length ,) + attr_info ["shape" ]
152161
153162 def __get_node_property (self , table_name : str , attr_name : str ) -> dict [str , Any ]:
154- if table_name in self .node_properties_cache and attr_name in self .node_properties_cache [table_name ]:
163+ if (
164+ table_name in self .node_properties_cache
165+ and attr_name in self .node_properties_cache [table_name ]
166+ ):
155167 return self .node_properties_cache [table_name ][attr_name ]
156168 self .__get_connection ()
157169 if table_name not in self .node_properties_cache :
158- self .node_properties_cache [table_name ] = self .connection ._get_node_property_names (table_name )
170+ self .node_properties_cache [table_name ] = (
171+ self .connection ._get_node_property_names (table_name )
172+ )
159173 if attr_name not in self .node_properties_cache [table_name ]:
160174 msg = f"Attribute { attr_name } not found in group { table_name } "
161175 raise ValueError (msg )
@@ -168,7 +182,9 @@ def get_all_tensor_attrs(self) -> list[TensorAttr]:
168182 self .__get_connection ()
169183 for table_name in self .connection ._get_node_table_names ():
170184 if table_name not in self .node_properties_cache :
171- self .node_properties_cache [table_name ] = self .connection ._get_node_property_names (table_name )
185+ self .node_properties_cache [table_name ] = (
186+ self .connection ._get_node_property_names (table_name )
187+ )
172188 for attr_name in self .node_properties_cache [table_name ]:
173189 if self .node_properties_cache [table_name ][attr_name ]["type" ] in [
174190 Type .INT64 .value ,
0 commit comments