Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 67 additions & 17 deletions src/smpclient/mcuboot.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,22 +235,34 @@ class ImageTLVInfo:
tlv_tot: int
"""size of TLV area (including tlv_info header)"""

def __post_init__(self) -> None:
"""Do initial validation of the header."""
if self.magic != IMAGE_TLV_INFO_MAGIC:
raise MCUBootImageError(
f"TLV info magic is {hex(self.magic)}, expected {hex(IMAGE_TLV_INFO_MAGIC)}"
)
REGION_SIZE = IMAGE_TLV_INFO_STRUCT.size

@staticmethod
def loads(data: bytes) -> 'ImageTLVInfo':
def loads(data: bytes, protected: bool = False) -> 'ImageTLVInfo':
"""Load an `ImageTLVInfo` from bytes."""
return ImageTLVInfo(*IMAGE_TLV_INFO_STRUCT.unpack(data))
info = ImageTLVInfo(*IMAGE_TLV_INFO_STRUCT.unpack(data))

if protected and info.magic != IMAGE_TLV_PROT_INFO_MAGIC:
raise MCUBootImageError(
f"Expected protected TLV info magic {hex(IMAGE_TLV_PROT_INFO_MAGIC)}, got {hex(info.magic)}"
)

if not protected and info.magic != IMAGE_TLV_INFO_MAGIC:
raise MCUBootImageError(
f"Expected TLV info magic {hex(IMAGE_TLV_INFO_MAGIC)}, got {hex(info.magic)}"
)

if info.tlv_tot < ImageTLVInfo.REGION_SIZE:
raise MCUBootImageError(
f"TLV total size must be at least {ImageTLVInfo.REGION_SIZE}, got {info.tlv_tot}"
)

return info

@staticmethod
def load_from(file: BytesIO | BufferedReader) -> 'ImageTLVInfo':
def load_from(file: BytesIO | BufferedReader, protected: bool = False) -> 'ImageTLVInfo':
"""Load an `ImageTLVInfo` from a file."""
return ImageTLVInfo.loads(file.read(IMAGE_TLV_INFO_STRUCT.size))
return ImageTLVInfo.loads(file.read(IMAGE_TLV_INFO_STRUCT.size), protected=protected)


@dataclass(frozen=True)
Expand Down Expand Up @@ -292,6 +304,8 @@ class ImageInfo:
header: ImageHeader
tlv_info: ImageTLVInfo
tlvs: list[ImageTLVValue]
protected_tlv_info: ImageTLVInfo | None = None
protected_tlvs: list[ImageTLVValue] = Field(default_factory=lambda: [])
file: str | None = None

def get_tlv(self, tlv: ImageTLVType) -> ImageTLVValue:
Expand All @@ -301,6 +315,17 @@ def get_tlv(self, tlv: ImageTLVType) -> ImageTLVValue:
else:
raise TLVNotFound(f"{tlv} not found in image.")

@staticmethod
def parse_tlvs(region: bytes) -> list[ImageTLVValue]:
"""Parse TLVs from a byte sequence."""
tlvs: list[ImageTLVValue] = []
f = BytesIO(region)
while f.tell() < len(region):
tlv_header = ImageTLV.load_from(f)
tlvs.append(ImageTLVValue(header=tlv_header, value=f.read(tlv_header.len)))

return tlvs
Comment on lines +318 to +327

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that I agree with abstracting this to reduce repetition, but I don't love that it has the side effect of mutating the file read position. Consider inlining again (even though it's repetition) or documenting the side effect.

Maybe some cleverness can be done with an iterator or generator, probably not worth the effort.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I reworked it to take in a bytes object and just operate off of that.


@staticmethod
def load_file(path: str) -> 'ImageInfo':
"""Load MCUBoot `ImageInfo` from the file at `path`.
Expand All @@ -325,18 +350,37 @@ def load_file(path: str) -> 'ImageInfo':
tlv_offset = image_header.hdr_size + image_header.img_size

f.seek(tlv_offset) # move to the start of the TLV area
tlv_info = ImageTLVInfo.load_from(f)

tlvs: list[ImageTLVValue] = []
while f.tell() < tlv_offset + tlv_info.tlv_tot:
tlv_header = ImageTLV.load_from(f)
tlvs.append(ImageTLVValue(header=tlv_header, value=f.read(tlv_header.len)))
# The mcuboot design doc says that optional protected TLV entries come before regular TLV entries
protected_tlvs: list[ImageTLVValue] = []
protected_tlv_info: ImageTLVInfo | None = None
if image_header.protect_tlv_size > 0:
protected_tlv_info = ImageTLVInfo.load_from(f, protected=True)

return ImageInfo(file=path, header=image_header, tlv_info=tlv_info, tlvs=tlvs)
if protected_tlv_info.tlv_tot != image_header.protect_tlv_size:
raise MCUBootImageError(
f"Protected TLV info total size {protected_tlv_info.tlv_tot} does not match header value {image_header.protect_tlv_size}"
)

protected_tlvs = ImageInfo.parse_tlvs(
f.read(protected_tlv_info.tlv_tot - ImageTLVInfo.REGION_SIZE)
)
Comment thread
ChapterSevenSeeds marked this conversation as resolved.
Comment thread
ChapterSevenSeeds marked this conversation as resolved.

tlv_info = ImageTLVInfo.load_from(f)
tlvs = ImageInfo.parse_tlvs(f.read(tlv_info.tlv_tot - ImageTLVInfo.REGION_SIZE))
Comment thread
ChapterSevenSeeds marked this conversation as resolved.

return ImageInfo(
file=path,
header=image_header,
tlv_info=tlv_info,
tlvs=tlvs,
protected_tlv_info=protected_tlv_info,
protected_tlvs=protected_tlvs,
)

@cached_property
def _map_tlv_type_to_value(self) -> dict[int, ImageTLVValue]:
return {tlv.header.type: tlv for tlv in self.tlvs}
return {tlv.header.type: tlv for tlv in (*self.tlvs, *self.protected_tlvs)}

def __str__(self) -> str:
rep = (
Expand All @@ -348,6 +392,12 @@ def __str__(self) -> str:
for tlv in self.tlvs:
rep += f" {str(tlv)}\n"

if self.protected_tlv_info:
rep += f"{self.protected_tlv_info}\n"

for tlv in self.protected_tlvs:
rep += f" {str(tlv)}\n"

return rep


Expand Down
Binary file added tests/fixtures/tf-m-9a4cb1a28/tfm_s_signed.bin
Binary file not shown.
22 changes: 22 additions & 0 deletions tests/test_mcuboot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,25 @@ def test_tlv_value_str_unknown() -> None:
tlv_header = ImageTLV(type=0x99, len=4)
tlv_value = ImageTLVValue(header=tlv_header, value=b"\xde\xad\xbe\xef")
assert str(tlv_value) == "0x99=deadbeef"


def test_protected_tlv_parsing() -> None:
"""Test that protected TLVs are parsed correctly when present."""
# tfm_s_signed.bin generated via https://docs.zephyrproject.org/latest/samples/tfm_integration/tfm_ipc/README.html#tfm_ipc
image_info = ImageInfo.load_file(
str(Path("tests", "fixtures", "tf-m-9a4cb1a28", "tfm_s_signed.bin"))
)

assert image_info.protected_tlv_info is not None
assert len(image_info.protected_tlvs) == 3
assert len(image_info.tlvs) == 3

# imgtool should put these three regular TLVs in the image
image_info.get_tlv(IMAGE_TLV.SHA256)
image_info.get_tlv(IMAGE_TLV.KEYHASH)
image_info.get_tlv(IMAGE_TLV.ECDSA_SIG)

# and these three protected TLVs
image_info.get_tlv(IMAGE_TLV.SEC_CNT)
image_info.get_tlv(IMAGE_TLV.BOOT_RECORD)
image_info.get_tlv(IMAGE_TLV.DEPENDENCY)
Loading