Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 2e523c1

Browse files
dlawinerezsh
authored andcommitted
squash: add get_stats method
1 parent 01abf3a commit 2e523c1

File tree

5 files changed

+133
-66
lines changed

5 files changed

+133
-66
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,6 @@ benchmark_*.png
141141

142142
# IntelliJ
143143
.idea
144+
145+
# VSCode
146+
.vscode

data_diff/__main__.py

Lines changed: 6 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -403,58 +403,19 @@ def _main(
403403
]
404404

405405
diff_iter = differ.diff_tables(*segments)
406-
info = diff_iter.info_tree.info
407406

408407
if limit:
409408
diff_iter = islice(diff_iter, int(limit))
410409

411410
if stats:
412-
diff = list(diff_iter)
413-
key_columns_len = len(key_columns)
414-
415-
diff_by_key = {}
416-
for sign, values in diff:
417-
k = values[:key_columns_len]
418-
if k in diff_by_key:
419-
assert sign != diff_by_key[k]
420-
diff_by_key[k] = "!"
421-
else:
422-
diff_by_key[k] = sign
423-
424-
diff_by_sign = {k: 0 for k in "+-!"}
425-
for sign in diff_by_key.values():
426-
diff_by_sign[sign] += 1
427-
428-
table1_count = info.rowcounts[1]
429-
table2_count = info.rowcounts[2]
430-
unchanged = table1_count - diff_by_sign["-"] - diff_by_sign["!"]
431-
diff_percent = 1 - unchanged / max(table1_count, table2_count)
432-
411+
# required to create this variable before get_stats
412+
diff_list = list(diff_iter)
413+
stats_output = diff_iter.get_stats()
433414
if json_output:
434-
json_output = {
435-
"rows_A": table1_count,
436-
"rows_B": table2_count,
437-
"exclusive_A": diff_by_sign["-"],
438-
"exclusive_B": diff_by_sign["+"],
439-
"updated": diff_by_sign["!"],
440-
"unchanged": unchanged,
441-
"total": sum(diff_by_sign.values()),
442-
"stats": differ.stats,
443-
}
444-
rich.print_json(json.dumps(json_output))
415+
rich.print(json.dumps(stats_output[0]))
445416
else:
446-
rich.print(f"{table1_count} rows in table A")
447-
rich.print(f"{table2_count} rows in table B")
448-
rich.print(f"{diff_by_sign['-']} rows exclusive to table A (not present in B)")
449-
rich.print(f"{diff_by_sign['+']} rows exclusive to table B (not present in A)")
450-
rich.print(f"{diff_by_sign['!']} rows updated")
451-
rich.print(f"{unchanged} rows unchanged")
452-
rich.print(f"{100*diff_percent:.2f}% difference score")
453-
454-
if differ.stats:
455-
print("\nExtra-Info:")
456-
for k, v in sorted(differ.stats.items()):
457-
rich.print(f" {k} = {v}")
417+
rich.print(stats_output[1])
418+
458419
else:
459420
for op, values in diff_iter:
460421
color = COLOR_SCHEME[op]

data_diff/diff_tables.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from concurrent.futures import ThreadPoolExecutor, as_completed
1212

1313
from runtype import dataclass
14+
from dataclasses import field
1415

1516
from data_diff.info_tree import InfoTree, SegmentInfo
1617

@@ -82,10 +83,66 @@ def _run_in_background(self, *funcs):
8283
class DiffResultWrapper:
8384
diff: iter # DiffResult
8485
info_tree: InfoTree
86+
stats: dict
87+
result_list: list = field(default_factory=list)
8588

8689
def __iter__(self):
87-
return iter(self.diff)
90+
for i in self.diff:
91+
self.result_list.append(i)
92+
yield i
93+
94+
def get_stats(self):
95+
96+
diff_by_key = {}
97+
if len(self.result_list) > 0:
98+
for sign, values in self.result_list:
99+
k = values[: len(self.info_tree.info.tables[0].key_columns)]
100+
if k in diff_by_key:
101+
assert sign != diff_by_key[k]
102+
diff_by_key[k] = "!"
103+
else:
104+
diff_by_key[k] = sign
105+
106+
diff_by_sign = {k: 0 for k in "+-!"}
107+
for sign in diff_by_key.values():
108+
diff_by_sign[sign] += 1
109+
110+
table1_count = self.info_tree.info.rowcounts[1]
111+
table2_count = self.info_tree.info.rowcounts[2]
112+
unchanged = table1_count - diff_by_sign["-"] - diff_by_sign["!"]
113+
diff_percent = 1 - unchanged / max(table1_count, table2_count)
114+
115+
116+
json_output = {
117+
"rows_A": table1_count,
118+
"rows_B": table2_count,
119+
"exclusive_A": diff_by_sign["-"],
120+
"exclusive_B": diff_by_sign["+"],
121+
"updated": diff_by_sign["!"],
122+
"unchanged": unchanged,
123+
"total": sum(diff_by_sign.values()),
124+
"stats": self.stats,
125+
}
126+
127+
string_output = ""
128+
string_output += f"{table1_count} rows in table A\n"
129+
string_output += f"{table2_count} rows in table B\n"
130+
string_output += f"{diff_by_sign['-']} rows exclusive to table A (not present in B)\n"
131+
string_output += f"{diff_by_sign['+']} rows exclusive to table B (not present in A)\n"
132+
string_output += f"{diff_by_sign['!']} rows updated\n"
133+
string_output += f"{unchanged} rows unchanged\n"
134+
string_output += f"{100*diff_percent:.2f}% difference score\n"
135+
136+
if self.stats:
137+
string_output += "\nExtra-Info:\n"
138+
for k, v in sorted(self.stats.items()):
139+
string_output += f" {k} = {v}\n"
140+
else:
141+
raise RuntimeError(
142+
"result_list is empty, consume the diff iterator to populate values: e.g. \ndiff_iter = diff_tables(...) \ndiff_list = list(diff_iter) \ndiff_iter.print_stats(json_output)"
143+
)
88144

145+
return json_output, string_output
89146

90147
class TableDiffer(ThreadBase, ABC):
91148
bisection_factor = 32
@@ -106,7 +163,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment, info_tree: Inf
106163
"""
107164
if info_tree is None:
108165
info_tree = InfoTree(SegmentInfo([table1, table2]))
109-
return DiffResultWrapper(self._diff_tables_wrapper(table1, table2, info_tree), info_tree)
166+
return DiffResultWrapper(self._diff_tables_wrapper(table1, table2, info_tree), info_tree, self.stats, [])
110167

111168
def _diff_tables_wrapper(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult:
112169
if is_tracking_enabled():
@@ -177,6 +234,8 @@ def _bisect_and_diff_tables(self, table1, table2, info_tree):
177234
raise NotImplementedError("Composite key not supported yet!")
178235
if len(table2.key_columns) > 1:
179236
raise NotImplementedError("Composite key not supported yet!")
237+
if len(table1.key_columns) != len(table2.key_columns):
238+
raise ValueError("Tables should have an equivalent number of key columns!")
180239
(key1,) = table1.key_columns
181240
(key2,) = table2.key_columns
182241

tests/test_api.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import unittest
2+
import io
3+
import unittest.mock
24
import arrow
35
from datetime import datetime
46

57
from data_diff import diff_tables, connect_to_table
68
from data_diff.databases import MySQL
79
from data_diff.sqeleton.queries import table, commit
810

9-
from .common import TEST_MYSQL_CONN_STRING, get_conn
11+
from .common import TEST_MYSQL_CONN_STRING, get_conn, random_table_suffix
1012

1113

1214
def _commit(conn):
@@ -16,16 +18,17 @@ def _commit(conn):
1618
class TestApi(unittest.TestCase):
1719
def setUp(self) -> None:
1820
self.conn = get_conn(MySQL)
19-
table_src_name = "test_api"
20-
table_dst_name = "test_api_2"
21+
suffix = random_table_suffix()
22+
self.table_src_name = f"test_api{suffix}"
23+
self.table_dst_name = f"test_api_2{suffix}"
2124

22-
self.table_src = table(table_src_name)
23-
self.table_dst = table(table_dst_name)
25+
self.table_src = table(self.table_src_name)
26+
self.table_dst = table(self.table_dst_name)
2427

2528
self.conn.query(self.table_src.drop(True))
2629
self.conn.query(self.table_dst.drop(True))
2730

28-
src_table = table(table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str})
31+
src_table = table(self.table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str})
2932
self.conn.query(src_table.create())
3033
self.now = now = arrow.get()
3134

@@ -53,8 +56,8 @@ def tearDown(self) -> None:
5356
return super().tearDown()
5457

5558
def test_api(self):
56-
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, "test_api")
57-
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, ("test_api_2",))
59+
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name)
60+
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, (self.table_dst_name,))
5861
diff = list(diff_tables(t1, t2))
5962
assert len(diff) == 1
6063

@@ -65,10 +68,38 @@ def test_api(self):
6568
diff_id = diff[0][1][0]
6669
where = f"id != {diff_id}"
6770

68-
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, "test_api", where=where)
69-
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, "test_api_2", where=where)
71+
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name, where=where)
72+
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_dst_name, where=where)
7073
diff = list(diff_tables(t1, t2))
7174
assert len(diff) == 0
7275

7376
t1.database.close()
7477
t2.database.close()
78+
79+
def test_api_get_stats(self):
80+
expected_string = "5 rows in table A\n4 rows in table B\n1 rows exclusive to table A (not present in B)\n0 rows exclusive to table B (not present in A)\n0 rows updated\n4 rows unchanged\n20.00% difference score\n\nExtra-Info:\n rows_downloaded = 5\n"
81+
expected_dict = {'rows_A': 5, 'rows_B': 4, 'exclusive_A': 1, 'exclusive_B': 0, 'updated': 0, 'unchanged': 4, 'total': 1, 'stats': {'rows_downloaded': 5}}
82+
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name)
83+
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_dst_name)
84+
diff = diff_tables(t1, t2)
85+
diff_list = list(diff)
86+
output = diff.get_stats()
87+
88+
self.assertEqual(expected_dict, output[0])
89+
self.assertEqual(expected_string, output[1])
90+
self.assertIsNotNone(diff)
91+
assert len(diff_list) == 1
92+
93+
t1.database.close()
94+
t2.database.close()
95+
96+
def test_api_print_error(self):
97+
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name)
98+
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, (self.table_dst_name,))
99+
diff = diff_tables(t1, t2)
100+
101+
with self.assertRaises(RuntimeError):
102+
diff.get_stats()
103+
104+
t1.database.close()
105+
t2.database.close()

tests/test_cli.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from data_diff.databases import MySQL
99
from data_diff.sqeleton.queries import table, commit
1010

11-
from .common import TEST_MYSQL_CONN_STRING, get_conn
11+
from .common import TEST_MYSQL_CONN_STRING, get_conn, random_table_suffix
1212

1313

1414
def _commit(conn):
@@ -30,15 +30,16 @@ class TestCLI(unittest.TestCase):
3030
def setUp(self) -> None:
3131
self.conn = get_conn(MySQL)
3232

33-
table_src_name = "test_cli"
34-
table_dst_name = "test_cli_2"
33+
suffix = random_table_suffix()
34+
self.table_src_name = f"test_api{suffix}"
35+
self.table_dst_name = f"test_api_2{suffix}"
3536

36-
self.table_src = table(table_src_name)
37-
self.table_dst = table(table_dst_name)
37+
self.table_src = table(self.table_src_name)
38+
self.table_dst = table(self.table_dst_name)
3839
self.conn.query(self.table_src.drop(True))
3940
self.conn.query(self.table_dst.drop(True))
4041

41-
src_table = table(table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str})
42+
src_table = table(self.table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str})
4243
self.conn.query(src_table.create())
4344
self.conn.query("SET @@session.time_zone='+00:00'")
4445
now = self.conn.query("select now()", datetime)
@@ -67,15 +68,15 @@ def tearDown(self) -> None:
6768
return super().tearDown()
6869

6970
def test_basic(self):
70-
diff = run_datadiff_cli(TEST_MYSQL_CONN_STRING, "test_cli", TEST_MYSQL_CONN_STRING, "test_cli_2")
71+
diff = run_datadiff_cli(TEST_MYSQL_CONN_STRING, self.table_src_name, TEST_MYSQL_CONN_STRING, self.table_dst_name)
7172
assert len(diff) == 1
7273

7374
def test_options(self):
7475
diff = run_datadiff_cli(
7576
TEST_MYSQL_CONN_STRING,
76-
"test_cli",
77+
self.table_src_name,
7778
TEST_MYSQL_CONN_STRING,
78-
"test_cli_2",
79+
self.table_dst_name,
7980
"--bisection-factor",
8081
"16",
8182
"--bisection-threshold",
@@ -88,3 +89,15 @@ def test_options(self):
8889
"1h",
8990
)
9091
assert len(diff) == 1
92+
93+
def test_stats(self):
94+
diff_output = run_datadiff_cli(
95+
TEST_MYSQL_CONN_STRING, self.table_src_name, TEST_MYSQL_CONN_STRING, self.table_dst_name, "-s"
96+
)
97+
assert len(diff_output) == 11
98+
99+
def test_stats_json(self):
100+
diff_output = run_datadiff_cli(
101+
TEST_MYSQL_CONN_STRING, self.table_src_name, TEST_MYSQL_CONN_STRING, self.table_dst_name, "-s", "--json"
102+
)
103+
assert len(diff_output) == 2

0 commit comments

Comments
 (0)