
Commit e511a66

Create new example notebook for Sessions + Spark
1 parent 54b2745 commit e511a66

File tree: 1 file changed, +362 −0

Lines changed: 362 additions & 0 deletions
@@ -0,0 +1,362 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": [
     "aura"
    ]
   },
   "source": [
    "# Aura Graph Analytics with Spark"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This Jupyter notebook is hosted [here](https://github.com/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless.ipynb) in the Neo4j Graph Data Science Client GitHub repository.\n",
    "\n",
    "The notebook shows how to use the `graphdatascience` Python library to create, manage, and use a GDS Session together with Apache Spark.\n",
    "\n",
    "We use a graph of bike trips between stations as a simple example to show how to project data from a Spark DataFrame into a GDS Session, run algorithms, and eventually stream the analytical results back into Spark.\n",
    "We will cover all management operations: creation, listing, and deletion.\n",
    "\n",
    "If you are using a self-managed database, follow [this example](../graph-analytics-serverless-self-managed)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "\n",
    "This notebook requires the Aura Graph Analytics [feature](https://neo4j.com/docs/aura/graph-analytics/#aura-gds-serverless) to be enabled for your Neo4j Aura project.\n",
    "\n",
    "You also need to have the `graphdatascience` Python library installed, version `1.18a2` or later, as well as `pyspark`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install \"graphdatascience>=1.18a2\" python-dotenv \"pyspark[sql]\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dotenv import load_dotenv\n",
    "\n",
    "# This loads required secrets from the `sessions.env` file in the local directory.\n",
    "# The file can include Aura API credentials and database credentials.\n",
    "# If the file does not exist, this is a no-op.\n",
    "load_dotenv(\"sessions.env\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Connecting to a Spark Session\n",
    "\n",
    "To interact with the Spark cluster, we first need to instantiate a Spark session. In this example we use a local Spark session, which runs Spark on the same machine.\n",
    "Working with a remote Spark cluster works similarly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from pyspark.sql import SparkSession\n",
    "\n",
    "# If necessary, point JAVA_HOME to your local Java installation, for example:\n",
    "# os.environ[\"JAVA_HOME\"] = \"/path/to/your/java\"\n",
    "\n",
    "spark = SparkSession.builder.master(\"local[4]\").appName(\"GraphAnalytics\").getOrCreate()\n",
    "\n",
    "# Enable Arrow-based columnar data transfers\n",
    "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aura API credentials\n",
    "\n",
    "The entry point for managing GDS Sessions is the `GdsSessions` object, which requires creating [Aura API credentials](https://neo4j.com/docs/aura/api/authentication)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from graphdatascience.session import AuraAPICredentials, GdsSessions\n",
    "\n",
    "# You can also use AuraAPICredentials.from_env() to load credentials from environment variables\n",
    "api_credentials = AuraAPICredentials(\n",
    "    client_id=os.environ[\"CLIENT_ID\"],\n",
    "    client_secret=os.environ[\"CLIENT_SECRET\"],\n",
    "    # If your account is a member of several projects, you must also specify the project ID to use\n",
    "    project_id=os.environ.get(\"PROJECT_ID\", None),\n",
    ")\n",
    "\n",
    "sessions = GdsSessions(api_credentials=api_credentials)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a new session\n",
    "\n",
    "A new session is created by calling `sessions.get_or_create()` with the following parameters:\n",
    "\n",
    "* A session name, which lets you reconnect to an existing session by calling `get_or_create` again.\n",
    "* The session memory.\n",
    "* The cloud location.\n",
    "* A time-to-live (TTL), which ensures that the session is automatically deleted after being unused for the set time, to avoid incurring costs.\n",
    "\n",
    "See the API reference [documentation](https://neo4j.com/docs/graph-data-science-client/current/api/sessions/gds_sessions/#graphdatascience.session.gds_sessions.GdsSessions.get_or_create) or the manual for more details on the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import timedelta\n",
    "\n",
    "from graphdatascience.session import CloudLocation, SessionMemory\n",
    "\n",
    "# Create a GDS session!\n",
    "gds = sessions.get_or_create(\n",
    "    # we give it a representative name\n",
    "    session_name=\"bike_trips_spark\",\n",
    "    memory=SessionMemory.m_2GB,\n",
    "    ttl=timedelta(minutes=30),\n",
    "    cloud_location=CloudLocation(\"gcp\", \"europe-west1\"),\n",
    ")"
   ]
  },
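  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Listing sessions\n",
    "\n",
    "To see which sessions currently exist for the project, we can list them. This is a minimal sketch of the listing operation mentioned in the introduction; the exact information returned depends on the client version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the GDS Sessions visible to the current Aura API credentials\n",
    "for session in sessions.list():\n",
    "    print(session)"
   ]
  },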
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adding a dataset\n",
    "\n",
    "As the next step, we will set up a dataset in Spark. In this example we will use the New York Bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import os\n",
    "import zipfile\n",
    "\n",
    "import requests\n",
    "\n",
    "download_path = \"bike_trips_data\"\n",
    "if not os.path.exists(download_path):\n",
    "    url = \"https://www.kaggle.com/api/v1/datasets/download/gabrielramos87/bike-trips\"\n",
    "\n",
    "    response = requests.get(url)\n",
    "    response.raise_for_status()\n",
    "\n",
    "    # Unzip the content\n",
    "    with zipfile.ZipFile(io.BytesIO(response.content)) as z:\n",
    "        z.extractall(download_path)\n",
    "\n",
    "df = spark.read.csv(download_path, header=True, inferSchema=True)\n",
    "df.createOrReplaceTempView(\"bike_trips\")\n",
    "df.limit(10).show()"
   ]
  },
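  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check (not strictly required), we can inspect the schema that Spark inferred for the CSV files. The projection below assumes that the dataset contains `start_station_id` and `end_station_id` columns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the inferred schema to verify the station id columns used for the projection\n",
    "df.printSchema()"
   ]
  },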
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Projecting Graphs\n",
    "\n",
    "Now that we have our dataset available within our Spark session, it is time to project it to the GDS Session.\n",
    "\n",
    "We first need to get access to the GDSArrowClient via `gds.arrow_client()`. This client allows us to communicate directly with the Arrow Flight server provided by the session.\n",
    "\n",
    "Our input data already resembles edge triplets, where each row represents an edge from a source station to a target station. This allows us to use the Arrow server's import-from-triplets functionality, which requires the following protocol:\n",
    "\n",
    "1. Send an action `v2/graph.project.fromTriplets`.\n",
    "   This initializes the import process and allows us to specify the graph name, as well as settings like `undirected_relationship_types`. It returns a job id that we need to reference the import job in the following steps.\n",
    "2. Send the data in batches to the Arrow server.\n",
    "3. Send another action called `v2/graph.project.fromTriplets.done` to tell the import process that no more data will be sent. This triggers the final graph creation inside the session.\n",
    "4. Wait for the import process to reach the `DONE` state.\n",
    "\n",
    "While the overall process is straightforward, we need to tell Spark to run the upload on its workers. We do this with the `mapInArrow` function, which hands the data to our upload function as Arrow record batches that we can forward to the session."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import pandas as pd\n",
    "import pyarrow\n",
    "from pyspark.sql import functions\n",
    "\n",
    "graph_name = \"bike_trips\"\n",
    "\n",
    "arrow_client = gds.arrow_client()\n",
    "\n",
    "# 1. Start the import process\n",
    "job_id = arrow_client.create_graph_from_triplets(graph_name, concurrency=4)\n",
    "\n",
    "\n",
    "# Define a function that receives Arrow record batches and uploads them to the session\n",
    "def upload_batch(iterator):\n",
    "    for batch in iterator:\n",
    "        arrow_client.upload_triplets(job_id, [batch])\n",
    "        yield pyarrow.RecordBatch.from_pandas(pd.DataFrame({\"batch_rows_imported\": [len(batch)]}))\n",
    "\n",
    "\n",
    "# Select the source/target pairs from our source data\n",
    "source_target_pairs = spark.sql(\"\"\"\n",
    "    SELECT start_station_id AS sourceNode, end_station_id AS targetNode\n",
    "    FROM bike_trips\n",
    "\"\"\")\n",
    "\n",
    "# 2. Use the `mapInArrow` function to upload the data to the session. Returns a DataFrame with a single column containing the batch sizes.\n",
    "uploaded_batches = source_target_pairs.mapInArrow(upload_batch, \"batch_rows_imported long\")\n",
    "\n",
    "# Aggregate the batch sizes to obtain the total row count.\n",
    "uploaded_batches.agg(functions.sum(\"batch_rows_imported\").alias(\"rows_imported\")).show()\n",
    "\n",
    "# 3. Finish the import process\n",
    "arrow_client.triplet_load_done(job_id)\n",
    "\n",
    "# 4. Wait for the import to finish, polling the job status periodically\n",
    "while not arrow_client.job_status(job_id).succeeded():\n",
    "    time.sleep(1)\n",
    "\n",
    "G = gds.v2.graph.get(graph_name)\n",
    "G"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running Algorithms\n",
    "\n",
    "We can run algorithms on the constructed graph using the standard GDS Python Client API. See the other tutorials for more examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Running PageRank ...\")\n",
    "pr_result = gds.v2.page_rank.mutate(G, mutate_property=\"pagerank\")"
   ]
  },
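  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `mutate` call returns a summary of the computation. Displaying it is an optional check; the exact fields depend on the client version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: inspect the summary returned by the PageRank mutate call\n",
    "pr_result"
   ]
  },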
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sending the computation result back to Spark\n",
    "\n",
    "Once the computation is done, we might want to further use the result in Spark.\n",
    "We can do this similarly to the projection, by streaming batches of data into each of the Spark workers.\n",
    "Retrieving the data is a bit more complicated, since we need some input DataFrame in order to trigger computations on the Spark workers.\n",
    "We use a range equal to the number of workers in our cluster as our driving table.\n",
    "On the workers, we disregard the input and instead stream the computation results from the GDS Session."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Start the node property export on the session\n",
    "job_id = arrow_client.get_node_properties(G.name(), [\"pagerank\"])\n",
    "\n",
    "\n",
    "# Define a function that receives data from the GDS Session and turns it into data batches\n",
    "def retrieve_data(ignored):\n",
    "    stream_data = arrow_client.stream_job(G.name(), job_id)\n",
    "    batches = pyarrow.Table.from_pandas(stream_data).to_batches(1000)\n",
    "    for b in batches:\n",
    "        yield b\n",
    "\n",
    "\n",
    "# Create DataFrame with a single column and one row per worker\n",
    "input_partitions = spark.range(spark.sparkContext.defaultParallelism).toDF(\"batch_id\")\n",
    "# 2. Stream the data from the GDS Session into the Spark workers\n",
    "received_batches = input_partitions.mapInArrow(retrieve_data, \"nodeId long, pagerank double\")\n",
    "# Optional: Repartition the data to make sure it is distributed equally\n",
    "result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism)\n",
    "\n",
    "result.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup\n",
    "\n",
    "Now that we have finished our analysis, we can delete the GDS Session and stop the Spark session.\n",
    "\n",
    "Deleting the session will release all resources associated with it, and stop incurring costs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gds.delete()\n",
    "spark.stop()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
