Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 66 additions & 53 deletions src/smpclient/mcuboot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,23 @@
Specification: https://docs.mcuboot.com/design.html
"""

from __future__ import annotations

import argparse
import pathlib
import struct
from enum import IntEnum, IntFlag, unique
from functools import cached_property
from io import BufferedReader, BytesIO
from typing import Annotated, Any, Final, Union
from typing import Annotated, Any, Final, Generic, Literal, TypeVar, Union

from intelhex import hex2bin # type: ignore
from pydantic import Field, GetCoreSchemaHandler
from pydantic.dataclasses import dataclass
from pydantic_core import CoreSchema, core_schema

IMAGE_MAGIC: Final = 0x96F3B83D
ImageMagic = Literal[0x96F3B83D]
IMAGE_MAGIC: Final[ImageMagic] = 0x96F3B83D
IMAGE_HEADER_SIZE: Final = 32

_IMAGE_VERSION_FORMAT_STRING: Final = "BBHL"
Expand All @@ -26,8 +29,18 @@
IMAGE_HEADER_STRUCT: Final = struct.Struct(f"<LLHHLL{_IMAGE_VERSION_FORMAT_STRING}4x")
assert IMAGE_HEADER_STRUCT.size == IMAGE_HEADER_SIZE

IMAGE_TLV_INFO_MAGIC: Final = 0x6907
IMAGE_TLV_PROT_INFO_MAGIC: Final = 0x6908
ImageTLVInfoMagic = Literal[0x6907]
IMAGE_TLV_INFO_MAGIC: Final[ImageTLVInfoMagic] = 0x6907

ImageTLVProtInfoMagic = Literal[0x6908]
IMAGE_TLV_PROT_INFO_MAGIC: Final[ImageTLVProtInfoMagic] = 0x6908

T = TypeVar("T", ImageTLVInfoMagic, ImageTLVProtInfoMagic)
"""Any TLV info magic type."""

MagicT = TypeVar("MagicT", ImageTLVInfoMagic, ImageTLVProtInfoMagic)
"""Method-scoped equivalent of `T` - class-scoped TypeVars only bind via `self`/`cls`,
so the static loaders of `ImageTLVInfo` require their own TypeVar."""

IMAGE_TLV_INFO_STRUCT: Final = struct.Struct("<HH")
assert IMAGE_TLV_INFO_STRUCT.size == 4
Expand Down Expand Up @@ -168,7 +181,7 @@ class ImageVersion:
build_num: int

@staticmethod
def loads(data: bytes) -> 'ImageVersion':
def loads(data: bytes) -> ImageVersion:
"""Load an `ImageVersion` from `bytes`."""
return ImageVersion(*IMAGE_VERSION_STRUCT.unpack(data))

Expand All @@ -180,7 +193,7 @@ def __str__(self) -> str:
class ImageHeader:
"""An MCUBoot signed FW update header."""

magic: int
magic: ImageMagic
load_addr: int
hdr_size: int
protect_tlv_size: int
Expand All @@ -189,7 +202,7 @@ class ImageHeader:
ver: ImageVersion

@staticmethod
def loads(data: bytes) -> 'ImageHeader':
def loads(data: bytes) -> ImageHeader:
"""Load an `ImageHeader` from `bytes`."""
(
magic,
Expand All @@ -200,6 +213,10 @@ def loads(data: bytes) -> 'ImageHeader':
flags,
*ver,
) = IMAGE_HEADER_STRUCT.unpack(data)

if magic != IMAGE_MAGIC:
raise MCUBootImageError(f"Magic is {hex(magic)}, expected {hex(IMAGE_MAGIC)}")

return ImageHeader(
magic=magic,
load_addr=load_addr,
Expand All @@ -210,59 +227,51 @@ def loads(data: bytes) -> 'ImageHeader':
ver=ImageVersion(*ver),
)

def __post_init__(self) -> None:
"""Do initial validation of the header."""
if self.magic != IMAGE_MAGIC:
raise MCUBootImageError(f"Magic is {hex(self.magic)}, expected {hex(IMAGE_MAGIC)}")

@staticmethod
def load_from(file: BytesIO | BufferedReader) -> 'ImageHeader':
def load_from(file: BytesIO | BufferedReader) -> ImageHeader:
"""Load an `ImageHeader` from an open file."""
return ImageHeader.loads(file.read(IMAGE_HEADER_STRUCT.size))

@staticmethod
def load_file(path: str) -> 'ImageHeader':
def load_file(path: str) -> ImageHeader:
"""Load an `ImageHeader` the file at `path`."""
with open(path, 'rb') as f:
return ImageHeader.load_from(f)


@dataclass(frozen=True)
class ImageTLVInfo:
class ImageTLVInfo(Generic[T]):
"""An image Type-Length-Value (TLV) region header."""

magic: int
magic: T
tlv_tot: int
"""size of TLV area (including tlv_info header)"""

REGION_SIZE = IMAGE_TLV_INFO_STRUCT.size

@staticmethod
def loads(data: bytes, protected: bool = False) -> 'ImageTLVInfo':
def loads(data: bytes, magic: MagicT) -> ImageTLVInfo[MagicT]:
"""Load an `ImageTLVInfo` from bytes."""
info = ImageTLVInfo(*IMAGE_TLV_INFO_STRUCT.unpack(data))
parsed_magic, tlv_tot = IMAGE_TLV_INFO_STRUCT.unpack(data)

if protected and info.magic != IMAGE_TLV_PROT_INFO_MAGIC:
if parsed_magic != magic:
raise MCUBootImageError(
f"Expected protected TLV info magic {hex(IMAGE_TLV_PROT_INFO_MAGIC)}, got {hex(info.magic)}"
f"Expected TLV info magic {hex(magic)}, got {hex(parsed_magic)}"
)

if not protected and info.magic != IMAGE_TLV_INFO_MAGIC:
if tlv_tot < IMAGE_TLV_INFO_STRUCT.size:
raise MCUBootImageError(
f"Expected TLV info magic {hex(IMAGE_TLV_INFO_MAGIC)}, got {hex(info.magic)}"
f"TLV total size must be at least {IMAGE_TLV_INFO_STRUCT.size}, got {tlv_tot}"
)

if info.tlv_tot < ImageTLVInfo.REGION_SIZE:
raise MCUBootImageError(
f"TLV total size must be at least {ImageTLVInfo.REGION_SIZE}, got {info.tlv_tot}"
)

return info
return ImageTLVInfo(magic=magic, tlv_tot=tlv_tot)

@staticmethod
def load_from(file: BytesIO | BufferedReader, protected: bool = False) -> 'ImageTLVInfo':
def load_from(file: BytesIO | BufferedReader, magic: MagicT) -> ImageTLVInfo[MagicT]:
"""Load an `ImageTLVInfo` from a file."""
return ImageTLVInfo.loads(file.read(IMAGE_TLV_INFO_STRUCT.size), protected=protected)
return ImageTLVInfo.loads(file.read(IMAGE_TLV_INFO_STRUCT.size), magic)

def load_tlvs_from(self, file: BytesIO | BufferedReader) -> list[ImageTLVValue]:
"""Read and parse the TLV entries that follow this header in `file`."""
return ImageInfo.parse_tlvs(file.read(self.tlv_tot - IMAGE_TLV_INFO_STRUCT.size))


@dataclass(frozen=True)
Expand All @@ -274,7 +283,7 @@ class ImageTLV:
"""Data length (not including TLV header)."""

@staticmethod
def load_from(file: BytesIO | BufferedReader) -> 'ImageTLV':
def load_from(file: BytesIO | BufferedReader) -> ImageTLV:
"""Load an `ImageTLV` from a file."""
return ImageTLV(*IMAGE_TLV_STRUCT.unpack_from(file.read(IMAGE_TLV_STRUCT.size)))

Expand Down Expand Up @@ -302,10 +311,10 @@ class ImageInfo:
"""A summary of an MCUBoot FW update image."""

header: ImageHeader
tlv_info: ImageTLVInfo
tlv_info: ImageTLVInfo[ImageTLVInfoMagic]
tlvs: list[ImageTLVValue]
protected_tlv_info: ImageTLVInfo | None = None
protected_tlvs: list[ImageTLVValue] = Field(default_factory=lambda: [])
protected_tlv_info: ImageTLVInfo[ImageTLVProtInfoMagic] | None = None
protected_tlvs: list[ImageTLVValue] | None = None
file: str | None = None

def get_tlv(self, tlv: ImageTLVType) -> ImageTLVValue:
Expand All @@ -327,7 +336,7 @@ def parse_tlvs(region: bytes) -> list[ImageTLVValue]:
return tlvs

@staticmethod
def load_file(path: str) -> 'ImageInfo':
def load_file(path: str) -> ImageInfo:
"""Load MCUBoot `ImageInfo` from the file at `path`.

Files with the `.hex` extension are treated as Intel HEX format.
Expand All @@ -352,22 +361,26 @@ def load_file(path: str) -> 'ImageInfo':
f.seek(tlv_offset) # move to the start of the TLV area

# The mcuboot design doc says that optional protected TLV entries come before regular TLV entries
protected_tlvs: list[ImageTLVValue] = []
protected_tlv_info: ImageTLVInfo | None = None
if image_header.protect_tlv_size > 0:
protected_tlv_info = ImageTLVInfo.load_from(f, protected=True)

if protected_tlv_info.tlv_tot != image_header.protect_tlv_size:
raise MCUBootImageError(
f"Protected TLV info total size {protected_tlv_info.tlv_tot} does not match header value {image_header.protect_tlv_size}"
)

protected_tlvs = ImageInfo.parse_tlvs(
f.read(protected_tlv_info.tlv_tot - ImageTLVInfo.REGION_SIZE)
protected_tlv_info = (
ImageTLVInfo.load_from(f, IMAGE_TLV_PROT_INFO_MAGIC)
if image_header.protect_tlv_size > 0
else None
)

if (
protected_tlv_info is not None
and protected_tlv_info.tlv_tot != image_header.protect_tlv_size
):
raise MCUBootImageError(
f"Protected TLV info total size {protected_tlv_info.tlv_tot} does not match header value {image_header.protect_tlv_size}"
)

tlv_info = ImageTLVInfo.load_from(f)
tlvs = ImageInfo.parse_tlvs(f.read(tlv_info.tlv_tot - ImageTLVInfo.REGION_SIZE))
protected_tlvs = (
protected_tlv_info.load_tlvs_from(f) if protected_tlv_info is not None else None
)

tlv_info = ImageTLVInfo.load_from(f, IMAGE_TLV_INFO_MAGIC)
tlvs = tlv_info.load_tlvs_from(f)

return ImageInfo(
file=path,
Expand All @@ -380,7 +393,7 @@ def load_file(path: str) -> 'ImageInfo':

@cached_property
def _map_tlv_type_to_value(self) -> dict[int, ImageTLVValue]:
return {tlv.header.type: tlv for tlv in (*self.tlvs, *self.protected_tlvs)}
return {tlv.header.type: tlv for tlv in (*self.tlvs, *(self.protected_tlvs or []))}

def __str__(self) -> str:
rep = (
Expand All @@ -395,7 +408,7 @@ def __str__(self) -> str:
if self.protected_tlv_info:
rep += f"{self.protected_tlv_info}\n"

for tlv in self.protected_tlvs:
for tlv in self.protected_tlvs or []:
rep += f" {str(tlv)}\n"

return rep
Expand Down
88 changes: 88 additions & 0 deletions tests/test_mcuboot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,29 @@
from typing import Protocol

import pytest
from typing_extensions import assert_type

from smpclient.mcuboot import (
IMAGE_HEADER_STRUCT,
IMAGE_MAGIC,
IMAGE_TLV,
IMAGE_TLV_INFO_MAGIC,
IMAGE_TLV_INFO_STRUCT,
IMAGE_TLV_PROT_INFO_MAGIC,
ImageHeader,
ImageInfo,
ImageMagic,
ImageTLV,
ImageTLVInfo,
ImageTLVInfoMagic,
ImageTLVProtInfoMagic,
ImageTLVType,
ImageTLVValue,
ImageVersion,
MCUBootImageError,
TLVNotFound,
VendorTLV,
mcuimg,
)


Expand Down Expand Up @@ -67,6 +78,8 @@ def test_ImageInfo(image: _ImageFileFixture) -> None:

# TLV header
t = image_info.tlv_info
assert_type(t, ImageTLVInfo[ImageTLVInfoMagic])
assert_type(t.magic, ImageTLVInfoMagic)
assert t.magic == IMAGE_TLV_INFO_MAGIC
assert t.tlv_tot == 336

Expand Down Expand Up @@ -96,6 +109,7 @@ def test_ImageInfo(image: _ImageFileFixture) -> None:
def test_ImageHeader(image: _ImageFileFixture) -> None:
h = ImageHeader.load_file(str(image.PATH))

assert_type(h.magic, ImageMagic)
assert h.magic == IMAGE_MAGIC
assert h.load_addr == 0
assert h.hdr_size == 512
Expand Down Expand Up @@ -240,9 +254,15 @@ def test_protected_tlv_parsing() -> None:
)

assert image_info.protected_tlv_info is not None
assert_type(image_info.protected_tlv_info, ImageTLVInfo[ImageTLVProtInfoMagic])
assert_type(image_info.protected_tlv_info.magic, ImageTLVProtInfoMagic)
assert image_info.protected_tlv_info.magic == IMAGE_TLV_PROT_INFO_MAGIC
assert image_info.protected_tlvs is not None
assert len(image_info.protected_tlvs) == 3
assert len(image_info.tlvs) == 3

assert "SEC_CNT=" in str(image_info)

# imgtool should put these three regular TLVs in the image
image_info.get_tlv(IMAGE_TLV.SHA256)
image_info.get_tlv(IMAGE_TLV.KEYHASH)
Expand All @@ -252,3 +272,71 @@ def test_protected_tlv_parsing() -> None:
image_info.get_tlv(IMAGE_TLV.SEC_CNT)
image_info.get_tlv(IMAGE_TLV.BOOT_RECORD)
image_info.get_tlv(IMAGE_TLV.DEPENDENCY)


def test_tlv_info_magic_type_binding() -> None:
"""The expected magic argument binds the static type and the runtime check."""
info = ImageTLVInfo.loads(struct.pack("<HH", IMAGE_TLV_INFO_MAGIC, 100), IMAGE_TLV_INFO_MAGIC)
assert_type(info, ImageTLVInfo[ImageTLVInfoMagic])
assert_type(info.magic, ImageTLVInfoMagic)
assert info.magic == IMAGE_TLV_INFO_MAGIC

prot_info = ImageTLVInfo.loads(
struct.pack("<HH", IMAGE_TLV_PROT_INFO_MAGIC, 100), IMAGE_TLV_PROT_INFO_MAGIC
)
assert_type(prot_info, ImageTLVInfo[ImageTLVProtInfoMagic])
assert_type(prot_info.magic, ImageTLVProtInfoMagic)
assert prot_info.magic == IMAGE_TLV_PROT_INFO_MAGIC

with pytest.raises(MCUBootImageError):
ImageTLVInfo.loads(struct.pack("<HH", IMAGE_TLV_INFO_MAGIC, 100), IMAGE_TLV_PROT_INFO_MAGIC)

with pytest.raises(MCUBootImageError):
ImageTLVInfo.loads(struct.pack("<HH", IMAGE_TLV_PROT_INFO_MAGIC, 100), IMAGE_TLV_INFO_MAGIC)


def test_image_header_bad_magic() -> None:
with pytest.raises(MCUBootImageError):
ImageHeader.loads(IMAGE_HEADER_STRUCT.pack(0xDEADBEEF, 0, 32, 0, 0, 0, 0, 0, 0, 0))


def test_tlv_info_total_size_too_small() -> None:
with pytest.raises(MCUBootImageError):
ImageTLVInfo.loads(struct.pack("<HH", IMAGE_TLV_INFO_MAGIC, 2), IMAGE_TLV_INFO_MAGIC)


def test_tlv_value_length_mismatch() -> None:
with pytest.raises(MCUBootImageError):
ImageTLVValue(header=ImageTLV(type=0x10, len=4), value=b"\x00")


def test_get_tlv_not_found() -> None:
image_info = ImageInfo.load_file(str(SIGNED_BIN.PATH))
with pytest.raises(TLVNotFound):
image_info.get_tlv(IMAGE_TLV.SEC_CNT)


def test_invalid_hex_file(tmp_path: Path) -> None:
bad = tmp_path / "bad.hex"
bad.write_text("not a hex file\n")
with pytest.raises(MCUBootImageError):
ImageInfo.load_file(str(bad))


def test_protected_tlv_size_mismatch(tmp_path: Path) -> None:
image = tmp_path / "image.bin"
image.write_bytes(
IMAGE_HEADER_STRUCT.pack(IMAGE_MAGIC, 0, IMAGE_HEADER_STRUCT.size, 12, 0, 0, 0, 0, 0, 0)
+ IMAGE_TLV_INFO_STRUCT.pack(IMAGE_TLV_PROT_INFO_MAGIC, 8)
)
with pytest.raises(MCUBootImageError):
ImageInfo.load_file(str(image))


def test_mcuimg(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
monkeypatch.setattr("sys.argv", ["mcuimg", str(SIGNED_BIN.PATH)])
assert mcuimg() == 0
assert "ImageInfo" in capsys.readouterr().out

monkeypatch.setattr("sys.argv", ["mcuimg", "does-not-exist.bin"])
assert mcuimg() == -1
Loading