Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit b49b93a

Browse files
committed
support combo pks in --dbt local_diff
1 parent d2d7849 commit b49b93a

File tree

2 files changed

+14
-25
lines changed

2 files changed

+14
-25
lines changed

data_diff/dbt.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -73,26 +73,16 @@ def dbt_diff(
7373

7474
if is_cloud and len(diff_vars.primary_keys) > 0:
7575
_cloud_diff(diff_vars)
76-
elif is_cloud:
77-
rich.print(
78-
"[red]"
79-
+ ".".join(diff_vars.prod_path)
80-
+ " <> "
81-
+ ".".join(diff_vars.dev_path)
82-
+ "[/] \n"
83-
+ "Skipped due to missing primary-key tag\n"
84-
)
85-
86-
if not is_cloud and len(diff_vars.primary_keys) == 1:
76+
elif not is_cloud and len(diff_vars.primary_keys) > 0:
8777
_local_diff(diff_vars)
88-
elif not is_cloud:
78+
else:
8979
rich.print(
9080
"[red]"
9181
+ ".".join(diff_vars.prod_path)
9282
+ " <> "
9383
+ ".".join(diff_vars.dev_path)
9484
+ "[/] \n"
95-
+ "Skipped due to missing primary-key tag or multi-column primary-key (unsupported for non --cloud diffs)\n"
85+
+ "Skipped due to missing primary-key tag\n"
9686
)
9787

9888
rich.print("Diffs Complete!")
@@ -127,10 +117,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
127117
column_diffs_str = ""
128118
dev_qualified_string = ".".join(diff_vars.dev_path)
129119
prod_qualified_string = ".".join(diff_vars.prod_path)
130-
primary_key = diff_vars.primary_keys[0]
131120

132-
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, primary_key)
133-
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, primary_key)
121+
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys))
122+
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys))
134123

135124
table1_columns = list(table1.get_schema())
136125
try:
@@ -159,7 +148,7 @@ def _local_diff(diff_vars: DiffVars) -> None:
159148
if table2_set_diff:
160149
column_diffs_str += "Column(s) removed: " + str(table2_set_diff) + "\n"
161150

162-
mutual_set.discard(primary_key)
151+
mutual_set = mutual_set - set(diff_vars.primary_keys)
163152
extra_columns = tuple(mutual_set)
164153

165154
diff = diff_tables(table1, table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=extra_columns)

tests/test_dbt.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -351,17 +351,17 @@ def test_local_diff(self, mock_diff_tables):
351351
mock_diff.__iter__.return_value = [1, 2, 3]
352352
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
353353
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
354-
expected_key = "key"
355-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, [expected_key], None, mock_connection)
354+
expected_keys = ["key"]
355+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection)
356356
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
357357
_local_diff(diff_vars)
358358

359359
mock_diff_tables.assert_called_once_with(
360360
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=tuple(column_set)
361361
)
362362
self.assertEqual(mock_connect.call_count, 2)
363-
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), expected_key)
364-
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), expected_key)
363+
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys))
364+
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys))
365365
mock_diff.get_stats_string.assert_called_once()
366366

367367
@patch("data_diff.dbt.diff_tables")
@@ -377,17 +377,17 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
377377
mock_diff.__iter__.return_value = []
378378
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
379379
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
380-
expected_key = "primary_key_column"
381-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, [expected_key], None, mock_connection)
380+
expected_keys = ["primary_key_column"]
381+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, None, mock_connection)
382382
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
383383
_local_diff(diff_vars)
384384

385385
mock_diff_tables.assert_called_once_with(
386386
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=tuple(column_set)
387387
)
388388
self.assertEqual(mock_connect.call_count, 2)
389-
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), expected_key)
390-
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), expected_key)
389+
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys))
390+
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys))
391391
mock_diff.get_stats_string.assert_not_called()
392392

393393
@patch("data_diff.dbt.rich.print")

0 commit comments

Comments
 (0)