Skip to content

Commit ce555a1

Browse files
committed
Run black
1 parent c6e5eb1 commit ce555a1

39 files changed

Lines changed: 1687 additions & 743 deletions

src_py/async_connection.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,17 @@ async def execute(
168168
conn, conn_index = self.__get_connection_with_least_queries()
169169

170170
try:
171-
return await loop.run_in_executor(self.executor, conn.execute, query, parameters)
171+
return await loop.run_in_executor(
172+
self.executor, conn.execute, query, parameters
173+
)
172174
except asyncio.CancelledError:
173175
conn.interrupt()
174176
finally:
175177
self.__decrement_connection_counter(conn_index)
176178

177-
async def _prepare(self, query: str, parameters: dict[str, Any] | None = None) -> PreparedStatement:
179+
async def _prepare(
180+
self, query: str, parameters: dict[str, Any] | None = None
181+
) -> PreparedStatement:
178182
"""
179183
The only parameters supported during prepare are dataframes.
180184
Any remaining parameters will be ignored and should be passed to execute().
@@ -183,12 +187,16 @@ async def _prepare(self, query: str, parameters: dict[str, Any] | None = None) -
183187
conn, conn_index = self.__get_connection_with_least_queries()
184188

185189
try:
186-
preparedStatement = await loop.run_in_executor(self.executor, conn.prepare, query, parameters)
190+
preparedStatement = await loop.run_in_executor(
191+
self.executor, conn.prepare, query, parameters
192+
)
187193
return preparedStatement
188194
finally:
189195
self.__decrement_connection_counter(conn_index)
190196

191-
async def prepare(self, query: str, parameters: dict[str, Any] | None = None) -> PreparedStatement:
197+
async def prepare(
198+
self, query: str, parameters: dict[str, Any] | None = None
199+
) -> PreparedStatement:
192200
"""
193201
Create a prepared statement for a query asynchronously.
194202

src_py/connection.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
import warnings
4-
from typing import TYPE_CHECKING, Any, Callable
4+
from collections.abc import Callable
5+
from typing import TYPE_CHECKING, Any
56

67
from . import _lbug
78
from .prepared_statement import PreparedStatement
@@ -130,8 +131,12 @@ def execute(
130131
if len(parameters) == 0 and isinstance(query, str):
131132
query_result_internal = self._connection.query(query)
132133
else:
133-
prepared_statement = self._prepare(query, parameters) if isinstance(query, str) else query
134-
query_result_internal = self._connection.execute(prepared_statement._prepared_statement, parameters)
134+
prepared_statement = (
135+
self._prepare(query, parameters) if isinstance(query, str) else query
136+
)
137+
query_result_internal = self._connection.execute(
138+
prepared_statement._prepared_statement, parameters
139+
)
135140
if not query_result_internal.isSuccess():
136141
raise RuntimeError(query_result_internal.getErrorMessage())
137142
current_query_result = QueryResult(self, query_result_internal)
@@ -234,7 +239,9 @@ def _get_rel_table_names(self) -> list[dict[str, Any]]:
234239
row = tables_result.get_next()
235240
if row[2] == "REL":
236241
name = row[1]
237-
connections_result = self.execute(f"CALL show_connection({name!r}) RETURN *;")
242+
connections_result = self.execute(
243+
f"CALL show_connection({name!r}) RETURN *;"
244+
)
238245
src_dst_row = connections_result.get_next()
239246
src_node = src_dst_row[0]
240247
dst_node = src_dst_row[1]
@@ -345,7 +352,9 @@ def create_arrow_table(
345352
346353
"""
347354
self.init_connection()
348-
query_result_internal = self._connection.create_arrow_table(table_name, dataframe)
355+
query_result_internal = self._connection.create_arrow_table(
356+
table_name, dataframe
357+
)
349358
if not query_result_internal.isSuccess():
350359
raise RuntimeError(query_result_internal.getErrorMessage())
351360
return QueryResult(self, query_result_internal)

src_py/database.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -253,19 +253,29 @@ def _scan_node_table(
253253

254254
if prop_type == Type.INT64.value:
255255
result = np.empty(len(indices) * dim, dtype=np.int64)
256-
self._database.scan_node_table_as_int64(table_name, prop_name, indices_cast, result, num_threads)
256+
self._database.scan_node_table_as_int64(
257+
table_name, prop_name, indices_cast, result, num_threads
258+
)
257259
elif prop_type == Type.INT32.value:
258260
result = np.empty(len(indices) * dim, dtype=np.int32)
259-
self._database.scan_node_table_as_int32(table_name, prop_name, indices_cast, result, num_threads)
261+
self._database.scan_node_table_as_int32(
262+
table_name, prop_name, indices_cast, result, num_threads
263+
)
260264
elif prop_type == Type.INT16.value:
261265
result = np.empty(len(indices) * dim, dtype=np.int16)
262-
self._database.scan_node_table_as_int16(table_name, prop_name, indices_cast, result, num_threads)
266+
self._database.scan_node_table_as_int16(
267+
table_name, prop_name, indices_cast, result, num_threads
268+
)
263269
elif prop_type == Type.DOUBLE.value:
264270
result = np.empty(len(indices) * dim, dtype=np.float64)
265-
self._database.scan_node_table_as_double(table_name, prop_name, indices_cast, result, num_threads)
271+
self._database.scan_node_table_as_double(
272+
table_name, prop_name, indices_cast, result, num_threads
273+
)
266274
elif prop_type == Type.FLOAT.value:
267275
result = np.empty(len(indices) * dim, dtype=np.float32)
268-
self._database.scan_node_table_as_float(table_name, prop_name, indices_cast, result, num_threads)
276+
self._database.scan_node_table_as_float(
277+
table_name, prop_name, indices_cast, result, num_threads
278+
)
269279

270280
if result is not None:
271281
return result

src_py/prepared_statement.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@ class PreparedStatement:
1212
same query for repeated execution.
1313
"""
1414

15-
def __init__(self, connection: Connection, query: str, parameters: dict[str, Any] | None = None):
15+
def __init__(
16+
self,
17+
connection: Connection,
18+
query: str,
19+
parameters: dict[str, Any] | None = None,
20+
):
1621
"""
1722
Parameters
1823
----------

src_py/query_result.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22

33
from typing import TYPE_CHECKING
44

5+
from .constants import DST, ID, LABEL, NODES, RELS, SRC
56
from .torch_geometric_result_converter import TorchGeometricResultConverter
67
from .types import Type
78

8-
from .constants import ID, LABEL, SRC, DST, NODES, RELS
9-
109
if TYPE_CHECKING:
1110
import sys
1211
from collections.abc import Iterator
@@ -191,7 +190,9 @@ def get_as_pl(self) -> pl.DataFrame:
191190
data=self.get_as_arrow(chunk_size=-1, fallbackExtensionTypes=True),
192191
)
193192

194-
def get_as_arrow(self, chunk_size: int | None = None, *, fallbackExtensionTypes: bool = False) -> pa.Table:
193+
def get_as_arrow(
194+
self, chunk_size: int | None = None, *, fallbackExtensionTypes: bool = False
195+
) -> pa.Table:
195196
"""
196197
Get the query result as a PyArrow Table.
197198
@@ -314,7 +315,9 @@ def get_as_networkx(
314315
table_to_label_dict = {}
315316
table_primary_key_dict = {}
316317

317-
def encode_node_id(node: dict[str, Any], table_primary_key_dict: dict[str, Any]) -> str:
318+
def encode_node_id(
319+
node: dict[str, Any], table_primary_key_dict: dict[str, Any]
320+
) -> str:
318321
node_label = node[LABEL]
319322
return f"{node_label}_{node[table_primary_key_dict[node_label]]!s}"
320323

src_py/torch_geometric_feature_store.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,29 @@ def __get_tensor_by_scan(self, attr: TensorAttr) -> FeatureTensorType | None:
6868
if indices.step is None or indices.step == 1:
6969
indices = np.arange(indices.start, indices.stop, dtype=np.uint64)
7070
else:
71-
indices = np.arange(indices.start, indices.stop, indices.step, dtype=np.uint64)
71+
indices = np.arange(
72+
indices.start, indices.stop, indices.step, dtype=np.uint64
73+
)
7274
elif isinstance(indices, int):
7375
indices = np.array([indices])
7476

7577
if table_name not in self.node_properties_cache:
76-
self.node_properties_cache[table_name] = self.connection._get_node_property_names(table_name)
78+
self.node_properties_cache[table_name] = (
79+
self.connection._get_node_property_names(table_name)
80+
)
7781
attr_info = self.node_properties_cache[table_name][attr_name]
7882

7983
flat_dim = 1
8084
if attr_info["dimension"] > 0:
8185
for i in range(attr_info["dimension"]):
8286
flat_dim *= attr_info["shape"][i]
8387
scan_result = self.connection.database._scan_node_table(
84-
table_name, attr_name, attr_info["type"], flat_dim, indices, self.num_threads
88+
table_name,
89+
attr_name,
90+
attr_info["type"],
91+
flat_dim,
92+
indices,
93+
self.num_threads,
8594
)
8695

8796
if attr_info["dimension"] > 0 and "shape" in attr_info:
@@ -151,11 +160,16 @@ def _get_tensor_size(self, attr: TensorAttr) -> tuple[Any, ...]:
151160
return (length,) + attr_info["shape"]
152161

153162
def __get_node_property(self, table_name: str, attr_name: str) -> dict[str, Any]:
154-
if table_name in self.node_properties_cache and attr_name in self.node_properties_cache[table_name]:
163+
if (
164+
table_name in self.node_properties_cache
165+
and attr_name in self.node_properties_cache[table_name]
166+
):
155167
return self.node_properties_cache[table_name][attr_name]
156168
self.__get_connection()
157169
if table_name not in self.node_properties_cache:
158-
self.node_properties_cache[table_name] = self.connection._get_node_property_names(table_name)
170+
self.node_properties_cache[table_name] = (
171+
self.connection._get_node_property_names(table_name)
172+
)
159173
if attr_name not in self.node_properties_cache[table_name]:
160174
msg = f"Attribute {attr_name} not found in group {table_name}"
161175
raise ValueError(msg)
@@ -168,7 +182,9 @@ def get_all_tensor_attrs(self) -> list[TensorAttr]:
168182
self.__get_connection()
169183
for table_name in self.connection._get_node_table_names():
170184
if table_name not in self.node_properties_cache:
171-
self.node_properties_cache[table_name] = self.connection._get_node_property_names(table_name)
185+
self.node_properties_cache[table_name] = (
186+
self.connection._get_node_property_names(table_name)
187+
)
172188
for attr_name in self.node_properties_cache[table_name]:
173189
if self.node_properties_cache[table_name][attr_name]["type"] in [
174190
Type.INT64.value,

src_py/torch_geometric_graph_store.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
if sys.version_info >= (3, 10):
2121
from typing import TypeAlias
2222
else:
23-
from typing_extensions import TypeAlias
23+
from typing import TypeAlias
2424

2525
StoreKeyType: TypeAlias = tuple[tuple[str], Any, bool]
2626

@@ -61,7 +61,9 @@ def _put_edge_index(self, edge_index: EdgeTensorType, edge_attr: EdgeAttr) -> No
6161
self.store[key].materialized = True
6262
self.store[key].size = edge_attr.size
6363
else:
64-
self.store[key] = Rel(key[0], key[1], key[2], edge_attr.size, True, edge_index)
64+
self.store[key] = Rel(
65+
key[0], key[1], key[2], edge_attr.size, True, edge_index
66+
)
6567

6668
def _get_edge_index(self, edge_attr: EdgeAttr) -> EdgeTensorType | None:
6769
if edge_attr.layout.value == EdgeLayout.COO.value: # noqa: SIM102
@@ -91,7 +93,10 @@ def _remove_edge_index(self, edge_attr: EdgeAttr) -> None:
9193

9294
def get_all_edge_attrs(self) -> list[EdgeAttr]:
9395
"""Return all EdgeAttr from the store values."""
94-
return [EdgeAttr(rel.edge_type, rel.layout, rel.is_sorted, rel.size) for rel in self.store.values()]
96+
return [
97+
EdgeAttr(rel.edge_type, rel.layout, rel.is_sorted, rel.size)
98+
for rel in self.store.values()
99+
]
95100

96101
def __get_edge_coo_from_database(self, key: StoreKeyType) -> None:
97102
if not self.connection:

0 commit comments

Comments
 (0)