
Commit 4f81d1d

Implement graph_construct using arrow v2

* brings terminationFlag to GdsArrowClient (V2) to interrupt uploads
* job client also supports waiting for a given status

1 parent e44512e

File tree

12 files changed: +370 -30 lines changed


graphdatascience/arrow_client/v2/gds_arrow_client.py

Lines changed: 22 additions & 3 deletions

@@ -11,6 +11,7 @@
 
 from graphdatascience.arrow_client.arrow_endpoint_version import ArrowEndpointVersion
 from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient, ConnectionInfo
+from graphdatascience.query_runner.termination_flag import TerminationFlag
 
 from ...procedure_surface.api.default_values import ALL_TYPES
 from ...procedure_surface.utils.config_converter import ConfigConverter
@@ -328,6 +329,7 @@ def upload_nodes(
         data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
         batch_size: int = 10000,
         progress_callback: Callable[[int], None] = lambda x: None,
+        termination_flag: TerminationFlag | None = None,
     ) -> None:
         """
         Uploads node data to the server for a given job.
@@ -342,15 +344,20 @@ def upload_nodes(
            The number of rows per batch
        progress_callback
            A callback function that is called with the number of rows uploaded after each batch
+       termination_flag
+           A termination flag to cancel the upload if requested
        """
-        self._upload_data("graph.project.fromTables.nodes", job_id, data, batch_size, progress_callback)
+        self._upload_data(
+            "graph.project.fromTables.nodes", job_id, data, batch_size, progress_callback, termination_flag
+        )
 
     def upload_relationships(
         self,
         job_id: str,
         data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
         batch_size: int = 10000,
         progress_callback: Callable[[int], None] = lambda x: None,
+        termination_flag: TerminationFlag | None = None,
     ) -> None:
         """
         Uploads relationship data to the server for a given job.
@@ -365,15 +372,20 @@ def upload_relationships(
            The number of rows per batch
        progress_callback
            A callback function that is called with the number of rows uploaded after each batch
+       termination_flag
+           A termination flag to cancel the upload if requested
        """
-        self._upload_data("graph.project.fromTables.relationships", job_id, data, batch_size, progress_callback)
+        self._upload_data(
+            "graph.project.fromTables.relationships", job_id, data, batch_size, progress_callback, termination_flag
+        )
 
     def upload_triplets(
         self,
         job_id: str,
         data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
         batch_size: int = 10000,
         progress_callback: Callable[[int], None] = lambda x: None,
+        termination_flag: TerminationFlag | None = None,
     ) -> None:
         """
         Uploads triplet data to the server for a given job.
@@ -388,8 +400,10 @@ def upload_triplets(
            The number of rows per batch
        progress_callback
            A callback function that is called with the number of rows uploaded after each batch
+       termination_flag
+           A termination flag to cancel the upload if requested
        """
-        self._upload_data("graph.project.fromTriplets", job_id, data, batch_size, progress_callback)
+        self._upload_data("graph.project.fromTriplets", job_id, data, batch_size, progress_callback, termination_flag)
 
     def abort_job(self, job_id: str) -> None:
         """
@@ -464,6 +478,7 @@ def _upload_data(
         data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
         batch_size: int = 10000,
         progress_callback: Callable[[int], None] = lambda x: None,
+        termination_flag: TerminationFlag | None = None,
     ) -> None:
         match data:
             case pyarrow.Table():
@@ -490,6 +505,10 @@ def upload_batch(p: RecordBatch) -> None:
 
         with put_stream:
             for partition in batches:
+                if termination_flag is not None and termination_flag.is_set():
+                    self.abort_job(job_id)  # closing the put_stream will raise an error
+                    break
+
                 upload_batch(partition)
                 ack_stream.read()
                 progress_callback(partition.num_rows)
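
The cancellation check above runs once per batch, so a raised flag takes effect at the next batch boundary: the job is aborted server-side and the loop breaks out of the put stream. A minimal sketch of driving this from another thread; the authenticated arrow_client, the job_id, and the ManualTerminationFlag stand-in (implementing only the is_set() part of the TerminationFlag protocol that the loop polls) are illustrative assumptions, not part of this diff:

import threading

import pandas

from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient


class ManualTerminationFlag:
    # Illustrative stand-in: the upload loop only needs is_set();
    # real callers would use a TerminationFlag.create() instance.
    def __init__(self) -> None:
        self._event = threading.Event()

    def set(self) -> None:
        self._event.set()

    def is_set(self) -> bool:
        return self._event.is_set()


# Assumed setup: arrow_client is an AuthenticatedArrowClient and job_id
# comes from a previously created import job.
client = GdsArrowClient(arrow_client)
flag = ManualTerminationFlag()
nodes = pandas.DataFrame({"nodeId": range(1_000_000)})

upload = threading.Thread(
    target=lambda: client.upload_nodes(job_id, nodes, batch_size=10_000, termination_flag=flag)
)
upload.start()

flag.set()  # at the next batch boundary the loop calls abort_job(job_id) and breaks
upload.join()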

graphdatascience/arrow_client/v2/job_client.py

Lines changed: 8 additions & 2 deletions

@@ -39,20 +39,26 @@ def wait_for_job(
         client: AuthenticatedArrowClient,
         job_id: str,
         show_progress: bool,
+        expected_status: str | None = None,
         termination_flag: TerminationFlag | None = None,
     ) -> None:
         progress_bar: TqdmProgressBar | None = None
 
+        def check_expected_status(status: JobStatus) -> bool:
+            return status.succeeded() if expected_status is None else status.status == expected_status
+
         if termination_flag is None:
             termination_flag = TerminationFlag.create()
 
-        for attempt in Retrying(retry=retry_if_result(lambda _: True), wait=wait_exponential(min=0.1, max=5)):
+        for attempt in Retrying(
+            retry=retry_if_result(lambda _: True), wait=wait_exponential(min=0.1, max=5), reraise=True
+        ):
             with attempt:
                 termination_flag.assert_running()
 
                 job_status = self.get_job_status(client, job_id)
 
-                if job_status.succeeded() or job_status.aborted():
+                if check_expected_status(job_status) or job_status.aborted():
                     if progress_bar:
                         progress_bar.finish(success=job_status.succeeded())
                     return
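
Two details of this hunk are easy to miss. retry_if_result(lambda _: True) makes every completed attempt retry, so the loop can only exit through the return statement; and because retry_if_result does not retry on exceptions, the new reraise=True makes an exception raised inside an attempt (for example by termination_flag.assert_running()) propagate as itself instead of arriving wrapped in tenacity.RetryError. A self-contained sketch of the same polling pattern; poll_until and the toy status sequence are illustrative, not part of the client:

from tenacity import Retrying, retry_if_result, wait_exponential


def poll_until(done, get_status):
    # Every completed attempt "fails" by result, so tenacity keeps backing
    # off and retrying until the body returns; exceptions escape unwrapped
    # because of reraise=True.
    for attempt in Retrying(
        retry=retry_if_result(lambda _: True),
        wait=wait_exponential(min=0.1, max=5),
        reraise=True,
    ):
        with attempt:
            status = get_status()
            if done(status):
                return status


statuses = iter(["PENDING", "RUNNING", "DONE"])
print(poll_until(lambda s: s == "DONE", lambda: next(statuses)))  # prints DONE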

graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ def construct(
         graph_name: str,
         nodes: DataFrame | list[DataFrame],
         relationships: DataFrame | list[DataFrame] | None = None,
-        concurrency: int = 4,
+        concurrency: int | None = None,
         undirected_relationship_types: list[str] | None = None,
     ) -> GraphV2:
         """Construct a graph from a list of node and relationship dataframes.

graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py

Lines changed: 67 additions & 2 deletions

@@ -8,6 +8,7 @@
 from pandas import DataFrame
 
 from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient
 from graphdatascience.arrow_client.v2.job_client import JobClient
 from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
 from graphdatascience.procedure_surface.api.base_result import BaseResult
@@ -31,6 +32,7 @@
 )
 from graphdatascience.procedure_surface.arrow.catalog.relationship_arrow_endpoints import RelationshipArrowEndpoints
 from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
+from graphdatascience.query_runner.progress.progress_bar import NoOpProgressBar, ProgressBar, TqdmProgressBar
 from graphdatascience.query_runner.protocol.project_protocols import ProjectProtocol
 from graphdatascience.query_runner.query_runner import QueryRunner
 from graphdatascience.query_runner.termination_flag import TerminationFlag
@@ -135,10 +137,73 @@ def construct(
         graph_name: str,
         nodes: DataFrame | list[DataFrame],
         relationships: DataFrame | list[DataFrame] | None = None,
-        concurrency: int = 4,
+        concurrency: int | None = None,
         undirected_relationship_types: list[str] | None = None,
     ) -> GraphV2:
-        raise NotImplementedError("Graph construction is not yet supported via V2 endpoints.")
+        gds_arrow_client = GdsArrowClient(self._arrow_client)
+        job_client = JobClient()
+        termination_flag = TerminationFlag.create()
+
+        if self._show_progress:
+            progress_bar: ProgressBar = TqdmProgressBar(task_name="Constructing graph", relative_progress=0.0)
+        else:
+            progress_bar = NoOpProgressBar()
+
+        with progress_bar:
+            create_job_id: str = gds_arrow_client.create_graph(
+                graph_name=graph_name,
+                undirected_relationship_types=undirected_relationship_types or [],
+                concurrency=concurrency,
+            )
+            node_count = nodes.shape[0] if isinstance(nodes, DataFrame) else sum(df.shape[0] for df in nodes)
+            if isinstance(relationships, DataFrame):
+                rel_count = relationships.shape[0]
+            elif relationships is None:
+                rel_count = 0
+                relationships = []
+            else:
+                rel_count = sum(df.shape[0] for df in relationships)
+            total_count = node_count + rel_count
+
+            gds_arrow_client.upload_nodes(
+                create_job_id,
+                nodes,
+                progress_callback=lambda rows_imported: progress_bar.update(
+                    sub_tasks_description="Uploading nodes", progress=rows_imported / total_count, status="Running"
+                ),
+                termination_flag=termination_flag,
+            )
+
+            gds_arrow_client.node_load_done(create_job_id)
+
+            # skipping progress bar here as we have our own for the overall process
+            job_client.wait_for_job(
+                self._arrow_client,
+                create_job_id,
+                expected_status="RELATIONSHIP_LOADING",
+                termination_flag=termination_flag,
+                show_progress=False,
+            )
+
+            if rel_count > 0:
+                gds_arrow_client.upload_relationships(
+                    create_job_id,
+                    relationships,
+                    progress_callback=lambda rows_imported: progress_bar.update(
+                        sub_tasks_description="Uploading relationships",
+                        progress=rows_imported / total_count,
+                        status="Running",
+                    ),
+                    termination_flag=termination_flag,
+                )
+
+            gds_arrow_client.relationship_load_done(create_job_id)
+
+            # will produce a second progress bar to show graph construction on the server side
+            job_client.wait_for_job(
+                self._arrow_client, create_job_id, termination_flag=termination_flag, show_progress=True
+            )
+        return get_graph(graph_name, self._arrow_client)
 
     def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | None:
         graph_name = G.name() if isinstance(G, GraphV2) else G
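
End to end, the new implementation stages the import as create_graph, upload_nodes, node_load_done, a wait for the RELATIONSHIP_LOADING status, upload_relationships, relationship_load_done, and a final wait while the server assembles the graph. A hedged usage sketch; catalog stands in for a CatalogArrowEndpoints wired to an authenticated Arrow client, and the column names follow the client's usual construct conventions (verify against the docs):

import pandas as pd

nodes = pd.DataFrame({
    "nodeId": [0, 1, 2],
    "labels": [["Person"], ["Person"], ["City"]],
})
relationships = pd.DataFrame({
    "sourceNodeId": [0, 1],
    "targetNodeId": [2, 2],
    "relationshipType": ["LIVES_IN", "LIVES_IN"],
})

# One client-side progress bar spans both uploads (progress is scaled by the
# combined row count); the final wait_for_job shows a second, server-side bar.
G = catalog.construct(
    "my-graph",
    nodes,
    relationships,
    undirected_relationship_types=["LIVES_IN"],
)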

graphdatascience/procedure_surface/cypher/catalog_cypher_endpoints.py

Lines changed: 18 additions & 18 deletions

@@ -23,6 +23,7 @@
 from graphdatascience.query_runner.arrow_graph_constructor import ArrowGraphConstructor
 from graphdatascience.query_runner.cypher_graph_constructor import CypherGraphConstructor
 from graphdatascience.query_runner.graph_constructor import GraphConstructor
+from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
 
 from ...call_parameters import CallParameters
 from ..api.base_result import BaseResult
@@ -34,9 +35,8 @@
 
 
 class CatalogCypherEndpoints(CatalogEndpoints):
-    def __init__(self, cypher_runner: Neo4jQueryRunner, arrow_client: GdsArrowClient | None = None):
-        self.cypher_runner = cypher_runner
-        self._arrow_client = arrow_client
+    def __init__(self, cypher_runner: Neo4jQueryRunner, arrow_client: GdsArrowClient | None = None):
+        self._cypher_runner = cypher_runner
         self._arrow_client = arrow_client
 
     def construct(
@@ -56,7 +56,7 @@ def construct(
 
         graph_constructor: GraphConstructor
         if self._arrow_client is not None:
-            database = require_database(self._query_runner)
+            database = require_database(self._cypher_runner)
 
             graph_constructor = ArrowGraphConstructor(
                 database=database,
@@ -67,20 +67,20 @@ def construct(
             )
         else:
             graph_constructor = CypherGraphConstructor(
-                query_runner=self._query_runner,
+                query_runner=self._cypher_runner,
                 graph_name=graph_name,
                 concurrency=concurrency,
                 undirected_relationship_types=undirected_relationship_types,
             )
 
         graph_constructor.run(node_dfs=nodes, relationship_dfs=relationships)
-        return GraphV2(name=graph_name, backend=CypherGraphBackend(graph_name, self._query_runner))
+        return GraphV2(name=graph_name, backend=CypherGraphBackend(graph_name, self._cypher_runner))
 
     def list(self, G: GraphV2 | str | None = None) -> list[GraphInfoWithDegrees]:
         graph_name = G if isinstance(G, str) else G.name() if G is not None else None
         params = CallParameters(graphName=graph_name) if graph_name else CallParameters()
 
-        result = self.cypher_runner.call_procedure(endpoint="gds.graph.list", params=params)
+        result = self._cypher_runner.call_procedure(endpoint="gds.graph.list", params=params)
         return [GraphInfoWithDegrees(**row.to_dict()) for _, row in result.iterrows()]
 
     def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | None:
@@ -92,7 +92,7 @@ def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | None:
             else CallParameters(graphName=graph_name)
         )
 
-        result = self.cypher_runner.call_procedure(endpoint="gds.graph.drop", params=params)
+        result = self._cypher_runner.call_procedure(endpoint="gds.graph.drop", params=params)
         if len(result) > 0:
             return GraphInfo(**result.iloc[0].to_dict())
         else:
@@ -128,11 +128,11 @@ def project(
         )
         params.ensure_job_id_in_config()
 
-        result = self.cypher_runner.call_procedure(
+        result = self._cypher_runner.call_procedure(
             endpoint="gds.graph.project", params=params, logging=log_progress
         ).squeeze()
         project_result = GraphProjectResult(**result.to_dict())
-        return GraphWithProjectResult(get_graph(project_result.graph_name, self.cypher_runner), project_result)
+        return GraphWithProjectResult(get_graph(project_result.graph_name, self._cypher_runner), project_result)
 
     def filter(
         self,
@@ -158,10 +158,10 @@ def filter(
         )
         params.ensure_job_id_in_config()
 
-        result = self.cypher_runner.call_procedure(
+        result = self._cypher_runner.call_procedure(
             endpoint="gds.graph.filter", params=params, logging=log_progress
         ).squeeze()
-        return GraphWithFilterResult(get_graph(graph_name, self.cypher_runner), GraphFilterResult(**result.to_dict()))
+        return GraphWithFilterResult(get_graph(graph_name, self._cypher_runner), GraphFilterResult(**result.to_dict()))
 
     def generate(
         self,
@@ -202,28 +202,28 @@ def generate(
 
         params.ensure_job_id_in_config()
 
-        result = self.cypher_runner.call_procedure(
+        result = self._cypher_runner.call_procedure(
             endpoint="gds.graph.generate", params=params, logging=log_progress
         ).squeeze()
         return GraphWithGenerationStats(
-            get_graph(graph_name, self.cypher_runner), GraphGenerationStats(**result.to_dict())
+            get_graph(graph_name, self._cypher_runner), GraphGenerationStats(**result.to_dict())
         )
 
     @property
     def sample(self) -> GraphSamplingEndpoints:
-        return GraphSamplingCypherEndpoints(self.cypher_runner)
+        return GraphSamplingCypherEndpoints(self._cypher_runner)
 
     @property
     def node_labels(self) -> NodeLabelCypherEndpoints:
-        return NodeLabelCypherEndpoints(self.cypher_runner)
+        return NodeLabelCypherEndpoints(self._cypher_runner)
 
     @property
     def node_properties(self) -> NodePropertiesCypherEndpoints:
-        return NodePropertiesCypherEndpoints(self.cypher_runner, self._arrow_client)
+        return NodePropertiesCypherEndpoints(self._cypher_runner, self._arrow_client)
 
     @property
     def relationships(self) -> RelationshipCypherEndpoints:
-        return RelationshipCypherEndpoints(self.cypher_runner, self._arrow_client)
+        return RelationshipCypherEndpoints(self._cypher_runner, self._arrow_client)
 
 
 class GraphProjectResult(BaseResult):
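
Beyond the mechanical rename, this fixes a real inconsistency: __init__ previously stored the runner as self.cypher_runner while construct read self._query_runner, an attribute that was never assigned, so the Arrow-backed branch would have failed with an AttributeError. A minimal reproduction of that failure mode, independent of the library:

class Before:
    def __init__(self, runner):
        self.cypher_runner = runner  # stored under one name...

    def construct(self):
        return self._query_runner  # ...read under another


Before("runner").construct()  # AttributeError: 'Before' object has no attribute '_query_runner'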
