Skip to content

Commit 54b2745

Browse files
committed
Expose arrow_client method in AuraGraphDataScience
1 parent 30a96ce commit 54b2745

File tree

9 files changed

+61
-31
lines changed

9 files changed

+61
-31
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import re
2+
3+
from neo4j.exceptions import ClientError
4+
from pyarrow import flight
5+
6+
7+
def handle_flight_error(e: Exception) -> None:
    """
    Re-raise ``e`` with Arrow Flight / gRPC boilerplate stripped from its message.

    For ``flight.FlightServerError``, ``flight.FlightInternalError`` and
    ``ClientError`` the known Java exception prefixes and the trailing
    "gRPC client debug context" section are removed from the message, and a
    ``flight.FlightServerError`` carrying the shortened message is raised,
    chained to the original exception. Any other exception is re-raised
    unchanged.

    This function never returns normally; the ``None`` return annotation is
    kept for compatibility with existing callers.

    Parameters
    ----------
    e
        The exception caught by the caller.

    Raises
    ------
    flight.FlightServerError
        When ``e`` is one of the handled Flight / client error types.
    Exception
        ``e`` itself, when it is any other exception type.
    """
    # A tuple of types works on every supported Python version, whereas
    # isinstance() with an `X | Y` union requires Python 3.10+.
    if not isinstance(e, (flight.FlightServerError, flight.FlightInternalError, ClientError)):
        raise e

    # Prefer the first positional argument; otherwise fall back to a `message`
    # attribute if the exception has one, and finally to str(e). The plain
    # attribute access would raise AttributeError on types without `message`.
    if e.args:
        original_message = e.args[0]
    else:
        original_message = getattr(e, "message", None)
    if original_message is None:
        original_message = str(e)
    elif not isinstance(original_message, str):
        # Guard against non-string payloads so the replacements below cannot fail.
        original_message = str(original_message)

    noise_prefixes = (
        "Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: ",
        "Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: ",
        "Failed to invoke procedure `gds.arrow.project`: Caused by: org.apache.arrow.flight.FlightRuntimeException: ",
    )
    improved_message = original_message
    for prefix in noise_prefixes:
        improved_message = improved_message.replace(prefix, "")

    # Drop the trailing gRPC debug context (and the ". " separator before it, if present).
    improved_message = re.sub(r"(\. )?gRPC client debug context: .+$", "", improved_message)

    # Chain explicitly so the original error stays available as __cause__.
    raise flight.FlightServerError(improved_message) from e

graphdatascience/arrow_client/v1/gds_arrow_client.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22

33
import json
44
import logging
5-
import re
65
from types import TracebackType
76
from typing import Any, Callable, Iterable, Type
87

98
import pandas
109
import pyarrow
11-
from neo4j.exceptions import ClientError
1210
from pyarrow import Array, ChunkedArray, DictionaryArray, RecordBatch, Table, chunked_array, flight
1311
from pyarrow.types import is_dictionary
1412
from pydantic import BaseModel
@@ -18,6 +16,7 @@
1816
from graphdatascience.arrow_client.v1.data_mapper_utils import deserialize_single
1917

2018
from ...semantic_version.semantic_version import SemanticVersion
19+
from ..error_handler import handle_flight_error
2120

2221

2322
class GdsArrowClient:
@@ -515,7 +514,7 @@ def upload_batch(p: RecordBatch) -> None:
515514
ack_stream.read()
516515
progress_callback(partition.num_rows)
517516
except Exception as e:
518-
GdsArrowClient.handle_flight_error(e)
517+
handle_flight_error(e)
519518

520519
def _get_data(
521520
self,
@@ -560,7 +559,7 @@ def _fetch_get_result(self, get: flight.FlightStreamReader) -> pandas.DataFrame:
560559
try:
561560
arrow_table = get.read_all()
562561
except Exception as e:
563-
GdsArrowClient.handle_flight_error(e)
562+
handle_flight_error(e)
564563
arrow_table = self._sanitize_arrow_table(arrow_table)
565564
if SemanticVersion.from_string(pandas.__version__) >= SemanticVersion(2, 0, 0):
566565
return arrow_table.to_pandas(types_mapper=pandas.ArrowDtype) # type: ignore
@@ -615,26 +614,6 @@ def _decode_pyarrow_array(array: Array) -> Array:
615614
else:
616615
return array
617616

618-
@staticmethod
619-
def handle_flight_error(e: Exception) -> None:
620-
if isinstance(e, flight.FlightServerError | flight.FlightInternalError | ClientError):
621-
original_message = e.args[0] if len(e.args) > 0 else e.message
622-
improved_message = original_message.replace(
623-
"Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: ", ""
624-
)
625-
improved_message = improved_message.replace(
626-
"Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: ", ""
627-
)
628-
improved_message = improved_message.replace(
629-
"Failed to invoke procedure `gds.arrow.project`: Caused by: org.apache.arrow.flight.FlightRuntimeException: ",
630-
"",
631-
)
632-
improved_message = re.sub(r"(\. )?gRPC client debug context: .+$", "", improved_message)
633-
634-
raise flight.FlightServerError(improved_message)
635-
else:
636-
raise e
637-
638617

639618
class NodeLoadDoneResult(BaseModel):
640619
name: str

graphdatascience/arrow_client/v2/api_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def aborted(self) -> bool:
3737
return self.status == "Aborted"
3838

3939
def succeeded(self) -> bool:
40-
return self.status == "Done"
40+
return self.status.lower() == "done"
4141

4242
def running(self) -> bool:
43-
return self.status == "Running"
43+
return self.status.lower() == "running"
4444

4545

4646
class MutateResult(ArrowBaseModel):

graphdatascience/query_runner/session_query_runner.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66

77
from pandas import DataFrame
88

9-
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
9+
from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient
1010
from graphdatascience.query_runner.graph_constructor import GraphConstructor
1111
from graphdatascience.query_runner.progress.query_progress_logger import QueryProgressLogger
1212
from graphdatascience.query_runner.query_mode import QueryMode
1313
from graphdatascience.query_runner.termination_flag import TerminationFlag
1414
from graphdatascience.server_version.server_version import ServerVersion
1515

16+
from ..arrow_client.error_handler import handle_flight_error
1617
from ..call_parameters import CallParameters
1718
from ..session.dbms.protocol_resolver import ProtocolVersionResolver
1819
from .protocol.project_protocols import ProjectProtocol
@@ -183,7 +184,7 @@ def run_projection() -> DataFrame:
183184
else:
184185
return run_projection()
185186
except Exception as e:
186-
GdsArrowClient.handle_flight_error(e)
187+
handle_flight_error(e)
187188
raise e # above should already raise
188189

189190
def _remote_write_back(

graphdatascience/session/aura_graph_data_science.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pandas import DataFrame
66

77
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
8-
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
8+
from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient
99
from graphdatascience.call_builder import IndirectCallBuilder
1010
from graphdatascience.endpoints import (
1111
AlphaRemoteEndpoints,
@@ -94,6 +94,7 @@ def create(
9494
v2_endpoints=SessionV2Endpoints(
9595
session_auth_arrow_client, db_bolt_query_runner, show_progress=show_progress
9696
),
97+
authenticated_arrow_client=session_auth_arrow_client,
9798
)
9899
else:
99100
standalone_query_runner = StandaloneSessionQueryRunner(session_arrow_query_runner)
@@ -102,6 +103,7 @@ def create(
102103
delete_fn=delete_fn,
103104
gds_version=gds_version,
104105
v2_endpoints=SessionV2Endpoints(session_auth_arrow_client, None, show_progress=show_progress),
106+
authenticated_arrow_client=session_auth_arrow_client,
105107
)
106108

107109
def __init__(
@@ -110,11 +112,13 @@ def __init__(
110112
delete_fn: Callable[[], bool],
111113
gds_version: ServerVersion,
112114
v2_endpoints: SessionV2Endpoints,
115+
authenticated_arrow_client: AuthenticatedArrowClient,
113116
):
114117
self._query_runner = query_runner
115118
self._delete_fn = delete_fn
116119
self._server_version = gds_version
117120
self._v2_endpoints = v2_endpoints
121+
self._authenticated_arrow_client = authenticated_arrow_client
118122

119123
super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)
120124

@@ -177,6 +181,18 @@ def v2(self) -> SessionV2Endpoints:
177181
def __getattr__(self, attr: str) -> IndirectCallBuilder:
178182
return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version)
179183

184+
def arrow_client(self) -> GdsArrowClient:
185+
"""
186+
Returns a GdsArrowClient that is authenticated to communicate with the Aura Graph Analytics Session.
187+
This client can be used to get direct access to the session's Arrow Flight server.
188+
189+
Returns:
190+
A GdsArrowClient
191+
-------
192+
193+
"""
194+
return GdsArrowClient(self._authenticated_arrow_client)
195+
180196
def set_database(self, database: str) -> None:
181197
"""
182198
Set the database which cypher queries are run against.

graphdatascience/session/aurads_sessions.py

Whitespace-only changes.

graphdatascience/tests/integrationV2/procedure_surface/session/test_walking_skeleton.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def gds(arrow_client: AuthenticatedArrowClient, db_query_runner: QueryRunner) ->
1515
delete_fn=lambda: True,
1616
gds_version=ServerVersion.from_string("1.2.3"),
1717
v2_endpoints=SessionV2Endpoints(arrow_client, db_query_runner),
18+
authenticated_arrow_client=arrow_client,
1819
)
1920

2021

graphdatascience/tests/unit/arrow_client/V1/test_gds_arrow_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616

1717
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
18+
from graphdatascience.arrow_client.error_handler import handle_flight_error
1819
from graphdatascience.arrow_client.v1.gds_arrow_client import GdsArrowClient
1920
from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication
2021
from graphdatascience.query_runner.arrow_info import ArrowInfo
@@ -397,7 +398,7 @@ def test_handle_flight_error() -> None:
397398
FlightServerError,
398399
match="FlightServerError: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.",
399400
):
400-
GdsArrowClient.handle_flight_error(
401+
handle_flight_error(
401402
FlightServerError(
402403
'FlightServerError: Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.. gRPC client debug context: UNKNOWN:Error received from peer ipv4:35.241.177.75:8491 {created_time:"2024-08-29T15:59:03.828903999+02:00", grpc_status:2, grpc_message:"org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database."}. Client context: IOError: Server never sent a data message. Detail: Internal'
403404
)
@@ -407,7 +408,7 @@ def test_handle_flight_error() -> None:
407408
FlightServerError,
408409
match=re.escape("FlightServerError: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"),
409410
):
410-
GdsArrowClient.handle_flight_error(
411+
handle_flight_error(
411412
FlightServerError(
412413
"FlightServerError: Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"
413414
)

graphdatascience/tests/unit/session/test_aura_graph_data_science.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def test_remote_projection_configuration(mocker: MockerFixture) -> None:
1414
delete_fn=lambda: True,
1515
gds_version=v,
1616
v2_endpoints=mocker.Mock(),
17+
authenticated_arrow_client=mocker.Mock(),
1718
)
1819

1920
g = gds.graph.project(
@@ -50,6 +51,7 @@ def test_remote_projection_defaults(mocker: MockerFixture) -> None:
5051
delete_fn=lambda: True,
5152
gds_version=v,
5253
v2_endpoints=mocker.Mock(),
54+
authenticated_arrow_client=mocker.Mock(),
5355
)
5456

5557
g = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)")
@@ -78,6 +80,7 @@ def test_remote_algo_write(mocker: MockerFixture) -> None:
7880
delete_fn=lambda: True,
7981
gds_version=v,
8082
v2_endpoints=mocker.Mock(),
83+
authenticated_arrow_client=mocker.Mock(),
8184
)
8285

8386
G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)")
@@ -99,6 +102,7 @@ def test_remote_algo_write_configuration(mocker: MockerFixture) -> None:
99102
delete_fn=lambda: True,
100103
gds_version=v,
101104
v2_endpoints=mocker.Mock(),
105+
authenticated_arrow_client=mocker.Mock(),
102106
)
103107

104108
G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)")
@@ -125,6 +129,7 @@ def test_remote_graph_write(mocker: MockerFixture) -> None:
125129
delete_fn=lambda: True,
126130
gds_version=v,
127131
v2_endpoints=mocker.Mock(),
132+
authenticated_arrow_client=mocker.Mock(),
128133
)
129134

130135
G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)")
@@ -149,6 +154,7 @@ def test_remote_graph_write_configuration(mocker: MockerFixture) -> None:
149154
delete_fn=lambda: True,
150155
gds_version=v,
151156
v2_endpoints=mocker.Mock(),
157+
authenticated_arrow_client=mocker.Mock(),
152158
)
153159

154160
G, _ = gds.graph.project("foo", "RETURN gds.graph.project(0, 1)")
@@ -176,6 +182,7 @@ def test_run_cypher_write(mocker: MockerFixture) -> None:
176182
delete_fn=lambda: True,
177183
gds_version=v,
178184
v2_endpoints=mocker.Mock(),
185+
authenticated_arrow_client=mocker.Mock(),
179186
)
180187

181188
gds.run_cypher("RETURN 1", params={"foo": 1}, mode=QueryMode.WRITE, database="bar", retryable=True)
@@ -193,6 +200,7 @@ def test_run_cypher_read(mocker: MockerFixture) -> None:
193200
delete_fn=lambda: True,
194201
gds_version=v,
195202
v2_endpoints=mocker.Mock(),
203+
authenticated_arrow_client=mocker.Mock(),
196204
)
197205

198206
gds.run_cypher("RETURN 1", params={"foo": 1}, mode=QueryMode.READ, retryable=False)

0 commit comments

Comments
 (0)