diff --git a/src/smpclient/mcuboot.py b/src/smpclient/mcuboot.py index 4e81acc..1e9f20b 100644 --- a/src/smpclient/mcuboot.py +++ b/src/smpclient/mcuboot.py @@ -235,22 +235,34 @@ class ImageTLVInfo: tlv_tot: int """size of TLV area (including tlv_info header)""" - def __post_init__(self) -> None: - """Do initial validation of the header.""" - if self.magic != IMAGE_TLV_INFO_MAGIC: - raise MCUBootImageError( - f"TLV info magic is {hex(self.magic)}, expected {hex(IMAGE_TLV_INFO_MAGIC)}" - ) + REGION_SIZE = IMAGE_TLV_INFO_STRUCT.size @staticmethod - def loads(data: bytes) -> 'ImageTLVInfo': + def loads(data: bytes, protected: bool = False) -> 'ImageTLVInfo': """Load an `ImageTLVInfo` from bytes.""" - return ImageTLVInfo(*IMAGE_TLV_INFO_STRUCT.unpack(data)) + info = ImageTLVInfo(*IMAGE_TLV_INFO_STRUCT.unpack(data)) + + if protected and info.magic != IMAGE_TLV_PROT_INFO_MAGIC: + raise MCUBootImageError( + f"Expected protected TLV info magic {hex(IMAGE_TLV_PROT_INFO_MAGIC)}, got {hex(info.magic)}" + ) + + if not protected and info.magic != IMAGE_TLV_INFO_MAGIC: + raise MCUBootImageError( + f"Expected TLV info magic {hex(IMAGE_TLV_INFO_MAGIC)}, got {hex(info.magic)}" + ) + + if info.tlv_tot < ImageTLVInfo.REGION_SIZE: + raise MCUBootImageError( + f"TLV total size must be at least {ImageTLVInfo.REGION_SIZE}, got {info.tlv_tot}" + ) + + return info @staticmethod - def load_from(file: BytesIO | BufferedReader) -> 'ImageTLVInfo': + def load_from(file: BytesIO | BufferedReader, protected: bool = False) -> 'ImageTLVInfo': """Load an `ImageTLVInfo` from a file.""" - return ImageTLVInfo.loads(file.read(IMAGE_TLV_INFO_STRUCT.size)) + return ImageTLVInfo.loads(file.read(IMAGE_TLV_INFO_STRUCT.size), protected=protected) @dataclass(frozen=True) @@ -292,6 +304,8 @@ class ImageInfo: header: ImageHeader tlv_info: ImageTLVInfo tlvs: list[ImageTLVValue] + protected_tlv_info: ImageTLVInfo | None = None + protected_tlvs: list[ImageTLVValue] = Field(default_factory=lambda: []) file: str | None = None def get_tlv(self, tlv: ImageTLVType) -> ImageTLVValue: @@ -301,6 +315,17 @@ def get_tlv(self, tlv: ImageTLVType) -> ImageTLVValue: else: raise TLVNotFound(f"{tlv} not found in image.") + @staticmethod + def parse_tlvs(region: bytes) -> list[ImageTLVValue]: + """Parse TLVs from a byte sequence.""" + tlvs: list[ImageTLVValue] = [] + f = BytesIO(region) + while f.tell() < len(region): + tlv_header = ImageTLV.load_from(f) + tlvs.append(ImageTLVValue(header=tlv_header, value=f.read(tlv_header.len))) + + return tlvs + @staticmethod def load_file(path: str) -> 'ImageInfo': """Load MCUBoot `ImageInfo` from the file at `path`. @@ -325,18 +350,37 @@ def load_file(path: str) -> 'ImageInfo': tlv_offset = image_header.hdr_size + image_header.img_size f.seek(tlv_offset) # move to the start of the TLV area - tlv_info = ImageTLVInfo.load_from(f) - tlvs: list[ImageTLVValue] = [] - while f.tell() < tlv_offset + tlv_info.tlv_tot: - tlv_header = ImageTLV.load_from(f) - tlvs.append(ImageTLVValue(header=tlv_header, value=f.read(tlv_header.len))) + # The mcuboot design doc says that optional protected TLV entries come before regular TLV entries + protected_tlvs: list[ImageTLVValue] = [] + protected_tlv_info: ImageTLVInfo | None = None + if image_header.protect_tlv_size > 0: + protected_tlv_info = ImageTLVInfo.load_from(f, protected=True) - return ImageInfo(file=path, header=image_header, tlv_info=tlv_info, tlvs=tlvs) + if protected_tlv_info.tlv_tot != image_header.protect_tlv_size: + raise MCUBootImageError( + f"Protected TLV info total size {protected_tlv_info.tlv_tot} does not match header value {image_header.protect_tlv_size}" + ) + + protected_tlvs = ImageInfo.parse_tlvs( + f.read(protected_tlv_info.tlv_tot - ImageTLVInfo.REGION_SIZE) + ) + + tlv_info = ImageTLVInfo.load_from(f) + tlvs = ImageInfo.parse_tlvs(f.read(tlv_info.tlv_tot - ImageTLVInfo.REGION_SIZE)) + + return ImageInfo( + file=path, + header=image_header, + tlv_info=tlv_info, + tlvs=tlvs, + protected_tlv_info=protected_tlv_info, + protected_tlvs=protected_tlvs, + ) @cached_property def _map_tlv_type_to_value(self) -> dict[int, ImageTLVValue]: - return {tlv.header.type: tlv for tlv in self.tlvs} + return {tlv.header.type: tlv for tlv in (*self.tlvs, *self.protected_tlvs)} def __str__(self) -> str: rep = ( @@ -348,6 +392,12 @@ def __str__(self) -> str: for tlv in self.tlvs: rep += f" {str(tlv)}\n" + if self.protected_tlv_info: + rep += f"{self.protected_tlv_info}\n" + + for tlv in self.protected_tlvs: + rep += f" {str(tlv)}\n" + return rep diff --git a/tests/fixtures/tf-m-9a4cb1a28/tfm_s_signed.bin b/tests/fixtures/tf-m-9a4cb1a28/tfm_s_signed.bin new file mode 100644 index 0000000..5fe4c1f Binary files /dev/null and b/tests/fixtures/tf-m-9a4cb1a28/tfm_s_signed.bin differ diff --git a/tests/test_mcuboot_tools.py b/tests/test_mcuboot_tools.py index 2ebd1e6..0d1fecc 100644 --- a/tests/test_mcuboot_tools.py +++ b/tests/test_mcuboot_tools.py @@ -230,3 +230,25 @@ def test_tlv_value_str_unknown() -> None: tlv_header = ImageTLV(type=0x99, len=4) tlv_value = ImageTLVValue(header=tlv_header, value=b"\xde\xad\xbe\xef") assert str(tlv_value) == "0x99=deadbeef" + + +def test_protected_tlv_parsing() -> None: + """Test that protected TLVs are parsed correctly when present.""" + # tfm_s_signed.bin generated via https://docs.zephyrproject.org/latest/samples/tfm_integration/tfm_ipc/README.html#tfm_ipc + image_info = ImageInfo.load_file( + str(Path("tests", "fixtures", "tf-m-9a4cb1a28", "tfm_s_signed.bin")) + ) + + assert image_info.protected_tlv_info is not None + assert len(image_info.protected_tlvs) == 3 + assert len(image_info.tlvs) == 3 + + # imgtool should put these three regular TLVs in the image + image_info.get_tlv(IMAGE_TLV.SHA256) + image_info.get_tlv(IMAGE_TLV.KEYHASH) + image_info.get_tlv(IMAGE_TLV.ECDSA_SIG) + + # and these three protected TLVs + image_info.get_tlv(IMAGE_TLV.SEC_CNT) + image_info.get_tlv(IMAGE_TLV.BOOT_RECORD) + image_info.get_tlv(IMAGE_TLV.DEPENDENCY)