Skip to content

Commit 80ff4d1

Browse files
committed
Use new GdsArrowClient internally
new one based on AuthenticatedArrowClient
1 parent 7dafba4 commit 80ff4d1

File tree

12 files changed

+69
-46
lines changed

12 files changed

+69
-46
lines changed

graphdatascience/arrow_client/authenticated_flight_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def connection_info(self) -> ConnectionInfo:
124124
125125
Returns
126126
-------
127-
tuple[str, int]
127+
ConnectionInfo
128128
the host and port of the GDS Arrow server
129129
"""
130130
return ConnectionInfo(self._host, self._port, self._encrypted)

graphdatascience/arrow_client/v1/gds_arrow_client.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pydantic import BaseModel
1515

1616
from graphdatascience.arrow_client.arrow_endpoint_version import ArrowEndpointVersion
17-
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
17+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient, ConnectionInfo
1818
from graphdatascience.arrow_client.v1.data_mapper_utils import deserialize_single
1919

2020
from ...semantic_version.semantic_version import SemanticVersion
@@ -451,6 +451,29 @@ def upload_triplets(
451451
"""
452452
self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback)
453453

454+
def advertised_connection_info(self) -> ConnectionInfo:
455+
"""
456+
Returns the host and port of the GDS Arrow server.
457+
458+
Returns
459+
-------
460+
ConnectionInfo
461+
the host and port of the GDS Arrow server
462+
"""
463+
return self._flight_client.advertised_connection_info()
464+
465+
def request_token(self) -> str | None:
466+
"""
467+
Requests a token from the server and returns it.
468+
469+
Returns
470+
-------
471+
str | None
472+
a token from the server and returns it.
473+
"""
474+
475+
return self._flight_client.request_token()
476+
454477
def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str, Any]:
455478
action_type = f"{ArrowEndpointVersion.V1.prefix()}{action_type}"
456479
raw_result = self._flight_client.do_action_with_retry(action_type, meta_data)

graphdatascience/procedure_surface/cypher/catalog/node_properties_cypher_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pandas import DataFrame
22

3+
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
34
from graphdatascience.call_parameters import CallParameters
45
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
56
from graphdatascience.procedure_surface.api.catalog.node_properties_endpoints import (
@@ -11,7 +12,6 @@
1112
from graphdatascience.procedure_surface.api.default_values import ALL_LABELS
1213
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
1314
from graphdatascience.procedure_surface.utils.result_utils import join_db_node_properties, transpose_property_columns
14-
from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient
1515
from graphdatascience.query_runner.query_runner import QueryRunner
1616

1717

graphdatascience/procedure_surface/cypher/catalog/relationship_cypher_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pandas import DataFrame
22

33
from graphdatascience import QueryRunner
4+
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
45
from graphdatascience.call_parameters import CallParameters
56
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
67
from graphdatascience.procedure_surface.api.catalog.relationships_endpoints import (
@@ -14,7 +15,6 @@
1415
)
1516
from graphdatascience.procedure_surface.api.default_values import ALL_TYPES
1617
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
17-
from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient
1818

1919

2020
class RelationshipCypherEndpoints(RelationshipsEndpoints):

graphdatascience/query_runner/arrow_graph_constructor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from pandas import DataFrame
1212
from tqdm.auto import tqdm
1313

14-
from .gds_arrow_client import GdsArrowClient
14+
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
15+
1516
from .graph_constructor import GraphConstructor
1617

1718

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55

66
from pandas import DataFrame
77

8+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
9+
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
810
from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication
911
from graphdatascience.query_runner.query_mode import QueryMode
10-
from graphdatascience.retry_utils.retry_config import RetryConfig
12+
from graphdatascience.retry_utils.retry_config import RetryConfigV2
1113

1214
from ..call_parameters import CallParameters
1315
from ..query_runner.arrow_info import ArrowInfo
1416
from ..server_version.server_version import ServerVersion
1517
from .arrow_graph_constructor import ArrowGraphConstructor
16-
from .gds_arrow_client import GdsArrowClient
1718
from .graph_constructor import GraphConstructor
1819
from .query_runner import QueryRunner
1920

@@ -27,20 +28,22 @@ def create(
2728
encrypted: bool = False,
2829
arrow_client_options: dict[str, Any] | None = None,
2930
connection_string_override: str | None = None,
30-
retry_config: RetryConfig | None = None,
31+
retry_config: RetryConfigV2 | None = None,
3132
) -> ArrowQueryRunner:
3233
if not arrow_info.enabled:
3334
raise ValueError("Arrow is not enabled on the server")
3435

35-
gds_arrow_client = GdsArrowClient.create(
36+
arrow_client = AuthenticatedArrowClient.create(
3637
arrow_info=arrow_info,
3738
auth=arrow_authentication,
3839
encrypted=encrypted,
40+
arrow_client_options=arrow_client_options,
3941
connection_string_override=connection_string_override,
4042
retry_config=retry_config,
41-
arrow_client_options=arrow_client_options,
4243
)
4344

45+
gds_arrow_client = GdsArrowClient(flight_client=arrow_client)
46+
4447
return ArrowQueryRunner(gds_arrow_client, fallback_query_runner, fallback_query_runner.server_version())
4548

4649
def __init__(

graphdatascience/query_runner/session_query_runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from pandas import DataFrame
88

9+
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
910
from graphdatascience.query_runner.graph_constructor import GraphConstructor
1011
from graphdatascience.query_runner.progress.query_progress_logger import QueryProgressLogger
1112
from graphdatascience.query_runner.query_mode import QueryMode
@@ -14,7 +15,6 @@
1415

1516
from ..call_parameters import CallParameters
1617
from ..session.dbms.protocol_resolver import ProtocolVersionResolver
17-
from .gds_arrow_client import GdsArrowClient
1818
from .protocol.project_protocols import ProjectProtocol
1919
from .protocol.write_protocols import WriteProtocol
2020
from .query_runner import QueryRunner
@@ -263,12 +263,12 @@ def _resolve_show_progress(self, show_progress: bool) -> bool:
263263
return self._show_progress and show_progress
264264

265265
def _inject_arrow_config(self, params: dict[str, Any]) -> None:
266-
host, port = self._gds_arrow_client.connection_info()
266+
connection_info = self._gds_arrow_client.advertised_connection_info()
267267
token = self._gds_arrow_client.request_token()
268268
if token is None:
269269
token = "IGNORED"
270270

271-
params["host"] = host
272-
params["port"] = port
271+
params["host"] = connection_info.host
272+
params["port"] = connection_info.port
273273
params["token"] = token
274-
params["encrypted"] = self._gds_query_runner.encrypted()
274+
params["encrypted"] = connection_info.encrypted

graphdatascience/session/aura_graph_data_science.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from graphdatascience import QueryRunner, ServerVersion
88
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
9+
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
910
from graphdatascience.call_builder import IndirectCallBuilder
1011
from graphdatascience.endpoints import (
1112
AlphaRemoteEndpoints,
@@ -17,7 +18,6 @@
1718
from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication
1819
from graphdatascience.query_runner.arrow_info import ArrowInfo
1920
from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner
20-
from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient
2121
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
2222
from graphdatascience.query_runner.query_mode import QueryMode
2323
from graphdatascience.query_runner.session_query_runner import SessionQueryRunner
@@ -59,23 +59,17 @@ def create(
5959
arrow_client_options=arrow_client_options,
6060
)
6161

62-
# TODO: merge with the gds_arrow_client created inside ArrowQueryRunner
63-
session_arrow_client = GdsArrowClient.create(
64-
arrow_info,
65-
arrow_authentication,
66-
session_bolt_query_runner.encrypted(),
67-
arrow_client_options=arrow_client_options,
68-
)
69-
70-
gds_version = session_bolt_query_runner.server_version()
71-
7262
session_auth_arrow_client = AuthenticatedArrowClient.create(
7363
arrow_info=arrow_info,
7464
auth=arrow_authentication,
7565
encrypted=session_bolt_query_runner.encrypted(),
7666
arrow_client_options=arrow_client_options,
7767
)
7868

69+
session_arrow_client = GdsArrowClient(flight_client=session_auth_arrow_client)
70+
71+
gds_version = session_bolt_query_runner.server_version()
72+
7973
if db_endpoint is not None:
8074
if isinstance(db_endpoint, Neo4jQueryRunner):
8175
db_bolt_query_runner = db_endpoint

graphdatascience/tests/integrationV2/procedure_surface/cypher/catalog/test_node_properties_cypher_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import pytest
44

55
from graphdatascience import QueryRunner
6+
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
67
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
78
from graphdatascience.procedure_surface.cypher.catalog.node_properties_cypher_endpoints import (
89
NodePropertiesCypherEndpoints,
910
)
10-
from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient
1111
from graphdatascience.tests.integrationV2.procedure_surface.cypher.cypher_graph_helper import create_graph
1212

1313

graphdatascience/tests/integrationV2/procedure_surface/cypher/catalog/test_relationship_cypher_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
import pytest
44

55
from graphdatascience import QueryRunner
6+
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
67
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
78
from graphdatascience.procedure_surface.api.catalog.relationships_endpoints import Aggregation
89
from graphdatascience.procedure_surface.cypher.catalog.relationship_cypher_endpoints import (
910
RelationshipCypherEndpoints,
1011
)
11-
from graphdatascience.query_runner.gds_arrow_client import GdsArrowClient
1212
from graphdatascience.tests.integrationV2.procedure_surface.cypher.cypher_graph_helper import create_graph
1313

1414

0 commit comments

Comments
 (0)