From 79fc42b4108e24388eb4ed5ad8e722459d733459 Mon Sep 17 00:00:00 2001 From: Robert Keller Date: Mon, 28 Apr 2025 22:34:26 -0400 Subject: [PATCH 1/2] Refactor for easier testing --- src/arinc.py | 65 +++++++++++++++++++++++++++++++++++++++ src/config.py | 3 -- src/database.py | 4 +-- src/main.py | 76 ++++------------------------------------------ src/record_maps.py | 23 +++++++++++--- 5 files changed, 92 insertions(+), 79 deletions(-) create mode 100644 src/arinc.py diff --git a/src/arinc.py b/src/arinc.py new file mode 100644 index 0000000..803ce1c --- /dev/null +++ b/src/arinc.py @@ -0,0 +1,65 @@ +from rich.progress import track +from database import DbConfig +from record_maps import record_maps + + +class ArincRecord: + def __init__(self, record_map: dict): + self.section: int | None = record_map.get("section_code") + self.subsection: int | None = record_map.get("subsection_code") + self.section_pos: int | None = record_map.get("section_pos") + self.subsection_pos: int | None = record_map.get("subsection_pos") + self.cont_rec_pos: int | None = record_map.get("cont_rec_pos") + self.cont_rec_vals: list[str] = record_map.get("cont_rec_vals", []) + self.name: str = record_map.get("name", "") + self.columns: list[dict] = record_map.get("columns", []) + self.column_names: list[str] = [c["name"] for c in record_map["columns"]] + + +class ArincParser: + def __init__(self, db: DbConfig, file: str): + self.db = db + self.file = file + self.lines = self.read_file() + self.cycle = self.get_cycle() + self.schema = f"cycle{self.cycle}" + + def read_file(self) -> list[str]: + with open(self.file) as file: + return file.readlines() + + def parse(self) -> None: + self.create_schema() + for record in record_maps: + self.create_arinc_record(record) + + def get_cycle(self) -> str: + return self.lines[0][35:39] + + def create_schema(self) -> None: + self.db.create_schema(self.schema) + + def create_table(self, record: ArincRecord) -> None: + self.db.create_table(self.schema, record.name, record.column_names) + + def add_row(self, name: str, values: list, cycle: str) -> None: + self.db.add_row(self.schema, name, values) + + def create_arinc_record(self, record_map) -> None: + record = ArincRecord(record_map) + + self.create_table(record) + + for line in track(self.lines, description=f"{record.name.rjust(26)}"): + if ( + record.section_pos is not None + and record.subsection_pos is not None + and line[record.section_pos] == record.section + and line[record.subsection_pos] == record.subsection + ): + if ( + not record.cont_rec_pos + or line[record.cont_rec_pos] in record.cont_rec_vals + ): + row = [f"{line[i['start']:i['end']]}" for i in record.columns] + self.add_row(record.name, row, self.cycle) diff --git a/src/config.py b/src/config.py index 0f9a4ee..268d63d 100644 --- a/src/config.py +++ b/src/config.py @@ -19,9 +19,6 @@ def __init__(self): self.dbtype = "sqlite" self.dbname = parser["sqlite"]["dbname"] - if not parser.has_section("cifp_file"): - raise ValueError("No cifp_file configuration found in config.ini") - self.file_loc = parser["cifp_file"]["file_loc"] diff --git a/src/database.py b/src/database.py index 4ccf5ee..5fea440 100644 --- a/src/database.py +++ b/src/database.py @@ -5,7 +5,7 @@ from config import UserConfigs -class DbConfig(Protocol): +class DbConfig(Protocol): # pragma: no cover def connect(self): pass @@ -81,7 +81,7 @@ def connect(self) -> Generator[sqlite3.Cursor, None, None]: conn.commit() conn.close() - def create_schema(self, _) -> None: + def create_schema(self, _) -> None: # pragma: no cover # SQLite does not support schemas in the same way as PostgreSQL. # This method is intentionally a no-op. pass diff --git a/src/main.py b/src/main.py index d8deb26..ef0b3c5 100644 --- a/src/main.py +++ b/src/main.py @@ -1,81 +1,17 @@ -from rich.progress import track -from database import get_db, DbConfig +from arinc import ArincParser from config import UserConfigs -from record_maps import record_maps - - -configs: UserConfigs = UserConfigs() - - -class ArincRecord: - def __init__(self, record_map: dict): - self.section: int | None = record_map.get("section_code") - self.subsection: int | None = record_map.get("subsection_code") - self.section_pos: int | None = record_map.get("section_pos") - self.subsection_pos: int | None = record_map.get("subsection_pos") - self.cont_rec_pos: int | None = record_map.get("cont_rec_pos") - self.cont_rec_vals: list[str] = record_map.get("cont_rec_vals", []) - self.name: str = record_map.get("name", "") - self.columns: list[dict] = record_map.get("columns", []) - self.column_names: list[str] = [c["name"] for c in record_map["columns"]] - - -class ArincParser: - def __init__(self, db: DbConfig, file=configs.file_loc): - self.db = db - self.file = file - self.lines = self.read_file() - self.cycle = self.get_cycle() - self.schema = f"cycle{self.cycle}" - - def read_file(self) -> list[str]: - with open(self.file) as file: - return file.readlines() - - def parse(self) -> None: - self.create_schema() - for record in record_maps: - self.create_arinc_record(record) - - def get_cycle(self) -> str: - return self.lines[0][35:39] - - def create_schema(self) -> None: - self.db.create_schema(self.schema) - - def create_table(self, record: ArincRecord) -> None: - self.db.create_table(self.schema, record.name, record.column_names) - - def add_row(self, name: str, values: list, cycle: str) -> None: - self.db.add_row(self.schema, name, values) - - def create_arinc_record(self, record_map) -> None: - record = ArincRecord(record_map) - - self.create_table(record) - - for line in track(self.lines, description=f"{record.name.rjust(26)}"): - if ( - record.section_pos is not None - and record.subsection_pos is not None - and line[record.section_pos] == record.section - and line[record.subsection_pos] == record.subsection - ): - if ( - not record.cont_rec_pos - or line[record.cont_rec_pos] in record.cont_rec_vals - ): - row = [f"{line[i['start']:i['end']]}" for i in record.columns] - self.add_row(record.name, row, self.cycle) +from database import DbConfig, get_db def main() -> None: + configs: UserConfigs = UserConfigs() + db: DbConfig = get_db(configs) with db.connect(): - parser = ArincParser(db) + parser = ArincParser(db, configs.file_loc) parser.parse() -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/record_maps.py b/src/record_maps.py index 8592529..a8183f6 100644 --- a/src/record_maps.py +++ b/src/record_maps.py @@ -1,6 +1,6 @@ # Store of mappings of each ARINC-424 record in FAA CIFP output. -record_maps = [ +record_maps = [ # pragma: no cover { "name": "airport", "section_code": "P", @@ -648,8 +648,18 @@ {"id": 11, "start": 27, "end": 28, "name": "Route_Indicator"}, {"id": 12, "start": 28, "end": 30, "name": "SBAS_Service_Provider"}, {"id": 13, "start": 30, "end": 32, "name": "Reference_Path_Data_Selector"}, - {"id": 14, "start": 32, "end": 36, "name": "Reference_Path_Data_Identifier"}, - {"id": 15, "start": 36, "end": 37, "name": "Approach_Performance_Designator"}, + { + "id": 14, + "start": 32, + "end": 36, + "name": "Reference_Path_Data_Identifier", + }, + { + "id": 15, + "start": 36, + "end": 37, + "name": "Approach_Performance_Designator", + }, {"id": 16, "start": 37, "end": 48, "name": "LTP_Latitude"}, {"id": 17, "start": 48, "end": 60, "name": "LTP_Longitude"}, {"id": 18, "start": 60, "end": 66, "name": "LTP_Ellipsoid_Height"}, @@ -802,7 +812,12 @@ {"id": 20, "start": 81, "end": 85, "name": "LOC__MLS__GLS_Identifier"}, {"id": 21, "start": 85, "end": 86, "name": "Category__Class"}, {"id": 22, "start": 86, "end": 90, "name": "Stopway"}, - {"id": 23, "start": 90, "end": 94, "name": "Secondary_LOC_MLS_GLS_Identifier"}, + { + "id": 23, + "start": 90, + "end": 94, + "name": "Secondary_LOC_MLS_GLS_Identifier", + }, {"id": 24, "start": 94, "end": 95, "name": "Category__Class_2"}, {"id": 25, "start": 101, "end": 123, "name": "Runway_Description"}, {"id": 26, "start": 123, "end": 128, "name": "File_Record_Number"}, From d241fae0093a77ccdfa0c38a41daa6029ceec83f Mon Sep 17 00:00:00 2001 From: Robert Keller Date: Mon, 28 Apr 2025 22:34:33 -0400 Subject: [PATCH 2/2] Add unit tests --- tests/test_arinc.py | 92 +++++++++++++++++++++++++++++++ tests/test_database.py | 121 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+) create mode 100644 tests/test_arinc.py create mode 100644 tests/test_database.py diff --git a/tests/test_arinc.py b/tests/test_arinc.py new file mode 100644 index 0000000..22eb8a5 --- /dev/null +++ b/tests/test_arinc.py @@ -0,0 +1,92 @@ +import os +import tempfile + +import arinc # type: ignore + + +class MockDbConfig: + def __init__(self): + self.schemas_created = [] + self.tables_created = [] + self.rows_added = [] + + def create_schema(self, schema_name: str) -> None: + self.schemas_created.append(schema_name) + + def create_table( + self, schema_name: str, table_name: str, columns: list[str] + ) -> None: + self.tables_created.append((schema_name, table_name, columns)) + + def add_row(self, schema_name: str, table_name: str, values: list[str]) -> None: + self.rows_added.append((schema_name, table_name, values)) + + +def test_arinc_parser_parse(): + test_record_map = { + "section_code": "A", + "subsection_code": "B", + "section_pos": 0, + "subsection_pos": 1, + "cont_rec_pos": 2, + "cont_rec_vals": ["C"], + "name": "test_record_parser", + "columns": [ + {"name": "col1", "start": 3, "end": 5}, + {"name": "col2", "start": 5, "end": 7}, + ], + } + arinc.record_maps = [test_record_map] + + cycle_line = "X" * 35 + "2023\n" + line_match = "ABCDEFAB\n" + line_nomatch = "ZZZZZZZZ\n" + file_content = cycle_line + line_match + line_nomatch + line_match + + with tempfile.NamedTemporaryFile("w+", delete=False) as tmp_file: + tmp_file.write(file_content) + tmp_file_path = tmp_file.name + + try: + mock_db = MockDbConfig() + parser = arinc.ArincParser(mock_db, tmp_file_path) + parser.parse() + + expected_schema = "cycle2023" + assert expected_schema in mock_db.schemas_created + + expected_table = (expected_schema, "test_record_parser", ["col1", "col2"]) + assert expected_table in mock_db.tables_created + + expected_row = ["DE", "FA"] + matching_rows = [ + row for _, tbl, row in mock_db.rows_added if tbl == "test_record_parser" + ] + assert len(matching_rows) == 2 + for row in matching_rows: + assert row == expected_row + finally: + os.unlink(tmp_file_path) + + +def test_arinc_record(): + record_map = { + "section_code": 1, + "subsection_code": 2, + "section_pos": 3, + "subsection_pos": 4, + "cont_rec_pos": 5, + "cont_rec_vals": ["val1", "val2"], + "name": "test_record", + "columns": [{"name": "col1"}, {"name": "col2"}], + } + record = arinc.ArincRecord(record_map) + + assert record.section == 1 + assert record.subsection == 2 + assert record.section_pos == 3 + assert record.subsection_pos == 4 + assert record.cont_rec_pos == 5 + assert record.cont_rec_vals == ["val1", "val2"] + assert record.name == "test_record" + assert record.column_names == ["col1", "col2"] diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..c063112 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,121 @@ +import pytest +from unittest.mock import MagicMock, patch +from database import PostgresDb, SqliteDb, get_db # type: ignore + + +class MockConfigs: + def __init__( + self, dbtype, dbname="test.db", user=None, password=None, host=None, port=None + ): + self.dbtype = dbtype + self.dbname = dbname + self.user = user + self.password = password + self.host = host + self.port = port + + +@pytest.fixture +def mock_postgres_configs(): + return MockConfigs( + dbtype="postgres", + dbname="test_db", + user="test_user", + password="test_pass", + host="localhost", + port=5432, + ) + + +@pytest.fixture +def mock_sqlite_configs(): + return MockConfigs(dbtype="sqlite", dbname=":memory:") + + +def test_get_db_postgres(mock_postgres_configs): + db = get_db(mock_postgres_configs) + assert isinstance(db, PostgresDb) + + +def test_get_db_sqlite(mock_sqlite_configs): + db = get_db(mock_sqlite_configs) + assert isinstance(db, SqliteDb) + + +def test_get_db_invalid_type(): + configs = MockConfigs(dbtype="invalid") + with pytest.raises(ValueError, match="Unsupported database type: invalid"): + get_db(configs) + + +@patch("database.psycopg2.connect") +def test_postgresdb_connect(mock_connect, mock_postgres_configs): + mock_cursor = MagicMock() + mock_connect.return_value.cursor.return_value = mock_cursor + db = PostgresDb(mock_postgres_configs) + + with db.connect() as cursor: + assert cursor == mock_cursor + + mock_cursor.close.assert_called_once() + mock_connect.return_value.commit.assert_called_once() + mock_connect.return_value.close.assert_called_once() + + +@patch("database.sqlite3.connect") +def test_sqlitedb_connect(mock_connect, mock_sqlite_configs): + mock_cursor = MagicMock() + mock_connect.return_value.cursor.return_value = mock_cursor + db = SqliteDb(mock_sqlite_configs) + + with db.connect() as cursor: + assert cursor == mock_cursor + + mock_cursor.close.assert_called_once() + mock_connect.return_value.commit.assert_called_once() + mock_connect.return_value.close.assert_called_once() + + +def test_postgresdb_create_schema(mock_postgres_configs): + db = PostgresDb(mock_postgres_configs) + db.cursor = MagicMock() + db.create_schema("test_schema") + db.cursor.execute.assert_called_once_with( + "DROP SCHEMA IF EXISTS test_schema CASCADE; CREATE SCHEMA test_schema;" + ) + + +def test_postgresdb_create_table(mock_postgres_configs): + db = PostgresDb(mock_postgres_configs) + db.cursor = MagicMock() + db.create_table("test_schema", "test_table", ["col1", "col2"]) + db.cursor.execute.assert_called_once_with( + "DROP TABLE IF EXISTS test_schema.test_table; CREATE TABLE test_schema.test_table (col1 varchar, col2 varchar);" + ) + + +def test_postgresdb_add_row(mock_postgres_configs): + db = PostgresDb(mock_postgres_configs) + db.cursor = MagicMock() + db.add_row("test_schema", "test_table", ["val1", "val2"]) + db.cursor.execute.assert_called_once_with( + "INSERT INTO test_schema.test_table VALUES ('val1', 'val2');" + ) + + +def test_sqlitedb_create_table(mock_sqlite_configs): + db = SqliteDb(mock_sqlite_configs) + db.cursor = MagicMock() + db.create_table(None, "test_table", ["col1", "col2"]) + db.cursor.executescript.assert_called_once_with( + "DROP TABLE IF EXISTS test_table; CREATE TABLE test_table (col1 TEXT, col2 TEXT);" + ) + + +def test_sqlitedb_add_row(mock_sqlite_configs): + db = SqliteDb(mock_sqlite_configs) + db.cursor = MagicMock() + db.add_row(None, "test_table", ["val1", "val2"]) + db.cursor.executescript.assert_called_once_with( + "INSERT INTO test_table VALUES ('val1', 'val2');" + )