diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml
index d093c094..21869889 100644
--- a/.github/workflows/changelog.yml
+++ b/.github/workflows/changelog.yml
@@ -40,7 +40,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          uv sync --frozen --group dev
+          uv sync --frozen --group docs
 
       - name: Check changelog entry exists
         run: |
diff --git a/docs/changelog/next_release/327.feature.rst b/docs/changelog/next_release/327.feature.rst
new file mode 100644
index 00000000..d5ded3bc
--- /dev/null
+++ b/docs/changelog/next_release/327.feature.rst
@@ -0,0 +1 @@
+Add SQL transformation to API
diff --git a/syncmaster/schemas/v1/transfers/transformations/__init__.py b/syncmaster/schemas/v1/transfers/transformations/__init__.py
index e9d9f05d..a8052e22 100644
--- a/syncmaster/schemas/v1/transfers/transformations/__init__.py
+++ b/syncmaster/schemas/v1/transfers/transformations/__init__.py
@@ -14,9 +14,12 @@
 from syncmaster.schemas.v1.transfers.transformations.file_metadata_filter import (
     FileMetadataFilter,
 )
+from syncmaster.schemas.v1.transfers.transformations.sql import (
+    Sql,
+)
 
 TransformationSchema = Annotated[
-    DataframeRowsFilter | DataframeColumnsFilter | FileMetadataFilter,
+    DataframeRowsFilter | DataframeColumnsFilter | FileMetadataFilter | Sql,
     Field(discriminator="type"),
 ]
diff --git a/syncmaster/schemas/v1/transfers/transformations/sql.py b/syncmaster/schemas/v1/transfers/transformations/sql.py
new file mode 100644
index 00000000..9723844b
--- /dev/null
+++ b/syncmaster/schemas/v1/transfers/transformations/sql.py
@@ -0,0 +1,23 @@
+# SPDX-FileCopyrightText: 2023-present MTS PJSC
+# SPDX-License-Identifier: Apache-2.0
+import re
+from typing import Literal
+
+from pydantic import BaseModel, field_validator
+
+ALLOWED_SELECT_QUERY = re.compile(r"^\s*(SELECT|WITH)\b", re.IGNORECASE)
+
+
+class Sql(BaseModel):
+    type: Literal["sql"]
+    query: str
+    dialect: Literal["spark"] = "spark"
+
+    @field_validator("query", mode="after")
+    @classmethod
+    def _validate_query(cls, value: str) -> str:
+        if not ALLOWED_SELECT_QUERY.match(value):
+            msg = f"Query must be a SELECT statement, got '{value}'"
+            raise ValueError(msg)
+
+        return value
diff --git a/syncmaster/schemas/v1/transformation_types.py b/syncmaster/schemas/v1/transformation_types.py
index 2560db0c..2b884d49 100644
--- a/syncmaster/schemas/v1/transformation_types.py
+++ b/syncmaster/schemas/v1/transformation_types.py
@@ -5,3 +5,4 @@
 DATAFRAME_ROWS_FILTER = Literal["dataframe_rows_filter"]
 DATAFRAME_COLUMNS_FILTER = Literal["dataframe_columns_filter"]
 FILE_METADATA_FILTER = Literal["file_metadata_filter"]
+SQL = Literal["sql"]
diff --git a/tests/test_unit/test_transfers/test_create_transfer.py b/tests/test_unit/test_transfers/test_create_transfer.py
index b554f2ba..5a370721 100644
--- a/tests/test_unit/test_transfers/test_create_transfer.py
+++ b/tests/test_unit/test_transfers/test_create_transfer.py
@@ -550,13 +550,13 @@ async def test_superuser_can_create_transfer(
                     "message": (
                         "Input tag 'some unknown transformation type' found using 'type' "
                         "does not match any of the expected tags: 'dataframe_rows_filter', "
-                        "'dataframe_columns_filter', 'file_metadata_filter'"
+                        "'dataframe_columns_filter', 'file_metadata_filter', 'sql'"
                     ),
                     "code": "union_tag_invalid",
                     "context": {
                         "discriminator": "'type'",
                         "expected_tags": (
-                            "'dataframe_rows_filter', 'dataframe_columns_filter', 'file_metadata_filter'"
+                            "'dataframe_rows_filter', 'dataframe_columns_filter', 'file_metadata_filter', 'sql'"
                         ),
type", }, @@ -809,6 +809,72 @@ async def test_superuser_can_create_transfer( }, id="invalid_name_regexp", ), + pytest.param( + { + "transformations": [ + { + "type": "sql", + "query": "INSERT INTO table1 VALUES (1, 2)", + "dialect": "spark", + }, + ], + }, + { + "error": { + "code": "invalid_request", + "message": "Invalid request", + "details": [ + { + "location": [ + "body", + "transformations", + 0, + "sql", + "query", + ], + "message": "Value error, Query must be a SELECT statement, got 'INSERT INTO table1 VALUES (1, 2)'", + "code": "value_error", + "context": {}, + "input": "INSERT INTO table1 VALUES (1, 2)", + }, + ], + }, + }, + id="sql_non_select_query", + ), + pytest.param( + { + "transformations": [ + { + "type": "sql", + "query": None, + "dialect": "spark", + }, + ], + }, + { + "error": { + "code": "invalid_request", + "message": "Invalid request", + "details": [ + { + "location": [ + "body", + "transformations", + 0, + "sql", + "query", + ], + "message": "Input should be a valid string", + "code": "string_type", + "context": {}, + "input": None, + }, + ], + }, + }, + id="sql_non_string_query", + ), pytest.param( { "resources": { @@ -1290,3 +1356,106 @@ async def test_superuser_cannot_create_transfer_with_unknown_queue_error( "details": None, }, } + + +async def test_developer_plus_can_create_transfer_with_sql_transformation( + client: AsyncClient, + two_group_connections: tuple[MockConnection, MockConnection], + session: AsyncSession, + role_developer_plus: UserTestRoles, + group_queue: Queue, + mock_group: MockGroup, +): + first_connection, second_connection = two_group_connections + user = mock_group.get_member_of_role(role_developer_plus) + + response = await client.post( + "v1/transfers", + headers={"Authorization": f"Bearer {user.token}"}, + json={ + "group_id": mock_group.group.id, + "name": "new test transfer with sql", + "source_connection_id": first_connection.id, + "target_connection_id": second_connection.id, + "source_params": {"type": "postgres", "table_name": "schema.source_table"}, + "target_params": {"type": "postgres", "table_name": "schema.target_table"}, + "transformations": [ + { + "type": "sql", + "query": "SELECT * FROM table", + "dialect": "spark", + }, + ], + "queue_id": group_queue.id, + }, + ) + assert response.status_code == 200, response.text + + transfer = ( + await session.scalars( + select(Transfer).filter_by( + name="new test transfer with sql", + group_id=mock_group.group.id, + ), + ) + ).one() + + assert response.json() == { + "id": transfer.id, + "group_id": transfer.group_id, + "name": transfer.name, + "description": transfer.description, + "schedule": transfer.schedule, + "is_scheduled": transfer.is_scheduled, + "source_connection_id": transfer.source_connection_id, + "target_connection_id": transfer.target_connection_id, + "source_params": transfer.source_params, + "target_params": transfer.target_params, + "strategy_params": transfer.strategy_params, + "transformations": transfer.transformations, + "resources": transfer.resources, + "queue_id": transfer.queue_id, + } + + +@pytest.mark.parametrize( + "query", + [ + "SELECT col1, col2 FROM table1 WHERE col1 > 100;", + "select col1 from table1", + " SELECT col1 FROM table1 ", + " WITH some AS (SELECT col1 FROM table1) SELECT * FROM some;", + ], +) +async def test_sql_transformation_with_valid_select_queries( + client: AsyncClient, + two_group_connections: tuple[MockConnection, MockConnection], + session: AsyncSession, + superuser: MockUser, + group_queue: Queue, + mock_group: MockGroup, + 
+    query: str,
+):
+    first_conn, second_conn = two_group_connections
+
+    response = await client.post(
+        "v1/transfers",
+        headers={"Authorization": f"Bearer {superuser.token}"},
+        json={
+            "group_id": mock_group.group.id,
+            "name": f"test transfer sql {query}",
+            "source_connection_id": first_conn.id,
+            "target_connection_id": second_conn.id,
+            "source_params": {"type": "postgres", "table_name": "schema.source_table"},
+            "target_params": {"type": "postgres", "table_name": "schema.target_table"},
+            "transformations": [
+                {
+                    "type": "sql",
+                    "query": query,
+                    "dialect": "spark",
+                },
+            ],
+            "queue_id": group_queue.id,
+        },
+    )
+    assert response.status_code == 200, response.text
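
Reviewer note on the new `Sql` schema: `ALLOWED_SELECT_QUERY` is a prefix check on the leading keyword, not a SQL parser, so any string starting with `SELECT` or `WITH` (in any case, after optional whitespace) passes validation. A minimal usage sketch of that behavior, assuming the package is importable as laid out in this diff:

# Illustration only, not part of the diff: exercising the Sql model directly.
from pydantic import ValidationError

from syncmaster.schemas.v1.transfers.transformations.sql import Sql

# Accepted: leading whitespace and lowercase keywords match the regex;
# "dialect" is omitted and falls back to its "spark" default.
Sql(type="sql", query="  select col1 FROM table1")

# Accepted: CTEs starting with WITH are explicitly allowed.
Sql(type="sql", query="WITH t AS (SELECT 1) SELECT * FROM t")

# Rejected: a non-SELECT statement fails the field validator.
try:
    Sql(type="sql", query="INSERT INTO table1 VALUES (1, 2)")
except ValidationError as e:
    print(e)  # Value error, Query must be a SELECT statement, got '...'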