diff --git a/dissect/database/sqlite3/sqlite3.py b/dissect/database/sqlite3/sqlite3.py index e009053..1528ba7 100644 --- a/dissect/database/sqlite3/sqlite3.py +++ b/dissect/database/sqlite3/sqlite3.py @@ -66,6 +66,7 @@ class SQLite3: fh: The path or file-like object to open a SQLite3 database on. wal: The path or file-like object to open a SQLite3 WAL file on. checkpoint: The checkpoint to apply from the WAL file. Can be a :class:`Checkpoint` object or an integer index. + validate_checksums: A boolean that sets whether to validate the checksum of frames when reading. Raises: InvalidDatabase: If the file-like object does not look like a SQLite3 database based on the header magic. @@ -79,6 +80,8 @@ def __init__( fh: Path | BinaryIO, wal: WAL | Path | BinaryIO | None = None, checkpoint: Checkpoint | int | None = None, + *, + validate_checksums: bool = False, ): if isinstance(fh, Path): path = fh @@ -90,6 +93,7 @@ def __init__( self.path = path self.wal = None self.checkpoint = None + self.validate_checksums = validate_checksums self.header = c_sqlite3.header(self.fh) if self.header.magic != SQLITE3_HEADER_MAGIC: @@ -211,7 +215,7 @@ def raw_page(self, num: int) -> bytes: # Check if the latest valid instance of the page is committed (either the frame itself # is the commit frame or it is included in a commit's frames). If so, return that frame's data. for commit in reversed(self.wal.commits): - if (frame := commit.get(num)) and frame.valid: + if (frame := commit.get(num)) and frame.is_valid(validate_checksums=self.validate_checksums): data = frame.data break diff --git a/dissect/database/sqlite3/wal.py b/dissect/database/sqlite3/wal.py index 7d4ec76..b40799f 100644 --- a/dissect/database/sqlite3/wal.py +++ b/dissect/database/sqlite3/wal.py @@ -125,13 +125,72 @@ def __init__(self, wal: WAL, offset: int): def __repr__(self) -> str: return f"" - @property - def valid(self) -> bool: + def is_valid(self, validate_checksums: bool = True) -> bool: + """Return whether the frame is valid by comparing its salt values and optionally verifying the checksum. + + A frame is valid if: + - Its salt1 and salt2 values match those in the WAL header. + - Its checksum matches the calculated checksum. + + References: + - https://sqlite.org/fileformat2.html#wal_file_format + """ + return (self.is_valid_salt() and self.is_valid_checksum()) if validate_checksums else self.is_valid_salt() + + def is_valid_salt(self) -> bool: + """Return whether the frame's salt values match those in the WAL header. + + References: + - https://sqlite.org/fileformat2.html#wal_file_format + """ salt1_match = self.header.salt1 == self.wal.header.salt1 salt2_match = self.header.salt2 == self.wal.header.salt2 return salt1_match and salt2_match + def is_valid_checksum(self) -> bool: + """Return whether the frame's checksum matches the calculated checksum. + + The checksum values in the final 8 bytes of the frame-header (checksum-1 and checksum-2) + exactly match the computed checksum over: + + 1. the first 24 bytes of the WAL header + 2. the first 8 bytes of each frame header (up to and including this frame) + 3. the page data of each frame (up to and including this frame) + + References: + - https://sqlite.org/fileformat2.html#wal_file_format + - https://github.com/sqlite/sqlite/blob/master/src/wal.c#L995-L1047 + """ + # Start seed with checksum over first 24 bytes of WAL header + seed = calculate_checksum(self.header.dumps()[:24], endian=self.wal.checksum_endian) + + # Iterate frames from the first frame up to and including this frame + frame_size = len(c_sqlite3.wal_frame) + self.wal.header.page_size + first_frame_offset = len(c_sqlite3.wal_header) + offset = first_frame_offset + + while offset <= self.offset: + # Read frame header + self.fh.seek(offset) + frame_hdr_bytes = self.fh.read(len(c_sqlite3.wal_frame)) + if len(frame_hdr_bytes) < len(c_sqlite3.wal_frame): + raise EOFError("Incomplete frame header while calculating checksum") + + # Checksum first 16 bytes of frame header + seed = calculate_checksum(frame_hdr_bytes[:16], seed=seed, endian=self.wal.checksum_endian) + + # Read and checksum page data + page_data = self.fh.read(self.wal.header.page_size) + if len(page_data) < self.wal.header.page_size: + raise EOFError("Incomplete page data while calculating checksum") + seed = calculate_checksum(page_data, seed=seed, endian=self.wal.checksum_endian) + + offset += frame_size + + # Compare calculated checksum to stored checksum in this frame header + return (seed[0], seed[1]) == (self.header.checksum1, self.header.checksum2) + @property def data(self) -> bytes: self.fh.seek(self.offset + len(c_sqlite3.wal_frame)) @@ -187,8 +246,14 @@ class Commit(_FrameCollection): """ -def checksum(buf: bytes, endian: str = ">") -> tuple[int, int]: - s0 = s1 = 0 +def calculate_checksum(buf: bytes, seed: tuple[int, int] = (0, 0), endian: str = ">") -> tuple[int, int]: + """Calculate the checksum of a WAL header or frame. + + References: + - https://sqlite.org/fileformat2.html#checksum_algorithm + """ + + s0, s1 = seed num_ints = len(buf) // 4 arr = struct.unpack(f"{endian}{num_ints}I", buf) diff --git a/tests/sqlite3/test_wal.py b/tests/sqlite3/test_wal.py index 6d477fe..46fe0b3 100644 --- a/tests/sqlite3/test_wal.py +++ b/tests/sqlite3/test_wal.py @@ -18,7 +18,7 @@ ("wal_as_path"), [pytest.param(True, id="wal_as_path"), pytest.param(False, id="wal_as_fh")], ) -def test_sqlite_wal(sqlite_db: Path, sqlite_wal: Path, db_as_path: bool, wal_as_path: bool) -> None: +def test_sqlite_wal_checkpoint(sqlite_db: Path, sqlite_wal: Path, db_as_path: bool, wal_as_path: bool) -> None: db = sqlite3.SQLite3( sqlite_db if db_as_path else sqlite_db.open("rb"), sqlite_wal if wal_as_path else sqlite_wal.open("rb"), @@ -47,6 +47,40 @@ def test_sqlite_wal(sqlite_db: Path, sqlite_wal: Path, db_as_path: bool, wal_as_ db.close() +@pytest.mark.parametrize( + ("db_as_path"), + [pytest.param(True, id="db_as_path"), pytest.param(False, id="db_as_fh")], +) +@pytest.mark.parametrize( + ("wal_as_path"), + [pytest.param(True, id="wal_as_path"), pytest.param(False, id="wal_as_fh")], +) +def test_sqlite_wal_checksum_validation(sqlite_db: Path, sqlite_wal: Path, db_as_path: bool, wal_as_path: bool) -> None: + # Test that the WAL checksum validation works as expected + # When validate_checksums=True, only entries before the last checkpoint are visible + db = sqlite3.SQLite3( + sqlite_db if db_as_path else sqlite_db.open("rb"), + sqlite_wal if wal_as_path else sqlite_wal.open("rb"), + validate_checksums=True, + ) + + _assert_valid_checksum(db) + + db.close() + + # When validate_checksums=False, entries after the last checkpoint are also visible + db = sqlite3.SQLite3( + sqlite_db if db_as_path else sqlite_db.open("rb"), + sqlite_wal if wal_as_path else sqlite_wal.open("rb"), + validate_checksums=False, + ) + + _assert_invalid_checksum(db) + + db.close() + + +# Assertion functions for test_sqlite_wal_checkpoint() def _assert_checkpoint_1(s: sqlite3.SQLite3) -> None: # After the first checkpoint the "after checkpoint" entries are present table = next(iter(s.tables())) @@ -162,3 +196,85 @@ def _assert_checkpoint_3(s: sqlite3.SQLite3) -> None: assert rows[9].id == 11 assert rows[9].name == "second checkpoint" assert rows[9].value == 101 + + +# Assertion functions for test_sqlite_wal_checksum_validation() +def _assert_valid_checksum(s: sqlite3.SQLite3) -> None: + # If the checksum validation is correct, all entries BEFORE the last checkpoint should be present + table = next(iter(s.tables())) + rows = list(table.rows()) + + assert len(rows) == 11 + + assert rows[0].id == 1 + assert rows[0].name == "testing" + assert rows[0].value == 1337 + assert rows[1].id == 2 + assert rows[1].name == "omg" + assert rows[1].value == 7331 + assert rows[2].id == 3 + assert rows[2].name == "A" * 4100 + assert rows[2].value == 4100 + assert rows[3].id == 4 + assert rows[3].name == "B" * 4100 + assert rows[3].value == 4100 + assert rows[4].id == 5 + assert rows[4].name == "negative" + assert rows[4].value == -11644473429 + assert rows[5].id == 6 + assert rows[5].name == "after checkpoint" + assert rows[5].value == 42 + assert rows[6].id == 7 + assert rows[6].name == "after checkpoint" + assert rows[6].value == 43 + assert rows[7].id == 8 + assert rows[7].name == "after checkpoint" + assert rows[7].value == 44 + assert rows[8].id == 9 + assert rows[8].name == "after checkpoint" + assert rows[8].value == 45 + assert rows[9].id == 10 + assert rows[9].name == "second checkpoint" + assert rows[9].value == 100 + assert rows[10].id == 11 + assert rows[10].name == "second checkpoint" + assert rows[10].value == 101 + + +def _assert_invalid_checksum(s: sqlite3.SQLite3) -> None: + # If the checksum validation is incorrect, all entries AFTER the last checkpoint should be present + table = next(iter(s.tables())) + rows = list(table.rows()) + + assert len(rows) == 10 + + assert rows[0].id == 1 + assert rows[0].name == "testing" + assert rows[0].value == 1337 + assert rows[1].id == 2 + assert rows[1].name == "omg" + assert rows[1].value == 7331 + assert rows[2].id == 3 + assert rows[2].name == "A" * 4100 + assert rows[2].value == 4100 + assert rows[3].id == 4 + assert rows[3].name == "B" * 4100 + assert rows[3].value == 4100 + assert rows[4].id == 5 + assert rows[4].name == "negative" + assert rows[4].value == -11644473429 + assert rows[5].id == 6 + assert rows[5].name == "after checkpoint" + assert rows[5].value == 42 + assert rows[6].id == 8 + assert rows[6].name == "after checkpoint" + assert rows[6].value == 44 + assert rows[7].id == 9 + assert rows[7].name == "wow" + assert rows[7].value == 1234 + assert rows[8].id == 10 + assert rows[8].name == "second checkpoint" + assert rows[8].value == 100 + assert rows[9].id == 11 + assert rows[9].name == "second checkpoint" + assert rows[9].value == 101