diff --git a/tests/integration/test_dedup.py b/tests/integration/test_dedup.py index dce4fde..0d29a9d 100644 --- a/tests/integration/test_dedup.py +++ b/tests/integration/test_dedup.py @@ -34,6 +34,7 @@ copy_table = _dd.copy_table swap_tables = _dd.swap_tables add_base_constraints_and_indexes = _dd.add_base_constraints_and_indexes +add_track_constraints_and_indexes = _dd.add_track_constraints_and_indexes add_constraints_and_indexes = _dd.add_constraints_and_indexes load_library_labels = _dd.load_library_labels load_label_hierarchy = _dd.load_label_hierarchy @@ -695,3 +696,278 @@ def test_total_release_count_same(self) -> None: count = cur.fetchone()[0] conn.close() assert count == 12 + + +class TestEnsureDedupIdsAlreadyExists: + """Verify ensure_dedup_ids returns existing count when table already populated.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, db_url): + self.__class__._db_url = db_url + conn = psycopg.connect(db_url, autocommit=True) + _drop_all_tables(conn) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + # Insert releases with duplicate master_ids (would normally be deduped) + cur.execute("INSERT INTO release (id, title, master_id) VALUES (1, 'A', 100)") + cur.execute("INSERT INTO release (id, title, master_id) VALUES (2, 'B', 100)") + # Pre-create the dedup_delete_ids table with some IDs + cur.execute(""" + CREATE UNLOGGED TABLE dedup_delete_ids ( + release_id integer PRIMARY KEY + ) + """) + cur.execute("INSERT INTO dedup_delete_ids (release_id) VALUES (1)") + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_returns_existing_count(self) -> None: + """When dedup_delete_ids already exists, returns its count without recreating.""" + conn = psycopg.connect(self.db_url, autocommit=True) + count = ensure_dedup_ids(conn) + conn.close() + assert count == 1 + + def test_table_not_recreated(self) -> None: + 
"""The pre-existing dedup_delete_ids table is not dropped and recreated.""" + conn = psycopg.connect(self.db_url, autocommit=True) + # Add a second ID to the existing table + with conn.cursor() as cur: + cur.execute( + "INSERT INTO dedup_delete_ids (release_id) VALUES (2) ON CONFLICT DO NOTHING" + ) + count = ensure_dedup_ids(conn) + conn.close() + # Should reflect the updated count (not recreate from ROW_NUMBER query) + assert count == 2 + + +class TestAddTrackConstraintsAndIndexes: + """Verify add_track_constraints_and_indexes creates FK constraints and indexes.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, db_url): + self.__class__._db_url = db_url + conn = psycopg.connect(db_url, autocommit=True) + _drop_all_tables(conn) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute(SCHEMA_DIR.joinpath("create_functions.sql").read_text()) + # Drop existing constraints (schema creates them, we want to test adding them) + cur.execute( + "ALTER TABLE release_track DROP CONSTRAINT IF EXISTS release_track_release_id_fkey" + ) + cur.execute( + "ALTER TABLE release_track_artist " + "DROP CONSTRAINT IF EXISTS release_track_artist_release_id_fkey" + ) + cur.execute("DROP INDEX IF EXISTS idx_release_track_release_id") + cur.execute("DROP INDEX IF EXISTS idx_release_track_artist_release_id") + cur.execute("DROP INDEX IF EXISTS idx_release_track_title_trgm") + cur.execute("DROP INDEX IF EXISTS idx_release_track_artist_name_trgm") + # Insert test data + cur.execute("INSERT INTO release (id, title) VALUES (1, 'Test Album')") + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) VALUES (1, 1, 'Track 1')" + ) + cur.execute( + "INSERT INTO release_track_artist (release_id, track_sequence, artist_name) " + "VALUES (1, 1, 'Test Artist')" + ) + add_track_constraints_and_indexes(conn, db_url=db_url) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = 
self.__class__._db_url + + def _connect(self): + return psycopg.connect(self.db_url) + + def test_release_track_fk_exists(self) -> None: + """FK constraint on release_track referencing release exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT constraint_name FROM information_schema.table_constraints + WHERE table_name = 'release_track' AND constraint_type = 'FOREIGN KEY' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + def test_release_track_artist_fk_exists(self) -> None: + """FK constraint on release_track_artist referencing release exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT constraint_name FROM information_schema.table_constraints + WHERE table_name = 'release_track_artist' AND constraint_type = 'FOREIGN KEY' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + def test_release_track_release_id_index_exists(self) -> None: + """Index on release_track(release_id) exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT indexname FROM pg_indexes + WHERE tablename = 'release_track' AND indexname = 'idx_release_track_release_id' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + def test_release_track_artist_release_id_index_exists(self) -> None: + """Index on release_track_artist(release_id) exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT indexname FROM pg_indexes + WHERE tablename = 'release_track_artist' + AND indexname = 'idx_release_track_artist_release_id' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + def test_release_track_title_trgm_index_exists(self) -> None: + """GIN trigram index on release_track(title) exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT indexname FROM pg_indexes + WHERE tablename = 'release_track' + AND indexname 
= 'idx_release_track_title_trgm' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + def test_release_track_artist_name_trgm_index_exists(self) -> None: + """GIN trigram index on release_track_artist(artist_name) exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT indexname FROM pg_indexes + WHERE tablename = 'release_track_artist' + AND indexname = 'idx_release_track_artist_name_trgm' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + +class TestAddConstraintsAndIndexes: + """Verify add_constraints_and_indexes creates both base and track constraints.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, db_url): + self.__class__._db_url = db_url + conn = psycopg.connect(db_url, autocommit=True) + _drop_all_tables(conn) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute(SCHEMA_DIR.joinpath("create_functions.sql").read_text()) + # Drop all FK constraints and indexes (schema creates them) + for constraint, table in [ + ("release_artist_release_id_fkey", "release_artist"), + ("release_label_release_id_fkey", "release_label"), + ("release_track_release_id_fkey", "release_track"), + ("release_track_artist_release_id_fkey", "release_track_artist"), + ("cache_metadata_release_id_fkey", "cache_metadata"), + ]: + cur.execute(f"ALTER TABLE {table} DROP CONSTRAINT IF EXISTS {constraint}") + # Drop the PK on release so add_constraints_and_indexes can recreate it + cur.execute("ALTER TABLE release DROP CONSTRAINT IF EXISTS release_pkey") + cur.execute("ALTER TABLE cache_metadata DROP CONSTRAINT IF EXISTS cache_metadata_pkey") + # Drop all indexes (FK, GIN trigram, cache metadata) + for idx in [ + "idx_release_artist_release_id", + "idx_release_label_release_id", + "idx_release_track_release_id", + "idx_release_track_artist_release_id", + "idx_release_artist_name_trgm", + "idx_release_title_trgm", + 
"idx_release_track_title_trgm", + "idx_release_track_artist_name_trgm", + "idx_cache_metadata_cached_at", + "idx_cache_metadata_source", + "idx_release_master_id", + ]: + cur.execute(f"DROP INDEX IF EXISTS {idx}") + # Insert test data + cur.execute("INSERT INTO release (id, title) VALUES (1, 'Test Album')") + cur.execute( + "INSERT INTO release_artist (release_id, artist_name) VALUES (1, 'Test Artist')" + ) + cur.execute("INSERT INTO release_label (release_id, label_name) VALUES (1, 'Test Lbl')") + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) VALUES (1, 1, 'Track 1')" + ) + cur.execute( + "INSERT INTO release_track_artist (release_id, track_sequence, artist_name) " + "VALUES (1, 1, 'Track Artist')" + ) + cur.execute("INSERT INTO cache_metadata (release_id, source) VALUES (1, 'bulk_import')") + add_constraints_and_indexes(conn, db_url=db_url) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def _connect(self): + return psycopg.connect(self.db_url) + + def test_release_pk_exists(self) -> None: + """Primary key on release(id) exists.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT constraint_name FROM information_schema.table_constraints + WHERE table_name = 'release' AND constraint_type = 'PRIMARY KEY' + """) + result = cur.fetchone() + conn.close() + assert result is not None + + def test_all_fk_constraints_exist(self) -> None: + """FK constraints on all child tables exist.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute(""" + SELECT tc.table_name + FROM information_schema.table_constraints tc + WHERE tc.constraint_type = 'FOREIGN KEY' + """) + fk_tables = {row[0] for row in cur.fetchall()} + conn.close() + expected = { + "release_artist", + "release_label", + "release_track", + "release_track_artist", + "cache_metadata", + } + assert expected.issubset(fk_tables) + + def test_base_and_track_indexes_exist(self) -> 
None: + """Both base and track FK indexes exist.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT indexname FROM pg_indexes WHERE schemaname = 'public'") + indexes = {row[0] for row in cur.fetchall()} + conn.close() + expected_indexes = { + "idx_release_artist_release_id", + "idx_release_label_release_id", + "idx_release_track_release_id", + "idx_release_track_artist_release_id", + } + assert expected_indexes.issubset(indexes) diff --git a/tests/integration/test_import.py b/tests/integration/test_import.py index 9d93d18..e784029 100644 --- a/tests/integration/test_import.py +++ b/tests/integration/test_import.py @@ -22,6 +22,8 @@ import_csv_func = _ic.import_csv import_artwork = _ic.import_artwork create_track_count_table = _ic.create_track_count_table +populate_cache_metadata = _ic.populate_cache_metadata +_import_tables = _ic._import_tables TABLES = _ic.TABLES BASE_TABLES = _ic.BASE_TABLES TRACK_TABLES = _ic.TRACK_TABLES @@ -467,3 +469,262 @@ def test_duplicate_release_ids_keep_first(self, tmp_path) -> None: conn.close() # First occurrence wins assert title == "DOGA" + + +class TestPopulateCacheMetadata: + """Verify populate_cache_metadata() inserts metadata for all releases via COPY.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up_database(self, db_url): + self.__class__._db_url = db_url + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute("INSERT INTO release (id, title) VALUES (5001, 'DOGA')") + cur.execute("INSERT INTO release (id, title) VALUES (5002, 'Aluminum Tunes')") + cur.execute("INSERT INTO release (id, title) VALUES (5003, 'Moon Pix')") + conn.close() + + conn = psycopg.connect(db_url) + populate_cache_metadata(conn) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def _connect(self): + return 
psycopg.connect(self.db_url) + + def test_metadata_row_count(self) -> None: + """One cache_metadata row per release.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM cache_metadata") + count = cur.fetchone()[0] + conn.close() + assert count == 3 + + def test_metadata_source(self) -> None: + """All rows have source='bulk_import'.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT DISTINCT source FROM cache_metadata") + sources = {row[0] for row in cur.fetchall()} + conn.close() + assert sources == {"bulk_import"} + + def test_metadata_release_ids(self) -> None: + """Metadata release_ids match the inserted releases.""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT release_id FROM cache_metadata ORDER BY release_id") + ids = [row[0] for row in cur.fetchall()] + conn.close() + assert ids == [5001, 5002, 5003] + + def test_metadata_cached_at_not_null(self) -> None: + """cached_at defaults to current timestamp (not null).""" + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM cache_metadata WHERE cached_at IS NOT NULL") + count = cur.fetchone()[0] + conn.close() + assert count == 3 + + +class TestImportArtwork: + """Verify import_artwork() populates artwork_url from release_image.csv.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up_database(self, db_url): + self.__class__._db_url = db_url + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute("INSERT INTO release (id, title) VALUES (101, 'Album A')") + cur.execute("INSERT INTO release (id, title) VALUES (102, 'Album B')") + cur.execute("INSERT INTO release (id, title) VALUES (103, 'Album C')") + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def _connect(self): + return 
psycopg.connect(self.db_url) + + def test_primary_image_preferred(self, tmp_path) -> None: + """Primary image type is used over secondary.""" + csv_path = tmp_path / "release_image.csv" + csv_path.write_text( + "release_id,type,width,height,uri\n" + "101,secondary,300,300,https://img.discogs.com/secondary-101.jpg\n" + "101,primary,600,600,https://img.discogs.com/primary-101.jpg\n" + ) + conn = psycopg.connect(self.db_url) + import_artwork(conn, tmp_path) + conn.close() + + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT artwork_url FROM release WHERE id = 101") + url = cur.fetchone()[0] + conn.close() + assert url == "https://img.discogs.com/primary-101.jpg" + + def test_fallback_when_no_primary(self, tmp_path) -> None: + """Secondary image used as fallback when no primary exists.""" + csv_path = tmp_path / "release_image.csv" + csv_path.write_text( + "release_id,type,width,height,uri\n" + "102,secondary,600,600,https://img.discogs.com/secondary-102.jpg\n" + ) + # Reset artwork_url for release 102 + conn = psycopg.connect(self.db_url) + with conn.cursor() as cur: + cur.execute("UPDATE release SET artwork_url = NULL WHERE id = 102") + conn.commit() + import_artwork(conn, tmp_path) + conn.close() + + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT artwork_url FROM release WHERE id = 102") + url = cur.fetchone()[0] + conn.close() + assert url == "https://img.discogs.com/secondary-102.jpg" + + def test_invalid_release_id_skipped(self, tmp_path) -> None: + """Rows with non-integer release_id are silently skipped.""" + csv_path = tmp_path / "release_image.csv" + csv_path.write_text( + "release_id,type,width,height,uri\n" + "abc,primary,600,600,https://img.discogs.com/bad.jpg\n" + "103,primary,600,600,https://img.discogs.com/good-103.jpg\n" + ) + conn = psycopg.connect(self.db_url) + with conn.cursor() as cur: + cur.execute("UPDATE release SET artwork_url = NULL WHERE id = 103") + conn.commit() + count = 
import_artwork(conn, tmp_path) + conn.close() + + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT artwork_url FROM release WHERE id = 103") + url = cur.fetchone()[0] + conn.close() + assert url == "https://img.discogs.com/good-103.jpg" + assert count >= 1 + + def test_empty_uri_skipped(self, tmp_path) -> None: + """Rows with empty URI are skipped.""" + csv_path = tmp_path / "release_image.csv" + csv_path.write_text("release_id,type,width,height,uri\n103,primary,600,600,\n") + conn = psycopg.connect(self.db_url) + with conn.cursor() as cur: + cur.execute("UPDATE release SET artwork_url = NULL WHERE id = 103") + conn.commit() + count = import_artwork(conn, tmp_path) + conn.close() + + conn = self._connect() + with conn.cursor() as cur: + cur.execute("SELECT artwork_url FROM release WHERE id = 103") + url = cur.fetchone()[0] + conn.close() + assert url is None + assert count == 0 + + +class TestImportArtworkMissing: + """Verify import_artwork() returns 0 when release_image.csv is missing.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up_database(self, db_url): + self.__class__._db_url = db_url + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute("INSERT INTO release (id, title) VALUES (1, 'Test')") + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_returns_zero(self, tmp_path) -> None: + """import_artwork returns 0 when release_image.csv does not exist.""" + conn = psycopg.connect(self.db_url) + result = import_artwork(conn, tmp_path) + conn.close() + assert result == 0 + + +class TestCreateTrackCountTableMissing: + """Verify create_track_count_table() returns 0 when release_track.csv is missing.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up_database(self, db_url): + self.__class__._db_url = db_url + 
_clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_returns_zero(self, tmp_path) -> None: + """create_track_count_table returns 0 when release_track.csv does not exist.""" + conn = psycopg.connect(self.db_url) + result = create_track_count_table(conn, tmp_path) + conn.close() + assert result == 0 + + +class TestImportTables: + """Verify _import_tables() sequential import of table configs.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up_database(self, db_url): + self.__class__._db_url = db_url + _clean_db(db_url) + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + conn.close() + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_imports_all_tables(self) -> None: + """_import_tables imports all CSVs in the table list and returns total count.""" + conn = psycopg.connect(self.db_url) + total = _import_tables(conn, CSV_DIR, BASE_TABLES) + conn.close() + + conn = psycopg.connect(self.db_url) + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM release") + release_count = cur.fetchone()[0] + cur.execute("SELECT count(*) FROM release_artist") + artist_count = cur.fetchone()[0] + conn.close() + assert release_count > 0 + assert artist_count > 0 + assert total == release_count + artist_count + 16 # + release_label count + + def test_skips_missing_csv(self, tmp_path) -> None: + """_import_tables skips table configs whose CSV file does not exist.""" + conn = psycopg.connect(self.db_url) + total = _import_tables(conn, tmp_path, TRACK_TABLES) + conn.close() + assert total == 0 diff --git a/tests/integration/test_prune.py b/tests/integration/test_prune.py index 
a3845c9..ea3cc5b 100644 --- a/tests/integration/test_prune.py +++ b/tests/integration/test_prune.py @@ -2,10 +2,12 @@ from __future__ import annotations +import asyncio import importlib.util import sys as _sys from pathlib import Path +import asyncpg import psycopg import pytest @@ -44,6 +46,9 @@ MultiIndexMatcher = _vc.MultiIndexMatcher Decision = _vc.Decision classify_all_releases = _vc.classify_all_releases +get_table_sizes = _vc.get_table_sizes +count_rows_to_delete = _vc.count_rows_to_delete +prune_releases = _vc.prune_releases pytestmark = pytest.mark.postgres @@ -208,3 +213,241 @@ def test_keep_releases_survive_prune(self) -> None: count = cur.fetchone()[0] conn.close() assert count == len(self.report.keep_ids) + + +def _set_up_tables_for_async(db_url: str) -> None: + """Set up schema and insert test data for async function tests.""" + conn = psycopg.connect(db_url, autocommit=True) + with conn.cursor() as cur: + for table in ALL_TABLES: + cur.execute(f"DROP TABLE IF EXISTS {table} CASCADE") + cur.execute("DROP TABLE IF EXISTS release_label CASCADE") + cur.execute(SCHEMA_DIR.joinpath("create_database.sql").read_text()) + cur.execute("INSERT INTO release (id, title) VALUES (101, 'Confield')") + cur.execute("INSERT INTO release (id, title) VALUES (102, 'DOGA')") + cur.execute("INSERT INTO release (id, title) VALUES (103, 'Moon Pix')") + cur.execute( + "INSERT INTO release_artist (release_id, artist_name, extra) " + "VALUES (101, 'Autechre', 0)" + ) + cur.execute( + "INSERT INTO release_artist (release_id, artist_name, extra) " + "VALUES (102, 'Juana Molina', 0)" + ) + cur.execute( + "INSERT INTO release_artist (release_id, artist_name, extra) " + "VALUES (103, 'Cat Power', 0)" + ) + cur.execute("INSERT INTO release_label (release_id, label_name) VALUES (101, 'Warp')") + cur.execute("INSERT INTO release_label (release_id, label_name) VALUES (102, 'Sonamos')") + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) " + "VALUES (101, 1, 'VI 
Scose Poise')" + ) + cur.execute( + "INSERT INTO release_track (release_id, sequence, title) VALUES (102, 1, 'Cosoco')" + ) + cur.execute( + "INSERT INTO release_track_artist (release_id, track_sequence, artist_name) " + "VALUES (101, 1, 'Autechre')" + ) + cur.execute("INSERT INTO cache_metadata (release_id, source) VALUES (101, 'bulk_import')") + cur.execute("INSERT INTO cache_metadata (release_id, source) VALUES (102, 'bulk_import')") + cur.execute("INSERT INTO cache_metadata (release_id, source) VALUES (103, 'bulk_import')") + conn.close() + + +class TestGetTableSizes: + """Verify get_table_sizes returns row counts and sizes for each release table.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, db_url): + self.__class__._db_url = db_url + _set_up_tables_for_async(db_url) + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_returns_all_tables(self) -> None: + """get_table_sizes returns entries for all RELEASE_TABLES.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await get_table_sizes(conn) + finally: + await conn.close() + + sizes = asyncio.run(_run()) + expected_tables = { + "release", + "release_artist", + "release_label", + "release_track", + "release_track_artist", + "cache_metadata", + } + assert set(sizes.keys()) == expected_tables + + def test_release_row_count(self) -> None: + """release table has 3 rows.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await get_table_sizes(conn) + finally: + await conn.close() + + sizes = asyncio.run(_run()) + row_count, size_bytes = sizes["release"] + assert row_count == 3 + assert size_bytes > 0 + + def test_track_artist_row_count(self) -> None: + """release_track_artist table has 1 row from test data.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await get_table_sizes(conn) + finally: + await conn.close() + + sizes = 
asyncio.run(_run()) + row_count, _ = sizes["release_track_artist"] + assert row_count == 1 + + +class TestCountRowsToDelete: + """Verify count_rows_to_delete counts rows matching given release IDs.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, db_url): + self.__class__._db_url = db_url + _set_up_tables_for_async(db_url) + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_counts_for_single_release(self) -> None: + """Counts rows to delete for a single release ID.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await count_rows_to_delete(conn, {101}) + finally: + await conn.close() + + counts = asyncio.run(_run()) + assert counts["release"] == 1 + assert counts["release_artist"] == 1 + assert counts["release_label"] == 1 + assert counts["release_track"] == 1 + assert counts["release_track_artist"] == 1 + assert counts["cache_metadata"] == 1 + + def test_counts_for_multiple_releases(self) -> None: + """Counts rows to delete for multiple release IDs.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await count_rows_to_delete(conn, {101, 102}) + finally: + await conn.close() + + counts = asyncio.run(_run()) + assert counts["release"] == 2 + assert counts["release_artist"] == 2 + + def test_empty_release_set(self) -> None: + """Empty release set returns zero counts.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await count_rows_to_delete(conn, set()) + finally: + await conn.close() + + counts = asyncio.run(_run()) + for table_count in counts.values(): + assert table_count == 0 + + def test_nonexistent_release_id(self) -> None: + """Non-existent release ID returns zero counts.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await count_rows_to_delete(conn, {99999}) + finally: + await conn.close() + + counts = asyncio.run(_run()) + assert 
counts["release"] == 0 + + +class TestPruneReleases: + """Verify prune_releases deletes releases and child rows via CASCADE.""" + + @pytest.fixture(autouse=True, scope="class") + def _set_up(self, db_url): + self.__class__._db_url = db_url + _set_up_tables_for_async(db_url) + + @pytest.fixture(autouse=True) + def _store_url(self): + self.db_url = self.__class__._db_url + + def test_deletes_specified_releases_and_cascades(self) -> None: + """prune_releases deletes the specified release IDs and cascades to children.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await prune_releases(conn, {101}) + finally: + await conn.close() + + result = asyncio.run(_run()) + assert result["release"] == 1 + + # Verify release 101 is gone, and cascades cleaned children + conn = psycopg.connect(self.db_url) + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM release WHERE id = 101") + assert cur.fetchone()[0] == 0 + cur.execute("SELECT count(*) FROM release_artist WHERE release_id = 101") + assert cur.fetchone()[0] == 0 + cur.execute("SELECT count(*) FROM release_track WHERE release_id = 101") + assert cur.fetchone()[0] == 0 + cur.execute("SELECT count(*) FROM cache_metadata WHERE release_id = 101") + assert cur.fetchone()[0] == 0 + conn.close() + + def test_undeleted_releases_survive(self) -> None: + """Releases not in the prune set are not affected.""" + conn = psycopg.connect(self.db_url) + with conn.cursor() as cur: + cur.execute("SELECT count(*) FROM release WHERE id IN (102, 103)") + count = cur.fetchone()[0] + conn.close() + assert count == 2 + + def test_empty_set_deletes_nothing(self) -> None: + """Empty release set returns zero count and deletes nothing.""" + + async def _run(): + conn = await asyncpg.connect(self.db_url) + try: + return await prune_releases(conn, set()) + finally: + await conn.close() + + result = asyncio.run(_run()) + assert result["release"] == 0 diff --git a/tests/unit/test_artist_splitting.py 
b/tests/unit/test_artist_splitting.py index 73b98b9..e5f9605 100644 --- a/tests/unit/test_artist_splitting.py +++ b/tests/unit/test_artist_splitting.py @@ -4,7 +4,12 @@ import pytest -from lib.artist_splitting import split_artist_name, split_artist_name_contextual +from lib.artist_splitting import ( + _split_trailing_and, + _try_ampersand_split, + split_artist_name, + split_artist_name_contextual, +) # --------------------------------------------------------------------------- # split_artist_name (context-free) @@ -159,3 +164,64 @@ def test_band_names_not_split(self, name: str) -> None: """Common band name patterns with 'and'/'with' should never be split.""" known = {"sly", "andy human", "my life", "nurse"} assert split_artist_name_contextual(name, known) == [] + + +# --------------------------------------------------------------------------- +# _comma_guard (numeric guard) +# --------------------------------------------------------------------------- + + +class TestCommaGuardNumeric: + """The numeric guard in comma splitting prevents splitting artist names with numbers.""" + + @pytest.mark.parametrize( + "name", + [ + "10,000 Maniacs", + "808,303", + "1,000 Homo DJs", + ], + ids=["10000-maniacs", "808-303", "1000-homo-djs"], + ) + def test_numeric_components_block_split(self, name: str) -> None: + assert split_artist_name(name) == [] + + +# --------------------------------------------------------------------------- +# _split_trailing_and (single-component input) +# --------------------------------------------------------------------------- + + +class TestSplitTrailingAnd: + """Direct tests for _split_trailing_and edge cases.""" + + def test_single_component_returns_as_is(self) -> None: + assert _split_trailing_and(["Autechre"]) == ["Autechre"] + + def test_empty_list_returns_as_is(self) -> None: + assert _split_trailing_and([]) == [] + + def test_two_components_with_trailing_and(self) -> None: + assert _split_trailing_and(["Emerson", "and Palmer"]) == ["Emerson", 
"Palmer"] + + def test_two_components_without_trailing_and(self) -> None: + assert _split_trailing_and(["Emerson", "Palmer"]) == ["Emerson", "Palmer"] + + +# --------------------------------------------------------------------------- +# _try_ampersand_split (edge cases) +# --------------------------------------------------------------------------- + + +class TestTryAmpersandSplit: + """Direct tests for _try_ampersand_split edge cases.""" + + def test_no_ampersand_returns_none(self) -> None: + assert _try_ampersand_split("Autechre", {"autechre"}) is None + + def test_no_known_artist_returns_none(self) -> None: + assert _try_ampersand_split("Simon & Garfunkel", set()) is None + + def test_single_char_components_rejected(self) -> None: + """When all components after filtering are too short, returns None.""" + assert _try_ampersand_split("A & B", {"a"}) is None diff --git a/tests/unit/test_csv_to_tsv.py b/tests/unit/test_csv_to_tsv.py index 1c7c592..52ec62e 100644 --- a/tests/unit/test_csv_to_tsv.py +++ b/tests/unit/test_csv_to_tsv.py @@ -15,6 +15,7 @@ _spec.loader.exec_module(_ct) convert = _ct.convert +main = _ct.main class TestConvert: @@ -90,3 +91,32 @@ def test_empty_csv(self, tmp_path: Path) -> None: assert count == 0 lines = tsv_file.read_text().splitlines() assert lines == ["h"] + + +class TestMain: + """Tests for the main() entry point.""" + + def test_wrong_arg_count_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("sys.argv", ["csv_to_tsv.py"]) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_too_many_args_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("sys.argv", ["csv_to_tsv.py", "a", "b", "c"]) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_happy_path(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + csv_file = tmp_path / "input.csv" + tsv_file = tmp_path / "output.tsv" + 
csv_file.write_text("name,value\nalpha,1\nbeta,2\n") + + monkeypatch.setattr("sys.argv", ["csv_to_tsv.py", str(csv_file), str(tsv_file)]) + main() + + lines = tsv_file.read_text().splitlines() + assert lines[0] == "name\tvalue" + assert lines[1] == "alpha\t1" + assert lines[2] == "beta\t2" diff --git a/tests/unit/test_enrich_library_artists.py b/tests/unit/test_enrich_library_artists.py index cdb245a..a2e107e 100644 --- a/tests/unit/test_enrich_library_artists.py +++ b/tests/unit/test_enrich_library_artists.py @@ -5,6 +5,7 @@ import importlib.util import sys from pathlib import Path +from unittest.mock import MagicMock, patch import pytest @@ -17,6 +18,9 @@ _spec.loader.exec_module(_mod) extract_base_artists = _mod.extract_base_artists +extract_alternate_names = _mod.extract_alternate_names +extract_cross_referenced_artists = _mod.extract_cross_referenced_artists +extract_release_cross_ref_artists = _mod.extract_release_cross_ref_artists merge_and_write = _mod.merge_and_write parse_args = _mod.parse_args @@ -302,3 +306,219 @@ def test_compilation_components_excluded(self, tmp_path: Path) -> None: lines = set(output.read_text().splitlines()) assert "Juana Molina" in lines assert "Various Artists" not in lines + + +# --------------------------------------------------------------------------- +# extract_alternate_names (mocked MySQL) +# --------------------------------------------------------------------------- + + +class TestExtractAlternateNames: + """extract_alternate_names() queries LIBRARY_RELEASE for alternate artist names.""" + + def test_returns_alternate_names(self) -> None: + """Mock cursor returns sample alternate artist names.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock(return_value=iter([("Body Count",), ("Ice Cube",)])) + mock_conn.cursor.return_value = mock_cursor + + result = 
extract_alternate_names(mock_conn) + assert result == {"Body Count", "Ice Cube"} + + def test_empty_result(self) -> None: + """Empty cursor returns empty set.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock(return_value=iter([])) + mock_conn.cursor.return_value = mock_cursor + + result = extract_alternate_names(mock_conn) + assert result == set() + + def test_strips_whitespace_and_skips_empty(self) -> None: + """Whitespace-only and None values are excluded.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock(return_value=iter([(" Stereolab ",), ("",), (None,)])) + mock_conn.cursor.return_value = mock_cursor + + result = extract_alternate_names(mock_conn) + assert result == {"Stereolab"} + + +# --------------------------------------------------------------------------- +# extract_cross_referenced_artists (mocked MySQL) +# --------------------------------------------------------------------------- + + +class TestExtractCrossReferencedArtists: + """extract_cross_referenced_artists() queries LIBRARY_CODE_CROSS_REFERENCE.""" + + def test_returns_cross_referenced_names(self) -> None: + """Mock cursor returns UNION query results.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock( + return_value=iter([("Cat Power",), ("Chan Marshall",), ("Cat Power",)]) + ) + mock_conn.cursor.return_value = mock_cursor + + result = extract_cross_referenced_artists(mock_conn) + # Duplicates from UNION are already handled by the set + assert "Cat Power" in result + assert "Chan Marshall" in result + + def 
test_deduplication(self) -> None: + """Duplicate names across UNION branches produce unique results.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock( + return_value=iter([("Jessica Pratt",), ("Jessica Pratt",), ("Chuquimamani-Condori",)]) + ) + mock_conn.cursor.return_value = mock_cursor + + result = extract_cross_referenced_artists(mock_conn) + assert len(result) == 2 + assert result == {"Jessica Pratt", "Chuquimamani-Condori"} + + +# --------------------------------------------------------------------------- +# extract_release_cross_ref_artists (mocked MySQL) +# --------------------------------------------------------------------------- + + +class TestExtractReleaseCrossRefArtists: + """extract_release_cross_ref_artists() queries RELEASE_CROSS_REFERENCE.""" + + def test_returns_release_cross_ref_names(self) -> None: + """Mock cursor returns cross-reference artist names.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock(return_value=iter([("Juana Molina",), ("Sessa",)])) + mock_conn.cursor.return_value = mock_cursor + + result = extract_release_cross_ref_artists(mock_conn) + assert result == {"Juana Molina", "Sessa"} + + def test_empty_result(self) -> None: + """Empty cursor returns empty set.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_cursor.__iter__ = MagicMock(return_value=iter([])) + mock_conn.cursor.return_value = mock_cursor + + result = extract_release_cross_ref_artists(mock_conn) + assert result == set() + + +# --------------------------------------------------------------------------- +# main (mocked) +# 
--------------------------------------------------------------------------- + + +class TestMain: + """main() orchestrates extraction and merge.""" + + def test_with_library_db_only(self, tmp_path) -> None: + """With --library-db only (no MySQL), base artists are extracted and written.""" + library_db = tmp_path / "library.db" + # Create a minimal SQLite library.db + import sqlite3 + + conn = sqlite3.connect(library_db) + conn.execute("CREATE TABLE library (artist TEXT, title TEXT)") + conn.execute("INSERT INTO library VALUES ('Stereolab', 'Aluminum Tunes')") + conn.execute("INSERT INTO library VALUES ('Cat Power', 'Moon Pix')") + conn.commit() + conn.close() + + output = tmp_path / "artists.txt" + + with patch.object( + _mod, + "parse_args", + return_value=parse_args(["--library-db", str(library_db), "--output", str(output)]), + ): + _mod.main() + + lines = set(output.read_text().splitlines()) + assert "Stereolab" in lines + assert "Cat Power" in lines + + def test_with_library_db_and_wxyc_db_url(self, tmp_path) -> None: + """With --library-db and --wxyc-db-url, MySQL enrichment is performed.""" + library_db = tmp_path / "library.db" + import sqlite3 + + conn = sqlite3.connect(library_db) + conn.execute("CREATE TABLE library (artist TEXT, title TEXT)") + conn.execute("INSERT INTO library VALUES ('Stereolab', 'Aluminum Tunes')") + conn.commit() + conn.close() + + output = tmp_path / "artists.txt" + mock_mysql_conn = MagicMock() + + with ( + patch.object( + _mod, + "parse_args", + return_value=parse_args( + [ + "--library-db", + str(library_db), + "--output", + str(output), + "--wxyc-db-url", + "mysql://user:pass@host/db", + ] + ), + ), + patch.object(_mod, "connect_mysql", return_value=mock_mysql_conn), + patch.object(_mod, "extract_alternate_names", return_value={"Nourished by Time"}), + patch.object(_mod, "extract_cross_referenced_artists", return_value={"Buck Meek"}), + patch.object(_mod, "extract_release_cross_ref_artists", return_value={"Sessa"}), + ): + 
_mod.main() + + lines = set(output.read_text().splitlines()) + assert "Stereolab" in lines + assert "Nourished by Time" in lines + assert "Buck Meek" in lines + assert "Sessa" in lines + mock_mysql_conn.close.assert_called_once() + + def test_missing_library_db_exits(self, tmp_path) -> None: + """Non-existent library.db triggers sys.exit(1).""" + output = tmp_path / "artists.txt" + with ( + patch.object( + _mod, + "parse_args", + return_value=parse_args( + [ + "--library-db", + str(tmp_path / "missing.db"), + "--output", + str(output), + ] + ), + ), + pytest.raises(SystemExit, match="1"), + ): + _mod.main() diff --git a/tests/unit/test_export.py b/tests/unit/test_export.py index a357d54..2aabeab 100644 --- a/tests/unit/test_export.py +++ b/tests/unit/test_export.py @@ -13,6 +13,7 @@ _mod = importlib.util.module_from_spec(_spec) _spec.loader.exec_module(_mod) _do_export = _mod._do_export +format_size = _mod.format_size _OUTPUT_PATH_ATTR = "OUTPUT_PATH" @@ -168,3 +169,91 @@ def test_mixed_rows_with_and_without_alternate(self): assert len(results) == 2 assert results[0] == (1, "Plug") assert results[1] == (2, None) + + def test_creates_correct_tables_and_row_count(self): + """Verify the exported database has the correct tables, FTS index, and row count.""" + rows = [ + { + "id": "1", + "title": "DOGA", + "artist": "Juana Molina", + "call_letters": "M", + "artist_call_number": "42", + "release_call_number": "1", + "genre": "Rock", + "format": "LP", + "alternate_artist_name": None, + }, + { + "id": "2", + "title": "Aluminum Tunes", + "artist": "Stereolab", + "call_letters": "S", + "artist_call_number": "88", + "release_call_number": "1", + "genre": "Rock", + "format": "CD", + "alternate_artist_name": None, + }, + { + "id": "3", + "title": "Moon Pix", + "artist": "Cat Power", + "call_letters": "C", + "artist_call_number": "7", + "release_call_number": "1", + "genre": "Rock", + "format": "LP", + "alternate_artist_name": None, + }, + ] + _do_export(rows) + + conn = 
sqlite3.connect(self.output_path)
+
+        # Verify table exists
+        cursor = conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='library'"
+        )
+        assert cursor.fetchone() is not None
+
+        # Verify FTS table exists
+        cursor = conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='library_fts'"
+        )
+        assert cursor.fetchone() is not None
+
+        # Verify row count
+        cursor = conn.execute("SELECT COUNT(*) FROM library")
+        assert cursor.fetchone()[0] == 3
+
+        # Verify FTS works
+        cursor = conn.execute("""
+            SELECT l.id FROM library l
+            JOIN library_fts fts ON l.id = fts.rowid
+            WHERE library_fts MATCH 'Stereolab'
+        """)
+        results = cursor.fetchall()
+        assert len(results) == 1
+        assert results[0][0] == 2
+
+        conn.close()
+
+
+class TestFormatSize:
+    """Test human-readable size formatting."""
+
+    @pytest.mark.parametrize(
+        "size_bytes, expected",
+        [
+            (0, "0.0 B"),
+            (1023, "1023.0 B"),
+            (1024, "1.0 KB"),
+            (1048576, "1.0 MB"),
+            (1073741824, "1.0 GB"),
+            (1099511627776, "1.0 TB"),
+        ],
+        ids=["zero", "bytes", "kilobytes", "megabytes", "gigabytes", "terabytes"],
+    )
+    def test_format_size(self, size_bytes: int, expected: str) -> None:
+        assert format_size(size_bytes) == expected
diff --git a/tests/unit/test_filter_csv.py b/tests/unit/test_filter_csv.py
index 50c686c..df22e2a 100644
--- a/tests/unit/test_filter_csv.py
+++ b/tests/unit/test_filter_csv.py
@@ -20,6 +20,7 @@
 find_matching_release_ids = _fc.find_matching_release_ids
 filter_csv_file = _fc.filter_csv_file
 get_release_id_column = _fc.get_release_id_column
+main = _fc.main
 
 FIXTURES_DIR = Path(__file__).parent.parent / "fixtures"
@@ -246,3 +247,185 @@
         with pytest.raises(ValueError, match="Column 'nonexistent'.*not found"):
             filter_csv_file(input_path, output_path, {1001}, "nonexistent")
+
+    def test_row_with_invalid_release_id_skipped(self, tmp_path: Path) -> None:
+        """Rows where the release_id is not a
valid integer are silently skipped.""" + csv_path = tmp_path / "release.csv" + output_path = tmp_path / "out.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["id", "title"]) + writer.writerow(["abc", "Bad ID"]) + writer.writerow(["1001", "Good ID"]) + + input_count, output_count = filter_csv_file(csv_path, output_path, {1001}, "id") + assert input_count == 2 + assert output_count == 1 + + def test_short_row_skipped(self, tmp_path: Path) -> None: + """Rows shorter than expected (IndexError on id column) are silently skipped.""" + csv_path = tmp_path / "release.csv" + output_path = tmp_path / "out.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["id", "title", "country"]) + # Normal row + writer.writerow(["1001", "DOGA", "AR"]) + # Write a short row manually (fewer columns than header) + f.write('"short"\n') + + input_count, output_count = filter_csv_file(csv_path, output_path, {1001}, "id") + assert input_count == 2 + assert output_count == 1 + + +class TestFindMatchingReleaseIdsEdgeCases: + """Edge cases for find_matching_release_ids.""" + + def test_short_row_skipped(self, tmp_path: Path) -> None: + """Rows missing the artist_name column are silently skipped.""" + csv_path = tmp_path / "release_artist.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["release_id", "artist_id", "artist_name", "extra", "anv", "position"]) + writer.writerow(["1001", "101", "Juana Molina", "0", "", "1"]) + # Short row missing artist_name + f.write('"2001","201"\n') + + ids = find_matching_release_ids(csv_path, {"juana molina"}) + assert ids == {1001} + + +# --------------------------------------------------------------------------- +# main +# --------------------------------------------------------------------------- + + +class TestMain: + """Tests for the main() entry point.""" + + def test_wrong_arg_count_exits(self, monkeypatch: pytest.MonkeyPatch) 
-> None: + monkeypatch.setattr("sys.argv", ["filter_csv.py"]) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_missing_library_artists_exits( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setattr( + "sys.argv", + [ + "filter_csv.py", + str(tmp_path / "nonexistent.txt"), + str(tmp_path), + str(tmp_path / "out"), + ], + ) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_missing_csv_dir_exits(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + artists_file = tmp_path / "artists.txt" + artists_file.write_text("Juana Molina\n") + monkeypatch.setattr( + "sys.argv", + ["filter_csv.py", str(artists_file), str(tmp_path / "nope"), str(tmp_path / "out")], + ) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_missing_release_artist_csv_exits( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + artists_file = tmp_path / "artists.txt" + artists_file.write_text("Juana Molina\n") + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + monkeypatch.setattr( + "sys.argv", + ["filter_csv.py", str(artists_file), str(csv_dir), str(tmp_path / "out")], + ) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_no_matches_exits(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + artists_file = tmp_path / "artists.txt" + artists_file.write_text("Nonexistent Artist XYZ\n") + + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + with open(csv_dir / "release_artist.csv", "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["release_id", "artist_id", "artist_name", "extra", "anv", "position"]) + writer.writerow(["1001", "101", "Juana Molina", "0", "", "1"]) + + monkeypatch.setattr( + "sys.argv", + ["filter_csv.py", str(artists_file), str(csv_dir), str(tmp_path / "out")], + ) + with pytest.raises(SystemExit) 
as exc_info: + main() + assert exc_info.value.code == 1 + + def test_happy_path(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + artists_file = tmp_path / "artists.txt" + artists_file.write_text("Juana Molina\n") + + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + out_dir = tmp_path / "out" + + with open(csv_dir / "release_artist.csv", "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["release_id", "artist_id", "artist_name", "extra", "anv", "position"]) + writer.writerow(["5001", "101", "Juana Molina", "0", "", "1"]) + writer.writerow(["5002", "102", "Stereolab", "0", "", "1"]) + + with open(csv_dir / "release.csv", "w", newline="") as f: + writer = csv.writer(f) + writer.writerow( + [ + "id", + "status", + "title", + "country", + "released", + "notes", + "data_quality", + "master_id", + "format", + ] + ) + writer.writerow( + ["5001", "Accepted", "DOGA", "AR", "2024-05-10", "", "Correct", "8001", "LP"] + ) + writer.writerow( + [ + "5002", + "Accepted", + "Aluminum Tunes", + "UK", + "1998-09-01", + "", + "Correct", + "8002", + "CD", + ] + ) + + monkeypatch.setattr( + "sys.argv", + ["filter_csv.py", str(artists_file), str(csv_dir), str(out_dir)], + ) + main() + + assert out_dir.exists() + with open(out_dir / "release.csv") as f: + reader = csv.DictReader(f) + rows = list(reader) + assert len(rows) == 1 + assert rows[0]["id"] == "5001" diff --git a/tests/unit/test_fix_csv_newlines.py b/tests/unit/test_fix_csv_newlines.py index df68e8b..772f665 100644 --- a/tests/unit/test_fix_csv_newlines.py +++ b/tests/unit/test_fix_csv_newlines.py @@ -16,6 +16,8 @@ _spec.loader.exec_module(_fn) fix_csv = _fn.fix_csv +fix_csv_dir = _fn.fix_csv_dir +main = _fn.main class TestFixCsv: @@ -101,3 +103,81 @@ def test_edge_cases(self, tmp_path: Path, field: str, expected: str) -> None: rows = self._read_csv(output_path) assert rows[0][0] == expected + + +class TestFixCsvDir: + """Batch-processing a directory of CSV files.""" + + def _write_csv(self, path: 
Path, headers: list[str], rows: list[list[str]]) -> None: + with open(path, "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(headers) + writer.writerows(rows) + + def test_processes_all_csv_files(self, tmp_path: Path) -> None: + input_dir = tmp_path / "input" + output_dir = tmp_path / "output" + input_dir.mkdir() + + self._write_csv(input_dir / "a.csv", ["text"], [["line\none"]]) + self._write_csv(input_dir / "b.csv", ["text"], [["hello\nworld"]]) + + fix_csv_dir(input_dir, output_dir) + + assert (output_dir / "a.csv").exists() + assert (output_dir / "b.csv").exists() + + with open(output_dir / "a.csv") as f: + reader = csv.reader(f) + next(reader) + assert next(reader)[0] == "line one" + + with open(output_dir / "b.csv") as f: + reader = csv.reader(f) + next(reader) + assert next(reader)[0] == "hello world" + + def test_empty_dir_logs_warning(self, tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: + import logging + + input_dir = tmp_path / "empty" + output_dir = tmp_path / "output" + input_dir.mkdir() + + with caplog.at_level(logging.WARNING): + fix_csv_dir(input_dir, output_dir) + + assert any("No .csv files found" in r.message for r in caplog.records) + + +class TestMain: + """Tests for the main() entry point.""" + + def test_wrong_arg_count_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("sys.argv", ["fix_csv_newlines.py"]) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_too_many_args_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("sys.argv", ["fix_csv_newlines.py", "a", "b", "c"]) + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_happy_path(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + input_path = tmp_path / "input.csv" + output_path = tmp_path / "output.csv" + + with open(input_path, "w", newline="", encoding="utf-8") as f: + writer = 
csv.writer(f) + writer.writerow(["text"]) + writer.writerow(["line\none"]) + + monkeypatch.setattr("sys.argv", ["fix_csv_newlines.py", str(input_path), str(output_path)]) + main() + + with open(output_path) as f: + reader = csv.reader(f) + next(reader) + assert next(reader)[0] == "line one" diff --git a/tests/unit/test_import_csv.py b/tests/unit/test_import_csv.py index d8c84c3..635613e 100644 --- a/tests/unit/test_import_csv.py +++ b/tests/unit/test_import_csv.py @@ -392,3 +392,106 @@ def test_no_header_returns_zero(self, tmp_path, caplog) -> None: assert count == 0 assert "No header" in caplog.text + + +# --------------------------------------------------------------------------- +# main() argument parsing and dispatch +# --------------------------------------------------------------------------- + + +class TestMainArgParsing: + """import_csv.py main() validates args and dispatches to correct import mode.""" + + def test_missing_csv_dir_exits(self, tmp_path) -> None: + """Non-existent CSV directory triggers sys.exit(1).""" + from unittest.mock import patch + + with ( + patch("sys.argv", ["import_csv.py", str(tmp_path / "missing_csv")]), + pytest.raises(SystemExit, match="1"), + ): + _ic.main() + + def test_default_mode_calls_import_tables(self, tmp_path) -> None: + """Default mode (no flags) calls _import_tables for all tables.""" + from unittest.mock import MagicMock, patch + + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + + mock_conn = MagicMock() + + with ( + patch("sys.argv", ["import_csv.py", str(csv_dir), "postgresql:///test"]), + patch.object(_ic.psycopg, "connect", return_value=mock_conn), + patch.object(_ic, "_import_tables", return_value=100) as mock_import, + patch.object(_ic, "import_artwork", return_value=10), + patch.object(_ic, "populate_cache_metadata", return_value=50), + ): + _ic.main() + + mock_import.assert_called_once() + call_args = mock_import.call_args + assert call_args[0][0] is mock_conn + assert call_args[0][1] == csv_dir + assert 
call_args[0][2] == TABLES + + def test_base_only_mode_calls_parallel(self, tmp_path) -> None: + """--base-only mode calls _import_tables_parallel with base tables.""" + from unittest.mock import MagicMock, patch + + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + + mock_conn = MagicMock() + + with ( + patch( + "sys.argv", + ["import_csv.py", "--base-only", str(csv_dir), "postgresql:///test"], + ), + patch.object(_ic.psycopg, "connect", return_value=mock_conn), + patch.object(_ic, "_import_tables_parallel", return_value=100) as mock_parallel, + patch.object(_ic, "import_artwork", return_value=10), + patch.object(_ic, "populate_cache_metadata", return_value=50), + patch.object(_ic, "create_track_count_table", return_value=20), + ): + _ic.main() + + mock_parallel.assert_called_once() + call_args = mock_parallel.call_args + assert call_args[0][0] == "postgresql:///test" + assert call_args[0][1] == csv_dir + # Parent tables are BASE_TABLES[:1], child tables are BASE_TABLES[1:] + assert call_args[1]["parent_tables"] == BASE_TABLES[:1] + assert call_args[1]["child_tables"] == BASE_TABLES[1:] + + def test_tracks_only_mode_uses_release_id_filter(self, tmp_path) -> None: + """--tracks-only mode queries release IDs and filters track import.""" + from unittest.mock import MagicMock, patch + + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.fetchall.return_value = [(5001,), (5002,), (5003,)] + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch( + "sys.argv", + ["import_csv.py", "--tracks-only", str(csv_dir), "postgresql:///test"], + ), + patch.object(_ic.psycopg, "connect", return_value=mock_conn), + patch.object(_ic, "_import_tables_parallel", return_value=200) as mock_parallel, + ): + _ic.main() + + mock_parallel.assert_called_once() + call_args = mock_parallel.call_args + # --tracks-only 
passes empty parent_tables and TRACK_TABLES as children
+        assert call_args[1]["parent_tables"] == []
+        assert call_args[1]["child_tables"] == TRACK_TABLES
+        assert call_args[1]["release_id_filter"] == {5001, 5002, 5003}
diff --git a/tests/unit/test_matching.py b/tests/unit/test_matching.py
new file mode 100644
index 0000000..f299076
--- /dev/null
+++ b/tests/unit/test_matching.py
@@ -0,0 +1,43 @@
+"""Unit tests for lib/matching.py."""
+
+from __future__ import annotations
+
+import pytest
+
+from lib.matching import is_compilation_artist
+
+
+class TestIsCompilationArtist:
+    """Compilation artist detection."""
+
+    @pytest.mark.parametrize(
+        "artist, expected",
+        [
+            ("Various Artists", True),
+            ("various", True),
+            ("Soundtrack", True),
+            ("Original Motion Picture Soundtrack", True),
+            ("V/A", True),
+            ("v.a.", True),
+            ("Compilation Hits", True),
+            ("Stereolab", False),
+            ("Juana Molina", False),
+            ("Cat Power", False),
+            ("", False),
+        ],
+        ids=[
+            "various-artists",
+            "various-lowercase",
+            "soundtrack",
+            "soundtrack-in-phrase",
+            "v-slash-a",
+            "v-dot-a",
+            "compilation-keyword",
+            "stereolab",
+            "juana-molina",
+            "cat-power",
+            "empty-string",
+        ],
+    )
+    def test_is_compilation_artist(self, artist: str, expected: bool) -> None:
+        assert is_compilation_artist(artist) == expected
diff --git a/tests/unit/test_run_pipeline.py b/tests/unit/test_run_pipeline.py
index 5e618a5..adf8b03 100644
--- a/tests/unit/test_run_pipeline.py
+++ b/tests/unit/test_run_pipeline.py
@@ -3,10 +3,11 @@
 from __future__ import annotations
 
 import importlib.util
+import json
 import logging
 import sys
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
@@ -492,3 +493,455 @@
         assert convert_calls[0][3] is not None, (
             "library_artists path should be passed to convert_and_filter"
         )
+
+
+# ---------------------------------------------------------------------------
+# wait_for_postgres +# --------------------------------------------------------------------------- + + +class TestWaitForPostgres: + """wait_for_postgres() polls until Postgres is ready or times out.""" + + def test_success_on_first_try(self) -> None: + """Successful connection on the first attempt returns immediately.""" + mock_conn = MagicMock() + with patch.object(run_pipeline.psycopg, "connect", return_value=mock_conn): + run_pipeline.wait_for_postgres("postgresql:///test") + mock_conn.close.assert_called_once() + + def test_retry_then_success(self) -> None: + """First call raises OperationalError, second succeeds.""" + mock_conn = MagicMock() + with ( + patch.object( + run_pipeline.psycopg, + "connect", + side_effect=[run_pipeline.psycopg.OperationalError("refused"), mock_conn], + ), + patch.object(run_pipeline.time, "sleep"), + ): + run_pipeline.wait_for_postgres("postgresql:///test") + mock_conn.close.assert_called_once() + + def test_timeout_exits(self) -> None: + """All connection attempts fail and timeout is exceeded -> sys.exit(1).""" + # monotonic returns: first call sets deadline, subsequent calls exceed it + with ( + patch.object( + run_pipeline.psycopg, + "connect", + side_effect=run_pipeline.psycopg.OperationalError("refused"), + ), + patch.object(run_pipeline.time, "monotonic", side_effect=[0.0, 100.0]), + patch.object(run_pipeline.time, "sleep"), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.wait_for_postgres("postgresql:///test") + + +# --------------------------------------------------------------------------- +# run_sql_file +# --------------------------------------------------------------------------- + + +class TestRunSqlFile: + """run_sql_file() executes SQL from a file against the database.""" + + def test_happy_path(self, tmp_path) -> None: + """SQL file contents are executed via cursor.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE t (id int)") + + mock_conn = MagicMock() + mock_cursor = 
MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with patch.object(run_pipeline.psycopg, "connect", return_value=mock_conn): + run_pipeline.run_sql_file("postgresql:///test", sql_file) + + mock_cursor.execute.assert_called_once_with("CREATE TABLE t (id int)") + mock_conn.close.assert_called_once() + + def test_sql_error_exits(self, tmp_path) -> None: + """psycopg.Error during execution triggers sys.exit(1).""" + sql_file = tmp_path / "bad.sql" + sql_file.write_text("INVALID SQL") + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute.side_effect = run_pipeline.psycopg.Error("syntax error") + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch.object(run_pipeline.psycopg, "connect", return_value=mock_conn), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.run_sql_file("postgresql:///test", sql_file) + mock_conn.close.assert_called() + + def test_strip_concurrently_removes_keyword(self, tmp_path) -> None: + """strip_concurrently=True removes CONCURRENTLY from SQL.""" + sql_file = tmp_path / "indexes.sql" + sql_file.write_text("CREATE INDEX CONCURRENTLY idx_a ON t(a)") + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with patch.object(run_pipeline.psycopg, "connect", return_value=mock_conn): + run_pipeline.run_sql_file("postgresql:///test", sql_file, strip_concurrently=True) + + executed_sql = mock_cursor.execute.call_args[0][0] + assert "CONCURRENTLY" not in executed_sql + assert "CREATE INDEX idx_a ON t(a)" == executed_sql + + +# --------------------------------------------------------------------------- +# run_sql_statements_parallel — 
error propagation +# --------------------------------------------------------------------------- + + +class TestRunSqlStatementsParallelError: + """Test that psycopg.Error from a parallel statement is re-raised.""" + + def test_psycopg_error_is_reraised(self) -> None: + """A psycopg.Error in a parallel statement propagates to the caller.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute.side_effect = run_pipeline.psycopg.Error("disk full") + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch.object(run_pipeline.psycopg, "connect", return_value=mock_conn), + pytest.raises(run_pipeline.psycopg.Error, match="disk full"), + ): + run_sql_statements_parallel("postgresql:///test", ["CREATE INDEX idx_x ON t(x)"]) + + +# --------------------------------------------------------------------------- +# report_sizes +# --------------------------------------------------------------------------- + + +class TestReportSizes: + """report_sizes() queries pg_stat_user_tables and logs results.""" + + def test_logs_table_sizes(self, caplog) -> None: + """Fetched rows are logged with table names and row counts.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.fetchall.return_value = [ + ("release", 50000, "120 MB"), + ("release_artist", 80000, "45 MB"), + ] + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch.object(run_pipeline.psycopg, "connect", return_value=mock_conn), + caplog.at_level(logging.INFO, logger=run_pipeline.logger.name), + ): + run_pipeline.report_sizes("postgresql:///test") + + mock_cursor.execute.assert_called_once() + logged = [r.message for r in caplog.records] + assert any("release" in msg and "50,000" in msg for msg in logged) + assert any("release_artist" in msg and "80,000" in msg 
for msg in logged) + mock_conn.close.assert_called_once() + + +# --------------------------------------------------------------------------- +# convert_and_filter +# --------------------------------------------------------------------------- + + +class TestConvertAndFilter: + """convert_and_filter() constructs the converter command and delegates to run_step.""" + + def test_command_with_library_artists(self) -> None: + """Command includes --library-artists when provided.""" + with patch.object(run_pipeline, "run_step") as mock_run: + run_pipeline.convert_and_filter( + Path("/data/releases.xml.gz"), + Path("/tmp/csv"), + "discogs-xml-converter", + library_artists=Path("/data/library_artists.txt"), + ) + + mock_run.assert_called_once() + cmd = mock_run.call_args[0][1] + assert cmd[0] == "discogs-xml-converter" + assert "/data/releases.xml.gz" in cmd + assert "--output-dir" in cmd + assert "--library-artists" in cmd + assert "/data/library_artists.txt" in cmd + + def test_command_with_database_url(self) -> None: + """Command includes --database-url for direct-PG mode.""" + with patch.object(run_pipeline, "run_step") as mock_run: + run_pipeline.convert_and_filter( + Path("/data/releases.xml.gz"), + Path("/tmp/csv"), + "discogs-xml-converter", + database_url="postgresql:///discogs", + ) + + cmd = mock_run.call_args[0][1] + assert "--database-url" in cmd + assert "postgresql:///discogs" in cmd + # Description mentions PostgreSQL + description = mock_run.call_args[0][0] + assert "PostgreSQL" in description + + def test_command_without_optional_args(self) -> None: + """Command omits --library-artists and --database-url when not provided.""" + with patch.object(run_pipeline, "run_step") as mock_run: + run_pipeline.convert_and_filter( + Path("/data/releases.xml.gz"), + Path("/tmp/csv"), + "discogs-xml-converter", + ) + + cmd = mock_run.call_args[0][1] + assert "--library-artists" not in cmd + assert "--database-url" not in cmd + description = mock_run.call_args[0][0] + 
assert "CSV" in description + + +# --------------------------------------------------------------------------- +# enrich_library_artists (orchestrator wrapper) +# --------------------------------------------------------------------------- + + +class TestEnrichLibraryArtists: + """enrich_library_artists() constructs the enrichment command.""" + + def test_command_with_wxyc_db_url(self) -> None: + """Command includes --wxyc-db-url when provided.""" + with patch.object(run_pipeline, "run_step") as mock_run: + run_pipeline.enrich_library_artists( + Path("/data/library.db"), + Path("/tmp/library_artists.txt"), + wxyc_db_url="mysql://user:pass@host/db", + ) + + cmd = mock_run.call_args[0][1] + assert "--library-db" in cmd + assert "/data/library.db" in cmd + assert "--output" in cmd + assert "/tmp/library_artists.txt" in cmd + assert "--wxyc-db-url" in cmd + assert "mysql://user:pass@host/db" in cmd + + def test_command_without_wxyc_db_url(self) -> None: + """Command omits --wxyc-db-url when not provided.""" + with patch.object(run_pipeline, "run_step") as mock_run: + run_pipeline.enrich_library_artists( + Path("/data/library.db"), + Path("/tmp/library_artists.txt"), + ) + + cmd = mock_run.call_args[0][1] + assert "--library-db" in cmd + assert "--output" in cmd + assert "--wxyc-db-url" not in cmd + + +# --------------------------------------------------------------------------- +# _load_or_create_state +# --------------------------------------------------------------------------- + + +class TestLoadOrCreateState: + """_load_or_create_state() handles resume modes.""" + + def test_resume_with_existing_state_file(self, tmp_path) -> None: + """When --resume and state file exists, load it.""" + state_file = tmp_path / "state.json" + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + + # Write a valid state file + state_data = { + "version": 3, + "database_url": "postgresql:///test", + "csv_dir": str(csv_dir.resolve()), + "steps": {s: {"status": "pending"} for s in 
run_pipeline.STEP_NAMES}, + } + state_data["steps"]["create_schema"] = {"status": "completed"} + state_file.write_text(json.dumps(state_data)) + + args = run_pipeline.parse_args( + [ + "--csv-dir", + str(csv_dir), + "--resume", + "--state-file", + str(state_file), + "--database-url", + "postgresql:///test", + ] + ) + + state = run_pipeline._load_or_create_state(args) + assert state.is_completed("create_schema") + assert not state.is_completed("import_csv") + + def test_resume_without_state_file_uses_db_introspect(self, tmp_path) -> None: + """When --resume but no state file, infer from database.""" + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + state_file = tmp_path / "nonexistent_state.json" + + args = run_pipeline.parse_args( + [ + "--csv-dir", + str(csv_dir), + "--resume", + "--state-file", + str(state_file), + "--database-url", + "postgresql:///test", + ] + ) + + mock_state = run_pipeline.PipelineState(db_url="postgresql:///test", csv_dir="") + mock_state.mark_completed("create_schema") + + with patch("lib.db_introspect.infer_pipeline_state", return_value=mock_state): + state = run_pipeline._load_or_create_state(args) + + assert state.is_completed("create_schema") + assert state.csv_dir == str(csv_dir.resolve()) + + def test_fresh_state_no_resume(self, tmp_path) -> None: + """Without --resume, create a fresh PipelineState.""" + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + + args = run_pipeline.parse_args( + [ + "--csv-dir", + str(csv_dir), + "--database-url", + "postgresql:///test", + ] + ) + + state = run_pipeline._load_or_create_state(args) + assert not any(state.is_completed(s) for s in run_pipeline.STEP_NAMES) + assert state.db_url == "postgresql:///test" + + +# --------------------------------------------------------------------------- +# main() — input validation +# --------------------------------------------------------------------------- + + +class TestMainValidation: + """main() validates file paths before running the pipeline.""" + + def 
test_missing_xml_file_exits(self, tmp_path) -> None: + """Non-existent XML file triggers sys.exit(1).""" + args = run_pipeline.parse_args(["--xml", str(tmp_path / "missing.xml.gz")]) + with ( + patch.object(run_pipeline, "parse_args", return_value=args), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.main() + + def test_missing_library_artists_file_exits(self, tmp_path) -> None: + """Non-existent library_artists.txt triggers sys.exit(1).""" + xml_file = tmp_path / "releases.xml.gz" + xml_file.touch() + args = run_pipeline.parse_args( + [ + "--xml", + str(xml_file), + "--library-artists", + str(tmp_path / "missing_artists.txt"), + ] + ) + with ( + patch.object(run_pipeline, "parse_args", return_value=args), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.main() + + def test_missing_csv_dir_exits(self, tmp_path) -> None: + """Non-existent CSV directory triggers sys.exit(1).""" + args = run_pipeline.parse_args(["--csv-dir", str(tmp_path / "missing_csv")]) + with ( + patch.object(run_pipeline, "parse_args", return_value=args), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.main() + + def test_missing_library_db_exits(self, tmp_path) -> None: + """Non-existent library.db triggers sys.exit(1).""" + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + args = run_pipeline.parse_args( + [ + "--csv-dir", + str(csv_dir), + "--library-db", + str(tmp_path / "missing_library.db"), + ] + ) + with ( + patch.object(run_pipeline, "parse_args", return_value=args), + pytest.raises(SystemExit, match="1"), + ): + run_pipeline.main() + + def test_missing_library_labels_exits(self, tmp_path) -> None: + """Non-existent library_labels.csv triggers sys.exit(1).""" + csv_dir = tmp_path / "csv" + csv_dir.mkdir() + args = run_pipeline.parse_args( + [ + "--csv-dir", + str(csv_dir), + "--library-labels", + str(tmp_path / "missing_labels.csv"), + ] + ) + with ( + patch.object(run_pipeline, "parse_args", return_value=args), + pytest.raises(SystemExit, match="1"), + 
): + run_pipeline.main() + + +# --------------------------------------------------------------------------- +# parse_args — additional validation +# --------------------------------------------------------------------------- + + +class TestParseArgsValidation: + """Additional argument validation in parse_args.""" + + def test_direct_pg_without_xml_exits(self) -> None: + """--direct-pg without --xml triggers parser.error (sys.exit(2)).""" + with pytest.raises(SystemExit): + run_pipeline.parse_args(["--csv-dir", "/tmp/csv", "--direct-pg"]) + + def test_generate_library_db_with_library_db_exits(self) -> None: + """--generate-library-db and --library-db are mutually exclusive.""" + with pytest.raises(SystemExit): + run_pipeline.parse_args( + [ + "--csv-dir", + "/tmp/csv", + "--generate-library-db", + "--library-db", + "/tmp/library.db", + ] + ) diff --git a/tests/unit/test_verify_cache.py b/tests/unit/test_verify_cache.py index 56df5bf..0a13672 100644 --- a/tests/unit/test_verify_cache.py +++ b/tests/unit/test_verify_cache.py @@ -643,6 +643,11 @@ async def test_empty_results(self): # Step 8: Argument Parsing # --------------------------------------------------------------------------- +format_bytes = _vc.format_bytes +ClassificationReport = _vc.ClassificationReport +MatchResult = _vc.MatchResult +print_report = _vc.print_report + classify_all_releases = _vc.classify_all_releases classify_artist_fuzzy = _vc.classify_artist_fuzzy classify_fuzzy_batch = _vc.classify_fuzzy_batch @@ -843,3 +848,132 @@ def test_default_no_copy_to(self, tmp_path): lib_db.touch() args = parse_args([str(lib_db)]) assert args.copy_to is None + + +# --------------------------------------------------------------------------- +# format_bytes +# --------------------------------------------------------------------------- + + +class TestFormatBytes: + """Test human-readable byte formatting.""" + + @pytest.mark.parametrize( + "num_bytes, expected", + [ + (0, "0.0 B"), + (1023, "1023.0 B"), + (1024, 
"1.0 KB"), + (1048576, "1.0 MB"), + (1073741824, "1.0 GB"), + (1099511627776, "1.0 TB"), + ], + ids=["zero", "bytes", "kilobytes", "megabytes", "gigabytes", "terabytes"], + ) + def test_format_bytes(self, num_bytes: int, expected: str) -> None: + assert format_bytes(num_bytes) == expected + + +# --------------------------------------------------------------------------- +# print_report +# --------------------------------------------------------------------------- + + +class TestPrintReport: + """Test the print_report function with mock data.""" + + def test_basic_report(self, sample_index, capsys: pytest.CaptureFixture[str]) -> None: + report = ClassificationReport( + keep_ids={1, 2, 3}, + prune_ids={4, 5}, + review_ids=set(), + review_by_artist={}, + artist_originals={}, + total_releases=5, + ) + + print_report(report, sample_index) + + captured = capsys.readouterr() + assert "VERIFICATION REPORT" in captured.out + assert "KEEP:" in captured.out + assert "PRUNE:" in captured.out + assert "3" in captured.out # keep count + assert "2" in captured.out # prune count + + def test_report_with_table_sizes( + self, sample_index, capsys: pytest.CaptureFixture[str] + ) -> None: + report = ClassificationReport( + keep_ids={1, 2}, + prune_ids={3, 4, 5}, + review_ids=set(), + review_by_artist={}, + artist_originals={}, + total_releases=5, + ) + table_sizes = { + "release": (100, 1048576), + "release_artist": (200, 2097152), + "release_label": (150, 524288), + "release_track": (500, 4194304), + "release_track_artist": (300, 1048576), + "cache_metadata": (100, 262144), + } + rows_to_delete = { + "release": 60, + "release_artist": 120, + "release_label": 90, + "release_track": 300, + "release_track_artist": 180, + "cache_metadata": 60, + } + + print_report(report, sample_index, table_sizes=table_sizes, rows_to_delete=rows_to_delete) + + captured = capsys.readouterr() + assert "Database size" in captured.out + assert "Estimated savings" in captured.out + assert "release_track" in 
captured.out + + def test_pruned_report(self, sample_index, capsys: pytest.CaptureFixture[str]) -> None: + report = ClassificationReport( + keep_ids={1, 2}, + prune_ids={3}, + review_ids=set(), + review_by_artist={}, + artist_originals={}, + total_releases=3, + ) + + print_report(report, sample_index, pruned=True) + + captured = capsys.readouterr() + assert "PRUNING REPORT" in captured.out + assert "Releases kept:" in captured.out + assert "Releases pruned:" in captured.out + + def test_report_with_review_artists( + self, sample_index, capsys: pytest.CaptureFixture[str] + ) -> None: + match_result = MatchResult( + decision=Decision.REVIEW, + exact_score=0.0, + token_set_score=0.70, + token_sort_score=0.65, + two_stage_score=0.60, + ) + report = ClassificationReport( + keep_ids={1}, + prune_ids={2}, + review_ids={3}, + review_by_artist={"some artist": [(3, "Some Album", match_result)]}, + artist_originals={"some artist": "Some Artist"}, + total_releases=3, + ) + + print_report(report, sample_index) + + captured = capsys.readouterr() + assert "REVIEW" in captured.out + assert "artist-level decisions needed" in captured.out