Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 5db828c

Browse files
authored
Merge branch 'master' into allow-dbt-selectors
2 parents 1eacf7f + cfd941f commit 5db828c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+7159
-456
lines changed

data_diff/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 from typing import Sequence, Tuple, Iterator, Optional, Union

-from sqeleton.abcs import DbTime, DbPath
+from data_diff.sqeleton.abcs import DbTime, DbPath

 from .tracking import disable_tracking
 from .databases import connect

data_diff/__main__.py

Lines changed: 2 additions & 2 deletions
@@ -10,8 +10,8 @@
 import rich
 import click

-from sqeleton.schema import create_schema
-from sqeleton.queries.api import current_timestamp
+from data_diff.sqeleton.schema import create_schema
+from data_diff.sqeleton.queries.api import current_timestamp

 from .dbt import dbt_diff
 from .utils import eval_name_template, remove_password_from_url, safezip, match_like

data_diff/cloud/data_source.py

Lines changed: 53 additions & 1 deletion
@@ -1,3 +1,4 @@
+import json
 import time
 from typing import List, Optional, Union, overload

@@ -50,14 +51,33 @@ def _validate_temp_schema(temp_schema: str):
         raise ValueError("Temporary schema should have a format <database>.<schema>")


+def _get_temp_schema(dbt_parser: DbtParser, db_type: str) -> Optional[str]:
+    diff_vars = dbt_parser.get_datadiff_variables()
+    config_prod_database = diff_vars.get("prod_database")
+    config_prod_schema = diff_vars.get("prod_schema")
+    if config_prod_database is not None and config_prod_schema is not None:
+        temp_schema = f"{config_prod_database}.{config_prod_schema}"
+        if db_type == "snowflake":
+            return temp_schema.upper()
+        elif db_type in {"pg", "postgres_aurora", "postgres_aws_rds", "redshift"}:
+            return temp_schema.lower()
+        return temp_schema
+    return
+
+
 def create_ds_config(
     ds_config: TCloudApiDataSourceConfigSchema,
     data_source_name: str,
     dbt_parser: Optional[DbtParser] = None,
 ) -> TDsConfig:
     options = _parse_ds_credentials(ds_config=ds_config, only_basic_settings=True, dbt_parser=dbt_parser)

-    temp_schema = TemporarySchemaPrompt.ask("Temporary schema (<database>.<schema>)")
+    temp_schema = _get_temp_schema(dbt_parser=dbt_parser, db_type=ds_config.db_type) if dbt_parser else None
+    if temp_schema:
+        temp_schema = TemporarySchemaPrompt.ask("Temporary schema", default=temp_schema)
+    else:
+        temp_schema = TemporarySchemaPrompt.ask("Temporary schema (<database>.<schema>)")
+
     float_tolerance = FloatPrompt.ask("Float tolerance", default=0.000001)

     return TDsConfig(
@@ -92,6 +112,37 @@ def _cast_value(value: str, type_: str) -> Union[bool, int, str]:
     return value


+def _get_data_from_bigquery_json(path: str):
+    with open(path, "r") as file:
+        return json.load(file)
+
+
+def _align_dbt_cred_params_with_datafold_params(dbt_creds: dict) -> dict:
+    db_type = dbt_creds["type"]
+    if db_type == "bigquery":
+        method = dbt_creds["method"]
+        if method == "service-account":
+            data = _get_data_from_bigquery_json(path=dbt_creds["keyfile"])
+            dbt_creds["jsonKeyFile"] = json.dumps(data)
+        elif method == "service-account-json":
+            dbt_creds["jsonKeyFile"] = json.dumps(dbt_creds["keyfile_json"])
+        else:
+            rich.print(
+                f'[red]Cannot extract bigquery credentials from dbt_project.yml for "{method}" type. '
+                f"If you want to provide credentials via dbt_project.yml, "
+                f'please, use "service-account" or "service-account-json" '
+                f"(more in docs: https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup). "
+                f"Otherwise, you can provide a path to a json key file or a json key file data as an input."
+            )
+        dbt_creds["projectId"] = dbt_creds["project"]
+    elif db_type == "snowflake":
+        dbt_creds["default_db"] = dbt_creds["database"]
+    elif db_type == "databricks":
+        dbt_creds["http_password"] = dbt_creds["token"]
+        dbt_creds["database"] = dbt_creds.get("catalog")
+    return dbt_creds
+
+
 def _parse_ds_credentials(
     ds_config: TCloudApiDataSourceConfigSchema, only_basic_settings: bool = True, dbt_parser: Optional[DbtParser] = None
 ):
@@ -101,6 +152,7 @@ def _parse_ds_credentials(
         use_dbt_data = Confirm.ask("Would you like to extract database credentials from dbt profiles.yml?")
         try:
             creds = dbt_parser.get_connection_creds()[0]
+            creds = _align_dbt_cred_params_with_datafold_params(dbt_creds=creds)
         except Exception as e:
             rich.print(f"[red]Cannot parse database credentials from dbt profiles.yml. Reason: {e}")

106158

data_diff/cloud/datafold_api.py

Lines changed: 17 additions & 4 deletions
@@ -1,3 +1,4 @@
+import base64
 import dataclasses
 import enum
 import time
@@ -6,6 +7,9 @@
 import pydantic
 import requests

+from ..utils import getLogger
+
+logger = getLogger(__name__)

 Self = TypeVar("Self", bound=pydantic.BaseModel)

@@ -97,6 +101,8 @@ class TCloudApiDataDiff(pydantic.BaseModel):
     table1: List[str]
     table2: List[str]
     pk_columns: List[str]
+    filter1: Optional[str] = None
+    filter2: Optional[str] = None


 class TSummaryResultPrimaryKeyStats(pydantic.BaseModel):
@@ -159,7 +165,7 @@ class TCloudDataSourceTestResult(pydantic.BaseModel):
 class TCloudApiDataSourceTestResult(pydantic.BaseModel):
     name: str
     status: str
-    result: TCloudDataSourceTestResult
+    result: Optional[TCloudDataSourceTestResult]


 @dataclasses.dataclass
@@ -191,7 +197,11 @@ def get_data_sources(self) -> List[TCloudApiDataSource]:
         return [TCloudApiDataSource(**item) for item in rv.json()]

     def create_data_source(self, config: TDsConfig) -> TCloudApiDataSource:
-        rv = self.make_post_request(url="api/v1/data_sources", payload=config.dict())
+        payload = config.dict()
+        if config.type == "bigquery":
+            json_string = payload["options"]["jsonKeyFile"].encode("utf-8")
+            payload["options"]["jsonKeyFile"] = base64.b64encode(json_string).decode("utf-8")
+        rv = self.make_post_request(url="api/v1/data_sources", payload=payload)
         return TCloudApiDataSource(**rv.json())

     def get_data_source_schema_config(
@@ -216,11 +226,12 @@ def poll_data_diff_results(self, diff_id: int) -> TCloudApiDataDiffSummaryResult
         summary_results = None
         start_time = time.monotonic()
         sleep_interval = 5  # starts at 5 sec
-        max_sleep_interval = 60
+        max_sleep_interval = 30
         max_wait_time = 300

         diff_url = f"{self.host}/datadiffs/{diff_id}/overview"
         while not summary_results:
+            logger.debug(f"Polling: {diff_url}")
             response = self.make_get_request(url=f"api/v1/datadiffs/{diff_id}/summary_results")
             response_json = response.json()
             if response_json["status"] == "success":
@@ -250,7 +261,9 @@ def check_data_source_test_results(self, job_id: int) -> List[TCloudApiDataSourc
                     status=item["result"]["code"].lower(),
                     message=item["result"]["message"],
                     outcome=item["result"]["outcome"],
-                ),
+                )
+                if item["result"] is not None
+                else None,
             )
             for item in rv.json()["results"]
         ]

data_diff/databases/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from sqeleton.databases import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError
+from data_diff.sqeleton.databases import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError

 from .postgresql import PostgreSQL
 from .mysql import MySQL

data_diff/databases/_connect.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import logging

-from sqeleton.databases import Connect
+from data_diff.sqeleton.databases import Connect

 from .postgresql import PostgreSQL
 from .mysql import MySQL

data_diff/databases/base.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
+from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue


 class DatadiffDialect(AbstractMixin_MD5, AbstractMixin_NormalizeValue):

data_diff/databases/bigquery.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from sqeleton.databases import bigquery
+from data_diff.sqeleton.databases import bigquery
 from .base import DatadiffDialect


data_diff/databases/clickhouse.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from sqeleton.databases import clickhouse
+from data_diff.sqeleton.databases import clickhouse
 from .base import DatadiffDialect



data_diff/databases/databricks.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from sqeleton.databases import databricks
+from data_diff.sqeleton.databases import databricks
 from .base import DatadiffDialect



0 commit comments

Comments (0)