From 5be338e354b4c6f843730c6ee648d05f2407589b Mon Sep 17 00:00:00 2001 From: Jake Bromberg Date: Mon, 9 Mar 2026 23:45:54 -0700 Subject: [PATCH 1/4] test: add pure unit tests for utilities and main() functions Cover matching.is_compilation_artist, artist_splitting edge cases (_split_trailing_and single-component, _try_ampersand_split, numeric comma guard), csv_to_tsv.main(), fix_csv_newlines.fix_csv_dir/main(), filter_csv edge cases and main(), verify_cache.format_bytes/print_report, and export_to_sqlite.format_size/_do_export table/FTS verification. --- tests/unit/test_artist_splitting.py | 68 ++++++++++- tests/unit/test_csv_to_tsv.py | 30 +++++ tests/unit/test_export.py | 89 ++++++++++++++ tests/unit/test_filter_csv.py | 183 ++++++++++++++++++++++++++++ tests/unit/test_fix_csv_newlines.py | 80 ++++++++++++ tests/unit/test_matching.py | 43 +++++++ tests/unit/test_verify_cache.py | 134 ++++++++++++++++++++ 7 files changed, 626 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_matching.py diff --git a/tests/unit/test_artist_splitting.py b/tests/unit/test_artist_splitting.py index 73b98b9..e5f9605 100644 --- a/tests/unit/test_artist_splitting.py +++ b/tests/unit/test_artist_splitting.py @@ -4,7 +4,12 @@ import pytest -from lib.artist_splitting import split_artist_name, split_artist_name_contextual +from lib.artist_splitting import ( + _split_trailing_and, + _try_ampersand_split, + split_artist_name, + split_artist_name_contextual, +) # --------------------------------------------------------------------------- # split_artist_name (context-free) @@ -159,3 +164,64 @@ def test_band_names_not_split(self, name: str) -> None: """Common band name patterns with 'and'/'with' should never be split.""" known = {"sly", "andy human", "my life", "nurse"} assert split_artist_name_contextual(name, known) == [] + + +# --------------------------------------------------------------------------- +# _comma_guard (numeric guard) +# --------------------------------------------------------------------------- + + +class TestCommaGuardNumeric: + """The numeric guard in comma splitting prevents splitting artist names with numbers.""" + + @pytest.mark.parametrize( + "name", + [ + "10,000 Maniacs", + "808,303", + "1,000 Homo DJs", + ], + ids=["10000-maniacs", "808-303", "1000-homo-djs"], + ) + def test_numeric_components_block_split(self, name: str) -> None: + assert split_artist_name(name) == [] + + +# --------------------------------------------------------------------------- +# _split_trailing_and (single-component input) +# --------------------------------------------------------------------------- + + +class TestSplitTrailingAnd: + """Direct tests for _split_trailing_and edge cases.""" + + def test_single_component_returns_as_is(self) -> None: + assert _split_trailing_and(["Autechre"]) == ["Autechre"] + + def test_empty_list_returns_as_is(self) -> None: + assert _split_trailing_and([]) == [] + + def test_two_components_with_trailing_and(self) -> None: + assert _split_trailing_and(["Emerson", "and Palmer"]) == ["Emerson", "Palmer"] + + def test_two_components_without_trailing_and(self) -> None: + assert _split_trailing_and(["Emerson", "Palmer"]) == ["Emerson", "Palmer"] + + +# --------------------------------------------------------------------------- +# _try_ampersand_split (edge cases) +# --------------------------------------------------------------------------- + + +class TestTryAmpersandSplit: + """Direct tests for _try_ampersand_split edge cases.""" + + def test_no_ampersand_returns_none(self) -> None: + assert _try_ampersand_split("Autechre", {"autechre"}) is None + + def test_no_known_artist_returns_none(self) -> None: + assert _try_ampersand_split("Simon & Garfunkel", set()) is None + + def test_single_char_components_rejected(self) -> None: + """When all components after filtering are too short, returns None.""" + assert _try_ampersand_split("A & B", {"a"}) is None diff --git a/tests/unit/test_csv_to_tsv.py b/tests/unit/test_csv_to_tsv.py index 1c7c592..52ec62e 100644 --- a/tests/unit/test_csv_to_tsv.py +++ b/tests/unit/test_csv_to_tsv.py @@ -15,6 +15,7 @@ _spec.loader.exec_module(_ct) convert = _ct.convert +main = _ct.main class TestConvert: @@ -90,3 +91,32 @@ def test_empty_csv(self, tmp_path: Path) -> None: assert count == 0 lines = tsv_file.read_text().splitlines() assert lines == ["h"] + + +class TestMain: + """Tests for the main() entry point.""" + + def test_wrong_arg_count_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("sys.argv", ["csv_to_tsv.py"]) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_too_many_args_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("sys.argv", ["csv_to_tsv.py", "a", "b", "c"]) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_happy_path(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + csv_file = tmp_path / "input.csv" + tsv_file = tmp_path / "output.tsv" + csv_file.write_text("name,value\nalpha,1\nbeta,2\n") + + monkeypatch.setattr("sys.argv", ["csv_to_tsv.py", str(csv_file), str(tsv_file)]) + main() + + lines = tsv_file.read_text().splitlines() + assert lines[0] == "name\tvalue" + assert lines[1] == "alpha\t1" + assert lines[2] == "beta\t2" diff --git a/tests/unit/test_export.py b/tests/unit/test_export.py index a357d54..2aabeab 100644 --- a/tests/unit/test_export.py +++ b/tests/unit/test_export.py @@ -13,6 +13,7 @@ _mod = importlib.util.module_from_spec(_spec) _spec.loader.exec_module(_mod) _do_export = _mod._do_export +format_size = _mod.format_size _OUTPUT_PATH_ATTR = "OUTPUT_PATH" @@ -168,3 +169,91 @@ def test_mixed_rows_with_and_without_alternate(self): assert len(results) == 2 assert results[0] == (1, "Plug") assert results[1] == (2, None) + + def test_creates_correct_tables_and_row_count(self): + """Verify the exported database has the correct tables, FTS index, and row count.""" + rows = [ + { + "id": "1", + "title": "DOGA", + "artist": "Juana Molina", + "call_letters": "M", + "artist_call_number": "42", + "release_call_number": "1", + "genre": "Rock", + "format": "LP", + "alternate_artist_name": None, + }, + { + "id": "2", + "title": "Aluminum Tunes", + "artist": "Stereolab", + "call_letters": "S", + "artist_call_number": "88", + "release_call_number": "1", + "genre": "Rock", + "format": "CD", + "alternate_artist_name": None, + }, + { + "id": "3", + "title": "Moon Pix", + "artist": "Cat Power", + "call_letters": "C", + "artist_call_number": "7", + "release_call_number": "1", + "genre": "Rock", + "format": "LP", + "alternate_artist_name": None, + }, + ] + _do_export(rows) + + conn = sqlite3.connect(self.output_path) + + # Verify table exists + cursor = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='library'" + ) + assert cursor.fetchone() is not None + + # Verify FTS table exists + cursor = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='library_fts'" + ) + assert cursor.fetchone() is not None + + # Verify row count + cursor = conn.execute("SELECT COUNT(*) FROM library") + assert cursor.fetchone()[0] == 3 + + # Verify FTS works + cursor = conn.execute(""" + SELECT l.id FROM library l + JOIN library_fts fts ON l.id = fts.rowid + WHERE library_fts MATCH 'Stereolab' + """) + results = cursor.fetchall() + assert len(results) == 1 + assert results[0][0] == 2 + + conn.close() + + +class TestFormatSize: + """Test human-readable size formatting.""" + + @pytest.mark.parametrize( + "size_bytes, expected", + [ + (0, "0.0 B"), + (1023, "1023.0 B"), + (1024, "1.0 KB"), + (1048576, "1.0 MB"), + (1073741824, "1.0 GB"), + (1099511627776, "1.0 TB"), + ], + ids=["zero", "bytes", "kilobytes", "megabytes", "gigabytes", "terabytes"], + ) + def test_format_size(self, size_bytes: int, expected: str) -> None: + assert format_size(size_bytes) == expected diff --git a/tests/unit/test_filter_csv.py b/tests/unit/test_filter_csv.py index 50c686c..df22e2a 100644 --- a/tests/unit/test_filter_csv.py +++ b/tests/unit/test_filter_csv.py @@ -20,6 +20,7 @@ find_matching_release_ids = _fc.find_matching_release_ids filter_csv_file = _fc.filter_csv_file get_release_id_column = _fc.get_release_id_column +main = _fc.main FIXTURES_DIR = Path(__file__).parent.parent / "fixtures" @@ -246,3 +247,185 @@ def test_missing_id_column_raises_clear_error(self, tmp_path: Path) -> None: with pytest.raises(ValueError, match="Column 'nonexistent'.*not found"): filter_csv_file(input_path, output_path, {1001}, "nonexistent") + + def test_row_with_invalid_release_id_skipped(self, tmp_path: Path) -> None: + """Rows where the release_id is not a valid integer are silently skipped.""" + csv_path = tmp_path / "release.csv" + output_path = tmp_path / "out.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["id", "title"]) + writer.writerow(["abc", "Bad ID"]) + writer.writerow(["1001", "Good ID"]) + + input_count, output_count = filter_csv_file(csv_path, output_path, {1001}, "id") + assert input_count == 2 + assert output_count == 1 + + def test_short_row_skipped(self, tmp_path: Path) -> None: + """Rows shorter than expected (IndexError on id column) are silently skipped.""" + csv_path = tmp_path / "release.csv" + output_path = tmp_path / "out.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["id", "title", "country"]) + # Normal row + writer.writerow(["1001", "DOGA", "AR"]) + # Write a short row manually (fewer columns than header) + f.write('"short"\n') + + input_count, output_count = filter_csv_file(csv_path, output_path, {1001}, "id") + assert input_count == 2 + assert output_count == 1 + + +class TestFindMatchingReleaseIdsEdgeCases: + """Edge cases for find_matching_release_ids.""" + + def test_short_row_skipped(self, tmp_path: Path) -> None: + """Rows missing the artist_name column are silently skipped.""" + csv_path = tmp_path / "release_artist.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["release_id", "artist_id", "artist_name", "extra", "anv", "position"]) + writer.writerow(["1001", "101", "Juana Molina", "0", "", "1"]) + # Short row missing artist_name + f.write('"2001","201"\n') + + ids = find_matching_release_ids(csv_path, {"juana molina"}) + assert ids == {1001} + + +# --------------------------------------------------------------------------- +# main +# --------------------------------------------------------------------------- + + +class TestMain: + """Tests for the main() entry point.""" + + def test_wrong_arg_count_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("sys.argv", ["filter_csv.py"]) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_missing_library_artists_exits( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setattr( + "sys.argv", + [ + "filter_csv.py", + str(tmp_path / "nonexistent.txt"), + str(tmp_path), + str(tmp_path / "out"), + ], + ) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_missing_csv_dir_exits(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + artists_file = tmp_path / "artists.txt" + artists_file.write_text("Juana Molina\n") + monkeypatch.setattr( + "sys.argv", + ["filter_csv.py", str(artists_file), str(tmp_path / "nope"), str(tmp_path / "out")], + ) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_missing_release_artist_csv_exits( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + artists_file = tmp_path / "artists.txt" + artists_file.write_text("Juana Molina\n") + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + monkeypatch.setattr( + "sys.argv", + ["filter_csv.py", str(artists_file), str(csv_dir), str(tmp_path / "out")], + ) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_no_matches_exits(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + artists_file = tmp_path / "artists.txt" + artists_file.write_text("Nonexistent Artist XYZ\n") + + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + with open(csv_dir / "release_artist.csv", "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["release_id", "artist_id", "artist_name", "extra", "anv", "position"]) + writer.writerow(["1001", "101", "Juana Molina", "0", "", "1"]) + + monkeypatch.setattr( + "sys.argv", + ["filter_csv.py", str(artists_file), str(csv_dir), str(tmp_path / "out")], + ) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_happy_path(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + artists_file = tmp_path / "artists.txt" + artists_file.write_text("Juana Molina\n") + + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + out_dir = tmp_path / "out" + + with open(csv_dir / "release_artist.csv", "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["release_id", "artist_id", "artist_name", "extra", "anv", "position"]) + writer.writerow(["5001", "101", "Juana Molina", "0", "", "1"]) + writer.writerow(["5002", "102", "Stereolab", "0", "", "1"]) + + with open(csv_dir / "release.csv", "w", newline="") as f: + writer = csv.writer(f) + writer.writerow( + [ + "id", + "status", + "title", + "country", + "released", + "notes", + "data_quality", + "master_id", + "format", + ] + ) + writer.writerow( + ["5001", "Accepted", "DOGA", "AR", "2024-05-10", "", "Correct", "8001", "LP"] + ) + writer.writerow( + [ + "5002", + "Accepted", + "Aluminum Tunes", + "UK", + "1998-09-01", + "", + "Correct", + "8002", + "CD", + ] + ) + + monkeypatch.setattr( + "sys.argv", + ["filter_csv.py", str(artists_file), str(csv_dir), str(out_dir)], + ) + main() + + assert out_dir.exists() + with open(out_dir / "release.csv") as f: + reader = csv.DictReader(f) + rows = list(reader) + assert len(rows) == 1 + assert rows[0]["id"] == "5001" diff --git a/tests/unit/test_fix_csv_newlines.py b/tests/unit/test_fix_csv_newlines.py index df68e8b..772f665 100644 --- a/tests/unit/test_fix_csv_newlines.py +++ b/tests/unit/test_fix_csv_newlines.py @@ -16,6 +16,8 @@ _spec.loader.exec_module(_fn) fix_csv = _fn.fix_csv +fix_csv_dir = _fn.fix_csv_dir +main = _fn.main class TestFixCsv: @@ -101,3 +103,81 @@ def test_edge_cases(self, tmp_path: Path, field: str, expected: str) -> None: rows = self._read_csv(output_path) assert rows[0][0] == expected + + +class TestFixCsvDir: + """Batch-processing a directory of CSV files.""" + + def _write_csv(self, path: Path, headers: list[str], rows: list[list[str]]) -> None: + with open(path, "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(headers) + writer.writerows(rows) + + def test_processes_all_csv_files(self, tmp_path: Path) -> None: + input_dir = tmp_path / "input" + output_dir = tmp_path / "output" + input_dir.mkdir() + + self._write_csv(input_dir / "a.csv", ["text"], [["line\none"]]) + self._write_csv(input_dir / "b.csv", ["text"], [["hello\nworld"]]) + + fix_csv_dir(input_dir, output_dir) + + assert (output_dir / "a.csv").exists() + assert (output_dir / "b.csv").exists() + + with open(output_dir / "a.csv") as f: + reader = csv.reader(f) + next(reader) + assert next(reader)[0] == "line one" + + with open(output_dir / "b.csv") as f: + reader = csv.reader(f) + next(reader) + assert next(reader)[0] == "hello world" + + def test_empty_dir_logs_warning(self, tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: + import logging + + input_dir = tmp_path / "empty" + output_dir = tmp_path / "output" + input_dir.mkdir() + + with caplog.at_level(logging.WARNING): + fix_csv_dir(input_dir, output_dir) + + assert any("No .csv files found" in r.message for r in caplog.records) + + +class TestMain: + """Tests for the main() entry point.""" + + def test_wrong_arg_count_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("sys.argv", ["fix_csv_newlines.py"]) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_too_many_args_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("sys.argv", ["fix_csv_newlines.py", "a", "b", "c"]) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_happy_path(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + input_path = tmp_path / "input.csv" + output_path = tmp_path / "output.csv" + + with open(input_path, "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(["text"]) + writer.writerow(["line\none"]) + + monkeypatch.setattr("sys.argv", ["fix_csv_newlines.py", str(input_path), str(output_path)]) + main() + + with open(output_path) as f: + reader = csv.reader(f) + next(reader) + assert next(reader)[0] == "line one" diff --git a/tests/unit/test_matching.py b/tests/unit/test_matching.py new file mode 100644 index 0000000..f299076 --- /dev/null +++ b/tests/unit/test_matching.py @@ -0,0 +1,43 @@ +"""Unit tests for lib/matching.py.""" + +from __future__ import annotations + +import pytest + +from lib.matching import is_compilation_artist + + +class TestIsCompilationArtist: + """Compilation artist detection.""" + + @pytest.mark.parametrize( + "artist, expected", + [ + ("Various Artists", True), + ("various", True), + ("Soundtrack", True), + ("Original Motion Picture Soundtrack", True), + ("V/A", True), + ("v.a.", True), + ("Compilation Hits", True), + ("Stereolab", False), + ("Juana Molina", False), + ("Cat Power", False), + ("", False), + ], + ids=[ + "various-artists", + "various-lowercase", + "soundtrack", + "soundtrack-in-phrase", + "v-slash-a", + "v-dot-a", + "compilation-keyword", + "stereolab", + "juana-molina", + "cat-power", + "empty-string", + ], + ) + def test_is_compilation_artist(self, artist: str, expected: bool) -> None: + assert is_compilation_artist(artist) == expected diff --git a/tests/unit/test_verify_cache.py b/tests/unit/test_verify_cache.py index 56df5bf..0a13672 100644 --- a/tests/unit/test_verify_cache.py +++ b/tests/unit/test_verify_cache.py @@ -643,6 +643,11 @@ async def test_empty_results(self): # Step 8: Argument Parsing # --------------------------------------------------------------------------- +format_bytes = _vc.format_bytes +ClassificationReport = _vc.ClassificationReport +MatchResult = _vc.MatchResult +print_report = _vc.print_report + classify_all_releases = _vc.classify_all_releases classify_artist_fuzzy = _vc.classify_artist_fuzzy classify_fuzzy_batch = _vc.classify_fuzzy_batch @@ -843,3 +848,132 @@ def test_default_no_copy_to(self, tmp_path): lib_db.touch() args = parse_args([str(lib_db)]) assert args.copy_to is None + + +# --------------------------------------------------------------------------- +# format_bytes +# --------------------------------------------------------------------------- + + +class TestFormatBytes: + """Test human-readable byte formatting.""" + + @pytest.mark.parametrize( + "num_bytes, expected", + [ + (0, "0.0 B"), + (1023, "1023.0 B"), + (1024, "1.0 KB"), + (1048576, "1.0 MB"), + (1073741824, "1.0 GB"), + (1099511627776, "1.0 TB"), + ], + ids=["zero", "bytes", "kilobytes", "megabytes", "gigabytes", "terabytes"], + ) + def test_format_bytes(self, num_bytes: int, expected: str) -> None: + assert format_bytes(num_bytes) == expected + + +# --------------------------------------------------------------------------- +# print_report +# --------------------------------------------------------------------------- + + +class TestPrintReport: + """Test the print_report function with mock data.""" + + def test_basic_report(self, sample_index, capsys: pytest.CaptureFixture[str]) -> None: + report = ClassificationReport( + keep_ids={1, 2, 3}, + prune_ids={4, 5}, + review_ids=set(), + review_by_artist={}, + artist_originals={}, + total_releases=5, + ) + + print_report(report, sample_index) + + captured = capsys.readouterr() + assert "VERIFICATION REPORT" in captured.out + assert "KEEP:" in captured.out + assert "PRUNE:" in captured.out + assert "3" in captured.out # keep count + assert "2" in captured.out # prune count + + def test_report_with_table_sizes( + self, sample_index, capsys: pytest.CaptureFixture[str] + ) -> None: + report = ClassificationReport( + keep_ids={1, 2}, + prune_ids={3, 4, 5}, + review_ids=set(), + review_by_artist={}, + artist_originals={}, + total_releases=5, + ) + table_sizes = { + "release": (100, 1048576), + "release_artist": (200, 2097152), + "release_label": (150, 524288), + "release_track": (500, 4194304), + "release_track_artist": (300, 1048576), + "cache_metadata": (100, 262144), + } + rows_to_delete = { + "release": 60, + "release_artist": 120, + "release_label": 90, + "release_track": 300, + "release_track_artist": 180, + "cache_metadata": 60, + } + + print_report(report, sample_index, table_sizes=table_sizes, rows_to_delete=rows_to_delete) + + captured = capsys.readouterr() + assert "Database size" in captured.out + assert "Estimated savings" in captured.out + assert "release_track" in captured.out + + def test_pruned_report(self, sample_index, capsys: pytest.CaptureFixture[str]) -> None: + report = ClassificationReport( + keep_ids={1, 2}, + prune_ids={3}, + review_ids=set(), + review_by_artist={}, + artist_originals={}, + total_releases=3, + ) + + print_report(report, sample_index, pruned=True) + + captured = capsys.readouterr() + assert "PRUNING REPORT" in captured.out + assert "Releases kept:" in captured.out + assert "Releases pruned:" in captured.out + + def test_report_with_review_artists( + self, sample_index, capsys: pytest.CaptureFixture[str] + ) -> None: + match_result = MatchResult( + decision=Decision.REVIEW, + exact_score=0.0, + token_set_score=0.70, + token_sort_score=0.65, + two_stage_score=0.60, + ) + report = ClassificationReport( + keep_ids={1}, + prune_ids={2}, + review_ids={3}, + review_by_artist={"some artist": [(3, "Some Album", match_result)]}, + artist_originals={"some artist": "Some Artist"}, + total_releases=3, + ) + + print_report(report, sample_index) + + captured = capsys.readouterr() + assert "REVIEW" in captured.out + assert "artist-level decisions needed" in captured.out From e53b989b2dc3fa8c86dd59ddd17c932ffd61bb53 Mon Sep 17 00:00:00 2001 From: Jake Bromberg Date: Mon, 9 Mar 2026 23:56:21 -0700 Subject: [PATCH 2/4] test: add mocked unit tests for orchestration logic in run_pipeline, enrich_library_artists, and import_csv Cover wait_for_postgres, run_sql_file, run_sql_statements_parallel error paths, report_sizes, convert_and_filter command construction, enrich_library_artists command construction, _load_or_create_state resume modes, main() input validation, parse_args validation, MySQL extraction functions (extract_alternate_names, extract_cross_referenced_artists, extract_release_cross_ref_artists), enrich main() with and without MySQL, and import_csv main() dispatch to _import_tables vs _import_tables_parallel. --- tests/unit/test_enrich_library_artists.py | 220 +++++++++++ tests/unit/test_import_csv.py | 103 +++++ tests/unit/test_run_pipeline.py | 455 +++++++++++++++++++++- 3 files changed, 777 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_enrich_library_artists.py b/tests/unit/test_enrich_library_artists.py index cdb245a..a2e107e 100644 --- a/tests/unit/test_enrich_library_artists.py +++ b/tests/unit/test_enrich_library_artists.py @@ -5,6 +5,7 @@ import importlib.util import sys from pathlib import Path +from unittest.mock import MagicMock, patch import pytest @@ -17,6 +18,9 @@ _spec.loader.exec_module(_mod) extract_base_artists = _mod.extract_base_artists +extract_alternate_names = _mod.extract_alternate_names +extract_cross_referenced_artists = _mod.extract_cross_referenced_artists +extract_release_cross_ref_artists = _mod.extract_release_cross_ref_artists merge_and_write = _mod.merge_and_write parse_args = _mod.parse_args @@ -302,3 +306,219 @@ def test_compilation_components_excluded(self, tmp_path: Path) -> None: lines = set(output.read_text().splitlines()) assert "Juana Molina" in lines assert "Various Artists" not in lines + + +# --------------------------------------------------------------------------- +# extract_alternate_names (mocked MySQL) +# --------------------------------------------------------------------------- + + +class TestExtractAlternateNames: + """extract_alternate_names() queries LIBRARY_RELEASE for alternate artist names.""" + + def test_returns_alternate_names(self) -> None: + """Mock cursor returns sample alternate artist names.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock(return_value=iter([("Body Count",), ("Ice Cube",)])) + mock_conn.cursor.return_value = mock_cursor + + result = extract_alternate_names(mock_conn) + assert result == {"Body Count", "Ice Cube"} + + def test_empty_result(self) -> None: + """Empty cursor returns empty set.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock(return_value=iter([])) + mock_conn.cursor.return_value = mock_cursor + + result = extract_alternate_names(mock_conn) + assert result == set() + + def test_strips_whitespace_and_skips_empty(self) -> None: + """Whitespace-only and None values are excluded.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock(return_value=iter([(" Stereolab ",), ("",), (None,)])) + mock_conn.cursor.return_value = mock_cursor + + result = extract_alternate_names(mock_conn) + assert result == {"Stereolab"} + + +# --------------------------------------------------------------------------- +# extract_cross_referenced_artists (mocked MySQL) +# --------------------------------------------------------------------------- + + +class TestExtractCrossReferencedArtists: + """extract_cross_referenced_artists() queries LIBRARY_CODE_CROSS_REFERENCE.""" + + def test_returns_cross_referenced_names(self) -> None: + """Mock cursor returns UNION query results.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock( + return_value=iter([("Cat Power",), ("Chan Marshall",), ("Cat Power",)]) + ) + mock_conn.cursor.return_value = mock_cursor + + result = extract_cross_referenced_artists(mock_conn) + # Duplicates from UNION are already handled by the set + assert "Cat Power" in result + assert "Chan Marshall" in result + + def test_deduplication(self) -> None: + """Duplicate names across UNION branches produce unique results.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock( + return_value=iter([("Jessica Pratt",), ("Jessica Pratt",), ("Chuquimamani-Condori",)]) + ) + mock_conn.cursor.return_value = mock_cursor + + result = extract_cross_referenced_artists(mock_conn) + assert len(result) == 2 + assert result == {"Jessica Pratt", "Chuquimamani-Condori"} + + +# --------------------------------------------------------------------------- +# extract_release_cross_ref_artists (mocked MySQL) +# --------------------------------------------------------------------------- + + +class TestExtractReleaseCrossRefArtists: + """extract_release_cross_ref_artists() queries RELEASE_CROSS_REFERENCE.""" + + def test_returns_release_cross_ref_names(self) -> None: + """Mock cursor returns cross-reference artist names.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock(return_value=iter([("Juana Molina",), ("Sessa",)])) + mock_conn.cursor.return_value = mock_cursor + + result = extract_release_cross_ref_artists(mock_conn) + assert result == {"Juana Molina", "Sessa"} + + def test_empty_result(self) -> None: + """Empty cursor returns empty set.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock(return_value=iter([])) + mock_conn.cursor.return_value = mock_cursor + + result = extract_release_cross_ref_artists(mock_conn) + assert result == set() + + +# --------------------------------------------------------------------------- +# main (mocked) +# --------------------------------------------------------------------------- + + +class TestMain: + """main() orchestrates extraction and merge.""" + + def test_with_library_db_only(self, tmp_path) -> None: + """With --library-db only (no MySQL), base artists are extracted and written.""" + library_db = tmp_path / "library.db" + # Create a minimal SQLite library.db + import sqlite3 + + conn = sqlite3.connect(library_db) + conn.execute("CREATE TABLE library (artist TEXT, title TEXT)") + conn.execute("INSERT INTO library VALUES ('Stereolab', 'Aluminum Tunes')") + conn.execute("INSERT INTO library VALUES ('Cat Power', 'Moon Pix')") + conn.commit() + conn.close() + + output = tmp_path / "artists.txt" + + with patch.object( + _mod, + "parse_args", + return_value=parse_args(["--library-db", str(library_db), "--output", str(output)]), + ): + _mod.main() + + lines = set(output.read_text().splitlines()) + assert "Stereolab" in lines + assert "Cat Power" in lines + + def test_with_library_db_and_wxyc_db_url(self, tmp_path) -> None: + """With --library-db and --wxyc-db-url, MySQL enrichment is performed.""" + library_db = tmp_path / "library.db" + import sqlite3 + + conn = sqlite3.connect(library_db) + conn.execute("CREATE TABLE library (artist TEXT, title TEXT)") + conn.execute("INSERT INTO library VALUES ('Stereolab', 'Aluminum Tunes')") + conn.commit() + conn.close() + + output = tmp_path / "artists.txt" + mock_mysql_conn = MagicMock() + + with ( + patch.object( + _mod, + "parse_args", + return_value=parse_args( + [ + "--library-db", + str(library_db), + "--output", + str(output), + "--wxyc-db-url", + "mysql://user:pass@host/db", + ] + ), + ), + patch.object(_mod, "connect_mysql", return_value=mock_mysql_conn), + patch.object(_mod, "extract_alternate_names", return_value={"Nourished by Time"}), + patch.object(_mod, "extract_cross_referenced_artists", return_value={"Buck Meek"}), + patch.object(_mod, "extract_release_cross_ref_artists", return_value={"Sessa"}), + ): + _mod.main() + + lines = set(output.read_text().splitlines()) + assert "Stereolab" in lines + assert "Nourished by Time" in lines + assert "Buck Meek" in lines + assert "Sessa" in lines + mock_mysql_conn.close.assert_called_once() + + def test_missing_library_db_exits(self, tmp_path) -> None: + """Non-existent library.db triggers sys.exit(1).""" + output = tmp_path / "artists.txt" + with ( + patch.object( + _mod, + "parse_args", + return_value=parse_args( + [ + "--library-db", + str(tmp_path / "missing.db"), + "--output", + str(output), + ] + ), + ), + pytest.raises(SystemExit, match="1"), + ): + _mod.main() diff --git a/tests/unit/test_import_csv.py b/tests/unit/test_import_csv.py index d8c84c3..635613e 100644 --- a/tests/unit/test_import_csv.py +++ b/tests/unit/test_import_csv.py @@ -392,3 +392,106 @@ def test_no_header_returns_zero(self, tmp_path, caplog) -> None: assert count == 0 assert "No header" in caplog.text + + +# --------------------------------------------------------------------------- +# main() argument parsing and dispatch +# --------------------------------------------------------------------------- + + +class TestMainArgParsing: + """import_csv.py main() validates args and dispatches to correct import mode.""" + + def test_missing_csv_dir_exits(self, tmp_path) -> None: + """Non-existent CSV directory triggers sys.exit(1).""" + from unittest.mock import patch + + with ( + patch("sys.argv", ["import_csv.py", str(tmp_path / "missing_csv")]), + pytest.raises(SystemExit, match="1"), + ): + _ic.main() + + def test_default_mode_calls_import_tables(self, tmp_path) -> None: + """Default mode (no flags) calls _import_tables for all tables.""" + from unittest.mock import MagicMock, patch + + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + + mock_conn = MagicMock() + + with ( + patch("sys.argv", ["import_csv.py", str(csv_dir), "postgresql:///test"]), + patch.object(_ic.psycopg, "connect", return_value=mock_conn), + patch.object(_ic, "_import_tables", return_value=100) as mock_import, + patch.object(_ic, "import_artwork", return_value=10), + patch.object(_ic, "populate_cache_metadata", return_value=50), + ): + _ic.main() + + mock_import.assert_called_once() + call_args = mock_import.call_args + assert call_args[0][0] is mock_conn + assert call_args[0][1] == csv_dir + assert call_args[0][2] == TABLES + + def test_base_only_mode_calls_parallel(self, tmp_path) -> None: + """--base-only mode calls _import_tables_parallel with base tables.""" + from unittest.mock import MagicMock, patch + + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + + mock_conn = MagicMock() + + with ( + patch( + "sys.argv", + ["import_csv.py", "--base-only", str(csv_dir), "postgresql:///test"], + ), + patch.object(_ic.psycopg, "connect", return_value=mock_conn), + patch.object(_ic, "_import_tables_parallel", return_value=100) as mock_parallel, + patch.object(_ic, "import_artwork", return_value=10), + patch.object(_ic, "populate_cache_metadata", return_value=50), + patch.object(_ic, "create_track_count_table", return_value=20), + ): + _ic.main() + + mock_parallel.assert_called_once() + call_args = mock_parallel.call_args + assert call_args[0][0] == "postgresql:///test" + assert call_args[0][1] == csv_dir + # Parent tables are BASE_TABLES[:1], child tables are BASE_TABLES[1:] + assert call_args[1]["parent_tables"] == BASE_TABLES[:1] + assert call_args[1]["child_tables"] == BASE_TABLES[1:] + + def test_tracks_only_mode_uses_release_id_filter(self, tmp_path) -> None: + """--tracks-only mode queries release IDs and filters track import.""" + from unittest.mock import MagicMock, patch + + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.fetchall.return_value = [(5001,), (5002,), (5003,)] + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch( + "sys.argv", + ["import_csv.py", "--tracks-only", str(csv_dir), "postgresql:///test"], + ), + patch.object(_ic.psycopg, "connect", return_value=mock_conn), + patch.object(_ic, "_import_tables_parallel", return_value=200) as mock_parallel, + ): + _ic.main() + + mock_parallel.assert_called_once() + call_args = mock_parallel.call_args + # --tracks-only passes empty parent_tables and TRACK_TABLES as children + assert call_args[1]["parent_tables"] == [] + assert call_args[1]["child_tables"] == TRACK_TABLES + assert call_args[1]["release_id_filter"] == {5001, 5002, 5003} diff --git a/tests/unit/test_run_pipeline.py b/tests/unit/test_run_pipeline.py index 5e618a5..adf8b03 100644 --- a/tests/unit/test_run_pipeline.py +++ b/tests/unit/test_run_pipeline.py @@ -3,10 +3,11 @@ from __future__ import annotations import importlib.util +import json import logging import sys from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -492,3 +493,455 @@ def fake_convert(xml, output_dir, converter, library_artists=None): assert convert_calls[0][3] is not None, ( "library_artists path should be passed to convert_and_filter" ) + + +# --------------------------------------------------------------------------- +# wait_for_postgres +# --------------------------------------------------------------------------- + + +class TestWaitForPostgres: + """wait_for_postgres() polls until Postgres is ready or times out.""" + + def test_success_on_first_try(self) -> None: + """Successful connection on the first attempt returns immediately.""" + mock_conn = MagicMock() + with patch.object(run_pipeline.psycopg, "connect", return_value=mock_conn): + run_pipeline.wait_for_postgres("postgresql:///test") + mock_conn.close.assert_called_once() + + def test_retry_then_success(self) -> None: + """First call raises OperationalError, second succeeds.""" + mock_conn = MagicMock() + with ( + patch.object( + run_pipeline.psycopg, + "connect", + side_effect=[run_pipeline.psycopg.OperationalError("refused"), mock_conn], + ), + patch.object(run_pipeline.time, "sleep"), + ): + run_pipeline.wait_for_postgres("postgresql:///test") + mock_conn.close.assert_called_once() + + def test_timeout_exits(self) -> None: + """All connection attempts fail and timeout is exceeded -> sys.exit(1).""" + # monotonic returns: first call sets deadline, subsequent calls exceed it + with ( + patch.object( + run_pipeline.psycopg, + "connect", + side_effect=run_pipeline.psycopg.OperationalError("refused"), + ), + patch.object(run_pipeline.time, "monotonic", side_effect=[0.0, 100.0]), + patch.object(run_pipeline.time, "sleep"), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.wait_for_postgres("postgresql:///test") + + +# --------------------------------------------------------------------------- +# run_sql_file +# --------------------------------------------------------------------------- + + +class TestRunSqlFile: + """run_sql_file() executes SQL from a file against the database.""" + + def test_happy_path(self, tmp_path) -> None: + """SQL file contents are executed via cursor.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE t (id int)") + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with patch.object(run_pipeline.psycopg, "connect", return_value=mock_conn): + run_pipeline.run_sql_file("postgresql:///test", sql_file) + + mock_cursor.execute.assert_called_once_with("CREATE TABLE t (id int)") + mock_conn.close.assert_called_once() + + def test_sql_error_exits(self, tmp_path) -> None: + """psycopg.Error during execution triggers sys.exit(1).""" + sql_file = tmp_path / "bad.sql" + sql_file.write_text("INVALID SQL") + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute.side_effect = run_pipeline.psycopg.Error("syntax error") + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch.object(run_pipeline.psycopg, "connect", return_value=mock_conn), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.run_sql_file("postgresql:///test", sql_file) + mock_conn.close.assert_called() + + def test_strip_concurrently_removes_keyword(self, tmp_path) -> None: + """strip_concurrently=True removes CONCURRENTLY from SQL.""" + sql_file = tmp_path / "indexes.sql" + sql_file.write_text("CREATE INDEX CONCURRENTLY idx_a ON t(a)") + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with patch.object(run_pipeline.psycopg, "connect", return_value=mock_conn): + run_pipeline.run_sql_file("postgresql:///test", sql_file, strip_concurrently=True) + + executed_sql = mock_cursor.execute.call_args[0][0] + assert "CONCURRENTLY" not in executed_sql + assert "CREATE INDEX idx_a ON t(a)" == executed_sql + + +# --------------------------------------------------------------------------- +# run_sql_statements_parallel — error propagation +# --------------------------------------------------------------------------- + + +class TestRunSqlStatementsParallelError: + """Test that psycopg.Error from a parallel statement is re-raised.""" + + def test_psycopg_error_is_reraised(self) -> None: + """A psycopg.Error in a parallel statement propagates to the caller.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute.side_effect = run_pipeline.psycopg.Error("disk full") + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch.object(run_pipeline.psycopg, "connect", return_value=mock_conn), + pytest.raises(run_pipeline.psycopg.Error, match="disk full"), + ): + run_sql_statements_parallel("postgresql:///test", ["CREATE INDEX idx_x ON t(x)"]) + + +# --------------------------------------------------------------------------- +# report_sizes +# --------------------------------------------------------------------------- + + +class TestReportSizes: + """report_sizes() queries pg_stat_user_tables and logs results.""" + + def test_logs_table_sizes(self, caplog) -> None: + """Fetched rows are logged with table names and row counts.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.fetchall.return_value = [ + ("release", 50000, "120 MB"), + ("release_artist", 80000, "45 MB"), + ] + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch.object(run_pipeline.psycopg, "connect", return_value=mock_conn), + caplog.at_level(logging.INFO, logger=run_pipeline.logger.name), + ): + run_pipeline.report_sizes("postgresql:///test") + + mock_cursor.execute.assert_called_once() + logged = [r.message for r in caplog.records] + assert any("release" in msg and "50,000" in msg for msg in logged) + assert any("release_artist" in msg and "80,000" in msg for msg in logged) + mock_conn.close.assert_called_once() + + +# --------------------------------------------------------------------------- +# convert_and_filter +# --------------------------------------------------------------------------- + + +class TestConvertAndFilter: + """convert_and_filter() constructs the converter command and delegates to run_step.""" + + def test_command_with_library_artists(self) -> None: + """Command includes --library-artists when provided.""" + with patch.object(run_pipeline, "run_step") as mock_run: + run_pipeline.convert_and_filter( + Path("/data/releases.xml.gz"), + Path("/tmp/csv"), + "discogs-xml-converter", + library_artists=Path("/data/library_artists.txt"), + ) + + mock_run.assert_called_once() + cmd = mock_run.call_args[0][1] + assert cmd[0] == "discogs-xml-converter" + assert "/data/releases.xml.gz" in cmd + assert "--output-dir" in cmd + assert "--library-artists" in cmd + assert "/data/library_artists.txt" in cmd + + def test_command_with_database_url(self) -> None: + """Command includes --database-url for direct-PG mode.""" + with patch.object(run_pipeline, "run_step") as mock_run: + run_pipeline.convert_and_filter( + Path("/data/releases.xml.gz"), + Path("/tmp/csv"), + "discogs-xml-converter", + database_url="postgresql:///discogs", + ) + + cmd = mock_run.call_args[0][1] + assert "--database-url" in cmd + assert "postgresql:///discogs" in cmd + # Description mentions PostgreSQL + description = mock_run.call_args[0][0] + assert "PostgreSQL" in description + + def test_command_without_optional_args(self) -> None: + """Command omits --library-artists and --database-url when not provided.""" + with patch.object(run_pipeline, "run_step") as mock_run: + run_pipeline.convert_and_filter( + Path("/data/releases.xml.gz"), + Path("/tmp/csv"), + "discogs-xml-converter", + ) + + cmd = mock_run.call_args[0][1] + assert "--library-artists" not in cmd + assert "--database-url" not in cmd + description = mock_run.call_args[0][0] + assert "CSV" in description + + +# --------------------------------------------------------------------------- +# enrich_library_artists (orchestrator wrapper) +# --------------------------------------------------------------------------- + + +class TestEnrichLibraryArtists: + """enrich_library_artists() constructs the enrichment command.""" + + def test_command_with_wxyc_db_url(self) -> None: + """Command includes --wxyc-db-url when provided.""" + with patch.object(run_pipeline, "run_step") as mock_run: + run_pipeline.enrich_library_artists( + Path("/data/library.db"), + Path("/tmp/library_artists.txt"), + wxyc_db_url="mysql://user:pass@host/db", + ) + + cmd = mock_run.call_args[0][1] + assert "--library-db" in cmd + assert "/data/library.db" in cmd + assert "--output" in cmd + assert "/tmp/library_artists.txt" in cmd + assert "--wxyc-db-url" in cmd + assert "mysql://user:pass@host/db" in cmd + + def test_command_without_wxyc_db_url(self) -> None: + """Command omits --wxyc-db-url when not provided.""" + with patch.object(run_pipeline, "run_step") as mock_run: + run_pipeline.enrich_library_artists( + Path("/data/library.db"), + Path("/tmp/library_artists.txt"), + ) + + cmd = mock_run.call_args[0][1] + assert "--library-db" in cmd + assert "--output" in cmd + assert "--wxyc-db-url" not in cmd + + +# --------------------------------------------------------------------------- +# _load_or_create_state +# --------------------------------------------------------------------------- + + +class TestLoadOrCreateState: + """_load_or_create_state() handles resume modes.""" + + def test_resume_with_existing_state_file(self, tmp_path) -> None: + """When --resume and state file exists, load it.""" + state_file = tmp_path / "state.json" + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + + # Write a valid state file + state_data = { + "version": 3, + "database_url": "postgresql:///test", + "csv_dir": str(csv_dir.resolve()), + "steps": {s: {"status": "pending"} for s in run_pipeline.STEP_NAMES}, + } + state_data["steps"]["create_schema"] = {"status": "completed"} + state_file.write_text(json.dumps(state_data)) + + args = run_pipeline.parse_args( + [ + "--csv-dir", + str(csv_dir), + "--resume", + "--state-file", + str(state_file), + "--database-url", + "postgresql:///test", + ] + ) + + state = run_pipeline._load_or_create_state(args) + assert state.is_completed("create_schema") + assert not state.is_completed("import_csv") + + def test_resume_without_state_file_uses_db_introspect(self, tmp_path) -> None: + """When --resume but no state file, infer from database.""" + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + state_file = tmp_path / "nonexistent_state.json" + + args = run_pipeline.parse_args( + [ + "--csv-dir", + str(csv_dir), + "--resume", + "--state-file", + str(state_file), + "--database-url", + "postgresql:///test", + ] + ) + + mock_state = run_pipeline.PipelineState(db_url="postgresql:///test", csv_dir="") + mock_state.mark_completed("create_schema") + + with patch("lib.db_introspect.infer_pipeline_state", return_value=mock_state): + state = run_pipeline._load_or_create_state(args) + + assert state.is_completed("create_schema") + assert state.csv_dir == str(csv_dir.resolve()) + + def test_fresh_state_no_resume(self, tmp_path) -> None: + """Without --resume, create a fresh PipelineState.""" + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + + args = run_pipeline.parse_args( + [ + "--csv-dir", + str(csv_dir), + "--database-url", + "postgresql:///test", + ] + ) + + state = run_pipeline._load_or_create_state(args) + assert not any(state.is_completed(s) for s in run_pipeline.STEP_NAMES) + assert state.db_url == "postgresql:///test" + + +# --------------------------------------------------------------------------- +# main() — input validation +# --------------------------------------------------------------------------- + + +class TestMainValidation: + """main() validates file paths before running the pipeline.""" + + def test_missing_xml_file_exits(self, tmp_path) -> None: + """Non-existent XML file triggers sys.exit(1).""" + args = run_pipeline.parse_args(["--xml", str(tmp_path / "missing.xml.gz")]) + with ( + patch.object(run_pipeline, "parse_args", return_value=args), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.main() + + def test_missing_library_artists_file_exits(self, tmp_path) -> None: + """Non-existent library_artists.txt triggers sys.exit(1).""" + xml_file = tmp_path / "releases.xml.gz" + xml_file.touch() + args = run_pipeline.parse_args( + [ + "--xml", + str(xml_file), + "--library-artists", + str(tmp_path / "missing_artists.txt"), + ] + ) + with ( + patch.object(run_pipeline, "parse_args", return_value=args), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.main() + + def test_missing_csv_dir_exits(self, tmp_path) -> None: + """Non-existent CSV directory triggers sys.exit(1).""" + args = run_pipeline.parse_args(["--csv-dir", str(tmp_path / "missing_csv")]) + with ( + patch.object(run_pipeline, "parse_args", return_value=args), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.main() + + def test_missing_library_db_exits(self, tmp_path) -> None: + """Non-existent library.db triggers sys.exit(1).""" + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + args = run_pipeline.parse_args( + [ + "--csv-dir", + str(csv_dir), + "--library-db", + str(tmp_path / "missing_library.db"), + ] + ) + with ( + patch.object(run_pipeline, "parse_args", return_value=args), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.main() + + def test_missing_library_labels_exits(self, tmp_path) -> None: + """Non-existent library_labels.csv triggers sys.exit(1).""" + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + args = run_pipeline.parse_args( + [ + "--csv-dir", + str(csv_dir), + "--library-labels", + str(tmp_path / "missing_labels.csv"), + ] + ) + with ( + patch.object(run_pipeline, "parse_args", return_value=args), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.main() + + +# --------------------------------------------------------------------------- +# parse_args — additional validation +# --------------------------------------------------------------------------- + + +class TestParseArgsValidation: + """Additional argument validation in parse_args.""" + + def test_direct_pg_without_xml_exits(self) -> None: + """--direct-pg without --xml triggers parser.error (sys.exit(2)).""" + with pytest.raises(SystemExit): + run_pipeline.parse_args(["--csv-dir", "/tmp/csv", "--direct-pg"]) + + def test_generate_library_db_with_library_db_exits(self) -> None: + """--generate-library-db and --library-db are mutually exclusive.""" + with pytest.raises(SystemExit): + run_pipeline.parse_args( + [ + "--csv-dir", + "/tmp/csv", + "--generate-library-db", + "--library-db", + "/tmp/library.db", + ] + ) From fcfd5df506d874a2b34df88883238a31b638952b Mon Sep 17 00:00:00 2001 From: Jake Bromberg Date: Tue, 10 Mar 2026 00:10:09 -0700 Subject: [PATCH 3/4] test: add PostgreSQL integration tests for dedup, import, and verify_cache coverage gaps Add integration tests exercising previously uncovered code paths against a real PostgreSQL database: dedup_releases.py (63% -> 73%): TestEnsureDedupIdsAlreadyExists verifies the early-return path when dedup_delete_ids table already exists. TestAddTrackConstraintsAndIndexes verifies FK constraints and GIN trigram indexes on track tables. TestAddConstraintsAndIndexes verifies the wrapper function creates both base and track constraints. import_csv.py (86% -> 97%): TestPopulateCacheMetadata verifies COPY-based bulk insert of cache_metadata rows. TestImportArtwork verifies primary/fallback image selection, invalid release_id skipping, and empty URI handling. TestImportArtworkMissing and TestCreateTrackCountTableMissing verify the missing-file warning paths. TestImportTables verifies sequential import of table configs including missing CSV handling. verify_cache.py (83% -> 86%): TestGetTableSizes verifies asyncpg-based row count and size queries across all release tables. TestCountRowsToDelete verifies row counting for single, multiple, empty, and nonexistent release ID sets. TestPruneReleases verifies CASCADE deletion of releases and child rows. --- tests/integration/test_dedup.py | 276 +++++++++++++++++++++++++++++++ tests/integration/test_import.py | 261 +++++++++++++++++++++++++++++ tests/integration/test_prune.py | 244 +++++++++++++++++++++++++++ 3 files changed, 781 insertions(+) diff --git a/tests/integration/test_dedup.py b/tests/integration/test_dedup.py index dce4fde..0d29a9d 100644 --- a/tests/integration/test_dedup.py +++ b/tests/integration/test_dedup.py @@ -34,6 +34,7 @@ copy_table = _dd.copy_table swap_tables = _dd.swap_tables add_base_constraints_and_indexes = _dd.add_base_constraints_and_indexes +add_track_constraints_and_indexes = _dd.add_track_constraints_and_indexes add_constraints_and_indexes = _dd.add_constraints_and_indexes load_library_labels = _dd.load_library_labels load_label_hierarchy = _dd.load_label_hierarchy @@ -695,3 +696,278 @@ def test_total_release_count_same(self) -> None: count = cur.fetchone()[0] conn.close() assert count == 12 + + +class TestEnsureDedupIdsAlreadyExists: + """Verify ensure_dedup_ids returns existing count when table already populated.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, db_url): + self.__class__._db_url = db_url + conn = psycopg.connect(db_url, autocommit=True) + _drop_all_tables(conn) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + # Insert releases with duplicate master_ids (would normally be deduped) + cur.execute("INSERT INTO release (id, title, master_id) VALUES (1, 'A', 100)") + cur.execute("INSERT INTO release (id, title, master_id) VALUES (2, 'B', 100)") + # Pre-create the dedup_delete_ids table with some IDs + cur.execute(""" + CREATE UNLOGGED TABLE dedup_delete_ids ( + release_id integer PRIMARY KEY + ) + """) + cur.execute("INSERT INTO dedup_delete_ids (release_id) VALUES (1)") + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_returns_existing_count(self) -> None: + """When dedup_delete_ids already exists, returns its count without recreating.""" + conn = psycopg.connect(self.db_url, autocommit=True) + count = ensure_dedup_ids(conn) + conn.close() + assert count == 1 + + def test_table_not_recreated(self) -> None: + """The pre-existing dedup_delete_ids table is not dropped and recreated.""" + conn = psycopg.connect(self.db_url, autocommit=True) + # Add a second ID to the existing table + with conn.cursor() as cur: + cur.execute( + "INSERT INTO dedup_delete_ids (release_id) VALUES (2) ON CONFLICT DO NOTHING" + ) + count = ensure_dedup_ids(conn) + conn.close() + # Should reflect the updated count (not recreate from ROW_NUMBER query) + assert count == 2 + + +class TestAddTrackConstraintsAndIndexes: + """Verify add_track_constraints_and_indexes creates FK constraints and indexes.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, db_url): + self.__class__._db_url = db_url + conn = psycopg.connect(db_url, autocommit=True) + _drop_all_tables(conn) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute(SCHEMA_DIR.joinpath("create_functions.sql").read_text()) + # Drop existing constraints (schema creates them, we want to test adding them) + cur.execute( + "ALTER TABLE release_track DROP CONSTRAINT IF EXISTS release_track_release_id_fkey" + ) + cur.execute( + "ALTER TABLE release_track_artist " + "DROP CONSTRAINT IF EXISTS release_track_artist_release_id_fkey" + ) + cur.execute("DROP INDEX IF EXISTS idx_release_track_release_id") + cur.execute("DROP INDEX IF EXISTS idx_release_track_artist_release_id") + cur.execute("DROP INDEX IF EXISTS idx_release_track_title_trgm") + cur.execute("DROP INDEX IF EXISTS idx_release_track_artist_name_trgm") + # Insert test data + cur.execute("INSERT INTO release (id, title) VALUES (1, 'Test Album')") + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) VALUES (1, 1, 'Track 1')" + ) + cur.execute( + "INSERT INTO release_track_artist (release_id, track_sequence, artist_name) " + "VALUES (1, 1, 'Test Artist')" + ) + add_track_constraints_and_indexes(conn, db_url=db_url) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def _connect(self): + return psycopg.connect(self.db_url) + + def test_release_track_fk_exists(self) -> None: + """FK constraint on release_track referencing release exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT constraint_name FROM information_schema.table_constraints + WHERE table_name = 'release_track' AND constraint_type = 'FOREIGN KEY' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + def test_release_track_artist_fk_exists(self) -> None: + """FK constraint on release_track_artist referencing release exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT constraint_name FROM information_schema.table_constraints + WHERE table_name = 'release_track_artist' AND constraint_type = 'FOREIGN KEY' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + def test_release_track_release_id_index_exists(self) -> None: + """Index on release_track(release_id) exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT indexname FROM pg_indexes + WHERE tablename = 'release_track' AND indexname = 'idx_release_track_release_id' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + def test_release_track_artist_release_id_index_exists(self) -> None: + """Index on release_track_artist(release_id) exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT indexname FROM pg_indexes + WHERE tablename = 'release_track_artist' + AND indexname = 'idx_release_track_artist_release_id' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + def test_release_track_title_trgm_index_exists(self) -> None: + """GIN trigram index on release_track(title) exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT indexname FROM pg_indexes + WHERE tablename = 'release_track' + AND indexname = 'idx_release_track_title_trgm' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + def test_release_track_artist_name_trgm_index_exists(self) -> None: + """GIN trigram index on release_track_artist(artist_name) exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT indexname FROM pg_indexes + WHERE tablename = 'release_track_artist' + AND indexname = 'idx_release_track_artist_name_trgm' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + +class TestAddConstraintsAndIndexes: + """Verify add_constraints_and_indexes creates both base and track constraints.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, db_url): + self.__class__._db_url = db_url + conn = psycopg.connect(db_url, autocommit=True) + _drop_all_tables(conn) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute(SCHEMA_DIR.joinpath("create_functions.sql").read_text()) + # Drop all FK constraints and indexes (schema creates them) + for constraint, table in [ + ("release_artist_release_id_fkey", "release_artist"), + ("release_label_release_id_fkey", "release_label"), + ("release_track_release_id_fkey", "release_track"), + ("release_track_artist_release_id_fkey", "release_track_artist"), + ("cache_metadata_release_id_fkey", "cache_metadata"), + ]: + cur.execute(f"ALTER TABLE {table} DROP CONSTRAINT IF EXISTS {constraint}") + # Drop the PK on release so add_constraints_and_indexes can recreate it + cur.execute("ALTER TABLE release DROP CONSTRAINT IF EXISTS release_pkey") + cur.execute("ALTER TABLE cache_metadata DROP CONSTRAINT IF EXISTS cache_metadata_pkey") + # Drop all indexes (FK, GIN trigram, cache metadata) + for idx in [ + "idx_release_artist_release_id", + "idx_release_label_release_id", + "idx_release_track_release_id", + "idx_release_track_artist_release_id", + "idx_release_artist_name_trgm", + "idx_release_title_trgm", + "idx_release_track_title_trgm", + "idx_release_track_artist_name_trgm", + "idx_cache_metadata_cached_at", + "idx_cache_metadata_source", + "idx_release_master_id", + ]: + cur.execute(f"DROP INDEX IF EXISTS {idx}") + # Insert test data + cur.execute("INSERT INTO release (id, title) VALUES (1, 'Test Album')") + cur.execute( + "INSERT INTO release_artist (release_id, artist_name) VALUES (1, 'Test Artist')" + ) + cur.execute("INSERT INTO release_label (release_id, label_name) VALUES (1, 'Test Lbl')") + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) VALUES (1, 1, 'Track 1')" + ) + cur.execute( + "INSERT INTO release_track_artist (release_id, track_sequence, artist_name) " + "VALUES (1, 1, 'Track Artist')" + ) + cur.execute("INSERT INTO cache_metadata (release_id, source) VALUES (1, 'bulk_import')") + add_constraints_and_indexes(conn, db_url=db_url) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def _connect(self): + return psycopg.connect(self.db_url) + + def test_release_pk_exists(self) -> None: + """Primary key on release(id) exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT constraint_name FROM information_schema.table_constraints + WHERE table_name = 'release' AND constraint_type = 'PRIMARY KEY' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + def test_all_fk_constraints_exist(self) -> None: + """FK constraints on all child tables exist.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT tc.table_name + FROM information_schema.table_constraints tc + WHERE tc.constraint_type = 'FOREIGN KEY' + """) + fk_tables = {row[0] for row in cur.fetchall()} + conn.close() + expected = { + "release_artist", + "release_label", + "release_track", + "release_track_artist", + "cache_metadata", + } + assert expected.issubset(fk_tables) + + def test_base_and_track_indexes_exist(self) -> None: + """Both base and track FK indexes exist.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT indexname FROM pg_indexes WHERE schemaname = 'public'") + indexes = {row[0] for row in cur.fetchall()} + conn.close() + expected_indexes = { + "idx_release_artist_release_id", + "idx_release_label_release_id", + "idx_release_track_release_id", + "idx_release_track_artist_release_id", + } + assert expected_indexes.issubset(indexes) diff --git a/tests/integration/test_import.py b/tests/integration/test_import.py index 9d93d18..e784029 100644 --- a/tests/integration/test_import.py +++ b/tests/integration/test_import.py @@ -22,6 +22,8 @@ import_csv_func = _ic.import_csv import_artwork = _ic.import_artwork create_track_count_table = _ic.create_track_count_table +populate_cache_metadata = _ic.populate_cache_metadata +_import_tables = _ic._import_tables TABLES = _ic.TABLES BASE_TABLES = _ic.BASE_TABLES TRACK_TABLES = _ic.TRACK_TABLES @@ -467,3 +469,262 @@ def test_duplicate_release_ids_keep_first(self, tmp_path) -> None: conn.close() # First occurrence wins assert title == "DOGA" + + +class TestPopulateCacheMetadata: + """Verify populate_cache_metadata() inserts metadata for all releases via COPY.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up_database(self, db_url): + self.__class__._db_url = db_url + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute("INSERT INTO release (id, title) VALUES (5001, 'DOGA')") + cur.execute("INSERT INTO release (id, title) VALUES (5002, 'Aluminum Tunes')") + cur.execute("INSERT INTO release (id, title) VALUES (5003, 'Moon Pix')") + conn.close() + + conn = psycopg.connect(db_url) + populate_cache_metadata(conn) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def _connect(self): + return psycopg.connect(self.db_url) + + def test_metadata_row_count(self) -> None: + """One cache_metadata row per release.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM cache_metadata") + count = cur.fetchone()[0] + conn.close() + assert count == 3 + + def test_metadata_source(self) -> None: + """All rows have source='bulk_import'.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT DISTINCT source FROM cache_metadata") + sources = {row[0] for row in cur.fetchall()} + conn.close() + assert sources == {"bulk_import"} + + def test_metadata_release_ids(self) -> None: + """Metadata release_ids match the inserted releases.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT release_id FROM cache_metadata ORDER BY release_id") + ids = [row[0] for row in cur.fetchall()] + conn.close() + assert ids == [5001, 5002, 5003] + + def test_metadata_cached_at_not_null(self) -> None: + """cached_at defaults to current timestamp (not null).""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM cache_metadata WHERE cached_at IS NOT NULL") + count = cur.fetchone()[0] + conn.close() + assert count == 3 + + +class TestImportArtwork: + """Verify import_artwork() populates artwork_url from release_image.csv.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up_database(self, db_url): + self.__class__._db_url = db_url + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute("INSERT INTO release (id, title) VALUES (101, 'Album A')") + cur.execute("INSERT INTO release (id, title) VALUES (102, 'Album B')") + cur.execute("INSERT INTO release (id, title) VALUES (103, 'Album C')") + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def _connect(self): + return psycopg.connect(self.db_url) + + def test_primary_image_preferred(self, tmp_path) -> None: + """Primary image type is used over secondary.""" + csv_path = tmp_path / "release_image.csv" + csv_path.write_text( + "release_id,type,width,height,uri\n" + "101,secondary,300,300,https://img.discogs.com/secondary-101.jpg\n" + "101,primary,600,600,https://img.discogs.com/primary-101.jpg\n" + ) + conn = psycopg.connect(self.db_url) + import_artwork(conn, tmp_path) + conn.close() + + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT artwork_url FROM release WHERE id = 101") + url = cur.fetchone()[0] + conn.close() + assert url == "https://img.discogs.com/primary-101.jpg" + + def test_fallback_when_no_primary(self, tmp_path) -> None: + """Secondary image used as fallback when no primary exists.""" + csv_path = tmp_path / "release_image.csv" + csv_path.write_text( + "release_id,type,width,height,uri\n" + "102,secondary,600,600,https://img.discogs.com/secondary-102.jpg\n" + ) + # Reset artwork_url for release 102 + conn = psycopg.connect(self.db_url) + with conn.cursor() as cur: + cur.execute("UPDATE release SET artwork_url = NULL WHERE id = 102") + conn.commit() + import_artwork(conn, tmp_path) + conn.close() + + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT artwork_url FROM release WHERE id = 102") + url = cur.fetchone()[0] + conn.close() + assert url == "https://img.discogs.com/secondary-102.jpg" + + def test_invalid_release_id_skipped(self, tmp_path) -> None: + """Rows with non-integer release_id are silently skipped.""" + csv_path = tmp_path / "release_image.csv" + csv_path.write_text( + "release_id,type,width,height,uri\n" + "abc,primary,600,600,https://img.discogs.com/bad.jpg\n" + "103,primary,600,600,https://img.discogs.com/good-103.jpg\n" + ) + conn = psycopg.connect(self.db_url) + with conn.cursor() as cur: + cur.execute("UPDATE release SET artwork_url = NULL WHERE id = 103") + conn.commit() + count = import_artwork(conn, tmp_path) + conn.close() + + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT artwork_url FROM release WHERE id = 103") + url = cur.fetchone()[0] + conn.close() + assert url == "https://img.discogs.com/good-103.jpg" + assert count >= 1 + + def test_empty_uri_skipped(self, tmp_path) -> None: + """Rows with empty URI are skipped.""" + csv_path = tmp_path / "release_image.csv" + csv_path.write_text("release_id,type,width,height,uri\n103,primary,600,600,\n") + conn = psycopg.connect(self.db_url) + with conn.cursor() as cur: + cur.execute("UPDATE release SET artwork_url = NULL WHERE id = 103") + conn.commit() + count = import_artwork(conn, tmp_path) + conn.close() + + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT artwork_url FROM release WHERE id = 103") + url = cur.fetchone()[0] + conn.close() + assert url is None + assert count == 0 + + +class TestImportArtworkMissing: + """Verify import_artwork() returns 0 when release_image.csv is missing.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up_database(self, db_url): + self.__class__._db_url = db_url + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute("INSERT INTO release (id, title) VALUES (1, 'Test')") + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_returns_zero(self, tmp_path) -> None: + """import_artwork returns 0 when release_image.csv does not exist.""" + conn = psycopg.connect(self.db_url) + result = import_artwork(conn, tmp_path) + conn.close() + assert result == 0 + + +class TestCreateTrackCountTableMissing: + """Verify create_track_count_table() returns 0 when release_track.csv is missing.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up_database(self, db_url): + self.__class__._db_url = db_url + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_returns_zero(self, tmp_path) -> None: + """create_track_count_table returns 0 when release_track.csv does not exist.""" + conn = psycopg.connect(self.db_url) + result = create_track_count_table(conn, tmp_path) + conn.close() + assert result == 0 + + +class TestImportTables: + """Verify _import_tables() sequential import of table configs.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up_database(self, db_url): + self.__class__._db_url = db_url + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_imports_all_tables(self) -> None: + """_import_tables imports all CSVs in the table list and returns total count.""" + conn = psycopg.connect(self.db_url) + total = _import_tables(conn, CSV_DIR, BASE_TABLES) + conn.close() + + conn = psycopg.connect(self.db_url) + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM release") + release_count = cur.fetchone()[0] + cur.execute("SELECT count(*) FROM release_artist") + artist_count = cur.fetchone()[0] + conn.close() + assert release_count > 0 + assert artist_count > 0 + assert total == release_count + artist_count + 16 # + release_label count + + def test_skips_missing_csv(self, tmp_path) -> None: + """_import_tables skips table configs whose CSV file does not exist.""" + conn = psycopg.connect(self.db_url) + total = _import_tables(conn, tmp_path, TRACK_TABLES) + conn.close() + assert total == 0 diff --git a/tests/integration/test_prune.py b/tests/integration/test_prune.py index a3845c9..b8cfce2 100644 --- a/tests/integration/test_prune.py +++ b/tests/integration/test_prune.py @@ -6,6 +6,9 @@ import sys as _sys from pathlib import Path +import asyncio + +import asyncpg import psycopg import pytest @@ -44,6 +47,9 @@ MultiIndexMatcher = _vc.MultiIndexMatcher Decision = _vc.Decision classify_all_releases = _vc.classify_all_releases +get_table_sizes = _vc.get_table_sizes +count_rows_to_delete = _vc.count_rows_to_delete +prune_releases = _vc.prune_releases pytestmark = pytest.mark.postgres @@ -208,3 +214,241 @@ def test_keep_releases_survive_prune(self) -> None: count = cur.fetchone()[0] conn.close() assert count == len(self.report.keep_ids) + + +def _set_up_tables_for_async(db_url: str) -> None: + """Set up schema and insert test data for async function tests.""" + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + for table in ALL_TABLES: + cur.execute(f"DROP TABLE IF EXISTS {table} CASCADE") + cur.execute("DROP TABLE IF EXISTS release_label CASCADE") + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute("INSERT INTO release (id, title) VALUES (101, 'Confield')") + cur.execute("INSERT INTO release (id, title) VALUES (102, 'DOGA')") + cur.execute("INSERT INTO release (id, title) VALUES (103, 'Moon Pix')") + cur.execute( + "INSERT INTO release_artist (release_id, artist_name, extra) " + "VALUES (101, 'Autechre', 0)" + ) + cur.execute( + "INSERT INTO release_artist (release_id, artist_name, extra) " + "VALUES (102, 'Juana Molina', 0)" + ) + cur.execute( + "INSERT INTO release_artist (release_id, artist_name, extra) " + "VALUES (103, 'Cat Power', 0)" + ) + cur.execute("INSERT INTO release_label (release_id, label_name) VALUES (101, 'Warp')") + cur.execute("INSERT INTO release_label (release_id, label_name) VALUES (102, 'Sonamos')") + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) " + "VALUES (101, 1, 'VI Scose Poise')" + ) + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) VALUES (102, 1, 'Cosoco')" + ) + cur.execute( + "INSERT INTO release_track_artist (release_id, track_sequence, artist_name) " + "VALUES (101, 1, 'Autechre')" + ) + cur.execute("INSERT INTO cache_metadata (release_id, source) VALUES (101, 'bulk_import')") + cur.execute("INSERT INTO cache_metadata (release_id, source) VALUES (102, 'bulk_import')") + cur.execute("INSERT INTO cache_metadata (release_id, source) VALUES (103, 'bulk_import')") + conn.close() + + +class TestGetTableSizes: + """Verify get_table_sizes returns row counts and sizes for each release table.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, db_url): + self.__class__._db_url = db_url + _set_up_tables_for_async(db_url) + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_returns_all_tables(self) -> None: + """get_table_sizes returns entries for all RELEASE_TABLES.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await get_table_sizes(conn) + finally: + await conn.close() + + sizes = asyncio.run(_run()) + expected_tables = { + "release", + "release_artist", + "release_label", + "release_track", + "release_track_artist", + "cache_metadata", + } + assert set(sizes.keys()) == expected_tables + + def test_release_row_count(self) -> None: + """release table has 3 rows.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await get_table_sizes(conn) + finally: + await conn.close() + + sizes = asyncio.run(_run()) + row_count, size_bytes = sizes["release"] + assert row_count == 3 + assert size_bytes > 0 + + def test_track_artist_row_count(self) -> None: + """release_track_artist table has 1 row from test data.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await get_table_sizes(conn) + finally: + await conn.close() + + sizes = asyncio.run(_run()) + row_count, _ = sizes["release_track_artist"] + assert row_count == 1 + + +class TestCountRowsToDelete: + """Verify count_rows_to_delete counts rows matching given release IDs.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, db_url): + self.__class__._db_url = db_url + _set_up_tables_for_async(db_url) + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_counts_for_single_release(self) -> None: + """Counts rows to delete for a single release ID.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await count_rows_to_delete(conn, {101}) + finally: + await conn.close() + + counts = asyncio.run(_run()) + assert counts["release"] == 1 + assert counts["release_artist"] == 1 + assert counts["release_label"] == 1 + assert counts["release_track"] == 1 + assert counts["release_track_artist"] == 1 + assert counts["cache_metadata"] == 1 + + def test_counts_for_multiple_releases(self) -> None: + """Counts rows to delete for multiple release IDs.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await count_rows_to_delete(conn, {101, 102}) + finally: + await conn.close() + + counts = asyncio.run(_run()) + assert counts["release"] == 2 + assert counts["release_artist"] == 2 + + def test_empty_release_set(self) -> None: + """Empty release set returns zero counts.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await count_rows_to_delete(conn, set()) + finally: + await conn.close() + + counts = asyncio.run(_run()) + for table_count in counts.values(): + assert table_count == 0 + + def test_nonexistent_release_id(self) -> None: + """Non-existent release ID returns zero counts.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await count_rows_to_delete(conn, {99999}) + finally: + await conn.close() + + counts = asyncio.run(_run()) + assert counts["release"] == 0 + + +class TestPruneReleases: + """Verify prune_releases deletes releases and child rows via CASCADE.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, db_url): + self.__class__._db_url = db_url + _set_up_tables_for_async(db_url) + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_deletes_specified_releases_and_cascades(self) -> None: + """prune_releases deletes the specified release IDs and cascades to children.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await prune_releases(conn, {101}) + finally: + await conn.close() + + result = asyncio.run(_run()) + assert result["release"] == 1 + + # Verify release 101 is gone, and cascades cleaned children + conn = psycopg.connect(self.db_url) + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM release WHERE id = 101") + assert cur.fetchone()[0] == 0 + cur.execute("SELECT count(*) FROM release_artist WHERE release_id = 101") + assert cur.fetchone()[0] == 0 + cur.execute("SELECT count(*) FROM release_track WHERE release_id = 101") + assert cur.fetchone()[0] == 0 + cur.execute("SELECT count(*) FROM cache_metadata WHERE release_id = 101") + assert cur.fetchone()[0] == 0 + conn.close() + + def test_undeleted_releases_survive(self) -> None: + """Releases not in the prune set are not affected.""" + conn = psycopg.connect(self.db_url) + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM release WHERE id IN (102, 103)") + count = cur.fetchone()[0] + conn.close() + assert count == 2 + + def test_empty_set_deletes_nothing(self) -> None: + """Empty release set returns zero count and deletes nothing.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await prune_releases(conn, set()) + finally: + await conn.close() + + result = asyncio.run(_run()) + assert result["release"] == 0 From e4169c8fc5d4b67910e7260c204ac1d2147a3365 Mon Sep 17 00:00:00 2001 From: Jake Bromberg Date: Tue, 10 Mar 2026 00:14:46 -0700 Subject: [PATCH 4/4] fix: sort imports in test_prune.py for ruff isort --- tests/integration/test_prune.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integration/test_prune.py b/tests/integration/test_prune.py index b8cfce2..ea3cc5b 100644 --- a/tests/integration/test_prune.py +++ b/tests/integration/test_prune.py @@ -2,12 +2,11 @@ from __future__ import annotations +import asyncio import importlib.util import sys as _sys from pathlib import Path -import asyncio - import asyncpg import psycopg import pytest