diff --git a/.github/workflows/ci-cd-ds-platform-utils.yaml b/.github/workflows/ci-cd-ds-platform-utils.yaml
index f620ff8..4a7031f 100644
--- a/.github/workflows/ci-cd-ds-platform-utils.yaml
+++ b/.github/workflows/ci-cd-ds-platform-utils.yaml
@@ -1,4 +1,4 @@
-name: Publish DS Projen
+name: Publish DS Platform Utils

 on:
   workflow_dispatch:
@@ -16,7 +16,7 @@ jobs:
       - name: Checkout Repository
        uses: actions/checkout@v4
         with:
-          fetch-depth: 0 # Fetch all history for version tagging
+          fetch-depth: 0  # Fetch all history for version tagging

       - name: Set up uv
         uses: astral-sh/setup-uv@v5
@@ -44,7 +44,7 @@ jobs:
           cache-dependency-glob: "${{ github.workspace }}/uv.lock"

       - name: Run pre-commit hooks
-        run: SKIP=no-commit-to-branch uv run poe lint # using poethepoet needs to be setup before using poe lint
+        run: SKIP=no-commit-to-branch uv run poe lint  # poethepoet needs to be set up before using poe lint

   build-wheel:
     name: Build Wheel
diff --git a/pyproject.toml b/pyproject.toml
index 0f163f4..4ba58fd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ds-platform-utils"
-version = "0.2.3"
+version = "0.3.0"
 description = "Utility library for Pattern Data Science."
 readme = "README.md"
 authors = [
diff --git a/src/ds_platform_utils/_snowflake/run_query.py b/src/ds_platform_utils/_snowflake/run_query.py
new file mode 100644
index 0000000..1b6b517
--- /dev/null
+++ b/src/ds_platform_utils/_snowflake/run_query.py
@@ -0,0 +1,41 @@
+"""Shared Snowflake utility functions."""
+
+import warnings
+from typing import Iterable, Optional
+
+from snowflake.connector import SnowflakeConnection
+from snowflake.connector.cursor import SnowflakeCursor
+from snowflake.connector.errors import ProgrammingError
+
+
+def _execute_sql(conn: SnowflakeConnection, sql: str) -> Optional[SnowflakeCursor]:
+    """Execute SQL statement(s) using Snowflake's ``connection.execute_string()`` and return the *last* resulting cursor.
+
+    Snowflake's ``execute_string`` allows a single string containing multiple SQL
+    statements (separated by semicolons) to be executed at once. Unlike
+    ``cursor.execute()``, which handles exactly one statement and returns a single
+    cursor object, ``execute_string`` returns a **list of cursors**: one cursor for
+    each individual SQL statement in the batch.
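+
+    Example (illustrative, assuming an open connection ``conn``; mirrors the
+    multi-statement test below)::
+
+        cur = _execute_sql(conn, "SELECT 1 AS x; SELECT 2 AS x;")
+        cur.fetchone()  # (2,) -- the returned cursor belongs to the last statement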
+
+    :param conn: Snowflake connection object
+    :param sql: SQL query or batch of semicolon-delimited SQL statements
+    :return: The cursor corresponding to the last executed statement, or None if no
+        statements were executed or if the SQL contains only whitespace/comments
+    """
+    if not sql.strip():
+        return None
+
+    try:
+        cursors: Iterable[SnowflakeCursor] = conn.execute_string(sql.strip())
+
+        if cursors is None:
+            return None
+
+        *_, last = cursors
+        return last
+    except ProgrammingError as e:
+        if "Empty SQL statement" in str(e):
+            # downgrade comment-only/empty-statement errors to a warning and return None
+            warnings.warn("Empty SQL statement encountered; returning None.", category=UserWarning, stacklevel=2)
+            return None
+        raise
diff --git a/src/ds_platform_utils/_snowflake/write_audit_publish.py b/src/ds_platform_utils/_snowflake/write_audit_publish.py
index 6475bed..5ae6046 100644
--- a/src/ds_platform_utils/_snowflake/write_audit_publish.py
+++ b/src/ds_platform_utils/_snowflake/write_audit_publish.py
@@ -7,6 +7,7 @@
 from jinja2 import DebugUndefined, Template
 from snowflake.connector.cursor import SnowflakeCursor

+from ds_platform_utils._snowflake.run_query import _execute_sql
 from ds_platform_utils.metaflow._consts import NON_PROD_SCHEMA, PROD_SCHEMA


@@ -200,8 +201,8 @@ def run_query(query: str, cursor: Optional[SnowflakeCursor] = None) -> None:
         print(f"Would execute query:\n{query}")
         return

-    # Count statements so we can tell Snowflake exactly how many to expect
-    cursor.execute(query, num_statements=0)  # 0 means any number of statements
+    # run the query through the _execute_sql helper, which handles multi-statement batches via execute_string
+    _execute_sql(cursor.connection, query)
     cursor.connection.commit()


@@ -216,7 +217,10 @@ def run_audit_query(query: str, cursor: Optional[SnowflakeCursor] = None) -> dic
     if cursor is None:
         return {"mock_result": True}

-    cursor.execute(query)
+    cursor = _execute_sql(cursor.connection, query)
+    if cursor is None:
+        return {}
+
     result = cursor.fetchone()
     if not result:
         return {}
@@ -243,11 +247,17 @@ def fetch_table_preview(
     if not cursor:
         return [{"mock_col": "mock_val"}]

-    cursor.execute(f"""
+    cursor = _execute_sql(
+        cursor.connection,
+        f"""
         SELECT *
         FROM {database}.{schema}.{table_name}
         LIMIT {n_rows};
-    """)
+        """,
+    )
+    if cursor is None:
+        return []
+
     columns = [col[0] for col in cursor.description]
     rows = cursor.fetchall()
     return [dict(zip(columns, row)) for row in rows]
diff --git a/src/ds_platform_utils/metaflow/get_snowflake_connection.py b/src/ds_platform_utils/metaflow/get_snowflake_connection.py
index 8eaeba1..926393f 100644
--- a/src/ds_platform_utils/metaflow/get_snowflake_connection.py
+++ b/src/ds_platform_utils/metaflow/get_snowflake_connection.py
@@ -4,6 +4,8 @@
 from metaflow import Snowflake, current
 from snowflake.connector import SnowflakeConnection

+from ds_platform_utils._snowflake.run_query import _execute_sql
+
 ####################
 # --- Metaflow --- #
 ####################
@@ -41,7 +43,12 @@ def get_snowflake_connection(

     In metaflow, each step is a separate Python process, so the connection will
     automatically be closed at the end of any steps that use this singleton.
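+
+    Example (illustrative sketch of typical usage inside a step)::
+
+        conn = get_snowflake_connection(use_utc=True)
+        with conn.cursor() as cur:
+            cur.execute("SELECT CURRENT_DATE;")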
""" - return _create_snowflake_connection(use_utc=use_utc, query_tag=current.project_name) + if current and hasattr(current, "project_name"): + query_tag = current.project_name + else: + query_tag = None + + return _create_snowflake_connection(use_utc=use_utc, query_tag=query_tag) ##################### @@ -66,11 +73,10 @@ def _create_snowflake_connection( if query_tag: queries.append(f"ALTER SESSION SET QUERY_TAG = '{query_tag}';") - # Execute all queries in single batch - with conn.cursor() as cursor: - sql = "\n".join(queries) - _debug_print_query(sql) - cursor.execute(sql, num_statements=0) + # Merge into single SQL batch + sql = "\n".join(queries) + _debug_print_query(sql) + _execute_sql(conn, sql) return conn diff --git a/src/ds_platform_utils/metaflow/pandas.py b/src/ds_platform_utils/metaflow/pandas.py index 743d48d..d89c0ec 100644 --- a/src/ds_platform_utils/metaflow/pandas.py +++ b/src/ds_platform_utils/metaflow/pandas.py @@ -11,6 +11,7 @@ from snowflake.connector import SnowflakeConnection from snowflake.connector.pandas_tools import write_pandas +from ds_platform_utils._snowflake.run_query import _execute_sql from ds_platform_utils.metaflow._consts import NON_PROD_SCHEMA, PROD_SCHEMA from ds_platform_utils.metaflow.get_snowflake_connection import _debug_print_query, get_snowflake_connection from ds_platform_utils.metaflow.write_audit_publish import ( @@ -111,15 +112,14 @@ def publish_pandas( # noqa: PLR0913 (too many arguments) # set warehouse if warehouse is not None: - with conn.cursor() as cur: - cur.execute(f"USE WAREHOUSE {warehouse};") + _execute_sql(conn, f"USE WAREHOUSE {warehouse};") - # set query tag for cost tracking in select.dev - # REASON: because write_pandas() doesn't allow modifying the SQL query to add SQL comments in it directly, - # so we set a session query tag instead. - tags = get_select_dev_query_tags() - query_tag_str = json.dumps(tags) - cur.execute(f"ALTER SESSION SET QUERY_TAG = '{query_tag_str}';") + # set query tag for cost tracking in select.dev + # REASON: because write_pandas() doesn't allow modifying the SQL query to add SQL comments in it directly, + # so we set a session query tag instead. 
+    tags = get_select_dev_query_tags()
+    query_tag_str = json.dumps(tags)
+    _execute_sql(conn, f"ALTER SESSION SET QUERY_TAG = '{query_tag_str}';")

     # https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/snowpark/api/snowflake.snowpark.Session.write_pandas
     write_pandas(
@@ -198,16 +198,20 @@ def query_pandas_from_snowflake(
         current.card.append(Markdown(f"```sql\n{query}\n```"))

     conn: SnowflakeConnection = get_snowflake_connection(use_utc)
-    with conn.cursor() as cur:
-        if warehouse is not None:
-            cur.execute(f"USE WAREHOUSE {warehouse};")
+    if warehouse is not None:
+        _execute_sql(conn, f"USE WAREHOUSE {warehouse};")

+    cursor_result = _execute_sql(conn, query)
+    if cursor_result is None:
+        # No statements to execute; return an empty DataFrame
+        df = pd.DataFrame()
+    else:
         # force_return_table=True -- returns a Pyarrow Table always even if the result is empty
-        result: pyarrow.Table = cur.execute(query).fetch_arrow_all(force_return_table=True)
-
+        result: pyarrow.Table = cursor_result.fetch_arrow_all(force_return_table=True)
         df = result.to_pandas()
         df.columns = df.columns.str.lower()

-        current.card.append(Markdown("### Query Result"))
-        current.card.append(Table.from_dataframe(df.head()))
-        return df
+    current.card.append(Markdown("### Query Result"))
+    current.card.append(Table.from_dataframe(df.head()))
+
+    return df
diff --git a/src/ds_platform_utils/metaflow/write_audit_publish.py b/src/ds_platform_utils/metaflow/write_audit_publish.py
index 04423a6..4672a64 100644
--- a/src/ds_platform_utils/metaflow/write_audit_publish.py
+++ b/src/ds_platform_utils/metaflow/write_audit_publish.py
@@ -10,6 +10,7 @@
 from metaflow.cards import Artifact, Markdown, Table
 from snowflake.connector.cursor import SnowflakeCursor

+from ds_platform_utils._snowflake.run_query import _execute_sql
 from ds_platform_utils.metaflow.get_snowflake_connection import get_snowflake_connection

 if TYPE_CHECKING:
@@ -97,7 +98,7 @@ def get_select_dev_query_tags() -> Dict[str, str]:
             stacklevel=2,
         )

-    def extract(prefix: str, default: str = "unknown") -> str:
+    def _extract(prefix: str, default: str = "unknown") -> str:
         for tag in fetched_tags:
             if tag.startswith(prefix + ":"):
                 return tag.split(":", 1)[1]
@@ -106,19 +107,19 @@ def extract(prefix: str, default: str = "unknown") -> str:
     # most of these will be unknown if no tags are set on the flow
     # (most likely for the flow runs which are triggered manually locally)
     return {
-        "app": extract(
+        "app": _extract(
             "ds.domain"
         ),  # first tag after 'app:', is the domain of the flow, fetched from current tags of the flow
-        "workload_id": extract(
+        "workload_id": _extract(
             "ds.project"
         ),  # second tag after 'workload_id:', is the project of the flow which it belongs to
-        "flow_name": current.flow_name,  # name of the metaflow flow
+        "flow_name": current.flow_name,
         "project": current.project_name,  # Project name from the @project decorator, lets us
         # identify the flow’s project without relying on user tags (added via --tag).
"step_name": current.step_name, # name of the current step "run_id": current.run_id, # run_id: unique id of the current run "user": current.username, # username of user who triggered the run (argo-workflows if its a deployed flow) - "domain": extract("ds.domain"), # business unit (domain) of the flow, same as app + "domain": _extract("ds.domain"), # business unit (domain) of the flow, same as app "namespace": current.namespace, # namespace of the flow "perimeter": str(os.environ.get("OB_CURRENT_PERIMETER") or os.environ.get("OBP_PERIMETER")), "is_production": str( @@ -216,7 +217,7 @@ def publish( # noqa: PLR0913, D417 with conn.cursor() as cur: if warehouse is not None: - cur.execute(f"USE WAREHOUSE {warehouse}") + _execute_sql(conn, f"USE WAREHOUSE {warehouse}") last_op_was_write = False for operation in write_audit_publish( @@ -334,20 +335,28 @@ def fetch_table_preview( :param table_name: Table name :param cursor: Snowflake cursor """ - cursor.execute(f""" - SELECT * - FROM {database}.{schema}.{table_name} - LIMIT {n_rows}; - """) - columns = [col[0] for col in cursor.description] - rows = cursor.fetchall() - - # Create header row plus data rows - table_rows = [[Artifact(col) for col in columns]] # Header row - for row in rows: - table_rows.append([Artifact(val) for val in row]) # Data rows - - return [ - Markdown(f"### Table Preview: ({database}.{schema}.{table_name})"), - Table(table_rows), - ] + if cursor is None: + return [] + else: + result_cursor = _execute_sql( + cursor.connection, + f""" + SELECT * + FROM {database}.{schema}.{table_name} + LIMIT {n_rows}; + """, + ) + if result_cursor is None: + return [] + columns = [col[0] for col in result_cursor.description] + rows = result_cursor.fetchall() + + # Create header row plus data rows + table_rows = [[Artifact(col) for col in columns]] # Header row + for row in rows: + table_rows.append([Artifact(val) for val in row]) # Data rows + + return [ + Markdown(f"### Table Preview: ({database}.{schema}.{table_name})"), + Table(table_rows), + ] diff --git a/tests/unit_tests/snowflake/test__execute_sql.py b/tests/unit_tests/snowflake/test__execute_sql.py new file mode 100644 index 0000000..4816841 --- /dev/null +++ b/tests/unit_tests/snowflake/test__execute_sql.py @@ -0,0 +1,59 @@ +"""Functional test for _execute_sql.""" + +from typing import Generator + +import pytest +from snowflake.connector import SnowflakeConnection + +from ds_platform_utils._snowflake.run_query import _execute_sql +from ds_platform_utils.metaflow.get_snowflake_connection import get_snowflake_connection + + +@pytest.fixture(scope="module") +def snowflake_conn() -> Generator[SnowflakeConnection, None, None]: + """Get a Snowflake connection for testing.""" + yield get_snowflake_connection(use_utc=True) + + +def test_execute_sql_empty_string(snowflake_conn): + """Empty string returns None.""" + cursor = _execute_sql(snowflake_conn, "") + assert cursor is None + + +def test_execute_sql_whitespace_only(snowflake_conn): + """Whitespace-only string returns None.""" + cursor = _execute_sql(snowflake_conn, " \n\t ") + assert cursor is None + + +def test_execute_sql_only_semicolons(snowflake_conn): + """String with only semicolons returns None and raises warning.""" + with pytest.warns(UserWarning, match="Empty SQL statement encountered"): + cursor = _execute_sql(snowflake_conn, " ; ;") + assert cursor is None + + +def test_execute_sql_only_comments(snowflake_conn): + """String with only comments returns None and raises warning.""" + with pytest.warns(UserWarning, match="Empty 
+        cursor = _execute_sql(snowflake_conn, "/* only comments */")
+    assert cursor is None
+
+
+def test_execute_sql_single_statement(snowflake_conn):
+    """A single statement returns a cursor with the expected result."""
+    cursor = _execute_sql(snowflake_conn, "SELECT 1 AS x;")
+    assert cursor is not None
+    rows = cursor.fetchall()
+    assert len(rows) == 1
+    assert rows[0][0] == 1
+
+
+def test_execute_sql_multi_statement(snowflake_conn):
+    """A multi-statement batch returns the cursor for the last statement only."""
+    cursor = _execute_sql(snowflake_conn, "SELECT 1 AS x; SELECT 2 AS x;")
+    assert cursor is not None
+    rows = cursor.fetchall()
+    assert len(rows) == 1
+    assert rows[0][0] == 2  # Last statement result
diff --git a/uv.lock b/uv.lock
index de6dc7f..c2d957a 100644
--- a/uv.lock
+++ b/uv.lock
@@ -478,7 +478,7 @@ wheels = [

 [[package]]
 name = "ds-platform-utils"
-version = "0.2.3"
+version = "0.3.0"
 source = { editable = "." }
 dependencies = [
     { name = "jinja2" },