diff --git a/README.md b/README.md index c18fda953..453183e3b 100644 --- a/README.md +++ b/README.md @@ -829,6 +829,16 @@ For example: | `role` | `DATACONTRACT_SNOWFLAKE_ROLE` | | `connection_timeout` | `DATACONTRACT_SNOWFLAKE_CONNECTION_TIMEOUT` | +##### Environment variable authentication options + +| Soda optional parameter | Environment Variable | +|--------------------------|-------------------------------------------------| +| `authenticator` | `DATACONTRACT_SNOWFLAKE_AUTHENTICATOR` | +| `private_key` | `DATACONTRACT_SNOWFLAKE_PRIVATE_KEY` | +| `private_key_passphrase` | `DATACONTRACT_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE` | +| `private_key_path` | `DATACONTRACT_SNOWFLAKE_PRIVATE_KEY_PATH` | + + Beware, that parameters: * `account` * `database` @@ -1619,6 +1629,7 @@ For more information about the Excel template structure, visit the [ODCS Excel T │ [default: None] │ │ --source TEXT The path to the file that │ │ should be imported. │ +│ also snowflake account │ │ [default: None] │ │ --dialect TEXT The SQL dialect to use │ │ when importing SQL files, │ @@ -1682,6 +1693,7 @@ For more information about the Excel template structure, visit the [ODCS Excel T │ [default: None] │ │ --id TEXT The identifier for the the │ │ data contract. │ +│ --database TEXT Snowflake target database │ │ [default: None] │ │ --debug --no-debug Enable debug logging │ │ [default: no-debug] │ @@ -1897,6 +1909,20 @@ Example: datacontract import --format protobuf --source "test.proto" ``` +#### snowflake + +Importing from a Snowflake schema. Specify the Snowflake account in the `source` parameter, the database name in `database` and the schema in `schema`. 
+Multiple authentication methods are supported: +login/password when the `DATACONTRACT_SNOWFLAKE_...` environment variables are set up, +MFA using an external browser is selected when `DATACONTRACT_SNOWFLAKE_PASSWORD` is missing, +TOML file authentication using the default profile when the `SNOWFLAKE_DEFAULT_CONNECTION_NAME` environment variable is defined. + +Example: + +```bash +datacontract import --format snowflake --source account.canada-central.azure --database databaseName --schema schemaName +``` + ### catalog ``` diff --git a/datacontract/cli.py b/datacontract/cli.py index fea5f8518..3e1a01599 100644 --- a/datacontract/cli.py +++ b/datacontract/cli.py @@ -445,6 +445,10 @@ def import_( Optional[str], typer.Option(help="The identifier for the the data contract."), ] = None, + database: Annotated[ + Optional[str], + typer.Option(help="The snowflake database name."), + ] = None, debug: debug_option = None, ): """ @@ -469,6 +473,7 @@ def import_( iceberg_table=iceberg_table, owner=owner, id=id, + database=database, ) if output is None: console.print(result.to_yaml(), markup=False, soft_wrap=True) diff --git a/datacontract/imports/importer.py b/datacontract/imports/importer.py index 24961fb8c..a87b71719 100644 --- a/datacontract/imports/importer.py +++ b/datacontract/imports/importer.py @@ -38,6 +38,7 @@ class ImportFormat(str, Enum): csv = "csv" protobuf = "protobuf" excel = "excel" + snowflake = "snowflake" @classmethod def get_supported_formats(cls): diff --git a/datacontract/imports/importer_factory.py b/datacontract/imports/importer_factory.py index 6566f18b1..55f2186a0 100644 --- a/datacontract/imports/importer_factory.py +++ b/datacontract/imports/importer_factory.py @@ -119,7 +119,11 @@ def load_module_class(module_path, class_name): module_path="datacontract.imports.excel_importer", class_name="ExcelImporter", ) - +importer_factory.register_lazy_importer( + name=ImportFormat.snowflake, + module_path="datacontract.imports.snowflake_importer", + 
class_name="SnowflakeImporter", +) importer_factory.register_lazy_importer( name=ImportFormat.json, diff --git a/datacontract/imports/snowflake_importer.py b/datacontract/imports/snowflake_importer.py new file mode 100644 index 000000000..6b535538d --- /dev/null +++ b/datacontract/imports/snowflake_importer.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +import json +import os +from collections import OrderedDict +from pathlib import Path +from typing import Any, Dict + +import yaml +from open_data_contract_standard.model import OpenDataContractStandard + +from datacontract.imports.importer import Importer +from datacontract.model.exceptions import DataContractException + + +class SnowflakeImporter(Importer): + def import_source(self, source: str, import_args: dict) -> OpenDataContractStandard: + if source is not None: + return import_Snowflake_from_connector( + account=source, + database=import_args.get("database"), + schema=import_args.get("schema"), + ) + + +def import_Snowflake_from_connector(account: str, database: str, schema: str) -> OpenDataContractStandard: + ## connect to snowflake and get cursor + conn = snowflake_cursor(account, database, schema) + try: + # To catch double_quoted identifier + from snowflake.connector.errors import ProgrammingError + except ImportError as e: + raise DataContractException( + type="schema", + result="failed", + name="snowflake extra missing", + reason="Install the extra datacontract-cli[snowflake] to use snowflake", + engine="datacontract", + original_exception=e, + ) + + with conn.cursor() as cur: + try: + cur.execute(f"USE SCHEMA {database}.{schema}") + schema_identifier = schema + except ProgrammingError: + # schema with double-quoted identifiers issue https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers + cur.execute(f'USE SCHEMA {database}."{schema}"') + schema_identifier = f'"{schema}"' + + cur.execute(f"SHOW COLUMNS IN SCHEMA {database}.{schema_identifier}") + 
schema_sfqid = str(cur.sfqid) + cur.execute(f"SHOW PRIMARY KEYS IN SCHEMA {database}.{schema_identifier}") + businessKey_sfqid = str(cur.sfqid) + # -- AS + # SET(col, pk) = (SELECT LAST_QUERY_ID(-2), LAST_QUERY_ID(-1));" + cur.execute_async(snowflake_query(account, schema, schema_sfqid, businessKey_sfqid)) + cur.get_results_from_sfqid(cur.sfqid) + # extract and save ddl script into sql file + json_contract = cur.fetchall() + + # Try to Preserve order when dumping to yaml properties as columns order matters + yaml.add_representer(dict, map_representer, Dumper=yaml.Dumper) + yaml.add_representer(OrderedDict, map_representer, Dumper=yaml.Dumper) + + if len(json_contract) == 0 or len(json_contract[0]) == 0: + raise DataContractException( + type="import", + result="failed", + name="snowflake import", + reason=f"No data contract returned from schema {schema} in database {database} please check connectivity and schema existence", + engine="datacontract", + ) + result_set = json.loads(json_contract[0][0], object_pairs_hook=OrderedDict) + sorted_properties = sort_schema_by_name_properties_by_ordinalPosition(result_set) + + toYaml = yaml.dump(sorted_properties, sort_keys=False) + + return OpenDataContractStandard.from_string(toYaml) + + +def map_representer(dumper, data): + return dumper.represent_dict(getattr(data, "items")()) + + +def snowflake_query(account: str, schema: str, schema_sfqid: str, businessKey_sfqid: str) -> str: + sqlStatement = """ + --SHOW COLUMNS; + --SHOW PRIMARY KEYS; + + --SET(schema_sfqid, businessKey_sfqid) = (SELECT LAST_QUERY_ID(-2), LAST_QUERY_ID(-1)); + +WITH Quality_Metric AS ( + SELECT + TABLE_NAME, + TYPES, + COLUMN_NAME, + ARRAY_AGG(quality) as qualities +FROM ( + SELECT + TABLE_NAME, + IFF(ARRAY_SIZE(ARGUMENT_NAMES) = 1 AND ARGUMENT_TYPES[0]::string = 'COLUMN', 'COLUMN', 'RECORD') as TYPES , + ARRAY_TO_STRING(ARGUMENT_NAMES, ',') as COLUMN_NAME, + OBJECT_CONSTRUCT( + 'id', CONCAT(LOWER(TABLE_NAME),'_',LOWER( + IFF(ARRAY_SIZE(ARGUMENT_NAMES) 
= 1 AND ARGUMENT_TYPES[0]::string = 'COLUMN', 'COLUMN', 'RECORD') + ), + '_', + LOWER(METRIC_NAME), + '_', + LOWER(ARRAY_TO_STRING(ARGUMENT_NAMES, ',')), + '_id'), + 'metric', COALESCE(ODCS_RULES.odcs_metric, METRIC_NAME), + COALESCE(odcs_operator,'mustBe'), VALUE + ) as quality, + FROM SNOWFLAKE.LOCAL.DATA_QUALITY_MONITORING_RESULTS + -- https://bitol-io.github.io/open-data-contract-standard/latest/data-quality/#metrics + LEFT JOIN (VALUES('NULL_COUNT', 'nullValues', 'mustBe'), + ('BLANK_COUNT','missingValues', 'mustBe'), + ('ROW_COUNT', 'rowCount','mustBeGreaterOrEqualTo'), + ('ACCEPTED_VALUES', 'invalidValues','mustBeLessThan'), + ('DUPLICATE_COUNT', 'duplicateValues','mustBeLessThan') + ) as ODCS_RULES(snowflake_metricName, odcs_metric, odcs_operator) ON METRIC_NAME = snowflake_metricName + WHERE TABLE_DATABASE = CURRENT_DATABASE() AND TABLE_SCHEMA = CURRENT_SCHEMA() + QUALIFY ROW_NUMBER() OVER (PARTITION BY TABLE_NAME, ARGUMENT_TYPES, METRIC_NAME, ARGUMENT_NAMES ORDER BY MEASUREMENT_TIME DESC) = 1 + ) as DMF_METRICS_RESULT + GROUP BY TABLE_NAME, TYPES, COLUMN_NAME +), +Server_Roles AS ( + SELECT P.TABLE_SCHEMA as table_schema, ARRAY_AGG( + OBJECT_CONSTRUCT( + 'role', P.GRANTEE, + 'access',LOWER(P.PRIVILEGE_TYPE), + 'firstLevelApprovers', P.GRANTOR + )) as Roles + FROM information_schema.table_privileges P + WHERE P.GRANTED_TO = 'ROLE' AND P.TABLE_SCHEMA = '{schema}' + GROUP BY P.TABLE_CATALOG, P.TABLE_SCHEMA +), +TagRef AS ( + SELECT + OBJECT_SCHEMA, + OBJECT_NAME, + COLUMN_NAME, + ARRAY_AGG(CONCAT(TAG_NAME,'=',TAG_VALUE)) as Tags + FROM SNOWFLAKE.ACCOUNT_USAGE.TAG_REFERENCES + WHERE OBJECT_SCHEMA = CURRENT_SCHEMA() + AND OBJECT_DATABASE = CURRENT_DATABASE() + GROUP BY OBJECT_SCHEMA, OBJECT_NAME, COLUMN_NAME +), +INFO_SCHEMA_COLUMNS AS ( + SELECT + "schema_name" as schema_name, + "table_name" as table_name, + "column_name" as "name", + "null?" 
= 'NOT_NULL' as required, + RIGHT("column_name",3) = '_SK' as "unique", + coalesce(GET_PATH(TRY_PARSE_JSON("comment"),'description'), "comment") as description, + CASE GET_PATH(TRY_PARSE_JSON("data_type"),'type')::string + WHEN 'TEXT' THEN 'string' + WHEN 'STRING' THEN 'string' + WHEN 'CHAR' THEN 'string' + WHEN 'FIXED' THEN 'number' + WHEN 'REAL' THEN 'number' + WHEN 'BOOLEAN' THEN 'boolean' + WHEN 'VARIANT' THEN 'object' + WHEN 'TIMESTAMP_TZ' THEN 'timestamp' + WHEN 'TIMESTAMP_NTZ' THEN 'timestamp' + WHEN 'TIMESTAMP_LTZ' THEN 'timestamp' + WHEN 'DATE' THEN 'date' + ELSE 'object' END as LogicalType, -- FIXED NUMBER + CASE GET_PATH(TRY_PARSE_JSON("data_type"),'type')::string + WHEN 'TEXT' THEN CONCAT('VARCHAR','(',GET_PATH(TRY_PARSE_JSON("data_type"),'length'),')') + WHEN 'STRING' THEN CONCAT('VARCHAR','(',GET_PATH(TRY_PARSE_JSON("data_type"),'length'),')') + WHEN 'CHAR' THEN CONCAT('VARCHAR','(',GET_PATH(TRY_PARSE_JSON("data_type"),'length'),')') + WHEN 'FIXED' THEN CONCAT('NUMBER','(',GET_PATH(TRY_PARSE_JSON("data_type"),'precision')::string,',',GET_PATH(TRY_PARSE_JSON("data_type"),'scale'),')',' ',"autoincrement") + WHEN 'BOOLEAN' THEN 'BOOLEAN' + WHEN 'TIMESTAMP_NTZ' THEN CONCAT('TIMESTAMP_NTZ','(',GET_PATH(TRY_PARSE_JSON("data_type"),'scale'),')') + ELSE GET_PATH(TRY_PARSE_JSON("data_type"),'type') END as PhysicalType, + IFF (GET_PATH(TRY_PARSE_JSON("data_type"),'type')::string = 'TEXT', GET_PATH(TRY_PARSE_JSON("data_type"),'length')::string , NULL) as logicalTypeOptions_maxlength, + IFF ("column_name" IN ('APP_NAME','CREATE_TS','CREATE_AUDIT_ID','UPDATE_TS','UPDATE_AUDIT_ID','CURRENT_RECORD_IND','DELETED_RECORD_IND', 'FILE_BLOB_PATH', 'FILE_ROW_NUMBER', 'FILE_LAST_MODIFIED', 'IS_VALID_IND', 'INVALID_MESSAGE' ), ARRAY_CONSTRUCT('metadata'), TR.Tags ) as tags, + IS_C.ORDINAL_POSITION, + GET_PATH(TRY_PARSE_JSON("data_type"),'precision') as CP_precision, + GET_PATH(TRY_PARSE_JSON("data_type"),'scale') as CP_scale, + "autoincrement" as CP_autoIncrement, + 
"default" as CP_default, + Q.qualities + FROM TABLE(RESULT_SCAN('$schema_sfqid')) as T + JOIN INFORMATION_SCHEMA.COLUMNS as IS_C ON T."table_name"= IS_C.TABLE_NAME + AND T."schema_name" = IS_C.TABLE_SCHEMA + AND T."column_name" = IS_C.COLUMN_NAME + AND T."database_name" = IS_C.TABLE_CATALOG + LEFT JOIN TagRef TR ON T."table_name" = TR.OBJECT_NAME + AND T."schema_name" = TR.OBJECT_SCHEMA + AND T."column_name" = TR.COLUMN_NAME + LEFT JOIN Quality_Metric Q ON( T."table_name" = Q.TABLE_NAME + AND T."column_name" = Q.COLUMN_NAME + AND 'COLUMN' = Q.TYPES) +), +INFO_SCHEMA_CONSTRAINTS AS ( +SELECT + "schema_name" as schema_name, + "table_name" as table_name, + "column_name" as "name", + IFF(RIGHT("column_name",3)='_SK', -1, "key_sequence") as primaryKeyPosition, + true as primaryKey, +FROM(TABLE(RESULT_SCAN('$businessKey_sfqid'))) +), +INFO_SCHEMA_TABLES AS ( +SELECT + T.TABLE_SCHEMA as table_schema, + T.TABLE_NAME as "name", + UPPER(T.TABLE_NAME) as physical_name, + NULLIF(coalesce(GET_PATH(TRY_PARSE_JSON(COMMENT),'description'), COMMENT),'') as description, + 'object' as logicalType, + lower(REPLACE(TABLE_TYPE,'BASE ','')) as physicalType, + 'quality', Q.qualities +FROM INFORMATION_SCHEMA.TABLES as T +LEFT JOIN Quality_Metric Q ON (T.TABLE_NAME= Q.TABLE_NAME + AND 'RECORD' = Q.TYPES) +), +PROPERTIES AS ( +SELECT +C.schema_name, +C.table_name, +ARRAY_AGG(OBJECT_CONSTRUCT( + 'id', REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(C."name", ' ', '_'), '(', ''), ')',''),'/','vs'),'.','_'),'&','and') ||'_propId', + 'name', C."name", + 'required', C.required, + 'unique', C."unique", + 'description', C.description, + 'logicalType', C.LogicalType, + IFF(logicalTypeOptions_maxlength::number IS NOT NULL,'logicalTypeOptions',NULL), IFF( logicalTypeOptions_maxlength::number IS NOT NULL , OBJECT_CONSTRUCT('maxLength',logicalTypeOptions_maxlength::number), NULL), + 'physicalType', TRIM(C.PhysicalType), + IFF(BK.primaryKey = true, 'primaryKey', NULL), IFF(BK.primaryKey = true,true, 
NULL), + IFF(BK.primaryKey = true, 'primaryKeyPosition', NULL), IFF(BK.primaryKey = true, BK.primaryKeyPosition, NULL), + IFF(C.tags IS NOT NULL,'tags',NULL) , IFF(C.tags IS NOT NULL, C.tags, NULL), + 'customProperties', ARRAY_CONSTRUCT_COMPACT( + OBJECT_CONSTRUCT( + 'property', 'ordinalPosition', + 'value', C.ORDINAL_POSITION + ), + OBJECT_CONSTRUCT( + 'property','scdType', + 'value', IFF( COALESCE(BK.primaryKey,false) ,0,1) + ), + IFF(BK.primaryKey = True AND Right(C."name",3) != '_SK', OBJECT_CONSTRUCT( + 'property','businessKey', + 'value', True), NULL), + IFF(C.CP_precision IS NOT NULL, OBJECT_CONSTRUCT( + 'property','precision', + 'value', C.CP_precision), NULL), + IFF(C.CP_scale IS NOT NULL, OBJECT_CONSTRUCT( + 'property','scale', + 'value', C.CP_scale), NULL), + IFF(NULLIF(C.CP_autoIncrement,'') IS NOT NULL, OBJECT_CONSTRUCT( + 'property','autoIncrement', + 'value', C.CP_autoIncrement), NULL), + IFF(NULLIF(C.CP_default,'') IS NOT NULL, OBJECT_CONSTRUCT( + 'property','defaultValue', + 'value', C.CP_default), NULL) + ), + 'quality', C.qualities + )) as properties +FROM INFO_SCHEMA_COLUMNS C +LEFT JOIN INFO_SCHEMA_CONSTRAINTS BK ON (C.schema_name = BK.schema_name + AND C.table_name = BK.table_name + AND C."name" = BK."name") + +GROUP BY C.schema_name, C.table_name +) +, SCHEMA_DEF AS ( +SELECT +T.table_schema, +ARRAY_AGG( OBJECT_CONSTRUCT( + 'id', REPLACE(T."name", ' ', '_') ||'_schId', + 'name',T."name", + 'physicalName',T.physical_name, + 'logicalType',T.logicalType, + 'physicalType',T.physicalType, + 'description',T.description, + 'properties', P.properties, + 'quality', T.qualities) + ) + as "schema" +FROM PROPERTIES P +LEFT JOIN INFO_SCHEMA_TABLES T ON (P.schema_name = T.table_schema + AND P.table_name = T."name") +WHERE T.table_schema = '{schema}' -- Ignore PUBLIC (default) +GROUP BY T.table_schema +) +SELECT +OBJECT_CONSTRUCT('apiVersion', 'v3.1.0', +'kind','DataContract', +'id', UUID_STRING(), +'name',SCHEMA_DEF.table_schema, +'version','0.0.1', 
+'domain','dataplatform', +'status','development', +'description', OBJECT_CONSTRUCT( + 'purpose','This data can be used for analytical purposes', + 'limitations', 'not defined', + 'usage', 'not defined'), +'customProperties', ARRAY_CONSTRUCT( OBJECT_CONSTRUCT('property','owner', 'value','dataplatform')), +'servers', ARRAY_CONSTRUCT( + OBJECT_CONSTRUCT( + 'server','snowflake_dev', + 'type','snowflake', + 'account', '{account}', + 'environment', 'dev', + 'host', '{account}.snowflakecomputing.com', + 'port', 443, + 'database', CURRENT_DATABASE(), + 'warehouse', CURRENT_WAREHOUSE(), + 'schema', SCHEMA_DEF.table_schema, + 'roles', Server_Roles.Roles + )), +'schema', "schema") as "DataContract (ODCS)" +FROM SCHEMA_DEF +LEFT JOIN Server_Roles ON SCHEMA_DEF.table_schema = Server_Roles.table_schema +WHERE SCHEMA_DEF.table_schema IS NOT NULL + """ + return ( + sqlStatement.replace("$schema_sfqid", schema_sfqid) + .replace("$businessKey_sfqid", businessKey_sfqid) + .replace("{schema}", schema) + .replace("{account}", account) + ) + + +def snowflake_cursor(account: str, databasename: str = "DEMO_DB", schema: str = "PUBLIC"): + try: + from snowflake.connector import connect + except ImportError as e: + raise DataContractException( + type="schema", + result="failed", + name="snowflake extra missing", + reason="Install the extra datacontract-cli[snowflake] to use snowflake", + engine="datacontract", + original_exception=e, + ) + + ### + ## Snowflake connection + ## https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect + ### + # gather connection parameters from environment variables + user_connect = os.environ.get("DATACONTRACT_SNOWFLAKE_USERNAME", None) + password_connect = os.environ.get("DATACONTRACT_SNOWFLAKE_PASSWORD", None) + account_connect = account + role_connect = os.environ.get("DATACONTRACT_SNOWFLAKE_ROLE", None) + authenticator_connect = ( + "externalbrowser" + if password_connect is None + else 
os.environ.get("DATACONTRACT_SNOWFLAKE_AUTHENTICATOR", "snowflake") + ) + warehouse_connect = os.environ.get("DATACONTRACT_SNOWFLAKE_WAREHOUSE", "COMPUTE_WH") + database_connect = databasename or "DEMO_DB" + schema_connect = schema or "PUBLIC" + snowflake_home = os.environ.get("DATACONTRACT_SNOWFLAKE_HOME") or os.environ.get("SNOWFLAKE_HOME") + snowflake_connections_file = os.environ.get("DATACONTRACT_SNOWFLAKE_CONNECTIONS_FILE") or os.environ.get( + "SNOWFLAKE_CONNECTIONS_FILE" + ) + if not snowflake_connections_file and snowflake_home: + snowflake_connections_file = os.path.join(snowflake_home, "connections.toml") + + default_connection = os.environ.get("DATACONTRACT_SNOWFLAKE_DEFAULT_CONNECTION_NAME") or os.environ.get( + "SNOWFLAKE_DEFAULT_CONNECTION_NAME" + ) + + private_key_file = os.environ.get("DATACONTRACT_SNOWFLAKE_PRIVATE_KEY_FILE") or os.environ.get( + "SNOWFLAKE_PRIVATE_KEY_FILE" + ) + private_key_file_pwd = os.environ.get("DATACONTRACT_SNOWFLAKE_PRIVATE_KEY_FILE_PWD") or os.environ.get( + "SNOWFLAKE_PRIVATE_KEY_FILE_PWD" + ) + + # build connection + if default_connection is not None and password_connect is None: + # use the default connection defined in the snowflake config file : connections.toml and config.toml + + # optional connection params (will override the defaults if set) + connection_params = {} + if default_connection: + connection_params["connection_name"] = default_connection + if snowflake_connections_file: + connection_params["connections_file_path"] = Path(snowflake_connections_file) + if role_connect: + connection_params["role"] = role_connect + if account_connect: + connection_params["account"] = account_connect + # don't override default connection with defaults set above + if database_connect and database_connect != "DEMO_DB": + connection_params["database"] = database_connect + if schema_connect and schema_connect != "PUBLIC": + connection_params["schema"] = schema_connect + if warehouse_connect and warehouse_connect != 
"COMPUTE_WH": + connection_params["warehouse"] = warehouse_connect + + conn = connect( + session_parameters={ + "QUERY_TAG": "datacontract-cli import", + "use_openssl_only": False, + }, + **connection_params, + ) + elif private_key_file is not None: + # use private key auth + if not os.path.exists(private_key_file): + raise FileNotFoundError(f"Private key file not found at: {private_key_file}") + + conn = connect( + user=user_connect, + account=account_connect, + private_key_file=private_key_file, + private_key_file_pwd=private_key_file_pwd, + session_parameters={ + "QUERY_TAG": "datacontract-cli import", + "use_openssl_only": False, + }, + warehouse=warehouse_connect, + role=role_connect, + database=database_connect, + schema=schema_connect, + ) + elif authenticator_connect == "externalbrowser": + # use external browser auth + conn = connect( + user=user_connect, + account=account_connect, + authenticator=authenticator_connect, + session_parameters={ + "QUERY_TAG": "datacontract-cli import", + "use_openssl_only": False, + }, + warehouse=warehouse_connect, + role=role_connect, + database=database_connect, + schema=schema_connect, + ) + else: + # use the login/password auth + conn = connect( + user=user_connect, + password=password_connect, + account=account_connect, + authenticator=authenticator_connect, + session_parameters={ + "QUERY_TAG": "datacontract-cli import", + "use_openssl_only": False, + }, + warehouse=warehouse_connect, + role=role_connect, + database=database_connect, + schema=schema_connect, + ) + return conn + + +def _get_ordinal_position_value(col: Dict[str, Any]) -> Any: + """Extract customProperties value where property == 'ordinalPosition'.""" + for cp in col.get("customProperties") or []: + if isinstance(cp, dict) and cp.get("property") == "ordinalPosition": + return cp.get("value") + return None + + +def sort_schema_by_name_properties_by_ordinalPosition(payload: Dict[str, Any]) -> Dict[str, Any]: + """ + - Does NOT reorder payload['schema']. 
+ - For each schema element (table), sorts table['properties'] by property ordinalPosition value. + - Does NOT sort customProperties themselves. + """ + schema = payload.get("schema") + if not isinstance(schema, list): + return payload + new_schema = [] + for table in schema: + props = table.get("properties") + if not isinstance(props, list): + continue + + def col_key(col: Dict[str, Any]): + ord_val = _get_ordinal_position_value(col) + return (ord_val, f"{table.get('name').lower()}.{col.get('name').lower()}") + + props.sort(key=col_key) + table["properties"] = props + new_schema.append(table) + + new_schema.sort(key=lambda t: t.get("name").lower()) + + payload["schema"] = new_schema + return payload diff --git a/tests/fixtures/snowflake/import/datacontract.yaml b/tests/fixtures/snowflake/import/datacontract.yaml new file mode 100644 index 000000000..cf7d5d132 --- /dev/null +++ b/tests/fixtures/snowflake/import/datacontract.yaml @@ -0,0 +1,181 @@ +version: 1.0.0 +kind: DataContract +apiVersion: v3.1.0 +id: my-data-contract +name: My Data Contract +status: draft +servers: +- server: snowflake + type: snowflake +schema: +- name: mytable + physicalType: table + logicalType: object + physicalName: mytable + properties: + - name: field_primary_key + physicalType: INT + description: Primary key + primaryKey: true + primaryKeyPosition: 1 + logicalType: integer + - name: field_not_null + physicalType: INT + description: Not null + logicalType: integer + required: true + - name: field_char + physicalType: CHAR(10) + description: Fixed-length string + logicalType: string + logicalTypeOptions: + maxLength: 10 + - name: field_varchar + physicalType: VARCHAR(100) + description: Variable-length string + logicalType: string + logicalTypeOptions: + maxLength: 100 + - name: field_text + physicalType: VARCHAR + description: Large variable-length string (alias for VARCHAR(16777216)) + logicalType: string + - name: field_string + physicalType: VARCHAR + description: Alias for 
VARCHAR(16777216) + logicalType: string + - name: field_nchar + physicalType: CHAR(10) + description: Fixed-length string (no separate NCHAR) + logicalType: string + logicalTypeOptions: + maxLength: 10 + - name: field_nvarchar + physicalType: VARCHAR(100) + description: Variable-length string (no separate NVARCHAR) + logicalType: string + logicalTypeOptions: + maxLength: 100 + - name: field_ntext + physicalType: VARCHAR + description: Large variable-length string + logicalType: string + - name: field_tinyint + physicalType: SMALLINT + description: Snowflake doesn't have TINYINT, use SMALLINT + logicalType: integer + - name: field_smallint + physicalType: SMALLINT + description: Integer (-32,768 to 32,767) + logicalType: integer + - name: field_int + physicalType: INT + description: Integer (-2.1B to 2.1B) + logicalType: integer + - name: field_bigint + physicalType: BIGINT + description: Large integer + logicalType: integer + - name: field_decimal + physicalType: DECIMAL(10, 2) + description: Fixed precision decimal + logicalType: number + logicalTypeOptions: + precision: 10 + scale: 2 + - name: field_numeric + physicalType: DECIMAL(10, 2) + description: Same as DECIMAL + logicalType: number + logicalTypeOptions: + precision: 10 + scale: 2 + - name: field_number + physicalType: DECIMAL(38, 0) + description: Default numeric type (more flexible than DECIMAL) + logicalType: number + logicalTypeOptions: + precision: 38 + scale: 0 + - name: field_double + physicalType: DOUBLE + description: Double precision floating-point (synonym for FLOAT) + logicalType: number + - name: field_float + physicalType: DOUBLE + description: Approximate floating-point + logicalType: number + - name: field_real + physicalType: DOUBLE + description: Snowflake doesn't have REAL, use DOUBLE as synonym of FLOAT + logicalType: number + - name: field_bit + physicalType: BOOLEAN + description: Boolean (TRUE/FALSE) + logicalType: boolean + - name: field_date + physicalType: DATE + description: Date 
only (YYYY-MM-DD) + logicalType: date + - name: field_time + physicalType: TIME + description: Time only (HH:MM:SS) + logicalType: string + - name: field_datetime2 + physicalType: TIMESTAMPNTZ + description: Timestamp without timezone + logicalType: date + - name: field_smalldatetime + physicalType: TIMESTAMPNTZ + description: Timestamp without timezone + logicalType: date + - name: field_datetimeoffset + physicalType: TIMESTAMPTZ + description: Timestamp with timezone + logicalType: date + - name: field_timestamp_ltz + physicalType: TIMESTAMPLTZ + description: Timestamp with local timezone + logicalType: date + - name: field_binary + physicalType: BINARY(16) + description: Fixed-length binary + logicalType: array + - name: field_varbinary + physicalType: BINARY(100) + description: Variable-length binary + logicalType: array + - name: field_uniqueidentifier + physicalType: VARCHAR(36) + description: GUID stored as string + logicalType: string + logicalTypeOptions: + maxLength: 36 + - name: field_xml + physicalType: VARIANT + description: Semi-structured data (XML as VARIANT) + logicalType: object + - name: field_json + physicalType: VARIANT + description: Semi-structured data (native JSON support) + logicalType: object + - name: field_object + physicalType: OBJECT + description: Semi-structured object (key-value pairs) + logicalType: object + - name: field_array + physicalType: ARRAY + description: Semi-structured array + logicalType: object + - name: field_geography + physicalType: GEOGRAPHY + description: Geospatial data (points, lines, polygons) + logicalType: object + - name: field_geometry + physicalType: GEOMETRY + description: Geospatial data (planar coordinates) + logicalType: object + - name: field_vector + physicalType: VECTOR(FLOAT, 16) + description: Vector data for ML/AI (16-dimensional float vector) + logicalType: object \ No newline at end of file diff --git a/tests/test_import_snowflake.py b/tests/test_import_snowflake.py new file mode 100644 index 
000000000..26068cc4e --- /dev/null +++ b/tests/test_import_snowflake.py @@ -0,0 +1,232 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +import yaml +from dotenv import load_dotenv +from open_data_contract_standard.model import OpenDataContractStandard +from typer.testing import CliRunner + +from datacontract.cli import app +from datacontract.data_contract import DataContract +from datacontract.imports.snowflake_importer import import_Snowflake_from_connector +from datacontract.model.exceptions import DataContractException + +# logging.basicConfig(level=logging.INFO, force=True) +load_dotenv(override=True) + + +data_definition_file = "fixtures/snowflake/import/ddl.sql" + + +def test_cli(): + runner = CliRunner() + result = runner.invoke( + app, + [ + "import", + "--format", + "sql", + "--source", + data_definition_file, + "--dialect", + "snowflake", + ], + ) + assert result.exit_code == 0 + + +def test_cli_connection(): + with patch("datacontract.imports.snowflake_importer.import_Snowflake_from_connector") as mock_import: + mock_import.return_value = OpenDataContractStandard(id="test", kind="DataContract", apiVersion="v3.1.0") + runner = CliRunner() + result = runner.invoke( + app, + [ + "import", + "--format", + "snowflake", + "--source", + "test_account", + "--database", + "TEST_DB", + "--schema", + "TEST_SCHEMA", + ], + ) + assert result.exit_code == 0 + + +def test_import_sql_snowflake(): + result = DataContract.import_from_source("sql", data_definition_file, dialect="snowflake") + + print("Result:\n", result.to_yaml()) + with open("fixtures/snowflake/import/datacontract.yaml") as file: + expected = file.read() + assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected) + + +def test_import_snowflake_from_connector_success(): + account = "test_account" + database = "TEST_DB" + schema = "TEST_SCHEMA" + + # Mock response from Snowflake query + # This JSON mimics the structure returned by the SQL query in snowflake_importer.py 
+ mock_response_data = { + "apiVersion": "v3.1.0", + "kind": "DataContract", + "id": "test-id", + "name": "TEST_SCHEMA", + "version": "0.0.1", + "status": "development", + "schema": [ + { + "id": "table1_propId", + "name": "TABLE1", + "physicalName": "TEST_DB.TEST_SCHEMA.TABLE1", + "logicalType": "object", + "physicalType": "table", + "description": "Test table description", + "properties": [ + { + "id": "col1_propId", + "name": "COL1", + "logicalType": "string", + "physicalType": "VARCHAR(16777216)", + "required": False, + "unique": False, + "description": "Column description", + "customProperties": [ + {"property": "ordinalPosition", "value": 1}, + {"property": "scdType", "value": 1}, + ], + }, + { + "id": "col2_propId", + "name": "COL2", + "logicalType": "integer", + "physicalType": "NUMBER(38,0)", + "required": True, + "unique": False, + "customProperties": [ + {"property": "ordinalPosition", "value": 2}, + {"property": "scdType", "value": 1}, + ], + }, + ], + } + ], + } + + # The fetchall returns a list of tuples/lists, where the first element is the JSON string + mock_fetchall_result = [[json.dumps(mock_response_data)]] + + with patch("datacontract.imports.snowflake_importer.snowflake_cursor") as mock_cursor_func: + # Setup mocks + mock_conn = MagicMock() + mock_cursor = MagicMock() + + mock_cursor_func.return_value = mock_conn + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + + # Mock cursor attributes and methods + mock_cursor.sfqid = "mock_sfqid" + mock_cursor.fetchall.return_value = mock_fetchall_result + + # Run the function + result = import_Snowflake_from_connector(account, database, schema) + + # Verify the result + assert result.apiVersion == "v3.1.0" + assert result.kind == "DataContract" + assert result.name == "TEST_SCHEMA" + + assert len(result.schema_) == 1 + table = result.schema_[0] + assert table.name == "TABLE1" + assert table.physicalName == "TEST_DB.TEST_SCHEMA.TABLE1" + + assert len(table.properties) == 2 + assert 
table.properties[0].name == "COL1" + assert table.properties[1].name == "COL2" + + # Verify Snowflake interactions + mock_cursor.execute.assert_any_call(f"USE SCHEMA {database}.{schema}") + mock_cursor.execute.assert_any_call(f"SHOW COLUMNS IN SCHEMA {database}.{schema}") + mock_cursor.execute.assert_any_call(f"SHOW PRIMARY KEYS IN SCHEMA {database}.{schema}") + + assert mock_cursor.execute_async.called + args, _ = mock_cursor.execute_async.call_args + query = args[0] + assert "WITH Quality_Metric AS" in query + assert "Server_Roles AS " in query + assert f"WHERE T.table_schema = '{schema}'" in query + + mock_cursor.get_results_from_sfqid.assert_called_with("mock_sfqid") + + +def test_import_snowflake_from_connector_empty_result(): + account = "test_account" + database = "TEST_DB" + schema = "TEST_SCHEMA" + + # Empty result + mock_fetchall_result = [] + + with patch("datacontract.imports.snowflake_importer.snowflake_cursor") as mock_cursor_func: + mock_conn = MagicMock() + mock_cursor = MagicMock() + + mock_cursor_func.return_value = mock_conn + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + mock_cursor.sfqid = "mock_sfqid" + mock_cursor.fetchall.return_value = mock_fetchall_result + + with pytest.raises(DataContractException) as excinfo: + import_Snowflake_from_connector(account, database, schema) + + assert "No data contract returned" in str(excinfo.value) + + +# @pytest.mark.skipif(os.environ.get("DATACONTRACT_SNOWFLAKE_USERNAME") is None, reason="Requires DATACONTRACT_SNOWFLAKE_USERNAME to be set") +# def test_cli(): +# load_dotenv(override=True) +# # os.environ['DATACONTRACT_SNOWFLAKE_USERNAME'] = "xxx" +# # os.environ['DATACONTRACT_SNOWFLAKE_PASSWORD'] = "xxx" +# # os.environ['DATACONTRACT_SNOWFLAKE_ROLE'] = "xxx" +# # os.environ['DATACONTRACT_SNOWFLAKE_WAREHOUSE'] = "COMPUTE_WH" +# runner = CliRunner() +# result = runner.invoke( +# app, +# [ +# "import", +# "--format", +# "snowflake", +# "--source", +# "workspace.canada-central.azure", 
+# "--schema", +# "PUBLIC", +# "--database", +# "DEMO_DB" +# ], +# ) +# assert result.exit_code == 0 + +# @pytest.mark.skipif(os.environ.get("DATACONTRACT_SNOWFLAKE_USERNAME") is None, reason="Requires DATACONTRACT_SNOWFLAKE_USERNAME to be set") +# def test_import_source(): +# load_dotenv(override=True) +# # os.environ['DATACONTRACT_SNOWFLAKE_USERNAME'] = "xxx" +# # os.environ['DATACONTRACT_SNOWFLAKE_PASSWORD'] = "xxx" +# # os.environ['DATACONTRACT_SNOWFLAKE_ROLE'] = "xxx" +# # os.environ['DATACONTRACT_SNOWFLAKE_WAREHOUSE'] = "COMPUTE_WH" +# result = DataContract.import_source("snowflake", { +# "source": "workspace.canada-central.azure", +# "schema": "PUBLIC", +# "database": "DEMO_DB" +# }) + +# print("Result:\n", result.to_yaml()) +# with open("fixtures/snowflake/import/datacontract.yaml") as file: +# expected = file.read() +# assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected) diff --git a/tests/test_test_snowflake.py b/tests/test_test_snowflake.py index 9cba635e7..b0fb697af 100644 --- a/tests/test_test_snowflake.py +++ b/tests/test_test_snowflake.py @@ -1,19 +1,29 @@ +import os + +import pytest +from dotenv import load_dotenv + +from datacontract.data_contract import DataContract + # logging.basicConfig(level=logging.INFO, force=True) +load_dotenv(override=True) datacontract = "fixtures/snowflake/datacontract.yaml" -# @pytest.mark.skipif(os.environ.get("DATACONTRACT_SNOWFLAKE_USERNAME") is None, reason="Requires DATACONTRACT_SNOWFLAKE_USERNAME to be set") -# def test_test_snowflake(): -# load_dotenv(override=True) -# # os.environ['DATACONTRACT_SNOWFLAKE_USERNAME'] = "xxx" -# # os.environ['DATACONTRACT_SNOWFLAKE_PASSWORD'] = "xxx" -# # os.environ['DATACONTRACT_SNOWFLAKE_ROLE'] = "xxx" -# # os.environ['DATACONTRACT_SNOWFLAKE_WAREHOUSE'] = "COMPUTE_WH" -# data_contract = DataContract(data_contract_file=datacontract) -# -# run = data_contract.test() -# -# print(run) -# assert run.result == "passed" -# assert all(check.result == "passed" for check in 
run.checks) +@pytest.mark.skipif( + os.environ.get("DATACONTRACT_SNOWFLAKE_USERNAME") is None, + reason="Requires DATACONTRACT_SNOWFLAKE_USERNAME to be set", +) +def test_test_snowflake(): + # os.environ['DATACONTRACT_SNOWFLAKE_USERNAME'] = "xxx" + # os.environ['DATACONTRACT_SNOWFLAKE_PASSWORD'] = "xxx" + # os.environ['DATACONTRACT_SNOWFLAKE_ROLE'] = "xxx" + # os.environ['DATACONTRACT_SNOWFLAKE_WAREHOUSE'] = "COMPUTE_WH" + data_contract = DataContract(data_contract_file=datacontract) + + run = data_contract.test() + + print(run) + assert run.result == "passed" + assert all(check.result == "passed" for check in run.checks)