From 8a6b8b692fc113841e53b59687a1962d2d0c0f7b Mon Sep 17 00:00:00 2001 From: Schamper <1254028+Schamper@users.noreply.github.com> Date: Thu, 12 Feb 2026 20:55:50 +0100 Subject: [PATCH] Add endian keyword argument to type reads --- dissect/cstruct/bitbuffer.py | 7 +- dissect/cstruct/compiler.py | 14 +-- dissect/cstruct/cstruct.py | 43 ++++++++-- dissect/cstruct/types/base.py | 108 ++++++++++++++++------- dissect/cstruct/types/char.py | 20 +++-- dissect/cstruct/types/enum.py | 30 ++++--- dissect/cstruct/types/int.py | 16 ++-- dissect/cstruct/types/leb128.py | 12 ++- dissect/cstruct/types/packed.py | 24 +++--- dissect/cstruct/types/pointer.py | 132 +++++++++++++++++------------ dissect/cstruct/types/structure.py | 49 +++++++---- dissect/cstruct/types/void.py | 10 ++- dissect/cstruct/types/wchar.py | 35 +++++--- dissect/cstruct/utils.py | 62 ++++++++------ tests/test_basic.py | 10 +-- tests/test_bitbuffer.py | 35 +++++--- tests/test_compiler.py | 26 +++--- tests/test_types_base.py | 10 ++- tests/test_types_custom.py | 58 +++++++++++-- tests/test_types_pointer.py | 20 ++++- tests/test_types_structure.py | 24 ++++++ tests/test_types_union.py | 21 +++++ 22 files changed, 524 insertions(+), 242 deletions(-) diff --git a/dissect/cstruct/bitbuffer.py b/dissect/cstruct/bitbuffer.py index 57b1fb47..6f3d08d4 100644 --- a/dissect/cstruct/bitbuffer.py +++ b/dissect/cstruct/bitbuffer.py @@ -9,9 +9,10 @@ class BitBuffer: """Implements a bit buffer that can read and write bit fields.""" - def __init__(self, stream: BinaryIO, endian: str): + def __init__(self, stream: BinaryIO, *, endian: str, **kwargs): self.stream = stream self.endian = endian + self.kwargs = kwargs self._type: type[BaseType] | None = None self._buffer = 0 @@ -24,7 +25,7 @@ def read(self, field_type: type[BaseType], bits: int) -> int: self._type = field_type self._remaining = field_type.size * 8 - self._buffer = field_type._read(self.stream) + self._buffer = field_type._read(self.stream, endian=self.endian, 
**self.kwargs) if isinstance(self._buffer, bytes): if self.endian == "<": @@ -71,7 +72,7 @@ def write(self, field_type: type[BaseType], data: int, bits: int) -> None: def flush(self) -> None: if self._type is not None: - self._type._write(self.stream, self._buffer) + self._type._write(self.stream, self._buffer, endian=self.endian, **self.kwargs) self._type = None self._remaining = 0 self._buffer = 0 diff --git a/dissect/cstruct/compiler.py b/dissect/cstruct/compiler.py index d9445b4f..c193d902 100644 --- a/dissect/cstruct/compiler.py +++ b/dissect/cstruct/compiler.py @@ -117,7 +117,7 @@ def generate_source(self) -> str: """ if any(field.bits for field in self.fields): - preamble += "bit_reader = BitBuffer(stream, cls.cs.endian)\n" + preamble += "bit_reader = BitBuffer(stream, endian=endian, **kwargs)\n" read_code = "\n".join(self._generate_fields()) @@ -130,7 +130,7 @@ def generate_source(self) -> str: code = indent(dedent(preamble).lstrip() + read_code + dedent(outro), " ") - return f"def _read(cls, stream, context=None):\n{code}" + return f"def _read(cls, stream, *, context=None, endian, **kwargs):\n{code}" def _generate_fields(self) -> Iterator[str]: current_offset = 0 @@ -227,7 +227,7 @@ def align_to_field(field: Field) -> Iterator[str]: def _generate_structure(self, field: Field) -> Iterator[str]: template = f""" {"_s = stream.tell()" if field.type.dynamic else ""} - r["{field._name}"] = {self._map_field(field)}._read(stream, context=r) + r["{field._name}"] = {self._map_field(field)}._read(stream, context=r, endian=endian, **kwargs) {f's["{field._name}"] = stream.tell() - _s' if field.type.dynamic else ""} """ @@ -236,7 +236,7 @@ def _generate_structure(self, field: Field) -> Iterator[str]: def _generate_array(self, field: Field) -> Iterator[str]: template = f""" {"_s = stream.tell()" if field.type.dynamic else ""} - r["{field._name}"] = {self._map_field(field)}._read(stream, context=r) + r["{field._name}"] = {self._map_field(field)}._read(stream, context=r, 
endian=endian, **kwargs) {f's["{field._name}"] = stream.tell() - _s' if field.type.dynamic else ""} """ @@ -309,7 +309,7 @@ def _generate_packed(self, fields: list[Field]) -> Iterator[str]: item_parser = parser_template.format(type="_et", getter=f"_b[i:i + {field_type.type.size}]") list_comp = f"[{item_parser} for i in range(0, {count}, {field_type.type.size})]" elif issubclass(field_type.type, Pointer): - item_parser = "_et.__new__(_et, e, stream, r)" + item_parser = "_et.__new__(_et, e, stream, context=r, endian=endian, **kwargs)" list_comp = f"[{item_parser} for e in {getter}]" else: item_parser = parser_template.format(type="_et", getter="e") @@ -320,7 +320,7 @@ def _generate_packed(self, fields: list[Field]) -> Iterator[str]: parser = f"type.__call__({self._map_field(field)}, {getter})" elif issubclass(field_type, Pointer): reads.append(f"_pt = {self._map_field(field)}") - parser = f"_pt.__new__(_pt, {getter}, stream, r)" + parser = f"_pt.__new__(_pt, {getter}, stream, context=r, endian=endian, **kwargs)" else: parser = parser_template.format(type=self._map_field(field), getter=getter) @@ -333,7 +333,7 @@ def _generate_packed(self, fields: list[Field]) -> Iterator[str]: if fmt == "x" or (len(fmt) == 2 and fmt[1] == "x"): unpack = "" else: - unpack = f'data = _struct(cls.cs.endian, "{fmt}").unpack(buf)\n' + unpack = f'data = _struct(endian, "{fmt}").unpack(buf)\n' template = f""" buf = stream.read({size}) diff --git a/dissect/cstruct/cstruct.py b/dissect/cstruct/cstruct.py index 2907734c..c62fd555 100644 --- a/dissect/cstruct/cstruct.py +++ b/dissect/cstruct/cstruct.py @@ -1,13 +1,15 @@ from __future__ import annotations import ctypes as _ctypes +import inspect import struct import sys import types +import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, BinaryIO, TypeVar, cast +from typing import TYPE_CHECKING, Any, BinaryIO, Literal, TypeVar, cast -from dissect.cstruct.exceptions import ResolveError +from dissect.cstruct.exceptions 
import Error, ResolveError from dissect.cstruct.expression import Expression from dissect.cstruct.parser import CStyleParser, TokenParser from dissect.cstruct.types import ( @@ -27,6 +29,7 @@ Void, Wchar, ) +from dissect.cstruct.types.base import normalize_endianness if TYPE_CHECKING: from collections.abc import Iterable @@ -35,20 +38,23 @@ T = TypeVar("T", bound=BaseType) +AllowedEndianness: TypeAlias = Literal["little", "big", "network", "<", ">", "!", "@", "="] +Endianness: TypeAlias = Literal["<", ">", "!", "@", "="] + class cstruct: """Main class of cstruct. All types are registered in here. Args: - endian: The endianness to use when parsing. + endian: The endianness to use when parsing (little, big, network, <, >, !, @ or =). pointer: The pointer type to use for pointers. """ DEF_CSTYLE = 1 DEF_LEGACY = 2 - def __init__(self, load: str = "", *, endian: str = "<", pointer: str | None = None): - self.endian = endian + def __init__(self, load: str = "", *, endian: AllowedEndianness = "<", pointer: str | None = None): + self.endian = normalize_endianness(endian) self.consts = {} self.lookups = {} @@ -242,6 +248,33 @@ def add_custom_type( alignment: The alignment of the type. **kwargs: Additional attributes to add to the type. 
""" + # In cstruct 4.8 we changed the function signature of _read and _write + # Check if the function signature is compatible, and throw an error if not + for type_to_check in (type_, type_.ArrayType): + type_name = type_.__name__ + (f".{type_.ArrayType.__name__}" if type_to_check is type_.ArrayType else "") + + for method in ("_read", "_read_array", "_read_0", "_write", "_write_array", "_write_0"): + if not hasattr(type_to_check, method): + continue + + signature = inspect.signature(getattr(type_to_check, method)) + + # We added a few keyword-only parameters to the function signature, but any custom type will + # continue to work fine as long as they accept **kwargs + if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): + raise Error( + f"Custom type {type_name} has an incompatible {method} method signature. " + "Please refer to the changelog of dissect.cstruct 4.8 for more information." + ) + + # Only warn if the method doesn't accept an endian parameter + if "endian" not in signature.parameters: + warnings.warn( + f"Custom type {type_name} is missing the 'endian' keyword-only parameter in its {method} method. 
" # noqa: E501 + "Please refer to the changelog of dissect.cstruct 4.8 for more information.", + stacklevel=2, + ) + self.add_type(name, self._make_type(name, (type_,), size, alignment=alignment, attrs=kwargs)) def load(self, definition: str, deftype: int | None = None, **kwargs) -> cstruct: diff --git a/dissect/cstruct/types/base.py b/dissect/cstruct/types/base.py index d5ca7dbb..21d6d26a 100644 --- a/dissect/cstruct/types/base.py +++ b/dissect/cstruct/types/base.py @@ -12,7 +12,7 @@ from typing_extensions import Self - from dissect.cstruct.cstruct import cstruct + from dissect.cstruct.cstruct import AllowedEndianness, Endianness, cstruct EOF = -0xE0F # Negative counts are illegal anyway, so abuse that for our EOF sentinel @@ -41,14 +41,15 @@ def __call__(cls, *args, **kwargs) -> Self: # type: ignore stream = args[0] if _is_readable_type(stream): - return cls._read(stream) + endian = normalize_endianness(endian) if (endian := kwargs.get("endian")) is not None else cls.cs.endian + return cls._read(stream, endian=endian) if issubclass(cls, bytes) and isinstance(stream, bytes) and len(stream) == cls.size: # Shortcut for char/bytes type return type.__call__(cls, *args, **kwargs) if _is_buffer_type(stream): - return cls.reads(stream) + return cls.reads(stream, endian=kwargs.get("endian")) return type.__call__(cls, *args, **kwargs) @@ -71,60 +72,71 @@ def __default__(cls) -> Self: # type: ignore """Return the default value of this type.""" return cls() - def reads(cls, data: bytes | memoryview | bytearray) -> Self: # type: ignore + def reads(cls, data: bytes | memoryview | bytearray, *, endian: AllowedEndianness | None = None) -> Self: # type: ignore """Parse the given data from a bytes-like object. Args: data: Bytes-like object to parse. + endian: The endianness to use when parsing. If not provided, the cstruct's default endianness will be used. Returns: The parsed value of this type. 
""" - return cls._read(BytesIO(data)) + endian = normalize_endianness(endian) if endian is not None else cls.cs.endian + return cls._read(BytesIO(data), endian=endian) - def read(cls, obj: BinaryIO | bytes | memoryview | bytearray) -> Self: # type: ignore + def read(cls, obj: BinaryIO | bytes | memoryview | bytearray, *, endian: AllowedEndianness | None = None) -> Self: # type: ignore """Parse the given data. Args: obj: Data to parse. Can be a bytes-like object or a file-like object. + endian: The endianness to use when parsing. If not provided, the cstruct's default endianness will be used. Returns: The parsed value of this type. """ if _is_buffer_type(obj): - return cls.reads(obj) + return cls.reads(obj, endian=endian) if not _is_readable_type(obj): raise TypeError("Invalid object type") - return cls._read(obj) + endian = normalize_endianness(endian) if endian is not None else cls.cs.endian + return cls._read(obj, endian=endian) - def write(cls, stream: BinaryIO, value: Any) -> int: + def write(cls, stream: BinaryIO, value: Any, *, endian: AllowedEndianness | None = None) -> int: """Write a value to a writable file-like object. Args: stream: File-like objects that supports writing. value: Value to write. + endian: The endianness to use when writing. If not provided, the cstruct's default endianness will be used. Returns: The amount of bytes written. """ - return cls._write(stream, value) + endian = normalize_endianness(endian) if endian is not None else cls.cs.endian + return cls._write(stream, value, endian=endian) - def dumps(cls, value: Any) -> bytes: + def dumps(cls, value: Any, *, endian: AllowedEndianness | None = None) -> bytes: """Dump a value to a byte string. Args: value: Value to dump. + endian: The endianness to use when dumping. If not provided, the cstruct's default endianness will be used. Returns: The raw bytes of this type. 
""" + endian = normalize_endianness(endian) if endian is not None else cls.cs.endian + out = BytesIO() - cls._write(out, value) + cls._write(out, value, endian=endian) return out.getvalue() - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: # type: ignore + def _read( + cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness | None = None, **kwargs + ) -> Self: # type: ignore """Internal function for reading value. Must be implemented per type. @@ -132,10 +144,19 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: Args: stream: The stream to read from. context: Optional reading context. + endian: The endianness to use when reading. If not provided, the cstruct's default endianness will be used. """ raise NotImplementedError - def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Self]: # type: ignore + def _read_array( + cls, + stream: BinaryIO, + count: int, + *, + context: dict[str, Any] | None = None, + endian: Endianness | None = None, + **kwargs, + ) -> list[Self]: # type: ignore """Internal function for reading array values. Allows type implementations to do optimized reading for their type. @@ -144,16 +165,19 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | Non stream: The stream to read from. count: The amount of values to read. context: Optional reading context. + endian: The endianness to use when reading. If not provided, the cstruct's default endianness will be used. 
""" if count == EOF: result = [] while not _is_eof(stream): - result.append(cls._read(stream, context)) + result.append(cls._read(stream, context=context, endian=endian, **kwargs)) return result - return [cls._read(stream, context) for _ in range(count)] + return [cls._read(stream, context=context, endian=endian, **kwargs) for _ in range(count)] - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Self]: + def _read_0( + cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness | None = None, **kwargs + ) -> list[Self]: """Internal function for reading null-terminated data. "Null" is type specific, so must be implemented per type. @@ -161,13 +185,14 @@ def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> lis Args: stream: The stream to read from. context: Optional reading context. + endian: The endianness to use when reading. If not provided, the cstruct's default endianness will be used. """ raise NotImplementedError - def _write(cls, stream: BinaryIO, data: Any) -> int: + def _write(cls, stream: BinaryIO, data: Any, *, endian: Endianness | None = None, **kwargs) -> int: raise NotImplementedError - def _write_array(cls, stream: BinaryIO, array: list[Self]) -> int: # type: ignore + def _write_array(cls, stream: BinaryIO, array: list[Self], *, endian: Endianness | None = None, **kwargs) -> int: # type: ignore """Internal function for writing arrays. Allows type implementations to do optimized writing for their type. @@ -175,10 +200,11 @@ def _write_array(cls, stream: BinaryIO, array: list[Self]) -> int: # type: igno Args: stream: The stream to read from. array: The array to write. + endian: The endianness to use when writing. If not provided, the cstruct's default endianness will be used.
""" - return sum(cls._write(stream, entry) for entry in array) + return sum(cls._write(stream, entry, endian=endian, **kwargs) for entry in array) - def _write_0(cls, stream: BinaryIO, array: list[Self]) -> int: # type: ignore + def _write_0(cls, stream: BinaryIO, array: list[Self], *, endian: Endianness | None = None, **kwargs) -> int: # type: ignore """Internal function for writing null-terminated arrays. Allows type implementations to do optimized writing for their type. @@ -186,8 +212,9 @@ def _write_0(cls, stream: BinaryIO, array: list[Self]) -> int: # type: ignore Args: stream: The stream to read from. array: The array to write. + endian: The endianness to use when writing. If not provided, the cstruct's default endianness will be used. """ - return cls._write_array(stream, [*array, cls.__default__()]) + return cls._write_array(stream, [*array, cls.__default__()], endian=endian, **kwargs) class _overload: @@ -250,9 +277,11 @@ def __default__(cls) -> BaseType: ) @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[BaseType]: + def _read( + cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs + ) -> list[BaseType]: if cls.null_terminated: - return cls.type._read_0(stream, context) + return cls.type._read_0(stream, context=context, endian=endian, **kwargs) if isinstance(cls.num_entries, int): num = max(0, cls.num_entries) @@ -266,23 +295,23 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[ raise num = EOF - return cls.type._read_array(stream, num, context) + return cls.type._read_array(stream, num, context=context, endian=endian, **kwargs) @classmethod - def _write(cls, stream: BinaryIO, data: list[Any]) -> int: + def _write(cls, stream: BinaryIO, data: list[Any], *, endian: Endianness, **kwargs) -> int: if cls.null_terminated: - return cls.type._write_0(stream, data) + return cls.type._write_0(stream, data, endian=endian, **kwargs) if not
cls.dynamic and cls.num_entries != (actual_size := len(data)): raise ArraySizeError(f"Expected static array size {cls.num_entries}, got {actual_size} instead.") - return cls.type._write_array(stream, data) + return cls.type._write_array(stream, data, endian=endian, **kwargs) class Array(list[T], BaseArray): @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[T]: - return cls(super()._read(stream, context)) + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> list[T]: + return cls(super()._read(stream, context=context, endian=endian, **kwargs)) def _is_readable_type(value: object) -> bool: @@ -305,5 +334,24 @@ def _is_eof(stream: BinaryIO) -> bool: return False +ENDIANNESS_MAP: dict[AllowedEndianness, Endianness] = { + "<": "<", + ">": ">", + "!": "!", + "@": "@", + "=": "=", + "network": "!", + "little": "<", + "big": ">", +} + + +def normalize_endianness(endian: AllowedEndianness) -> Endianness: + """Normalize an endianness string to one of the standard format characters.""" + if endian not in ENDIANNESS_MAP: + raise ValueError(f"Invalid endianness: {endian}") + return ENDIANNESS_MAP[endian] + + # As mentioned in the BaseType class, we correctly set the type here MetaType.ArrayType = Array diff --git a/dissect/cstruct/types/char.py b/dissect/cstruct/types/char.py index a7364b46..2443a4be 100644 --- a/dissect/cstruct/types/char.py +++ b/dissect/cstruct/types/char.py @@ -7,6 +7,8 @@ if TYPE_CHECKING: from typing_extensions import Self + from dissect.cstruct.cstruct import Endianness + class CharArray(bytes, BaseArray): """Character array type for reading and writing byte strings.""" @@ -16,11 +18,11 @@ def __default__(cls) -> Self: return type.__call__(cls, b"\x00" * (0 if cls.dynamic or cls.null_terminated else cls.num_entries)) @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: - return type.__call__(cls, 
super()._read(stream, context)) + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: + return type.__call__(cls, super()._read(stream, context=context, endian=endian, **kwargs)) @classmethod - def _write(cls, stream: BinaryIO, data: bytes) -> int: + def _write(cls, stream: BinaryIO, data: bytes, *, endian: Endianness, **kwargs) -> int: if isinstance(data, list) and data and isinstance(data[0], int): data = bytes(data) @@ -42,11 +44,13 @@ def __default__(cls) -> Self: return type.__call__(cls, b"\x00") @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: - return cls._read_array(stream, 1, context) + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: + return cls._read_array(stream, 1, context=context, endian=endian, **kwargs) @classmethod - def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> Self: + def _read_array( + cls, stream: BinaryIO, count: int, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs + ) -> Self: if count == 0: return type.__call__(cls, b"") @@ -57,7 +61,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | Non return type.__call__(cls, data) @classmethod - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: + def _read_0(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: buf = [] while True: byte = stream.read(1) @@ -72,7 +76,7 @@ def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Sel return type.__call__(cls, b"".join(buf)) @classmethod - def _write(cls, stream: BinaryIO, data: bytes | int | str) -> int: + def _write(cls, stream: BinaryIO, data: bytes | int | str, *, endian: Endianness, **kwargs) -> int: if isinstance(data, int): data = chr(data) diff --git 
a/dissect/cstruct/types/enum.py b/dissect/cstruct/types/enum.py index d1f6bcb8..70e887d8 100644 --- a/dissect/cstruct/types/enum.py +++ b/dissect/cstruct/types/enum.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from typing_extensions import Self - from dissect.cstruct.cstruct import cstruct + from dissect.cstruct.cstruct import Endianness, cstruct PY_311 = sys.version_info >= (3, 11, 0) @@ -83,25 +83,29 @@ def __contains__(cls, value: Any) -> bool: return True return value in cls._value2member_map_ - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: - return cls(cls.type._read(stream, context)) + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: + return cls(cls.type._read(stream, context=context, endian=endian, **kwargs)) - def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Self]: - return list(map(cls, cls.type._read_array(stream, count, context))) + def _read_array( + cls, stream: BinaryIO, count: int, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs + ) -> list[Self]: + return list(map(cls, cls.type._read_array(stream, count, context=context, endian=endian, **kwargs))) - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Self]: - return list(map(cls, cls.type._read_0(stream, context))) + def _read_0( + cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs + ) -> list[Self]: + return list(map(cls, cls.type._read_0(stream, context=context, endian=endian, **kwargs))) - def _write(cls, stream: BinaryIO, data: Enum) -> int: - return cls.type._write(stream, data.value) + def _write(cls, stream: BinaryIO, data: Enum, *, endian: Endianness, **kwargs) -> int: + return cls.type._write(stream, data.value, endian=endian, **kwargs) - def _write_array(cls, stream: BinaryIO, array: list[BaseType | int]) -> int: + def _write_array(cls, stream: BinaryIO, 
array: list[BaseType | int], *, endian: Endianness, **kwargs) -> int: data = [entry.value if isinstance(entry, _Enum) else entry for entry in array] - return cls.type._write_array(stream, data) + return cls.type._write_array(stream, data, endian=endian, **kwargs) - def _write_0(cls, stream: BinaryIO, array: list[BaseType | int]) -> int: + def _write_0(cls, stream: BinaryIO, array: list[BaseType | int], *, endian: Endianness, **kwargs) -> int: data = [entry.value if isinstance(entry, _Enum) else entry for entry in array] - return cls._write_array(stream, [*data, cls.type.__default__()]) + return cls._write_array(stream, [*data, cls.type.__default__()], endian=endian, **kwargs) def _fix_alias_members(cls: type[Enum]) -> None: diff --git a/dissect/cstruct/types/int.py b/dissect/cstruct/types/int.py index 4d96ad15..9dfa1468 100644 --- a/dissect/cstruct/types/int.py +++ b/dissect/cstruct/types/int.py @@ -3,11 +3,13 @@ from typing import TYPE_CHECKING, Any, BinaryIO from dissect.cstruct.types.base import BaseType -from dissect.cstruct.utils import ENDIANNESS_MAP +from dissect.cstruct.utils import ENDIANNESS_TO_BYTEORDER_MAP if TYPE_CHECKING: from typing_extensions import Self + from dissect.cstruct.cstruct import Endianness + class Int(int, BaseType): """Integer type that can span an arbitrary amount of bytes.""" @@ -15,20 +17,20 @@ class Int(int, BaseType): signed: bool @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: data = stream.read(cls.size) if len(data) != cls.size: raise EOFError(f"Read {len(data)} bytes, but expected {cls.size}") - return cls.from_bytes(data, ENDIANNESS_MAP[cls.cs.endian], signed=cls.signed) + return cls.from_bytes(data, ENDIANNESS_TO_BYTEORDER_MAP[endian], signed=cls.signed) @classmethod - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: + def 
_read_0(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: result = [] while True: - if (value := cls._read(stream, context)) == 0: + if (value := cls._read(stream, context=context, endian=endian, **kwargs)) == 0: break result.append(value) @@ -36,5 +38,5 @@ def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Sel return result @classmethod - def _write(cls, stream: BinaryIO, data: int) -> int: - return stream.write(data.to_bytes(cls.size, ENDIANNESS_MAP[cls.cs.endian], signed=cls.signed)) + def _write(cls, stream: BinaryIO, data: int, *, endian: Endianness, **kwargs) -> int: + return stream.write(data.to_bytes(cls.size, ENDIANNESS_TO_BYTEORDER_MAP[endian], signed=cls.signed)) diff --git a/dissect/cstruct/types/leb128.py b/dissect/cstruct/types/leb128.py index ccccffcc..ce9fc4a1 100644 --- a/dissect/cstruct/types/leb128.py +++ b/dissect/cstruct/types/leb128.py @@ -7,6 +7,8 @@ if TYPE_CHECKING: from typing_extensions import Self + from dissect.cstruct.cstruct import Endianness + class LEB128(int, BaseType): """Variable-length code compression to store an arbitrarily large integer in a small number of bytes. 
@@ -17,7 +19,7 @@ class LEB128(int, BaseType): signed: bool @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: result = 0 shift = 0 while True: @@ -37,11 +39,13 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: return cls.__new__(cls, result) @classmethod - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Self]: + def _read_0( + cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs + ) -> list[Self]: result = [] while True: - if (value := cls._read(stream, context)) == 0: + if (value := cls._read(stream, context=context, endian=endian, **kwargs)) == 0: break result.append(value) @@ -49,7 +53,7 @@ def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> lis return result @classmethod - def _write(cls, stream: BinaryIO, data: int) -> int: + def _write(cls, stream: BinaryIO, data: int, *, endian: Endianness, **kwargs) -> int: # only write negative numbers when in signed mode if data < 0 and not cls.signed: raise ValueError("Attempt to encode a negative integer using unsigned LEB128 encoding") diff --git a/dissect/cstruct/types/packed.py b/dissect/cstruct/types/packed.py index 76e70996..78d2cd26 100644 --- a/dissect/cstruct/types/packed.py +++ b/dissect/cstruct/types/packed.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: from typing_extensions import Self + from dissect.cstruct.cstruct import Endianness + @lru_cache(1024) def _struct(endian: str, packchar: str) -> Struct: @@ -24,11 +26,13 @@ class Packed(BaseType, Generic[T]): packchar: str @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: - return cls._read_array(stream, 1, context)[0] + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: + 
return cls._read_array(stream, 1, context=context, endian=endian, **kwargs)[0] @classmethod - def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Self]: + def _read_array( + cls, stream: BinaryIO, count: int, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs + ) -> list[Self]: if count == EOF: data = stream.read() length = len(data) @@ -37,7 +41,7 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | Non length = cls.size * count data = stream.read(length) - fmt = _struct(cls.cs.endian, f"{count}{cls.packchar}") + fmt = _struct(endian, f"{count}{cls.packchar}") if len(data) != length: raise EOFError(f"Read {len(data)} bytes, but expected {length}") @@ -45,10 +49,10 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | Non return [cls.__new__(cls, value) for value in fmt.unpack(data)] @classmethod - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None, *, endian: Endianness, **kwargs) -> Self: result = [] - fmt = _struct(cls.cs.endian, cls.packchar) + fmt = _struct(endian, cls.packchar) while True: data = stream.read(cls.size) @@ -63,9 +67,9 @@ def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Sel return result @classmethod - def _write(cls, stream: BinaryIO, data: Packed[T]) -> int: - return stream.write(_struct(cls.cs.endian, cls.packchar).pack(data)) + def _write(cls, stream: BinaryIO, data: Packed[T], *, endian: Endianness, **kwargs) -> int: + return stream.write(_struct(endian, cls.packchar).pack(data)) @classmethod - def _write_array(cls, stream: BinaryIO, data: list[Packed[T]]) -> int: - return stream.write(_struct(cls.cs.endian, f"{len(data)}{cls.packchar}").pack(*data)) + def _write_array(cls, stream: BinaryIO, data: list[Packed[T]], *, endian: Endianness, **kwargs) -> int: + return stream.write(_struct(endian,
f"{len(data)}{cls.packchar}").pack(*data)) diff --git a/dissect/cstruct/types/pointer.py b/dissect/cstruct/types/pointer.py index 289e4004..c4a6c0be 100644 --- a/dissect/cstruct/types/pointer.py +++ b/dissect/cstruct/types/pointer.py @@ -1,15 +1,20 @@ from __future__ import annotations +from functools import cache from typing import TYPE_CHECKING, Any, BinaryIO, Generic, TypeVar from dissect.cstruct.exceptions import NullPointerDereference -from dissect.cstruct.types.base import BaseType +from dissect.cstruct.types.base import BaseType, normalize_endianness from dissect.cstruct.types.char import Char from dissect.cstruct.types.void import Void if TYPE_CHECKING: + from collections.abc import Callable + from typing_extensions import Self + from dissect.cstruct.cstruct import AllowedEndianness, Endianness + T = TypeVar("T", bound=BaseType) @@ -19,13 +24,18 @@ class Pointer(int, BaseType, Generic[T]): type: type[T] _stream: BinaryIO | None _context: dict[str, Any] | None - _value: T | None + _endian: Endianness + _kwargs: dict[str, Any] - def __new__(cls, value: int, stream: BinaryIO | None, context: dict[str, Any] | None = None) -> Self: + def __new__( + cls, value: int, stream: BinaryIO | None, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs + ) -> Self: obj = super().__new__(cls, value) obj._stream = stream obj._context = context - obj._value = None + obj._endian = endian + obj._kwargs = kwargs + obj.dereference = cache(obj.dereference) return obj def __repr__(self) -> str: @@ -37,68 +47,80 @@ def __str__(self) -> str: def __getattr__(self, attr: str) -> Any: return getattr(self.dereference(), attr) - def __add__(self, other: int) -> Self: - return type.__call__(self.__class__, int.__add__(self, other), self._stream, self._context) - - def __sub__(self, other: int) -> Self: - return type.__call__(self.__class__, int.__sub__(self, other), self._stream, self._context) - - def __mul__(self, other: int) -> Self: - return 
type.__call__(self.__class__, int.__mul__(self, other), self._stream, self._context) - - def __floordiv__(self, other: int) -> Self: - return type.__call__(self.__class__, int.__floordiv__(self, other), self._stream, self._context) - - def __mod__(self, other: int) -> Self: - return type.__call__(self.__class__, int.__mod__(self, other), self._stream, self._context) - - def __pow__(self, other: int) -> Self: - return type.__call__(self.__class__, int.__pow__(self, other), self._stream, self._context) - - def __lshift__(self, other: int) -> Self: - return type.__call__(self.__class__, int.__lshift__(self, other), self._stream, self._context) - - def __rshift__(self, other: int) -> Self: - return type.__call__(self.__class__, int.__rshift__(self, other), self._stream, self._context) - - def __and__(self, other: int) -> Self: - return type.__call__(self.__class__, int.__and__(self, other), self._stream, self._context) - - def __xor__(self, other: int) -> Self: - return type.__call__(self.__class__, int.__xor__(self, other), self._stream, self._context) - - def __or__(self, other: int) -> Self: - return type.__call__(self.__class__, int.__or__(self, other), self._stream, self._context) + @staticmethod + def __op(op: Callable[[int, int], int]) -> Callable[[Self, int], Self]: + def method(self: Self, other: int) -> Self: + return type.__call__( + self.__class__, + op(self, other), + self._stream, + context=self._context, + endian=self._endian, + **self._kwargs, + ) + + return method + + __add__ = __op(int.__add__) + __sub__ = __op(int.__sub__) + __mul__ = __op(int.__mul__) + __floordiv__ = __op(int.__floordiv__) + __mod__ = __op(int.__mod__) + __pow__ = __op(int.__pow__) + __lshift__ = __op(int.__lshift__) + __rshift__ = __op(int.__rshift__) + __and__ = __op(int.__and__) + __xor__ = __op(int.__xor__) + __or__ = __op(int.__or__) @classmethod def __default__(cls) -> Self: - return cls.__new__(cls, cls.cs.pointer.__default__(), None, None) + return cls.__new__( + cls, +
cls.cs.pointer.__default__(), + None, + context=None, + endian=cls.cs.endian, + ) @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: - return cls.__new__(cls, cls.cs.pointer._read(stream, context), stream, context) + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: + return cls.__new__( + cls, + cls.cs.pointer._read(stream, context=context, endian=endian, **kwargs), + stream, + context=context, + endian=endian, + **kwargs, + ) @classmethod - def _write(cls, stream: BinaryIO, data: int) -> int: - return cls.cs.pointer._write(stream, data) + def _write(cls, stream: BinaryIO, data: int, *, endian: Endianness, **kwargs) -> int: + return cls.cs.pointer._write(stream, data, endian=endian, **kwargs) + + def dereference(self, *, endian: AllowedEndianness | None = None) -> T: + """Dereference the pointer and read the value it points to. - def dereference(self) -> T: + Args: + endian: Optional endianness to use when reading the value. + If not provided, the endianness used when reading the pointer itself will be used. 
+ """ if self == 0 or self._stream is None: raise NullPointerDereference - if self._value is None and not issubclass(self.type, Void): - # Read current position of file read/write pointer - position = self._stream.tell() - # Reposition the file read/write pointer - self._stream.seek(self) + endian = normalize_endianness(endian) if endian is not None else self._endian + if issubclass(self.type, Void): + return None - if issubclass(self.type, Char): - # this makes the assumption that a char pointer is a null-terminated string - value = self.type._read_0(self._stream, self._context) - else: - value = self.type._read(self._stream, self._context) + position = self._stream.tell() + self._stream.seek(self) - self._stream.seek(position) - self._value = value + if issubclass(self.type, Char): + # This makes the assumption that a char pointer is a null-terminated string + value = self.type._read_0(self._stream, context=self._context, endian=endian, **self._kwargs) + else: + value = self.type._read(self._stream, context=self._context, endian=endian, **self._kwargs) - return self._value + # Restore the stream position after reading the value + self._stream.seek(position) + return value diff --git a/dissect/cstruct/types/structure.py b/dissect/cstruct/types/structure.py index f2e8e037..c462109f 100644 --- a/dissect/cstruct/types/structure.py +++ b/dissect/cstruct/types/structure.py @@ -28,6 +28,8 @@ from typing_extensions import Self + from dissect.cstruct.cstruct import Endianness + class Field: """Structure field.""" @@ -246,8 +248,8 @@ def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) - # The structure size is whatever the currently calculated offset is return offset, alignment - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: # type: ignore - bit_buffer = BitBuffer(stream, cls.cs.endian) + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: # type: 
ignore + bit_buffer = BitBuffer(stream, endian=endian, **kwargs) struct_start = stream.tell() result = {} @@ -276,7 +278,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: bit_buffer.reset() - value = field.type._read(stream, result) + value = field.type._read(stream, context=result, endian=endian, **kwargs) result[field._name] = value if field.type.dynamic: @@ -292,16 +294,18 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: obj.__dynamic_sizes__ = sizes return obj - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Self]: # type: ignore + def _read_0( + cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs + ) -> list[Self]: # type: ignore result = [] - while obj := cls._read(stream, context): + while obj := cls._read(stream, context=context, endian=endian, **kwargs): result.append(obj) return result - def _write(cls, stream: BinaryIO, data: Structure) -> int: - bit_buffer = BitBuffer(stream, cls.cs.endian) + def _write(cls, stream: BinaryIO, data: Structure, *, endian: Endianness, **kwargs) -> int: + bit_buffer = BitBuffer(stream, endian=endian, **kwargs) struct_start = stream.tell() num = 0 @@ -346,7 +350,7 @@ def _write(cls, stream: BinaryIO, data: Structure) -> int: else: bit_buffer.write(field_type, value, field.bits) else: - field_type._write(stream, value) + field_type._write(stream, value, endian=endian, **kwargs) num += stream.tell() - offset if bit_buffer._type is not None: @@ -459,6 +463,13 @@ class UnionMetaType(StructureMetaType): def __call__(cls, *args, **kwargs) -> Self: # type: ignore obj: Union = super().__call__(*args, **kwargs) + if not hasattr(obj, "_buf"): + # If we don't have a _buf attribute, we haven't read from a stream and are initializing with values + # Set default internal attributes + object.__setattr__(obj, "_buf", None) + object.__setattr__(obj, "_endian", cls.cs.endian) + 
object.__setattr__(obj, "_kwargs", {}) + # Calling with non-stream args or kwargs means we are initializing with values if (args and not (len(args) == 1 and (_is_readable_type(args[0]) or _is_buffer_type(args[0])))) or kwargs: # We don't support user initialization of dynamic unions yet @@ -502,7 +513,7 @@ def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) - return size, alignment def _read_fields( - cls, stream: BinaryIO, context: dict[str, Any] | None = None + cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs ) -> tuple[dict[str, Any], dict[str, int]]: result = {} sizes = {} @@ -522,7 +533,7 @@ def _read_fields( start = field.offset buf.seek(offset + start) - value = field_type._read(buf, result) + value = field_type._read(buf, context=result, endian=endian, **kwargs) result[field._name] = value if field.type.dynamic: @@ -530,10 +541,10 @@ def _read_fields( return result, sizes - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: # type: ignore + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: # type: ignore if cls.size is None: start = stream.tell() - result, sizes = cls._read_fields(stream, context) + result, sizes = cls._read_fields(stream, context=context, endian=endian, **kwargs) size = stream.tell() - start stream.seek(start) buf = stream.read(size) @@ -550,6 +561,8 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: obj: Union = type.__call__(cls, **result) object.__setattr__(obj, "__dynamic_sizes__", sizes) object.__setattr__(obj, "_buf", buf) + object.__setattr__(obj, "_endian", endian) + object.__setattr__(obj, "_kwargs", kwargs) if cls.size is not None: obj._update() @@ -559,7 +572,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: return obj - def _write(cls, stream: BinaryIO, data: Union) -> int: + def _write(cls, 
stream: BinaryIO, data: Union, *, endian: Endianness, **kwargs) -> int: if cls.dynamic: raise NotImplementedError("Writing dynamic unions is not yet supported") @@ -578,12 +591,12 @@ def _write(cls, stream: BinaryIO, data: Union) -> int: continue # Write the value - field.type._write(stream, getattr(data, field._name)) + field.type._write(stream, getattr(data, field._name), endian=endian, **kwargs) break # If we haven't written anything yet and we initially skipped an anonymous struct, write it now if stream.tell() == offset and anonymous_struct: - anonymous_struct._write(stream, data) + anonymous_struct._write(stream, data, endian=endian, **kwargs) # If we haven't filled the union size yet, pad it if remaining := expected_offset - stream.tell(): @@ -596,6 +609,8 @@ class Union(Structure, metaclass=UnionMetaType): """Base class for cstruct union type classes.""" _buf: bytes + _endian: Endianness + _kwargs: dict[str, Any] def __eq__(self, other: object) -> bool: return self.__class__ is other.__class__ and bytes(self) == bytes(other) @@ -619,7 +634,7 @@ def _rebuild(self, attr: str) -> None: if (value := getattr(self, attr)) is None: value = field.type.__default__() - field.type._write(buf, value) + field.type._write(buf, value, endian=self._endian, **self._kwargs) object.__setattr__(self, "_buf", buf.getvalue()) self._update() @@ -628,7 +643,7 @@ def _rebuild(self, attr: str) -> None: self._proxify() def _update(self) -> None: - result, sizes = self.__class__._read_fields(io.BytesIO(self._buf)) + result, sizes = self.__class__._read_fields(io.BytesIO(self._buf), endian=self._endian, **self._kwargs) self.__dict__.update(result) object.__setattr__(self, "__dynamic_sizes__", sizes) diff --git a/dissect/cstruct/types/void.py b/dissect/cstruct/types/void.py index fc07888c..c6afca17 100644 --- a/dissect/cstruct/types/void.py +++ b/dissect/cstruct/types/void.py @@ -7,6 +7,8 @@ if TYPE_CHECKING: from typing_extensions import Self + from dissect.cstruct.cstruct import 
Endianness + class VoidArray(list, BaseArray): """Array type representing void elements, primarily used for no-op reading and writing operations.""" @@ -16,11 +18,11 @@ def __default__(cls) -> Self: return cls() @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: return cls() @classmethod - def _write(cls, stream: BinaryIO, data: bytes) -> int: + def _write(cls, stream: BinaryIO, data: bytes, *, endian: Endianness, **kwargs) -> int: return 0 @@ -36,9 +38,9 @@ def __eq__(self, value: object) -> bool: return isinstance(value, Void) @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Self: return cls.__new__(cls) @classmethod - def _write(cls, stream: BinaryIO, data: Void) -> int: + def _write(cls, stream: BinaryIO, data: Void, *, endian: Endianness, **kwargs) -> int: return 0 diff --git a/dissect/cstruct/types/wchar.py b/dissect/cstruct/types/wchar.py index 1cd88224..3ec62577 100644 --- a/dissect/cstruct/types/wchar.py +++ b/dissect/cstruct/types/wchar.py @@ -1,10 +1,13 @@ from __future__ import annotations import sys -from typing import Any, BinaryIO, ClassVar +from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar from dissect.cstruct.types.base import EOF, BaseArray, BaseType +if TYPE_CHECKING: + from dissect.cstruct.cstruct import Endianness + class WcharArray(str, BaseArray): """Wide-character array type for reading and writing UTF-16 strings.""" @@ -16,14 +19,16 @@ def __default__(cls) -> WcharArray: return type.__call__(cls, "\x00" * (0 if cls.dynamic or cls.null_terminated else cls.num_entries)) @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> WcharArray: - return type.__call__(cls, super()._read(stream, 
context)) + def _read( + cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs + ) -> WcharArray: + return type.__call__(cls, super()._read(stream, context=context, endian=endian, **kwargs)) @classmethod - def _write(cls, stream: BinaryIO, data: str) -> int: + def _write(cls, stream: BinaryIO, data: str, *, endian: Endianness, **kwargs) -> int: if cls.null_terminated: data += "\x00" - return stream.write(data.encode(Wchar.__encoding_map__[cls.cs.endian])) + return stream.write(data.encode(Wchar.__encoding_map__[endian])) class Wchar(str, BaseType): @@ -32,7 +37,7 @@ class Wchar(str, BaseType): ArrayType = WcharArray __slots__ = () - __encoding_map__: ClassVar[dict[str, str]] = { + __encoding_map__: ClassVar[dict[Endianness, str]] = { "@": f"utf-16-{sys.byteorder[0]}e", "=": f"utf-16-{sys.byteorder[0]}e", "<": "utf-16-le", @@ -45,11 +50,13 @@ def __default__(cls) -> Wchar: return type.__call__(cls, "\x00") @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Wchar: - return cls._read_array(stream, 1, context) + def _read(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Wchar: + return cls._read_array(stream, 1, context=context, endian=endian, **kwargs) @classmethod - def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> Wchar: + def _read_array( + cls, stream: BinaryIO, count: int, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs + ) -> Wchar: if count == 0: return type.__call__(cls, "") @@ -60,10 +67,10 @@ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | Non if count != EOF and len(data) != count: raise EOFError(f"Read {len(data)} bytes, but expected {count}") - return type.__call__(cls, data.decode(cls.__encoding_map__[cls.cs.endian])) + return type.__call__(cls, data.decode(cls.__encoding_map__[endian])) @classmethod - def _read_0(cls, stream: 
BinaryIO, context: dict[str, Any] | None = None) -> Wchar: + def _read_0(cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs) -> Wchar: buf = [] while True: point = stream.read(2) @@ -75,8 +82,8 @@ def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Wch buf.append(point) - return type.__call__(cls, b"".join(buf).decode(cls.__encoding_map__[cls.cs.endian])) + return type.__call__(cls, b"".join(buf).decode(cls.__encoding_map__[endian])) @classmethod - def _write(cls, stream: BinaryIO, data: str) -> int: - return stream.write(data.encode(cls.__encoding_map__[cls.cs.endian])) + def _write(cls, stream: BinaryIO, data: str, *, endian: Endianness, **kwargs) -> int: + return stream.write(data.encode(cls.__encoding_map__[endian])) diff --git a/dissect/cstruct/utils.py b/dissect/cstruct/utils.py index 72c13f03..63d93a7a 100644 --- a/dissect/cstruct/utils.py +++ b/dissect/cstruct/utils.py @@ -13,6 +13,8 @@ from collections.abc import Iterator from typing import Literal + from dissect.cstruct.cstruct import AllowedEndianness, Endianness + COLOR_RED = "\033[1;31m" COLOR_GREEN = "\033[1;32m" COLOR_YELLOW = "\033[1;33m" @@ -32,13 +34,16 @@ PRINTABLE = string.digits + string.ascii_letters + string.punctuation + " " -ENDIANNESS_MAP: dict[str, Literal["big", "little"]] = { - "@": sys.byteorder, - "=": sys.byteorder, + +ENDIANNESS_TO_BYTEORDER_MAP: dict[AllowedEndianness, Literal["big", "little"]] = { "<": "little", ">": "big", "!": "big", + "@": sys.byteorder, + "=": sys.byteorder, "network": "big", + "little": "little", + "big": "big", } Palette = list[tuple[int, str]] @@ -215,111 +220,118 @@ def dumpstruct( raise ValueError("Invalid arguments") -def pack(value: int, size: int | None = None, endian: str = "little") -> bytes: +def pack(value: int, size: int | None = None, endian: AllowedEndianness = "little") -> bytes: """Pack an integer value to a given bit size, endianness. Arguments: value: Value to pack. 
size: Integer size in bits. - endian: Endianness to use (little, big, network, <, > or !) + endian: Endianness to use (little, big, network, <, >, !, @ or =). """ + if endian not in ENDIANNESS_TO_BYTEORDER_MAP: + raise ValueError(f"Invalid endianness: {endian!r} (should be little, big, network, <, >, !, @ or =)") + size = ((size or value.bit_length()) + 7) // 8 - return value.to_bytes(size, ENDIANNESS_MAP.get(endian, endian), signed=value < 0) + return value.to_bytes(size, ENDIANNESS_TO_BYTEORDER_MAP[endian], signed=value < 0) -def unpack(value: bytes, size: int | None = None, endian: str = "little", sign: bool = False) -> int: +def unpack(value: bytes, size: int | None = None, endian: AllowedEndianness = "little", sign: bool = False) -> int: """Unpack an integer value from a given bit size, endianness and sign. Arguments: value: Value to unpack. size: Integer size in bits. - endian: Endianness to use (little, big, network, <, > or !) + endian: Endianness to use (little, big, network, <, >, !, @ or =). sign: Signedness of the integer. """ if size and len(value) != size // 8: raise ValueError(f"Invalid byte value, expected {size // 8} bytes, got {len(value)} bytes") - return int.from_bytes(value, ENDIANNESS_MAP.get(endian, endian), signed=sign) + + if endian not in ENDIANNESS_TO_BYTEORDER_MAP: + raise ValueError(f"Invalid endianness: {endian!r} (should be little, big, network, <, >, !, @ or =)") + + return int.from_bytes(value, ENDIANNESS_TO_BYTEORDER_MAP[endian], signed=sign) -def p8(value: int, endian: str = "little") -> bytes: +def p8(value: int, endian: AllowedEndianness = "little") -> bytes: """Pack an 8 bit integer. Arguments: value: Value to pack. - endian: Endianness to use (little, big, network, <, > or !) + endian: Endianness to use (little, big, network, <, >, !, @ or =). """ return pack(value, 8, endian) -def p16(value: int, endian: str = "little") -> bytes: +def p16(value: int, endian: AllowedEndianness = "little") -> bytes: """Pack a 16 bit integer. 
Arguments: value: Value to pack. - endian: Endianness to use (little, big, network, <, > or !) + endian: Endianness to use (little, big, network, <, >, !, @ or =). """ return pack(value, 16, endian) -def p32(value: int, endian: str = "little") -> bytes: +def p32(value: int, endian: AllowedEndianness = "little") -> bytes: """Pack a 32 bit integer. Arguments: value: Value to pack. - endian: Endianness to use (little, big, network, <, > or !) + endian: Endianness to use (little, big, network, <, >, !, @ or =). """ return pack(value, 32, endian) -def p64(value: int, endian: str = "little") -> bytes: +def p64(value: int, endian: AllowedEndianness = "little") -> bytes: """Pack a 64 bit integer. Arguments: value: Value to pack. - endian: Endianness to use (little, big, network, <, > or !) + endian: Endianness to use (little, big, network, <, >, !, @ or =). """ return pack(value, 64, endian) -def u8(value: bytes, endian: str = "little", sign: bool = False) -> int: +def u8(value: bytes, endian: AllowedEndianness = "little", sign: bool = False) -> int: """Unpack an 8 bit integer. Arguments: value: Value to unpack. - endian: Endianness to use (little, big, network, <, > or !) + endian: Endianness to use (little, big, network, <, >, !, @ or =). sign: Signedness of the integer. """ return unpack(value, 8, endian, sign) -def u16(value: bytes, endian: str = "little", sign: bool = False) -> int: +def u16(value: bytes, endian: AllowedEndianness = "little", sign: bool = False) -> int: """Unpack a 16 bit integer. Arguments: value: Value to unpack. - endian: Endianness to use (little, big, network, <, > or !) + endian: Endianness to use (little, big, network, <, >, !, @ or =). sign: Signedness of the integer. """ return unpack(value, 16, endian, sign) -def u32(value: bytes, endian: str = "little", sign: bool = False) -> int: +def u32(value: bytes, endian: AllowedEndianness = "little", sign: bool = False) -> int: """Unpack a 32 bit integer. Arguments: value: Value to unpack. 
- endian: Endianness to use (little, big, network, <, > or !) + endian: Endianness to use (little, big, network, <, >, !, @ or =). sign: Signedness of the integer. """ return unpack(value, 32, endian, sign) -def u64(value: bytes, endian: str = "little", sign: bool = False) -> int: +def u64(value: bytes, endian: AllowedEndianness = "little", sign: bool = False) -> int: """Unpack a 64 bit integer. Arguments: value: Value to unpack. - endian: Endianness to use (little, big, network, <, > or !) + endian: Endianness to use (little, big, network, <, >, !, @ or =). sign: Signedness of the integer. """ return unpack(value, 64, endian, sign) diff --git a/tests/test_basic.py b/tests/test_basic.py index bb0dd477..64a9a9f3 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -6,7 +6,7 @@ import pytest -from dissect.cstruct.cstruct import cstruct +from dissect.cstruct.cstruct import Endianness, cstruct from dissect.cstruct.exceptions import ArraySizeError, ParserError, ResolveError from dissect.cstruct.types import BaseType @@ -215,12 +215,12 @@ class OffByOne(int, BaseType): type: BaseType @classmethod - def _read(cls, stream: BinaryIO, context: dict | None = None) -> OffByOne: - return cls(cls.type._read(stream, context) + 1) + def _read(cls, stream: BinaryIO, *, context: dict | None = None, endian: Endianness, **kwargs) -> OffByOne: + return cls(cls.type._read(stream, context=context, endian=endian, **kwargs) + 1) @classmethod - def _write(cls, stream: BinaryIO, data: int) -> OffByOne: - return cls(cls.type._write(stream, data - 1)) + def _write(cls, stream: BinaryIO, data: int, *, endian: Endianness, **kwargs) -> OffByOne: + return cls(cls.type._write(stream, data - 1, endian=endian, **kwargs)) # Add an unsupported type for the cstruct compiler # so that it returns the original struct, diff --git a/tests/test_bitbuffer.py b/tests/test_bitbuffer.py index 9d3a1637..770dd66c 100644 --- a/tests/test_bitbuffer.py +++ b/tests/test_bitbuffer.py @@ -12,26 +12,41 @@ def 
test_bitbuffer_read(cs: cstruct) -> None: - bb = BitBuffer(BytesIO(b"\xff"), "<") + # http://mjfrazer.org/mjfrazer/bitfields/ + bb = BitBuffer(BytesIO(b"\xff"), endian="<") assert bb.read(cs.uint8, 8) == 0b11111111 - bb = BitBuffer(BytesIO(b"\xf0"), "<") + bb = BitBuffer(BytesIO(b"\xf0"), endian="<") assert bb.read(cs.uint8, 4) == 0b0000 assert bb.read(cs.uint8, 4) == 0b1111 - bb = BitBuffer(BytesIO(b"\xf0"), ">") + bb = BitBuffer(BytesIO(b"\xf0"), endian=">") assert bb.read(cs.uint8, 4) == 0b1111 assert bb.read(cs.uint8, 4) == 0b0000 - bb = BitBuffer(BytesIO(b"\xff\x00"), "<") - assert bb.read(cs.uint16, 12) == 0b11111111 + bb = BitBuffer(BytesIO(b"\xff\x00"), endian="<") + assert bb.read(cs.uint16, 12) == 0b000011111111 assert bb.read(cs.uint16, 4) == 0b0 - bb = BitBuffer(BytesIO(b"\xff\x00"), ">") - assert bb.read(cs.uint16, 12) == 0b000000001111 - assert bb.read(cs.uint16, 4) == 0b1111 - - bb = BitBuffer(BytesIO(b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"), "<") + bb = BitBuffer(BytesIO(b"\xff\x00"), endian=">") + assert bb.read(cs.uint16, 12) == 0b111111110000 + assert bb.read(cs.uint16, 4) == 0b0000 + + bb = BitBuffer(BytesIO(b"\x12\x34"), endian=">") + assert bb.read(cs.uint16, 4) == 1 + assert bb.read(cs.uint16, 4) == 2 + assert bb.read(cs.uint16, 4) == 3 + assert bb.read(cs.uint16, 4) == 4 + + bb = BitBuffer(BytesIO(b"\x12\x34"), endian="<") + assert bb.read(cs.uint16, 4) == 2 + assert bb.read(cs.uint16, 4) == 1 + assert bb.read(cs.uint16, 4) == 4 + assert bb.read(cs.uint16, 4) == 3 + + bb = BitBuffer( + BytesIO(b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"), endian="<" + ) with pytest.raises(ValueError, match="Reading straddled bits is unsupported"): assert bb.read(cs.uint32, 160) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 6ce1017a..9103715f 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -109,7 +109,7 @@ def 
test_generate_packed_read(cs: cstruct) -> None: expected = """ buf = stream.read(15) if len(buf) != 15: raise EOFError() - data = _struct(cls.cs.endian, "BhIq").unpack(buf) + data = _struct(endian, "BhIq").unpack(buf) r["a"] = type.__call__(_0, data[0]) @@ -135,7 +135,7 @@ def test_generate_packed_read_array(cs: cstruct) -> None: expected = """ buf = stream.read(64) if len(buf) != 64: raise EOFError() - data = _struct(cls.cs.endian, "2B3h4I5q").unpack(buf) + data = _struct(endian, "2B3h4I5q").unpack(buf) _t = _0 _et = _t.type @@ -171,7 +171,7 @@ def test_generate_packed_read_byte_types(cs: cstruct) -> None: expected = """ buf = stream.read(18) if len(buf) != 18: raise EOFError() - data = _struct(cls.cs.endian, "18x").unpack(buf) + data = _struct(endian, "18x").unpack(buf) r["a"] = type.__call__(_0, buf[0:1]) @@ -207,12 +207,12 @@ def test_generate_packed_read_composite_types(cs: cstruct, TestEnum: type[Enum]) expected = """ buf = stream.read(11) if len(buf) != 11: raise EOFError() - data = _struct(cls.cs.endian, "BQ2B").unpack(buf) + data = _struct(endian, "BQ2B").unpack(buf) r["a"] = type.__call__(_0, data[0]) _pt = _1 - r["b"] = _pt.__new__(_pt, data[1], stream, r) + r["b"] = _pt.__new__(_pt, data[1], stream, context=r, endian=endian, **kwargs) _t = _2 _et = _t.type @@ -232,7 +232,7 @@ def test_generate_packed_read_offsets(cs: cstruct) -> None: expected = """ buf = stream.read(9) if len(buf) != 9: raise EOFError() - data = _struct(cls.cs.endian, "B7xB").unpack(buf) + data = _struct(endian, "B7xB").unpack(buf) r["a"] = type.__call__(_0, data[0]) @@ -251,7 +251,7 @@ def test_generate_structure_read(cs: cstruct) -> None: expected = """ _s = stream.tell() - r["a"] = _0._read(stream, context=r) + r["a"] = _0._read(stream, context=r, endian=endian, **kwargs) s["a"] = stream.tell() - _s """ @@ -267,7 +267,7 @@ def test_generate_structure_read_anonymous(cs: cstruct) -> None: expected = """ _s = stream.tell() - r["a"] = _0._read(stream, context=r) + r["a"] = 
_0._read(stream, context=r, endian=endian, **kwargs) s["a"] = stream.tell() - _s """ @@ -280,7 +280,7 @@ def test_generate_array_read(cs: cstruct) -> None: expected = """ _s = stream.tell() - r["a"] = _0._read(stream, context=r) + r["a"] = _0._read(stream, context=r, endian=endian, **kwargs) s["a"] = stream.tell() - _s """ @@ -325,7 +325,7 @@ def test_generate_fields_dynamic_after_bitfield(cs: cstruct, TestEnum: Enum, oth expected = """ buf = stream.read(2) if len(buf) != 2: raise EOFError() - data = _struct(cls.cs.endian, "H").unpack(buf) + data = _struct(endian, "H").unpack(buf) r["size"] = type.__call__(_0, data[0]) @@ -341,7 +341,7 @@ def test_generate_fields_dynamic_after_bitfield(cs: cstruct, TestEnum: Enum, oth stream.seek(o + 3) _s = stream.tell() - r["c"] = _3._read(stream, context=r) + r["c"] = _3._read(stream, context=r, endian=endian, **kwargs) s["c"] = stream.tell() - _s """ @@ -364,7 +364,7 @@ def test_generate_fields_dynamic_before_bitfield(cs: cstruct, TestEnum: Enum, ot expected = """ buf = stream.read(2) if len(buf) != 2: raise EOFError() - data = _struct(cls.cs.endian, "H").unpack(buf) + data = _struct(endian, "H").unpack(buf) r["size"] = type.__call__(_0, data[0]) @@ -380,7 +380,7 @@ def test_generate_fields_dynamic_before_bitfield(cs: cstruct, TestEnum: Enum, ot stream.seek(o + 3) _s = stream.tell() - r["c"] = _3._read(stream, context=r) + r["c"] = _3._read(stream, context=r, endian=endian, **kwargs) s["c"] = stream.tell() - _s """ diff --git a/tests/test_types_base.py b/tests/test_types_base.py index 6e73ceb4..45d466b7 100644 --- a/tests/test_types_base.py +++ b/tests/test_types_base.py @@ -10,7 +10,7 @@ from .utils import verify_compiled if TYPE_CHECKING: - from dissect.cstruct.cstruct import cstruct + from dissect.cstruct.cstruct import Endianness, cstruct def test_array_size_mismatch(cs: cstruct) -> None: @@ -93,7 +93,7 @@ def __init__(self, value: bytes = b""): self.value = value.upper() @classmethod - def _read(cls, stream: BinaryIO, 
context: dict | None = None) -> CustomType: + def _read(cls, stream: BinaryIO, *, context: dict | None = None, endian: Endianness, **kwargs) -> CustomType: length = stream.read(1)[0] value = stream.read(length) return type.__call__(cls, value) @@ -104,8 +104,10 @@ def __default__(cls) -> CustomType: return cls.type() @classmethod - def _read(cls, stream: BinaryIO, context: dict | None = None) -> CustomType: - value = cls.type._read(stream, context) + def _read( + cls, stream: BinaryIO, *, context: dict | None = None, endian: Endianness, **kwargs + ) -> CustomType: + value = cls.type._read(stream, context=context, endian=endian, **kwargs) if str(cls.num_entries) == "lower": value.value = value.value.lower() diff --git a/tests/test_types_custom.py b/tests/test_types_custom.py index 74656dcd..507a8400 100644 --- a/tests/test_types_custom.py +++ b/tests/test_types_custom.py @@ -4,10 +4,12 @@ import pytest +from dissect.cstruct.exceptions import Error from dissect.cstruct.types import BaseType +from dissect.cstruct.types.base import BaseArray if TYPE_CHECKING: - from dissect.cstruct.cstruct import cstruct + from dissect.cstruct.cstruct import Endianness, cstruct class EtwPointer(BaseType): @@ -19,16 +21,20 @@ def __default__(cls) -> int: return cls.cs.uint64.__default__() @classmethod - def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> BaseType: - return cls.type._read(stream, context) + def _read( + cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs + ) -> BaseType: + return cls.type._read(stream, context=context, endian=endian, **kwargs) @classmethod - def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[BaseType]: - return cls.type._read_0(stream, context) + def _read_0( + cls, stream: BinaryIO, *, context: dict[str, Any] | None = None, endian: Endianness, **kwargs + ) -> list[BaseType]: + return cls.type._read_0(stream, context=context, endian=endian, **kwargs) 
@classmethod - def _write(cls, stream: BinaryIO, data: Any) -> int: - return cls.type._write(stream, data) + def _write(cls, stream: BinaryIO, data: Any, *, endian: Endianness, **kwargs) -> int: + return cls.type._write(stream, data, endian=endian, **kwargs) @classmethod def as_32bit(cls) -> None: @@ -91,3 +97,41 @@ def test_custom_default(cs: cstruct) -> None: assert cs.EtwPointer[1].__default__() == [0] assert cs.EtwPointer[None].__default__() == [] + + +def test_custom_deprecated_signature(cs: cstruct) -> None: + class New(int, BaseType): + @classmethod + def _read(cls, stream: BinaryIO, *, context: dict | None = None, endian: Endianness, **kwargs) -> New: + pass + + @classmethod + def _write(cls, stream: BinaryIO, data: int, *, endian: Endianness, **kwargs) -> New: + pass + + class OldRead(int, BaseType): + @classmethod + def _read(cls, stream: BinaryIO, context: dict | None = None) -> OldRead: + pass + + class OldWrite(int, BaseType): + @classmethod + def _write(cls, stream: BinaryIO, data: int) -> OldRead: + pass + + class OldArrayRead(int, BaseType): + class ArrayType(list, BaseArray): + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list: + pass + + # No errors + cs.add_custom_type("New", New) + + with pytest.raises(Error, match=r"OldRead has an incompatible _read method signature"): + cs.add_custom_type("OldRead", OldRead) + + with pytest.raises(Error, match=r"OldWrite has an incompatible _write method signature"): + cs.add_custom_type("OldWrite", OldWrite) + + with pytest.raises(Error, match=r"OldArrayRead\.ArrayType has an incompatible _read method signature"): + cs.add_custom_type("OldArrayRead", OldArrayRead) diff --git a/tests/test_types_pointer.py b/tests/test_types_pointer.py index 05bcbf04..3d0538af 100644 --- a/tests/test_types_pointer.py +++ b/tests/test_types_pointer.py @@ -27,7 +27,7 @@ def test_pointer(cs: cstruct) -> None: assert str(obj) == "255" with pytest.raises(NullPointerDereference): - ptr(0, 
None).dereference() + ptr(0, None, endian=cs.endian).dereference() def test_pointer_char(cs: cstruct) -> None: @@ -251,3 +251,21 @@ def test_pointer_default(cs: cstruct) -> None: with pytest.raises(NullPointerDereference): ptr.__default__().dereference() + + +def test_pointer_changing_endian(cs: cstruct) -> None: + cs.pointer = cs.uint16 + + ptr = cs._make_pointer(cs.uint32) + assert issubclass(ptr, Pointer) + assert ptr.__name__ == "uint32*" + + assert cs.endian == "<" + + obj = ptr(b"\x00\x02\x01\x02\x03\x04", endian=">") + assert repr(obj) == "<uint32* @ 0x2>" + + assert obj == 2 + assert obj.dumps(endian=">") == b"\x00\x02" + assert obj.dereference() == 0x01020304 + assert obj.dereference(endian="<") == 0x04030201 diff --git a/tests/test_types_structure.py b/tests/test_types_structure.py index a137fa37..93862aa1 100644 --- a/tests/test_types_structure.py +++ b/tests/test_types_structure.py @@ -844,3 +844,27 @@ def test_structure_definition_newline(cs: cstruct, compiled: bool) -> None: obj.wstring = "test" assert obj.dumps() == buf + + +def test_structure_changing_endian(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + uint32 a; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + assert cs.endian == "<" + + buf = b"\x01\x02\x03\x04" + obj = cs.test(buf) + + assert obj.a == 0x04030201 + + assert obj.dumps() == buf + + obj = cs.test(buf, endian=">") + assert obj.a == 0x01020304 + + assert obj.dumps(endian=">") == buf diff --git a/tests/test_types_union.py b/tests/test_types_union.py index 660684ff..df8cab16 100644 --- a/tests/test_types_union.py +++ b/tests/test_types_union.py @@ -547,3 +547,24 @@ def test_codegen_hashable(cs: cstruct) -> None: assert hash(structure._generate_union__init__(hashable_fields).__code__) assert hash(structure._generate_union__init__(unhashable_fields).__code__) + + +def test_union_changing_endian(cs: cstruct) -> None: + cdef = """ + union test { + uint32 a; + char b[8]; + }; + """ +
cs.load(cdef, compiled=False) + + assert len(cs.test) == 8 + assert cs.endian == "<" + + buf = b"zomgbeef" + obj = cs.test(buf, endian=">") + + assert obj.a == 0x7A6F6D67 + assert obj.b == b"zomgbeef" + + assert obj.dumps(endian=">") == buf