diff --git a/src/smpclient/mcuboot.py b/src/smpclient/mcuboot.py index 1e9f20b..ebf17fc 100644 --- a/src/smpclient/mcuboot.py +++ b/src/smpclient/mcuboot.py @@ -3,20 +3,23 @@ Specification: https://docs.mcuboot.com/design.html """ +from __future__ import annotations + import argparse import pathlib import struct from enum import IntEnum, IntFlag, unique from functools import cached_property from io import BufferedReader, BytesIO -from typing import Annotated, Any, Final, Union +from typing import Annotated, Any, Final, Generic, Literal, TypeVar, Union from intelhex import hex2bin # type: ignore from pydantic import Field, GetCoreSchemaHandler from pydantic.dataclasses import dataclass from pydantic_core import CoreSchema, core_schema -IMAGE_MAGIC: Final = 0x96F3B83D +ImageMagic = Literal[0x96F3B83D] +IMAGE_MAGIC: Final[ImageMagic] = 0x96F3B83D IMAGE_HEADER_SIZE: Final = 32 _IMAGE_VERSION_FORMAT_STRING: Final = "BBHL" @@ -26,8 +29,18 @@ IMAGE_HEADER_STRUCT: Final = struct.Struct(f" 'ImageVersion': + def loads(data: bytes) -> ImageVersion: """Load an `ImageVersion` from `bytes`.""" return ImageVersion(*IMAGE_VERSION_STRUCT.unpack(data)) @@ -180,7 +193,7 @@ def __str__(self) -> str: class ImageHeader: """An MCUBoot signed FW update header.""" - magic: int + magic: ImageMagic load_addr: int hdr_size: int protect_tlv_size: int @@ -189,7 +202,7 @@ class ImageHeader: ver: ImageVersion @staticmethod - def loads(data: bytes) -> 'ImageHeader': + def loads(data: bytes) -> ImageHeader: """Load an `ImageHeader` from `bytes`.""" ( magic, @@ -200,6 +213,10 @@ def loads(data: bytes) -> 'ImageHeader': flags, *ver, ) = IMAGE_HEADER_STRUCT.unpack(data) + + if magic != IMAGE_MAGIC: + raise MCUBootImageError(f"Magic is {hex(magic)}, expected {hex(IMAGE_MAGIC)}") + return ImageHeader( magic=magic, load_addr=load_addr, @@ -210,59 +227,51 @@ def loads(data: bytes) -> 'ImageHeader': ver=ImageVersion(*ver), ) - def __post_init__(self) -> None: - """Do initial validation of the header.""" - if self.magic != IMAGE_MAGIC: - raise MCUBootImageError(f"Magic is {hex(self.magic)}, expected {hex(IMAGE_MAGIC)}") - @staticmethod - def load_from(file: BytesIO | BufferedReader) -> 'ImageHeader': + def load_from(file: BytesIO | BufferedReader) -> ImageHeader: """Load an `ImageHeader` from an open file.""" return ImageHeader.loads(file.read(IMAGE_HEADER_STRUCT.size)) @staticmethod - def load_file(path: str) -> 'ImageHeader': + def load_file(path: str) -> ImageHeader: """Load an `ImageHeader` the file at `path`.""" with open(path, 'rb') as f: return ImageHeader.load_from(f) @dataclass(frozen=True) -class ImageTLVInfo: +class ImageTLVInfo(Generic[T]): """An image Type-Length-Value (TLV) region header.""" - magic: int + magic: T tlv_tot: int """size of TLV area (including tlv_info header)""" - REGION_SIZE = IMAGE_TLV_INFO_STRUCT.size - @staticmethod - def loads(data: bytes, protected: bool = False) -> 'ImageTLVInfo': + def loads(data: bytes, magic: MagicT) -> ImageTLVInfo[MagicT]: """Load an `ImageTLVInfo` from bytes.""" - info = ImageTLVInfo(*IMAGE_TLV_INFO_STRUCT.unpack(data)) + parsed_magic, tlv_tot = IMAGE_TLV_INFO_STRUCT.unpack(data) - if protected and info.magic != IMAGE_TLV_PROT_INFO_MAGIC: + if parsed_magic != magic: raise MCUBootImageError( - f"Expected protected TLV info magic {hex(IMAGE_TLV_PROT_INFO_MAGIC)}, got {hex(info.magic)}" + f"Expected TLV info magic {hex(magic)}, got {hex(parsed_magic)}" ) - if not protected and info.magic != IMAGE_TLV_INFO_MAGIC: + if tlv_tot < IMAGE_TLV_INFO_STRUCT.size: raise MCUBootImageError( - f"Expected TLV info magic {hex(IMAGE_TLV_INFO_MAGIC)}, got {hex(info.magic)}" + f"TLV total size must be at least {IMAGE_TLV_INFO_STRUCT.size}, got {tlv_tot}" ) - if info.tlv_tot < ImageTLVInfo.REGION_SIZE: - raise MCUBootImageError( - f"TLV total size must be at least {ImageTLVInfo.REGION_SIZE}, got {info.tlv_tot}" - ) - - return info + return ImageTLVInfo(magic=magic, tlv_tot=tlv_tot) @staticmethod - def load_from(file: BytesIO | BufferedReader, protected: bool = False) -> 'ImageTLVInfo': + def load_from(file: BytesIO | BufferedReader, magic: MagicT) -> ImageTLVInfo[MagicT]: """Load an `ImageTLVInfo` from a file.""" - return ImageTLVInfo.loads(file.read(IMAGE_TLV_INFO_STRUCT.size), protected=protected) + return ImageTLVInfo.loads(file.read(IMAGE_TLV_INFO_STRUCT.size), magic) + + def load_tlvs_from(self, file: BytesIO | BufferedReader) -> list[ImageTLVValue]: + """Read and parse the TLV entries that follow this header in `file`.""" + return ImageInfo.parse_tlvs(file.read(self.tlv_tot - IMAGE_TLV_INFO_STRUCT.size)) @dataclass(frozen=True) @@ -274,7 +283,7 @@ class ImageTLV: """Data length (not including TLV header).""" @staticmethod - def load_from(file: BytesIO | BufferedReader) -> 'ImageTLV': + def load_from(file: BytesIO | BufferedReader) -> ImageTLV: """Load an `ImageTLV` from a file.""" return ImageTLV(*IMAGE_TLV_STRUCT.unpack_from(file.read(IMAGE_TLV_STRUCT.size))) @@ -302,10 +311,10 @@ class ImageInfo: """A summary of an MCUBoot FW update image.""" header: ImageHeader - tlv_info: ImageTLVInfo + tlv_info: ImageTLVInfo[ImageTLVInfoMagic] tlvs: list[ImageTLVValue] - protected_tlv_info: ImageTLVInfo | None = None - protected_tlvs: list[ImageTLVValue] = Field(default_factory=lambda: []) + protected_tlv_info: ImageTLVInfo[ImageTLVProtInfoMagic] | None = None + protected_tlvs: list[ImageTLVValue] | None = None file: str | None = None def get_tlv(self, tlv: ImageTLVType) -> ImageTLVValue: @@ -327,7 +336,7 @@ def parse_tlvs(region: bytes) -> list[ImageTLVValue]: return tlvs @staticmethod - def load_file(path: str) -> 'ImageInfo': + def load_file(path: str) -> ImageInfo: """Load MCUBoot `ImageInfo` from the file at `path`. Files with the `.hex` extension are treated as Intel HEX format. @@ -352,22 +361,26 @@ def load_file(path: str) -> 'ImageInfo': f.seek(tlv_offset) # move to the start of the TLV area # The mcuboot design doc says that optional protected TLV entries come before regular TLV entries - protected_tlvs: list[ImageTLVValue] = [] - protected_tlv_info: ImageTLVInfo | None = None - if image_header.protect_tlv_size > 0: - protected_tlv_info = ImageTLVInfo.load_from(f, protected=True) - - if protected_tlv_info.tlv_tot != image_header.protect_tlv_size: - raise MCUBootImageError( - f"Protected TLV info total size {protected_tlv_info.tlv_tot} does not match header value {image_header.protect_tlv_size}" - ) - - protected_tlvs = ImageInfo.parse_tlvs( - f.read(protected_tlv_info.tlv_tot - ImageTLVInfo.REGION_SIZE) + protected_tlv_info = ( + ImageTLVInfo.load_from(f, IMAGE_TLV_PROT_INFO_MAGIC) + if image_header.protect_tlv_size > 0 + else None + ) + + if ( + protected_tlv_info is not None + and protected_tlv_info.tlv_tot != image_header.protect_tlv_size + ): + raise MCUBootImageError( + f"Protected TLV info total size {protected_tlv_info.tlv_tot} does not match header value {image_header.protect_tlv_size}" ) - tlv_info = ImageTLVInfo.load_from(f) - tlvs = ImageInfo.parse_tlvs(f.read(tlv_info.tlv_tot - ImageTLVInfo.REGION_SIZE)) + protected_tlvs = ( + protected_tlv_info.load_tlvs_from(f) if protected_tlv_info is not None else None + ) + + tlv_info = ImageTLVInfo.load_from(f, IMAGE_TLV_INFO_MAGIC) + tlvs = tlv_info.load_tlvs_from(f) return ImageInfo( file=path, @@ -380,7 +393,7 @@ def load_file(path: str) -> 'ImageInfo': @cached_property def _map_tlv_type_to_value(self) -> dict[int, ImageTLVValue]: - return {tlv.header.type: tlv for tlv in (*self.tlvs, *self.protected_tlvs)} + return {tlv.header.type: tlv for tlv in (*self.tlvs, *(self.protected_tlvs or []))} def __str__(self) -> str: rep = ( @@ -395,7 +408,7 @@ def __str__(self) -> str: if self.protected_tlv_info: rep += f"{self.protected_tlv_info}\n" - for tlv in self.protected_tlvs: + for tlv in self.protected_tlvs or []: rep += f" {str(tlv)}\n" return rep diff --git a/tests/test_mcuboot_tools.py b/tests/test_mcuboot_tools.py index 0d1fecc..3b2b8a8 100644 --- a/tests/test_mcuboot_tools.py +++ b/tests/test_mcuboot_tools.py @@ -5,18 +5,29 @@ from typing import Protocol import pytest +from typing_extensions import assert_type from smpclient.mcuboot import ( + IMAGE_HEADER_STRUCT, IMAGE_MAGIC, IMAGE_TLV, IMAGE_TLV_INFO_MAGIC, + IMAGE_TLV_INFO_STRUCT, + IMAGE_TLV_PROT_INFO_MAGIC, ImageHeader, ImageInfo, + ImageMagic, ImageTLV, + ImageTLVInfo, + ImageTLVInfoMagic, + ImageTLVProtInfoMagic, ImageTLVType, ImageTLVValue, ImageVersion, + MCUBootImageError, + TLVNotFound, VendorTLV, + mcuimg, ) @@ -67,6 +78,8 @@ def test_ImageInfo(image: _ImageFileFixture) -> None: # TLV header t = image_info.tlv_info + assert_type(t, ImageTLVInfo[ImageTLVInfoMagic]) + assert_type(t.magic, ImageTLVInfoMagic) assert t.magic == IMAGE_TLV_INFO_MAGIC assert t.tlv_tot == 336 @@ -96,6 +109,7 @@ def test_ImageInfo(image: _ImageFileFixture) -> None: def test_ImageHeader(image: _ImageFileFixture) -> None: h = ImageHeader.load_file(str(image.PATH)) + assert_type(h.magic, ImageMagic) assert h.magic == IMAGE_MAGIC assert h.load_addr == 0 assert h.hdr_size == 512 @@ -240,9 +254,15 @@ def test_protected_tlv_parsing() -> None: ) assert image_info.protected_tlv_info is not None + assert_type(image_info.protected_tlv_info, ImageTLVInfo[ImageTLVProtInfoMagic]) + assert_type(image_info.protected_tlv_info.magic, ImageTLVProtInfoMagic) + assert image_info.protected_tlv_info.magic == IMAGE_TLV_PROT_INFO_MAGIC + assert image_info.protected_tlvs is not None assert len(image_info.protected_tlvs) == 3 assert len(image_info.tlvs) == 3 + assert "SEC_CNT=" in str(image_info) + # imgtool should put these three regular TLVs in the image image_info.get_tlv(IMAGE_TLV.SHA256) image_info.get_tlv(IMAGE_TLV.KEYHASH) @@ -252,3 +272,71 @@ def test_protected_tlv_parsing() -> None: image_info.get_tlv(IMAGE_TLV.SEC_CNT) image_info.get_tlv(IMAGE_TLV.BOOT_RECORD) image_info.get_tlv(IMAGE_TLV.DEPENDENCY) + + +def test_tlv_info_magic_type_binding() -> None: + """The expected magic argument binds the static type and the runtime check.""" + info = ImageTLVInfo.loads(struct.pack(" None: + with pytest.raises(MCUBootImageError): + ImageHeader.loads(IMAGE_HEADER_STRUCT.pack(0xDEADBEEF, 0, 32, 0, 0, 0, 0, 0, 0, 0)) + + +def test_tlv_info_total_size_too_small() -> None: + with pytest.raises(MCUBootImageError): + ImageTLVInfo.loads(struct.pack(" None: + with pytest.raises(MCUBootImageError): + ImageTLVValue(header=ImageTLV(type=0x10, len=4), value=b"\x00") + + +def test_get_tlv_not_found() -> None: + image_info = ImageInfo.load_file(str(SIGNED_BIN.PATH)) + with pytest.raises(TLVNotFound): + image_info.get_tlv(IMAGE_TLV.SEC_CNT) + + +def test_invalid_hex_file(tmp_path: Path) -> None: + bad = tmp_path / "bad.hex" + bad.write_text("not a hex file\n") + with pytest.raises(MCUBootImageError): + ImageInfo.load_file(str(bad)) + + +def test_protected_tlv_size_mismatch(tmp_path: Path) -> None: + image = tmp_path / "image.bin" + image.write_bytes( + IMAGE_HEADER_STRUCT.pack(IMAGE_MAGIC, 0, IMAGE_HEADER_STRUCT.size, 12, 0, 0, 0, 0, 0, 0) + + IMAGE_TLV_INFO_STRUCT.pack(IMAGE_TLV_PROT_INFO_MAGIC, 8) + ) + with pytest.raises(MCUBootImageError): + ImageInfo.load_file(str(image)) + + +def test_mcuimg(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None: + monkeypatch.setattr("sys.argv", ["mcuimg", str(SIGNED_BIN.PATH)]) + assert mcuimg() == 0 + assert "ImageInfo" in capsys.readouterr().out + + monkeypatch.setattr("sys.argv", ["mcuimg", "does-not-exist.bin"]) + assert mcuimg() == -1