diff --git a/TESTING.md b/TESTING.md index 46cec387b..e74479119 100644 --- a/TESTING.md +++ b/TESTING.md @@ -7,8 +7,7 @@ Tests can be found in `graphdatascience/tests`. In each of the folders there, `u Please see the section [Specifically for this project](CONTRIBUTING.md#specifically-for-this-project) of our [contribution guidelines](CONTRIBUTING.md) for how to set up an environment for testing and style checking. > **_NOTE:_** This document does not cover documentation testing. -Please see the [documentation README](doc/README.md#testing) for that. - +> Please see the [documentation README](doc/README.md#testing) for that. ## Unit testing @@ -20,11 +19,13 @@ To run the unit tests (with default options), simply call: pytest graphdatascience/tests/unit ``` +or, if you use just, run `just unit-tests`. ## Integration testing In order to run the integration tests one must have a [Neo4j DBMS](https://neo4j.com/docs/getting-started/current/) with the Neo4j Graph Data Science library installed running. +If you use just, you can run `just it`. ### V2 endpoints @@ -32,33 +33,33 @@ The integration tests for the V2 endpoints are located in `graphdatascience/test In order to run the tests, you need to have Docker running. You also need to either bring two Docker images, or configure authenticated access to the GCP repository where the production Docker images are stored. +If you use just, you can run `just it-v2`. ### Bringing your own Docker images Set the environment variables `NEO4J_DATABASE_IMAGE` and `GDS_SESSION_IMAGE` to the names of the Docker images you want to use. - ### Configuring authenticated access to the GCP repository 1. `gcloud init` 2. `gcloud auth login` 3. `gcloud auth configure-docker europe-west1-docker.pkg.dev` - ### Configuring +If you do not want to bring your own Neo4j DB, you can use the test environments under `scripts/test-envs`. 
+ The tests will through the [Neo4j Python driver](https://neo4j.com/docs/python-manual/current/) connect to a Neo4j database based on the environment variables: -* `NEO4J_URI` (defaulting to "bolt://localhost:7687" if unset), -* `NEO4J_USER`, -* `NEO4J_PASSWORD` (defaulting to "neo4j" if unset), -* `NEO4J_DB` (defaulting to "neo4j" if unset). +- `NEO4J_URI` (defaulting to "bolt://localhost:7687" if unset), +- `NEO4J_USER`, +- `NEO4J_PASSWORD` (defaulting to "neo4j" if unset), +- `NEO4J_DB` (defaulting to "neo4j" if unset). However, if `NEO4J_USER` is not set the tests will try to connect without authentication. Once the driver connects successfully to the Neo4j DBMS the tests will go on to execute against the `NEO4J_DB` database. - ### Running To run the integration tests (with default options), simply call: @@ -75,7 +76,6 @@ Note however that this also requires you to have specified a valid path for the If the database you are targeting is an AuraDS instance, you should use the option `--target-aura` which makes sure that tests of operations not supported on AuraDS are skipped. - ### Running tests that require encrypted connections In order to run integration tests that test encryption features, you must setup the Neo4j server accordingly: @@ -99,14 +99,12 @@ To run only integration tests that are marked as `encrypted_only`, call: pytest graphdatascience/tests/integration --encrypted-only ``` - ### GDS library versions There are integration tests that are only compatible with certain versions of the GDS library. For example, a procedure (which does not follow the standard algorithm procedure pattern) introduced in version 2.1.0 of the library will not exist in version 2.0.3, and so any client side integration tests that call this procedure should not run when testing against server library version 2.0.3. For this reason only tests compatible with the GDS library server version you are running against will run. 
- ## Style guide The code and examples use [ruff](https://docs.astral.sh/ruff/) to format and lint. @@ -116,7 +114,6 @@ Use `SKIP_NOTEBOOKS=true` to only format the code. See `pyproject.toml` for the configuration. - ### Static typing The code is annotated with type hints in order to provide documentation and allow for static type analysis with [mypy](http://mypy-lang.org/). @@ -129,18 +126,16 @@ mypy . from the root. See `mypy.ini` for our custom mypy settings. - ## Notebook examples The notebooks under `/examples` can be run using `scripts/run_notebooks`. - ### Cell Tags -*Verify version* +_Verify version_ If you only want to let CI run the notebook given a certain condition, tag a given cell in the notebook with `verify-version`. As the name suggests, the tag was introduced to only run for given GDS server versions. -*Teardown* +_Teardown_ To make sure certain cells are always run even in case of failure, tag the cell with `teardown`. diff --git a/changelog.md b/changelog.md index 1ae5c3630..dae9c4de6 100644 --- a/changelog.md +++ b/changelog.md @@ -2,14 +2,14 @@ ## Breaking changes - ## New features - ## Bug fixes - ## Improvements +- `GdsSessions.get_or_create` now allows specifying the `aura_instance_id` instead of the `uri` as part of the `db_connection`. This is required if the instance id cannot be derived from the provided database connection URI, such as for Multi-Database instances. ## Other changes + +- Deprecate deriving the `aura_instance_id` from the provided `uri` in `GdsSessions.get_or_create`. 
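The changelog entries above hinge on one new rule: a database connection must now carry either a `uri` or an `aura_instance_id`. That rule can be sketched as a small dataclass (a simplified stand-in for illustration only, not the library's actual `DbmsConnectionInfo`):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ConnectionInfo:
    """Simplified stand-in for DbmsConnectionInfo (illustrative only)."""

    uri: Optional[str] = None
    username: Optional[str] = None
    password: Optional[str] = None
    aura_instance_id: Optional[str] = None

    def __post_init__(self) -> None:
        # At least one way to locate the DBMS must be given.
        if self.uri is None and self.aura_instance_id is None:
            raise ValueError("Either 'uri' or 'aura_instance_id' must be provided.")
```

With this shape, `ConnectionInfo(aura_instance_id="mydbid")` is valid without a URI, while constructing with neither field raises a `ValueError`.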
diff --git a/doc/modules/ROOT/pages/graph-analytics-serverless.adoc b/doc/modules/ROOT/pages/graph-analytics-serverless.adoc index 19c097f11..f5b86be8c 100644 --- a/doc/modules/ROOT/pages/graph-analytics-serverless.adoc +++ b/doc/modules/ROOT/pages/graph-analytics-serverless.adoc @@ -85,6 +85,7 @@ sessions.get_or_create( db_connection: Optional[DbmsConnectionInfo] = None, ttl: Optional[timedelta] = None, cloud_location: Optional[CloudLocation] = None, + aura_instance_id: Optional[str] = None, timeout: Optional[int] = None, neo4j_driver_options: Optional[dict[str, Any]] = None, arrow_client_options: Optional[dict[str, Any]] = None, @@ -97,7 +98,8 @@ sessions.get_or_create( | Name | Type | Optional | Default | Description | session_name | str | no | - | Name of the session. Must be unique within the project. | memory | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/session_memory[SessionMemory] | no | - | Amount of memory available to the session. -| db_connection | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/dbms_connection_info[DbmsConnectionInfo] | yes | None | Bolt server URL, username, and password to a Neo4j DBMS. Required for the Attached and Self-managed types. Alternatively to username and password, you can provide a `neo4j.Auth` https://neo4j.com/docs/python-manual/current/connect-advanced/#authentication-methods[object]. +| db_connection | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/dbms_connection_info[DbmsConnectionInfo] | yes | None | Aura instance-id, username, and password to a Neo4j DBMS. Required for the Attached and Self-managed types. For self-managed, provide the URI instead of the instance-id. +As an alternative to username and password, you can provide a `neo4j.Auth` https://neo4j.com/docs/python-manual/current/connect-advanced/#authentication-methods[object]. | ttl | datetime.timedelta | yes | 1h | Time-to-live for the session. 
| cloud_location | https://neo4j.com/docs/graph-data-science-client/{docs-version}/api/sessions/cloud_location[CloudLocation] | yes | None | Aura-supported cloud provider and region where the GDS Session will run. Required for the Self-managed and Standalone types. | timeout | int | yes | None | Seconds to wait for the session to enter Ready state. If the time is exceeded, an error will be returned. @@ -123,9 +125,9 @@ gds = sessions.get_or_create( session_name="my-attached-session", memory=SessionMemory.m_4GB, db_connection=DbmsConnectionInfo( - "neo4j+s://mydbid.databases.neo4j.io", - "my-user", - "my-password" + aura_instance_id="mydbid", + username="my-user", + password="my-password" ), ) ---- @@ -141,7 +143,11 @@ from graphdatascience.session import DbmsConnectionInfo, CloudLocation, SessionM gds = sessions.get_or_create( session_name="my-self-managed-session", memory=SessionMemory.m_4GB, - db_connection=DbmsConnectionInfo("neo4j://localhost", "my-user", "my-password"), + db_connection=DbmsConnectionInfo( + uri="neo4j://localhost", + username="my-user", + password="my-password" + ), cloud_location=CloudLocation(provider="gcp", region="europe-west1"), ) ---- @@ -315,7 +321,13 @@ from graphdatascience.session import SessionMemory, DbmsConnectionInfo, GdsSessi sessions = GdsSessions(api_credentials=AuraAPICredentials(os.environ["CLIENT_ID"], os.environ["CLIENT_SECRET"])) -db_connection = DbmsConnectionInfo(os.environ["DB_URI"], os.environ["DB_USER"], os.environ["DB_PASSWORD"]) +# you can also use DbmsConnectionInfo.from_env() to load credentials from environment variables +db_connection = DbmsConnectionInfo( + uri=os.environ["NEO4J_URI"], + username=os.environ["NEO4J_USERNAME"], + password=os.environ["NEO4J_PASSWORD"], + aura_instance_id=os.environ["AURA_INSTANCEID"] +) gds = sessions.get_or_create( session_name="my-new-session", memory=SessionMemory.m_8GB, diff --git a/doc/modules/ROOT/pages/tutorials/graph-analytics-serverless-spark.adoc 
b/doc/modules/ROOT/pages/tutorials/graph-analytics-serverless-spark.adoc new file mode 100644 index 000000000..f1e302e7e --- /dev/null +++ b/doc/modules/ROOT/pages/tutorials/graph-analytics-serverless-spark.adoc @@ -0,0 +1,294 @@ +// DO NOT EDIT - AsciiDoc file generated automatically + += Aura Graph Analytics with Spark + + +https://colab.research.google.com/github/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless-spark.ipynb[image:https://colab.research.google.com/assets/colab-badge.svg[Open +In Colab]] + + +This Jupyter notebook is hosted +https://github.com/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless-spark.ipynb[here] +in the Neo4j Graph Data Science Client GitHub repository. + +The notebook shows how to use the `graphdatascience` Python library to +create, manage, and use a GDS Session from within an Apache Spark +cluster. + +We consider a graph of bicycle rentals, which we’re using as a simple +example to show how to project data from Spark to a GDS Session, run +algorithms, and eventually return results back to Spark. In this +notebook we will focus on the interaction with Apache Spark, and will +not cover all possible actions using GDS sessions. We refer to the other +tutorials for additional details. + +== Prerequisites + +We need to have the `graphdatascience` Python library installed, +version `1.18` or later, as well as `pyspark`. + +[source, python, role=no-test] +---- +%pip install "graphdatascience>=1.18" python-dotenv "pyspark[sql]" +---- + +[source, python, role=no-test] +---- +from dotenv import load_dotenv + +# This loads required secrets from a `.env` file in the local directory. +# This can include Aura API Credentials and Database Credentials. +# If the file does not exist, this is a no-op. +load_dotenv("sessions.env") +---- + +=== Connecting to a Spark Session + +To interact with the Spark cluster we first need to instantiate a Spark +session. 
In this example we will use a local Spark session, which will +run Spark on the same machine. Working with a remote Spark cluster will +work similarly. For more information about setting up pyspark visit +https://spark.apache.org/docs/latest/api/python/getting++_++started/ + +[source, python, role=no-test] +---- +from pyspark.sql import SparkSession + +spark = SparkSession.builder.master("local[4]").appName("GraphAnalytics").getOrCreate() + +# Enable Arrow-based columnar data transfers +spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true") +---- + +== Aura API credentials + +The entry point for managing GDS Sessions is the `GdsSessions` object, +which requires creating +https://neo4j.com/docs/aura/api/authentication[Aura API credentials]. + +[source, python, role=no-test] +---- +import os + +from graphdatascience.session import AuraAPICredentials, GdsSessions + +# you can also use AuraAPICredentials.from_env() to load credentials from environment variables +api_credentials = AuraAPICredentials( + client_id=os.environ["CLIENT_ID"], + client_secret=os.environ["CLIENT_SECRET"], + # If your account is a member of several projects, you must also specify the project ID to use + project_id=os.environ.get("PROJECT_ID", None), +) + +sessions = GdsSessions(api_credentials=api_credentials) +---- + +== Creating a new session + +A new session is created by calling `sessions.get++_++or++_++create()` +with the following parameters: + +* A session name, which lets you reconnect to an existing session by +calling `get++_++or++_++create` again. +* The session memory. +* The cloud location. +* A time-to-live (TTL), which ensures that the session is automatically +deleted after being unused for the set time, to avoid incurring costs. + +See the API reference +https://neo4j.com/docs/graph-data-science-client/current/api/sessions/gds_sessions/#graphdatascience.session.gds_sessions.GdsSessions.get_or_create[documentation] +or the manual for more details on the parameters. 
+ +[source, python, role=no-test] +---- +from datetime import timedelta + +from graphdatascience.session import CloudLocation, SessionMemory + +# Create a GDS session! +gds = sessions.get_or_create( + # we give it a representative name + session_name="bike_trips", + memory=SessionMemory.m_2GB, + ttl=timedelta(minutes=30), + cloud_location=CloudLocation("gcp", "europe-west1"), +) +---- + +== Adding a dataset + +As the next step, we will set up a dataset in Spark. In this example we +will use the New York Bike trips dataset +(https://www.kaggle.com/datasets/gabrielramos87/bike-trips). The bike +trips form a graph where nodes represent bike rental stations and +relationships connect the start and end stations of a rental trip. + +[source, python, role=no-test] +---- +import io +import os +import zipfile + +import requests + +download_path = "bike_trips_data" +if not os.path.exists(download_path): + url = "https://www.kaggle.com/api/v1/datasets/download/gabrielramos87/bike-trips" + + response = requests.get(url) + response.raise_for_status() + + # Unzip the content + with zipfile.ZipFile(io.BytesIO(response.content)) as z: + z.extractall(download_path) + +df = spark.read.csv(download_path, header=True, inferSchema=True) +df.createOrReplaceTempView("bike_trips") +df.limit(10).show() +---- + +== Projecting Graphs + +Now that we have our dataset available within our Spark session, it is +time to project it to the GDS Session. + +We first need to get access to the GDSArrowClient. This client allows us +to directly communicate with the Arrow Flight server provided by the +session. + +Our input data already resembles triplets, where each row represents an +edge from a source station to a target station. This allows us to use +the Arrow Server’s "`graph import from triplets`" functionality, which +requires the following protocol: + +[arabic] +. 
Send an action `v2/graph.project.fromTriplets`. This initializes +the import process and allows us to specify the graph name and settings +like `undirected++_++relationship++_++types`. It returns a job id, which +we use to reference the import job in the following steps. +. Send the data in batches to the Arrow server. +. Send another action called `v2/graph.project.fromTriplets.done` to +tell the import process that no more data will be sent. This will +trigger the final graph creation inside the GDS session. +. Wait for the import process to reach the `DONE` state. + +The most complicated step here is to run the actual data upload on each +Spark worker. We will use the `mapInArrow` function to run custom code +on each Spark worker. Each worker will receive a number of Arrow record +batches that we can directly send to the GDS session’s Arrow server. + +[source, python, role=no-test] +---- +import time + +import pandas as pd +import pyarrow +from pyspark.sql import functions + +graph_name = "bike_trips" + +arrow_client = gds.arrow_client() + +# 1. Start the import process +job_id = arrow_client.create_graph_from_triplets(graph_name, concurrency=4) + + +# Define a function that receives an arrow batch and uploads it to the GDS session +def upload_batch(iterator): + for batch in iterator: + arrow_client.upload_triplets(job_id, [batch]) + yield pyarrow.RecordBatch.from_pandas(pd.DataFrame({"batch_rows_imported": [len(batch)]})) + + +# Select the source target pairs from our source data +source_target_pairs = spark.sql(""" + SELECT start_station_id AS sourceNode, end_station_id AS targetNode + FROM bike_trips + """) + +# 2. Use the `mapInArrow` function to upload the data to the GDS session. 
Returns a DataFrame with a single column containing the batch sizes. +uploaded_batches = source_target_pairs.mapInArrow(upload_batch, "batch_rows_imported long") + +# Aggregate the batch sizes to compute the total row count. +aggregated_batch_sizes = uploaded_batches.agg(functions.sum("batch_rows_imported").alias("rows_imported")) + +# Show the result. This will trigger the computation and thus run the data upload. +aggregated_batch_sizes.show() + +# 3. Finish the import process +arrow_client.triplet_load_done(job_id) + +# 4. Wait for the import to finish +while not arrow_client.job_status(job_id).succeeded(): + time.sleep(1) + +G = gds.v2.graph.get(graph_name) +G +---- + +== Running Algorithms + +We can run algorithms on the constructed graph using the standard GDS +Python Client API. See the other tutorials for more examples. + +[source, python, role=no-test] +---- +print("Running PageRank ...") +pr_result = gds.v2.page_rank.mutate(G, mutate_property="pagerank") +---- + +== Sending the computation result back to Spark + +Once the computation is done, we might want to use the result further in +Spark. We can do this in a similar way to the projection, by streaming +batches of data into each of the Spark workers. Retrieving the data is a +bit more complicated since we need some input DataFrame in order to +trigger computations on the Spark workers. We use a range with one row +per worker in our cluster as our driving table. On the +workers we will disregard the input and instead stream the computation +data from the GDS Session. + +[source, python, role=no-test] +---- +# 1. 
Start the node property export on the GDS session +job_id = arrow_client.get_node_properties(G.name(), ["pagerank"]) + + +# Define a function that receives data from the GDS Session and turns it into data batches +def retrieve_data(ignored): + stream_data = arrow_client.stream_job(G.name(), job_id) + batches = pyarrow.Table.from_pandas(stream_data).to_batches(1000) + for b in batches: + yield b + + +# Create DataFrame with a single column and one row per worker +input_partitions = spark.range(spark.sparkContext.defaultParallelism).toDF("batch_id") +# 2. Stream the data from the GDS Session into the Spark workers +received_batches = input_partitions.mapInArrow(retrieve_data, "nodeId long, pagerank double") +# Optional: Repartition the data to make sure it is distributed equally +result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism) + +result.toPandas() +---- + +== Cleanup + +Now that we have finished our analysis, we can delete the GDS session +and stop the Spark session. + +Deleting the GDS session will release all resources associated with it, +and stop incurring costs. + +[source, python, role=no-test] +---- +gds.delete() +spark.stop() +---- diff --git a/doc/modules/ROOT/pages/tutorials/graph-analytics-serverless.adoc b/doc/modules/ROOT/pages/tutorials/graph-analytics-serverless.adoc index a0f7bcf07..6a17d3628 100644 --- a/doc/modules/ROOT/pages/tutorials/graph-analytics-serverless.adoc +++ b/doc/modules/ROOT/pages/tutorials/graph-analytics-serverless.adoc @@ -115,9 +115,9 @@ from graphdatascience.session import DbmsConnectionInfo # Identify the AuraDB instance # you can also use DbmsConnectionInfo.from_env() to load credentials from environment variables db_connection = DbmsConnectionInfo( - uri=os.environ["NEO4J_URI"], username=os.environ["NEO4J_USERNAME"], password=os.environ["NEO4J_PASSWORD"], + aura_instance_id=os.environ["AURA_INSTANCEID"], ) # Create a GDS session! 
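The documentation and notebook hunks that follow switch the environment-based connection from `NEO4J_URI` to `AURA_INSTANCEID`. The resolution rule implemented in `DbmsConnectionInfo.from_env` elsewhere in this diff is: if `AURA_INSTANCEID` is set it takes precedence, and `NEO4J_URI` is then not required. A minimal sketch of that rule (the function name and dict shape are illustrative, not the library's API):

```python
from typing import Optional


def resolve_connection_env(env: dict) -> dict:
    """Pick connection settings, preferring AURA_INSTANCEID over NEO4J_URI."""
    aura_instance_id: Optional[str] = env.get("AURA_INSTANCEID")
    # The URI is only required when no instance id is given;
    # with an instance id, the URI is resolved from Aura later.
    uri = None if aura_instance_id else env["NEO4J_URI"]
    return {
        "uri": uri,
        "username": env.get("NEO4J_USERNAME", "neo4j"),
        "password": env["NEO4J_PASSWORD"],
        "aura_instance_id": aura_instance_id,
    }
```

So an environment with only `AURA_INSTANCEID` and credentials resolves to a URI-less connection, while an environment without it still requires `NEO4J_URI`.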
diff --git a/examples/graph-analytics-serverless.ipynb b/examples/graph-analytics-serverless.ipynb index 7852fd4aa..e1b0cc8b1 100644 --- a/examples/graph-analytics-serverless.ipynb +++ b/examples/graph-analytics-serverless.ipynb @@ -155,9 +155,9 @@ "# Identify the AuraDB instance\n", "# you can also use DbmsConnectionInfo.from_env() to load credentials from environment variables\n", "db_connection = DbmsConnectionInfo(\n", - " uri=os.environ[\"NEO4J_URI\"],\n", " username=os.environ[\"NEO4J_USERNAME\"],\n", " password=os.environ[\"NEO4J_PASSWORD\"],\n", + " aura_instance_id=os.environ[\"AURA_INSTANCEID\"],\n", ")\n", "\n", "# Create a GDS session!\n", diff --git a/justfile b/justfile new file mode 100644 index 000000000..5ff11e072 --- /dev/null +++ b/justfile @@ -0,0 +1,28 @@ +style: + ./scripts/makestyle && ./scripts/checkstyle + +convert-notebooks: + ./scripts/nb2doc/convert.sh + +unit-tests: + pytest tests/unit + +it filter="" enterprise="true": + #!/usr/bin/env bash + set -e + if [ "{{enterprise}}" = "true" ]; then + ENV_DIR="scripts/test_envs/gds_plugin_enterprise" + EXTRA_FLAGS="--include-model-store-location --include-enterprise" + else + ENV_DIR="scripts/test_envs/gds_plugin_community" + EXTRA_FLAGS="" + fi + trap "cd $ENV_DIR && docker compose down" EXIT + cd $ENV_DIR && docker compose up -d + cd - + pytest tests/integration $EXTRA_FLAGS --basetemp=tmp/ {{ if filter != "" { "-k '" + filter + "'" } else { "" } }} + + +# such as `just it-v2 wcc` +it-v2 filter="": + pytest tests/integrationV2 --include-integration-v2 --basetemp=tmp/ {{ if filter != "" { "-k '" + filter + "'" } else { "" } }} diff --git a/scripts/ci/run_targeting_aura_sessions.py b/scripts/ci/run_targeting_aura_sessions.py index 58ba32906..963588c83 100644 --- a/scripts/ci/run_targeting_aura_sessions.py +++ b/scripts/ci/run_targeting_aura_sessions.py @@ -41,6 +41,7 @@ def handle_signal(sig: int, frame: FrameType | None) -> None: uri = create_result["connection_url"] username = 
create_result["username"] password = create_result["password"] + instance_id = create_result["id"] if project_id: # Avoid the `None` literal as the value for the variable @@ -48,7 +49,7 @@ def handle_signal(sig: int, frame: FrameType | None) -> None: else: project_id_part = "" - cmd = f"AURA_ENV=staging CLIENT_ID={client_id} CLIENT_SECRET={client_secret} {project_id_part} NEO4J_URI={uri} NEO4J_USERNAME={username} NEO4J_PASSWORD={password} tox -e jupyter-notebook-session-ci" + cmd = f"AURA_ENV=staging CLIENT_ID={client_id} CLIENT_SECRET={client_secret} {project_id_part} AURA_INSTANCEID={instance_id} NEO4J_URI={uri} NEO4J_USERNAME={username} NEO4J_PASSWORD={password} tox -e jupyter-notebook-session-ci" if os.system(cmd) != 0: raise Exception("Failed to run notebooks") diff --git a/src/graphdatascience/query_runner/db_environment_resolver.py b/src/graphdatascience/query_runner/db_environment_resolver.py new file mode 100644 index 000000000..e966b2936 --- /dev/null +++ b/src/graphdatascience/query_runner/db_environment_resolver.py @@ -0,0 +1,17 @@ +from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner + + +class DbEnvironmentResolver: + @staticmethod + def hosted_in_aura(db_runner: Neo4jQueryRunner) -> bool: + return ( + db_runner.run_retryable_cypher(""" + CALL dbms.components() YIELD name, versions + WHERE name = "Neo4j Kernel" + UNWIND versions as v + WITH name, v + WHERE v ENDS WITH "aura" + RETURN count(*) <> 0 + """).squeeze() + is True + ) diff --git a/src/graphdatascience/session/aura_api.py b/src/graphdatascience/session/aura_api.py index 275fd3a1f..4ec651377 100644 --- a/src/graphdatascience/session/aura_api.py +++ b/src/graphdatascience/session/aura_api.py @@ -121,14 +121,14 @@ def get_or_create_session( self, name: str, memory: SessionMemoryValue, - dbid: str | None = None, + instance_id: str | None = None, ttl: timedelta | None = None, cloud_location: CloudLocation | None = None, ) -> SessionDetails: json = {"name": name, "memory": 
memory.value, "project_id": self._project_id} - if dbid: - json["instance_id"] = dbid + if instance_id: + json["instance_id"] = instance_id if ttl: json["ttl"] = f"{ttl.total_seconds()}s" diff --git a/src/graphdatascience/session/aura_graph_data_science.py b/src/graphdatascience/session/aura_graph_data_science.py index 47397deda..d9c18e114 100644 --- a/src/graphdatascience/session/aura_graph_data_science.py +++ b/src/graphdatascience/session/aura_graph_data_science.py @@ -46,7 +46,7 @@ def create( show_progress: bool = True, ) -> AuraGraphDataScience: session_bolt_query_runner = Neo4jQueryRunner.create_for_session( - endpoint=session_bolt_connection_info.uri, + endpoint=session_bolt_connection_info.get_uri(), auth=session_bolt_connection_info.get_auth(), show_progress=show_progress, ) @@ -76,7 +76,7 @@ def create( db_bolt_query_runner = db_endpoint else: db_bolt_query_runner = Neo4jQueryRunner.create_for_db( - db_endpoint.uri, + db_endpoint.get_uri(), db_endpoint.get_auth(), aura_ds=True, show_progress=False, diff --git a/src/graphdatascience/session/dbms_connection_info.py b/src/graphdatascience/session/dbms_connection_info.py index 756400627..46d219734 100644 --- a/src/graphdatascience/session/dbms_connection_info.py +++ b/src/graphdatascience/session/dbms_connection_info.py @@ -13,13 +13,17 @@ class DbmsConnectionInfo: Supports both username/password as well as the authentication options provided by the Neo4j Python driver. """ - uri: str + # 'uri' or 'aura_instance_id' must be provided. + uri: str | None = None + username: str | None = None password: str | None = None database: str | None = None # Optional: typed authentication, used instead of username/password. Supports for example a token. 
See https://neo4j.com/docs/python-manual/current/connect-advanced/#authentication-methods auth: Auth | None = None + aura_instance_id: str | None = None + def __post_init__(self) -> None: # Validate auth fields if (self.username or self.password) and self.auth: @@ -28,6 +32,9 @@ def __post_init__(self) -> None: "Please provide either a username/password or a token." ) + if (self.aura_instance_id is None) and (self.uri is None): + raise ValueError("Either 'uri' or 'aura_instance_id' must be provided.") + def get_auth(self) -> Auth | None: """ Returns: @@ -38,6 +45,14 @@ def get_auth(self) -> Auth | None: auth = basic_auth(self.username, self.password) return auth + def set_uri(self, uri: str) -> None: + self.uri = uri + + def get_uri(self) -> str: + if not self.uri: + raise ValueError("'uri' is not provided.") + return self.uri + @staticmethod def from_env() -> DbmsConnectionInfo: """ @@ -47,10 +62,17 @@ def from_env() -> DbmsConnectionInfo: - NEO4J_USERNAME - NEO4J_PASSWORD - NEO4J_DATABASE + - AURA_INSTANCEID """ - uri = os.environ["NEO4J_URI"] username = os.environ.get("NEO4J_USERNAME", "neo4j") password = os.environ["NEO4J_PASSWORD"] database = os.environ.get("NEO4J_DATABASE") + aura_instance_id = os.environ.get("AURA_INSTANCEID") + + # instance id takes precedence over uri + if not aura_instance_id: + uri = os.environ["NEO4J_URI"] + else: + uri = None - return DbmsConnectionInfo(uri, username, password, database) + return DbmsConnectionInfo(uri, username, password, database, aura_instance_id=aura_instance_id) diff --git a/src/graphdatascience/session/dedicated_sessions.py b/src/graphdatascience/session/dedicated_sessions.py index b4296398e..ce8f8a75e 100644 --- a/src/graphdatascience/session/dedicated_sessions.py +++ b/src/graphdatascience/session/dedicated_sessions.py @@ -6,6 +6,7 @@ from typing import Any from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication +from graphdatascience.query_runner.db_environment_resolver import 
DbEnvironmentResolver
 from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
 from graphdatascience.session.algorithm_category import AlgorithmCategory
 from graphdatascience.session.aura_api import AuraApi
@@ -66,29 +67,49 @@ def get_or_create(
         arrow_client_options: dict[str, Any] | None = None,
     ) -> AuraGraphDataScience:
         if db_connection is None:
-            if not cloud_location:
-                raise ValueError("cloud_location must be provided for creating standalone sessions.")
-
-            session_details = self._get_or_create_standalone_session(session_name, memory.value, cloud_location, ttl)
             db_runner = None
+            aura_db_instance = None
         else:
-            db_runner = self._create_db_runner(db_connection, neo4j_driver_options)
+            if aura_instance_id := db_connection.aura_instance_id:
+                aura_db_instance = self._aura_api.list_instance(aura_instance_id)
-            dbid = AuraApi.extract_id(db_connection.uri)
-            aura_db_instance = self._aura_api.list_instance(dbid)
+                if not aura_db_instance:
+                    raise ValueError(
+                        f"Aura instance with id `{aura_instance_id}` could not be found. Please verify that the instance id is correct and that you have access to the Aura instance."
+                    )
-            if aura_db_instance is None:
-                if not cloud_location:
-                    raise ValueError("cloud_location must be provided for sessions against a self-managed DB.")
+                db_connection.set_uri(aura_db_instance.connection_url)
-                session_details = self._get_or_create_self_managed_session(
-                    session_name, memory.value, cloud_location, ttl
-                )
+                db_runner = self._create_db_runner(db_connection, neo4j_driver_options)
             else:
-                if cloud_location is not None:
-                    raise ValueError("cloud_location cannot be provided for sessions against an AuraDB.")
+                db_runner = self._create_db_runner(db_connection, neo4j_driver_options)
+
+                if DbEnvironmentResolver.hosted_in_aura(db_runner):
+                    warnings.warn(
+                        DeprecationWarning(
+                            "Deriving the Aura instance from the database URI is deprecated and will be removed in a future release. "
+                            "Please specify the `aura_instance_id` in the `db_connection` argument."
+                        )
+                    )
+
+                    aura_instance_id = AuraApi.extract_id(db_connection.get_uri())
+                    aura_db_instance = self._aura_api.list_instance(aura_instance_id)
+                    if not aura_db_instance:
+                        raise ValueError(
+                            f"Aura instance with id `{aura_instance_id}` could not be found. Please specify the `aura_instance_id` in the `db_connection` argument."
+                        )
+                else:
+                    aura_db_instance = None
+
+        if aura_db_instance is None:
+            if not cloud_location:
+                raise ValueError("cloud_location must be provided for sessions not attached to an AuraDB.")
-            session_details = self._get_or_create_attached_session(session_name, memory.value, dbid, ttl)
+            session_details = self._get_or_create_self_managed_session(session_name, memory.value, cloud_location, ttl)
+        else:
+            if cloud_location is not None:
+                raise ValueError("cloud_location cannot be provided for sessions against an AuraDB.")
+            session_details = self._get_or_create_attached_session(session_name, memory.value, aura_db_instance.id, ttl)

         self._await_session_running(session_details, timeout)
@@ -112,7 +133,7 @@ def _create_db_runner(
         self, db_connection: DbmsConnectionInfo, config: dict[str, Any] | None = None
     ) -> Neo4jQueryRunner:
         db_runner = Neo4jQueryRunner.create_for_db(
-            endpoint=db_connection.uri,
+            endpoint=db_connection.get_uri(),
             auth=db_connection.get_auth(),
             aura_ds=True,
             show_progress=False,
@@ -187,9 +208,9 @@ def _get_or_create_standalone_session(
         return self._aura_api.get_or_create_session(session_name, memory, ttl=ttl, cloud_location=cloud_location)

     def _get_or_create_attached_session(
-        self, session_name: str, memory: SessionMemoryValue, dbid: str, ttl: timedelta | None = None
+        self, session_name: str, memory: SessionMemoryValue, instance_id: str, ttl: timedelta | None = None
     ) -> SessionDetails:
-        return self._aura_api.get_or_create_session(name=session_name, dbid=dbid, memory=memory, ttl=ttl)
+        return self._aura_api.get_or_create_session(name=session_name, instance_id=instance_id, memory=memory, ttl=ttl)

     def _get_or_create_self_managed_session(
         self,
diff --git a/src/graphdatascience/session/gds_sessions.py b/src/graphdatascience/session/gds_sessions.py
index 5daecf1db..536bbe1e9 100644
--- a/src/graphdatascience/session/gds_sessions.py
+++ b/src/graphdatascience/session/gds_sessions.py
@@ -122,7 +122,7 @@ def get_or_create(
             ttl: (timedelta | None): The sessions time to live after inactivity in seconds.
             cloud_location (CloudLocation | None): The cloud location. Required if the GDS session is for a self-managed database.
             timeout (int | None): Optional timeout (in seconds) when waiting for session to become ready. If unset the method will wait forever. If set and session does not become ready an exception will be raised. It is user responsibility to ensure resource gets cleaned up in this situation.
-            neo4j_driver_config (dict[str, Any] | None): Optional configuration for the Neo4j driver to the Neo4j DBMS. Only relevant if `db_connection` is specified..
+            neo4j_driver_config (dict[str, Any] | None): Optional configuration for the Neo4j driver to the Neo4j DBMS. Only relevant if `db_connection` is specified.
             arrow_client_options (dict[str, Any] | None): Optional configuration for the Arrow Flight client.

         Returns:
             AuraGraphDataScience: The session.
diff --git a/tests/integration/test_db_environment_resolver.py b/tests/integration/test_db_environment_resolver.py
new file mode 100644
index 000000000..571d32d97
--- /dev/null
+++ b/tests/integration/test_db_environment_resolver.py
@@ -0,0 +1,14 @@
+import pytest
+
+from graphdatascience.query_runner.db_environment_resolver import DbEnvironmentResolver
+from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
+
+
+@pytest.mark.only_on_aura
+def test_hosted_in_aura_aura_dbms(aura_runner: Neo4jQueryRunner) -> None:
+    assert DbEnvironmentResolver.hosted_in_aura(aura_runner)
+
+
+@pytest.mark.skip_on_aura
+def test_hosted_in_aura_self_managed_dbms(runner: Neo4jQueryRunner) -> None:
+    assert not DbEnvironmentResolver.hosted_in_aura(runner)
diff --git a/tests/integration/test_progress_logging.py b/tests/integration/test_progress_logging.py
index 3d82604ec..af0ce8b80 100644
--- a/tests/integration/test_progress_logging.py
+++ b/tests/integration/test_progress_logging.py
@@ -5,7 +5,7 @@
 from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
 from graphdatascience.query_runner.session_query_runner import SessionQueryRunner
 from tests.unit.conftest import CollectingQueryRunner
-from tests.unit.test_session_query_runner import FakeArrowClient
+from tests.unit.session.test_session_query_runner import FakeArrowClient


 def test_disabled_progress_logging(neo4j_driver: Driver) -> None:
diff --git a/tests/integrationV2/procedure_surface/plugin/conftest.py b/tests/integrationV2/procedure_surface/plugin/conftest.py
index b962eff37..64bc64599 100644
--- a/tests/integrationV2/procedure_surface/plugin/conftest.py
+++ b/tests/integrationV2/procedure_surface/plugin/conftest.py
@@ -31,7 +31,7 @@ def neo4j_connection(gds_plugin_container: Neo4jContainer) -> Generator[DbmsConn
 @pytest.fixture(scope="package")
 def query_runner(neo4j_connection: DbmsConnectionInfo) -> Generator[Neo4jQueryRunner, None, None]:
     query_runner = Neo4jQueryRunner.create_for_db(
-        neo4j_connection.uri,
+        neo4j_connection.get_uri(),
         (neo4j_connection.username, neo4j_connection.password),  # type: ignore
     )
diff --git a/tests/integrationV2/procedure_surface/plugin/test_plugin_walking_skeleton.py b/tests/integrationV2/procedure_surface/plugin/test_plugin_walking_skeleton.py
index 9ad7c1129..8fe8a50f6 100644
--- a/tests/integrationV2/procedure_surface/plugin/test_plugin_walking_skeleton.py
+++ b/tests/integrationV2/procedure_surface/plugin/test_plugin_walking_skeleton.py
@@ -9,7 +9,7 @@
 @pytest.fixture(scope="package")
 def gds(neo4j_connection: DbmsConnectionInfo) -> GraphDataScience:
     return GraphDataScience(
-        endpoint=neo4j_connection.uri,
+        endpoint=neo4j_connection.get_uri(),
         auth=(neo4j_connection.username, neo4j_connection.password),  # type: ignore
     )
diff --git a/tests/unit/session/test_dbms_connection_info.py b/tests/unit/session/test_dbms_connection_info.py
index df06bb8f9..899d471c1 100644
--- a/tests/unit/session/test_dbms_connection_info.py
+++ b/tests/unit/session/test_dbms_connection_info.py
@@ -1,4 +1,5 @@
 import neo4j
+import pytest

 from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo

@@ -24,6 +25,22 @@ def test_dbms_connection_info_advanced_auth() -> None:
     assert dci.get_auth() == advanced_auth


+def test_dbms_connection_info_aura_instance() -> None:
+    dci = DbmsConnectionInfo(
+        aura_instance_id="instance-id",
+        username="neo4j",
+        password="password",
+        database="neo4j",
+    )
+
+    with pytest.raises(ValueError, match="'uri' is not provided."):
+        dci.get_uri()
+
+    dci.set_uri("neo4j+s://instance-id.databases.neo4j.io")
+
+    assert dci.get_uri() == "neo4j+s://instance-id.databases.neo4j.io"
+
+
 def test_dbms_connection_info_fail_on_auth_and_username() -> None:
     try:
         DbmsConnectionInfo(
@@ -39,3 +56,11 @@ def test_dbms_connection_info_fail_on_auth_and_username() -> None:
         )
     else:
         assert False, "Expected ValueError was not raised"
+
+
+def test_dbms_connection_info_fail_on_missing_instance_uri() -> None:
+    with pytest.raises(ValueError, match="Either 'uri' or 'aura_instance_id' must be provided."):
+        DbmsConnectionInfo(
+            username="neo4j",
+            password="password",
+        )
diff --git a/tests/unit/test_dedicated_sessions.py b/tests/unit/session/test_dedicated_sessions.py
similarity index 80%
rename from tests/unit/test_dedicated_sessions.py
rename to tests/unit/session/test_dedicated_sessions.py
index 161ed7b3b..9b01fd7ff 100644
--- a/tests/unit/test_dedicated_sessions.py
+++ b/tests/unit/session/test_dedicated_sessions.py
@@ -57,21 +57,24 @@ def get_or_create_session(
         self,
         name: str,
         memory: SessionMemoryValue,
-        dbid: str | None = None,
+        instance_id: str | None = None,
         ttl: timedelta | None = None,
         cloud_location: CloudLocation | None = None,
     ) -> SessionDetails:
-        if not cloud_location and dbid:
-            instance_details = self.list_instance(dbid)
+        if not cloud_location and instance_id:
+            instance_details = self.list_instance(instance_id)
             if instance_details:
                 cloud_location = CloudLocation(instance_details.cloud_provider, instance_details.region)
+            id_prefix = instance_id
+        else:
+            id_prefix = "selfmanaged"

         for s in self._sessions.values():
             if s.name == name:
                 if (
                     s.memory == memory
                     and s.user_id == self._console_user
-                    and (not dbid or s.instance_id == dbid)
+                    and (not instance_id or s.instance_id == instance_id)
                     and (not cloud_location or s.cloud_location == cloud_location)
                 ):
                     if errors := s.errors:
@@ -81,9 +84,9 @@ def get_or_create_session(
                 raise RuntimeError("Session exists with different config")

         details = SessionDetailsWithErrors(
-            id=f"{dbid}-ffff{self.id_counter}",
+            id=f"{id_prefix}-ffff{self.id_counter}",
             name=name,
-            instance_id=dbid,
+            instance_id=instance_id,
             memory=memory,
             status="Creating",
             created_at=datetime.fromisoformat("2021-01-01T00:00:00+00:00"),
@@ -130,7 +133,7 @@ def create_instance(
             id=id,
             username="neo4j",
             password="fake-pw",
-            connection_url=f"neo4j+s://{id}.neo4j.io",
+            connection_url=f"neo4j+s://{id}.databases.neo4j.io",
         )

         specific_details = InstanceSpecificDetails(
@@ -236,7 +239,7 @@ def test_list_session(aura_api: AuraApi) -> None:
     _setup_db_instance(aura_api)
     session = aura_api.get_or_create_session(
         name="gds-session-my-session-name",
-        dbid=aura_api.list_instances()[0].id,
+        instance_id=aura_api.list_instances()[0].id,
         memory=SessionMemory.m_8GB.value,
     )
     sessions = DedicatedSessions(aura_api)
@@ -264,7 +267,7 @@ def test_list_session_paused_instance(aura_api: AuraApi) -> None:

     session = aura_api.get_or_create_session(
         name="gds-session-my-session-name",
-        dbid=db.id,
+        instance_id=db.id,
         memory=SessionMemory.m_8GB.value,
     )
     sessions = DedicatedSessions(aura_api)
@@ -322,7 +325,7 @@ def test_list_session_gds_instance(aura_api: AuraApi) -> None:

     session = aura_api.get_or_create_session(
         name="gds-session-my-session-name",
-        dbid=db.id,
+        instance_id=db.id,
         memory=SessionMemory.m_8GB.value,
     )
     sessions = DedicatedSessions(aura_api)
@@ -342,7 +345,7 @@ def test_create_attached_session(mocker: MockerFixture, aura_api: AuraApi) -> No
     gds_parameters = sessions.get_or_create(
         "my-session",
         SessionMemory.m_8GB,
-        DbmsConnectionInfo("neo4j+s://ffff0.databases.neo4j.io", "dbuser", "db_pw"),
+        DbmsConnectionInfo(username="dbuser", password="db_pw", aura_instance_id="ffff0"),
         ttl=ttl,
     )
@@ -378,6 +381,58 @@ def test_create_attached_session(mocker: MockerFixture, aura_api: AuraApi) -> No
     assert actual_session.ttl == ttl


+def test_create_attached_session_with_only_uri(mocker: MockerFixture, aura_api: AuraApi) -> None:
+    _setup_db_instance(aura_api)
+
+    sessions = DedicatedSessions(aura_api)
+
+    patch_construct_client(mocker)
+    patch_neo4j_query_runner(mocker)
+
+    ttl = timedelta(hours=42)
+    with pytest.warns(
+        DeprecationWarning,
+        match=re.escape("Deriving the Aura instance from the database URI is deprecated"),
+    ):
+        gds_parameters = sessions.get_or_create(
+            "my-session",
+            SessionMemory.m_8GB,
+            DbmsConnectionInfo("neo4j+s://ffff0.databases.neo4j.io", "dbuser", "db_pw"),
+            ttl=ttl,
+        )
+
+    arrow_authentication = gds_parameters["arrow_authentication"]  # type: ignore
+    del gds_parameters["arrow_authentication"]
+
+    dbms_authentication = gds_parameters["db_runner"].pop("auth")  # type: ignore
+
+    assert (dbms_authentication.principal, dbms_authentication.credentials) == ("dbuser", "db_pw")
+
+    assert gds_parameters == {  # type: ignore
+        "db_runner": {
+            "endpoint": "neo4j+s://ffff0.databases.neo4j.io",
+            "aura_ds": True,
+            "database": None,
+            "show_progress": False,
+            "config": None,
+        },
+        "session_bolt_connection_info": DbmsConnectionInfo(
+            uri="neo4j+s://foo.bar", username="client-id", password="client_secret"
+        ),
+        "session_id": "ffff0-ffff1",
+        "arrow_client_options": None,
+    }
+
+    assert isinstance(arrow_authentication, AuraApiTokenAuthentication)
+
+    assert len(sessions.list()) == 1
+    actual_session = sessions.list()[0]
+
+    assert actual_session.name == "my-session"
+    assert actual_session.user_id == "user-1"
+    assert actual_session.ttl == ttl
+
+
 def test_create_attached_session_passthrough_arrow_settings(mocker: MockerFixture, aura_api: AuraApi) -> None:
     _setup_db_instance(aura_api)
@@ -390,7 +445,7 @@ def test_create_attached_session_passthrough_arrow_settings(mocker: MockerFixtur
     gds_parameters = sessions.get_or_create(
         "my-session",
         SessionMemory.m_8GB,
-        DbmsConnectionInfo("neo4j+s://ffff0.databases.neo4j.io", "dbuser", "db_pw"),
+        DbmsConnectionInfo(username="dbuser", password="db_pw", aura_instance_id="ffff0"),
         ttl=ttl,
         arrow_client_options={"foo": "bar"},
     )
@@ -431,7 +486,7 @@ def test_create_standalone_session(mocker: MockerFixture, aura_api: AuraApi) ->
     sessions = DedicatedSessions(aura_api)

     patch_construct_client(mocker)
-    patch_neo4j_query_runner(mocker)
+    patch_neo4j_query_runner(mocker, hosted_in_aura=False)

     ttl = timedelta(hours=42)
@@ -461,7 +516,7 @@ def test_create_standalone_session(mocker: MockerFixture, aura_api: AuraApi) ->
         "session_bolt_connection_info": DbmsConnectionInfo(
             uri="neo4j+s://foo.bar", username="client-id", password="client_secret"
         ),
-        "session_id": "None-ffff0",
+        "session_id": "selfmanaged-ffff0",
         "arrow_client_options": None,
     }
@@ -475,7 +530,7 @@ def test_create_standalone_session(mocker: MockerFixture, aura_api: AuraApi) ->
     assert actual_session.ttl == ttl


-def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None:
+def test_get_or_create_existing_session(mocker: MockerFixture, aura_api: AuraApi) -> None:
     _setup_db_instance(aura_api)
     sessions = DedicatedSessions(aura_api)
@@ -486,12 +541,12 @@ def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None:
     gds_args1 = sessions.get_or_create(
         "my-session",
         SessionMemory.m_8GB,
-        DbmsConnectionInfo("neo4j+s://ffff0.databases.neo4j.io", "dbuser", "db_pw"),
+        DbmsConnectionInfo(username="dbuser", password="db_pw", aura_instance_id="ffff0"),
     )
     gds_args2 = sessions.get_or_create(
         "my-session",
         SessionMemory.m_8GB,
-        DbmsConnectionInfo("neo4j+s://ffff0.databases.neo4j.io", "dbuser", "db_pw"),
+        DbmsConnectionInfo(username="dbuser", password="db_pw", aura_instance_id="ffff0"),
     )

     arrow_authentication = gds_args1["arrow_authentication"]  # type: ignore
@@ -526,6 +581,22 @@ def test_get_or_create(mocker: MockerFixture, aura_api: AuraApi) -> None:
     assert [i.name for i in sessions.list()] == ["my-session"]


+def test_get_or_create_with_explicit_aura_instance_id(mocker: MockerFixture, aura_api: AuraApi) -> None:
+    db = _setup_db_instance(aura_api)
+    sessions = DedicatedSessions(aura_api)
+    patch_construct_client(mocker)
+    patch_neo4j_query_runner(mocker)
+
+    sessions.get_or_create(
+        "my-session",
+        SessionMemory.m_8GB,
+        DbmsConnectionInfo(
+            username="dbuser", password="db_pw", aura_instance_id=db.id
+        ),  # not part of list instances result
+        cloud_location=None,
+    )
+
+
 def test_get_or_create_expired_session(mocker: MockerFixture, aura_api: AuraApi) -> None:
     db = _setup_db_instance(aura_api)
@@ -552,7 +623,11 @@ def test_get_or_create_expired_session(mocker: MockerFixture, aura_api: AuraApi)
     with pytest.raises(SessionStatusError, match=re.escape("Session is in an unhealthy state")):
         sessions = DedicatedSessions(aura_api)
-        sessions.get_or_create("one", SessionMemory.m_8GB, DbmsConnectionInfo(db.connection_url, "", ""))
+        sessions.get_or_create(
+            "one",
+            SessionMemory.m_8GB,
+            DbmsConnectionInfo(username="dbuser", password="db_pw", aura_instance_id="ffff0"),
+        )


 def test_get_or_create_soon_expired_session(mocker: MockerFixture, aura_api: AuraApi) -> None:
@@ -579,7 +654,11 @@ def test_get_or_create_soon_expired_session(mocker: MockerFixture, aura_api: Aur
     with pytest.raises(Warning, match=re.escape("Session `one` is expiring in less than a day.")):
         sessions = DedicatedSessions(aura_api)
-        sessions.get_or_create("one", SessionMemory.m_8GB, DbmsConnectionInfo(db.connection_url, "", ""))
+        sessions.get_or_create(
+            "one",
+            SessionMemory.m_8GB,
+            DbmsConnectionInfo(username="dbuser", password="db_pw", aura_instance_id="ffff0"),
+        )


 def test_get_or_create_for_auradb_with_cloud_location(mocker: MockerFixture, aura_api: AuraApi) -> None:
@@ -594,17 +673,17 @@ def test_get_or_create_for_auradb_with_cloud_location(mocker: MockerFixture, aur
         sessions.get_or_create(
             "my-session",
             SessionMemory.m_8GB,
-            DbmsConnectionInfo(db.connection_url, "dbuser", "db_pw"),
+            DbmsConnectionInfo(username="dbuser", password="db_pw", aura_instance_id=db.id),
             cloud_location=CloudLocation(region="leipzig-1", provider="aws"),
         )


 def test_get_or_create_for_without_cloud_location(mocker: MockerFixture, aura_api: AuraApi) -> None:
     sessions = DedicatedSessions(aura_api)
-    patch_neo4j_query_runner(mocker)
+    patch_neo4j_query_runner(mocker, hosted_in_aura=False)

     with pytest.raises(
-        ValueError, match=re.escape("cloud_location must be provided for sessions against a self-managed DB.")
+        ValueError, match=re.escape("cloud_location must be provided for sessions not attached to an AuraDB.")
     ):
         sessions.get_or_create(
             "my-session",
@@ -614,17 +693,60 @@ def test_get_or_create_for_without_cloud_location(mocker: MockerFixture, aura_ap
     )


-def test_get_or_create_failed_session(mocker: MockerFixture, aura_api: AuraApi) -> None:
-    db = _setup_db_instance(aura_api)
+def test_get_or_create_for_non_derivable_aura_instance_id(mocker: MockerFixture, aura_api: AuraApi) -> None:
+    sessions = DedicatedSessions(aura_api)
+    patch_neo4j_query_runner(mocker)
+
+    with (
+        pytest.raises(
+            ValueError,
+            match=re.escape(
+                "Aura instance with id `06cba79f` could not be found. Please specify the `aura_instance_id` in the `db_connection` argument."
+            ),
+        ),
+        pytest.warns(
+            DeprecationWarning, match=re.escape("Deriving the Aura instance from the database URI is deprecated")
+        ),
+    ):
+        sessions.get_or_create(
+            "my-session",
+            SessionMemory.m_8GB,
+            DbmsConnectionInfo(
+                "neo4j+s://06cba79f.databases.neo4j.io", "dbuser", "db_pw"
+            ),  # not part of list instances result
+            cloud_location=None,
+        )
+
+
+def test_get_or_create_for_non_accessible_aura_instance(mocker: MockerFixture, aura_api: AuraApi) -> None:
+    sessions = DedicatedSessions(aura_api)
     patch_neo4j_query_runner(mocker)

+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Aura instance with id `06cba79f` could not be found. Please verify that the instance id is correct and that you have access to the Aura instance."
+        ),
+    ):
+        sessions.get_or_create(
+            "my-session",
+            SessionMemory.m_8GB,
+            DbmsConnectionInfo(
+                "neo4j+s://foo.bar", "dbuser", "db_pw", aura_instance_id="06cba79f"
+            ),  # not part of list instances result
+            cloud_location=None,
+        )
+
+
+def test_get_or_create_failed_session(mocker: MockerFixture, aura_api: AuraApi) -> None:
+    patch_neo4j_query_runner(mocker, False)
+
     fake_aura_api = cast(FakeAuraApi, aura_api)
     fake_aura_api.add_session(
         SessionDetailsWithErrors(
             id="ffff0-ffff1",
             name="one",
-            instance_id=db.id,
+            instance_id=None,
             memory=SessionMemory.m_8GB.value,
             status="Failed",
             created_at=datetime.now(),
@@ -638,19 +760,21 @@ def test_get_or_create_failed_session(mocker: MockerFixture, aura_api: AuraApi)
         )
     )

-    db_connection = DbmsConnectionInfo(db.connection_url, "", "")
+    db_connection = DbmsConnectionInfo("foo.bar", "", "")
     sessions = DedicatedSessions(aura_api)

     with pytest.raises(
         SessionStatusError,
         match=re.escape("Session is in an unhealthy state. Details: ['Reason: reason, Message: error']"),
     ):
-        sessions.get_or_create("one", SessionMemory.m_8GB, db_connection)
+        sessions.get_or_create(
+            "one", SessionMemory.m_8GB, db_connection, cloud_location=CloudLocation("aws", "leipzig-1")
+        )


 def test_delete_session_by_name(aura_api: AuraApi) -> None:
-    aura_api.get_or_create_session("one", memory=SessionMemory.m_8GB.value, dbid="12345")
-    aura_api.get_or_create_session("other", memory=SessionMemory.m_8GB.value, dbid="123123")
+    aura_api.get_or_create_session("one", memory=SessionMemory.m_8GB.value, instance_id="12345")
+    aura_api.get_or_create_session("other", memory=SessionMemory.m_8GB.value, instance_id="123123")

     sessions = DedicatedSessions(aura_api)
@@ -705,8 +829,8 @@ def test_delete_session_by_name_admin() -> None:

 def test_delete_session_by_id(aura_api: AuraApi) -> None:
-    s1 = aura_api.get_or_create_session("one", memory=SessionMemory.m_8GB.value, dbid="12345")
-    s2 = aura_api.get_or_create_session("other", memory=SessionMemory.m_8GB.value, dbid="123123")
+    s1 = aura_api.get_or_create_session("one", memory=SessionMemory.m_8GB.value, instance_id="12345")
+    s2 = aura_api.get_or_create_session("other", memory=SessionMemory.m_8GB.value, instance_id="123123")

     sessions = DedicatedSessions(aura_api)
     assert sessions.delete(session_id=s1.id)
@@ -715,7 +839,7 @@ def test_delete_session_by_id(aura_api: AuraApi) -> None:

 def test_delete_nonexisting_session(aura_api: AuraApi) -> None:
     db1 = aura_api.create_instance("db1", SessionMemory.m_4GB.value, "aura", "leipzig").id
-    aura_api.get_or_create_session("one", memory=SessionMemory.m_8GB.value, dbid=db1)
+    aura_api.get_or_create_session("one", memory=SessionMemory.m_8GB.value, instance_id=db1)
     sessions = DedicatedSessions(aura_api)

     assert sessions.delete(session_name="other") is False
@@ -741,7 +865,7 @@ def test_delete_session_paused_instance(aura_api: AuraApi) -> None:

     session = aura_api.get_or_create_session(
         name="gds-session-my-session-name",
-        dbid=paused_db.id,
+        instance_id=paused_db.id,
         memory=SessionMemory.m_8GB.value,
     )
     sessions = DedicatedSessions(aura_api)
@@ -754,15 +878,17 @@ def test_create_waiting_forever(
     mocker: MockerFixture,
 ) -> None:
     aura_api = FakeAuraApi(status_after_creating="updating")
-    _setup_db_instance(aura_api)
     sessions = DedicatedSessions(aura_api)
-    patch_neo4j_query_runner(mocker)
+    patch_neo4j_query_runner(mocker, False)

     with pytest.raises(
-        RuntimeError, match="Failed to get or create session `one`: Session `ffff0-ffff1` is not running"
+        RuntimeError, match="Failed to get or create session `one`: Session `selfmanaged-ffff0` is not running"
     ):
         sessions.get_or_create(
-            "one", SessionMemory.m_8GB, DbmsConnectionInfo("neo4j+ssc://ffff0.databases.neo4j.io", "", "")
+            "one",
+            SessionMemory.m_8GB,
+            DbmsConnectionInfo("neo4j+ssc://ffff0.databases.neo4j.io", "", ""),
+            cloud_location=CloudLocation("aws", "leipzig-1"),
         )
@@ -788,12 +914,16 @@ def _setup_db_instance(aura_api: AuraApi) -> InstanceCreateDetails:
     return aura_api.create_instance("test", SessionMemory.m_8GB.value, "aws", "leipzig-1")


-def patch_neo4j_query_runner(mocker: MockerFixture) -> None:
+def patch_neo4j_query_runner(mocker: MockerFixture, hosted_in_aura: bool = True) -> None:
     mocker.patch(
         "graphdatascience.query_runner.neo4j_query_runner.Neo4jQueryRunner.create_for_db",
         lambda *args, **kwargs: kwargs,
     )
     mocker.patch("graphdatascience.session.dedicated_sessions.DedicatedSessions._validate_db_connection")
+    mocker.patch(
+        "graphdatascience.query_runner.db_environment_resolver.DbEnvironmentResolver.hosted_in_aura",
+        lambda *args, **kwargs: hosted_in_aura,
+    )


 def patch_construct_client(mocker: MockerFixture) -> None:
diff --git a/tests/unit/test_session_query_runner.py b/tests/unit/session/test_session_query_runner.py
similarity index 100%
rename from tests/unit/test_session_query_runner.py
rename to tests/unit/session/test_session_query_runner.py
diff --git a/tests/unit/test_session_sizes.py b/tests/unit/session/test_session_sizes.py
similarity index 100%
rename from tests/unit/test_session_sizes.py
rename to tests/unit/session/test_session_sizes.py
diff --git a/tests/unit/test_aura_api.py b/tests/unit/test_aura_api.py
index c158346eb..01276f1c1 100644
--- a/tests/unit/test_aura_api.py
+++ b/tests/unit/test_aura_api.py
@@ -74,7 +74,7 @@ def assert_body(request: _RequestObjectProxy) -> bool:
     )

     result = api.get_or_create_session(
-        name="name-0", dbid="dbid-1", memory=SessionMemory.m_4GB.value, ttl=timedelta(seconds=42)
+        name="name-0", instance_id="dbid-1", memory=SessionMemory.m_4GB.value, ttl=timedelta(seconds=42)
     )

     assert result == SessionDetails(
diff --git a/tox.ini b/tox.ini
index 917307671..06eaf2bb1 100644
--- a/tox.ini
+++ b/tox.ini
@@ -89,6 +89,8 @@ passenv =
     CLIENT_SECRET
     PROJECT_ID
+
+    AURA_INSTANCEID
     NEO4J_URI
     NEO4J_USERNAME
     NEO4J_PASSWORD