diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a88768..0b8605d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ fail_fast: false repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.2 + rev: v0.15.5 hooks: - id: ruff-check - id: ruff-format @@ -18,7 +18,7 @@ repos: - id: check-added-large-files - id: check-shebang-scripts-are-executable - repo: https://github.com/tombi-toml/tombi-pre-commit - rev: v0.7.32 + rev: v0.9.4 hooks: - id: tombi-format - id: tombi-lint diff --git a/.vscode/settings.json b/.vscode/settings.json index a7fee3e..f3ec19f 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,7 +3,7 @@ "[python]": { "editor.defaultFormatter": "charliermarsh.ruff", "editor.codeActionsOnSave": { - "source.organizeImports": "explicit" + "source.organizeImports.ruff": "explicit" } }, "[markdown][yaml]": { @@ -21,8 +21,7 @@ "**/templates/common/**": "mako" }, "pylint.enabled": false, - "ruff.enabled": true, + "ruff.enable": true, "python.analysis.importFormat": "relative", - "python.analysis.typeCheckingMode": "standard", "python-envs.pythonProjects": [] } diff --git a/pyproject.toml b/pyproject.toml index be9ea13..a28647d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,13 @@ build-backend = "setuptools.build_meta" [tool.pyright] venvPath = "." venv = ".venv" +exclude = [ + ".venv", + "build", + "**/.*", + "**/node_modules", + "**/__pycache__", +] [tool.ruff.lint] select = [ @@ -107,3 +114,6 @@ packages = ["zmk"] [tool.setuptools_scm] version_file = "zmk/_version.py" + +[tool.tombi.files] +exclude = ["uv.lock"] diff --git a/zmk/build.py b/zmk/build.py index e1aaab7..9b8e1e7 100644 --- a/zmk/build.py +++ b/zmk/build.py @@ -3,47 +3,19 @@ """ from collections.abc import Iterable, Mapping, Sequence -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field from pathlib import Path from typing import Any, TypeVar, cast, overload import dacite +from .hardware import BoardTarget, BuildItem from .repo import Repo from .yaml import YAML T = TypeVar("T") -@dataclass -class BuildItem: - """An item in the build matrix""" - - board: str - shield: str | None = None - snippet: str | None = None - cmake_args: str | None = None - artifact_name: str | None = None - - def __rich__(self) -> str: - parts = [] - parts.append(self.board) - - if self.shield: - parts.append(self.shield) - - if self.snippet: - parts.append(f"[dim]snippet: {self.snippet}[/dim]") - - if self.artifact_name: - parts.append(f"[dim]artifact-name: {self.artifact_name}[/dim]") - - if self.cmake_args: - parts.append(f"[dim]cmake-args: {self.cmake_args}[/dim]") - - return "[dim], [/dim]".join(parts) - - @dataclass class _BuildMatrixWrapper: include: list[BuildItem] = field(default_factory=list) @@ -71,7 +43,7 @@ def __init__(self, path: Path): self._data = None def write(self) -> None: - """Updated the YAML file, creating it if necessary""" + """Update the YAML file, creating it if necessary""" self._yaml.dump(self._data, self._path) @property @@ -86,7 +58,8 @@ def include(self) -> list[BuildItem]: if not normalized: return [] - wrapper = dacite.from_dict(_BuildMatrixWrapper, normalized) + config = dacite.Config(type_hooks={BoardTarget: BoardTarget.parse}) + wrapper = dacite.from_dict(_BuildMatrixWrapper, normalized, config) return wrapper.include def has_item(self, item: BuildItem) -> bool: @@ -183,12 +156,19 @@ def fix_key(key: str): def _to_yaml(item: BuildItem): """ - Convert a BuildItem to a dict with keys changed back from underscores to hyphens. + Convert a BuildItem to a dict with keys changed back from underscores to hyphens + and values changed to YAML-compatible types. """ def fix_key(key: str): return key.replace("_", "-") - data = asdict(item) + def fix_value(value: Any): + match value: + case BoardTarget(): + return str(value) + + case _: + return value - return {fix_key(k): v for k, v in data.items() if v is not None} + return {fix_key(k): fix_value(v) for k, v in item.__dict__.items() if v is not None} diff --git a/zmk/commands/keyboard/add.py b/zmk/commands/keyboard/add.py index 269df6b..3450f94 100644 --- a/zmk/commands/keyboard/add.py +++ b/zmk/commands/keyboard/add.py @@ -2,31 +2,22 @@ "zmk keyboard add" command. """ -import itertools import shutil from pathlib import Path from typing import Annotated import rich import typer +from rich.padding import Padding -from ...build import BuildItem, BuildMatrix +from ...build import BuildMatrix from ...config import get_config from ...exceptions import FatalError -from ...hardware import ( - Board, - Hardware, - Keyboard, - Shield, - append_revision, - get_hardware, - is_compatible, - normalize_revision, - show_hardware_menu, - show_revision_menu, - split_revision, -) +from ...hardware import Board, BoardTarget, Keyboard, Shield +from ...hardware_list import get_hardware, show_hardware_menu, show_revision_menu from ...repo import Repo +from ...revision import Revision +from ...styles import THEME, BoardIdHighlighter, chain_highlighters from ...util import spinner @@ -55,6 +46,8 @@ def keyboard_add( """Add configuration for a keyboard and add it to the build.""" console = rich.get_console() + console.push_theme(THEME) + console.highlighter = chain_highlighters(console.highlighter, BoardIdHighlighter()) cfg = get_config(ctx) repo = cfg.get_repo() @@ -62,94 +55,111 @@ def keyboard_add( with spinner("Finding hardware..."): hardware = get_hardware(repo) - keyboard = None - controller = None - revision = None + keyboard = Keyboard() if keyboard_id: - keyboard_id, keyboard_revision = split_revision(keyboard_id) + # User specified a keyboard. It might be either a board (with optional + # revision and board qualifiers) or a shield. + target = BoardTarget.parse(keyboard_id) + keyboard_id = target.name + target.qualifiers - keyboard = hardware.find_keyboard(keyboard_id) - if keyboard is None: + keys_component = hardware.find_keyboard(keyboard_id) + if not keys_component: raise KeyboardNotFoundError(keyboard_id) - # If the keyboard ID contained a revision, use that. - # Make sure it is valid before continuing to any other prompts. - if keyboard_revision: - revision = keyboard_revision - _check_revision(keyboard, revision) + keyboard.add_component(keys_component) + if target.revision: + _check_revision(keys_component, target.revision) + keyboard.board_revision = target.revision if controller_id: - if not isinstance(keyboard, Shield): + # User also specified a controller board (with optional revision and + # board qualifiers). If the specified keyboard was already a board, + # then this is invalid because we can't have two boards. + if keyboard.board: raise FatalError( - f'Keyboard "{keyboard.id}" has an onboard controller ' + f'Keyboard "{keys_component.id}" has an onboard controller ' "and does not require a controller board." ) + target = BoardTarget.parse(controller_id) + controller_id = target.name + target.qualifiers + controller = hardware.find_controller(controller_id) - if controller is None: + if not controller: raise ControllerNotFoundError(controller_id) + keyboard.add_component(controller) + if target.revision: + _check_revision(controller, target.revision) + keyboard.board_revision = target.revision + elif controller_id: - controller_id, controller_revision = split_revision(controller_id) + # User specified a controller but not a keyboard. Find the controller. + # We will prompt for the shield later. + + target = BoardTarget.parse(controller_id) - # User specified a controller but not a keyboard. Filter the keyboard - # list to just those compatible with the controller. controller = hardware.find_controller(controller_id) - if controller is None: + if not controller: raise ControllerNotFoundError(controller_id) - # If the controller ID contained a revision, use that. - # Make sure it is valid before continuing to any other prompts. - if controller_revision: - revision = controller_revision - _check_revision(controller, revision) + keyboard.add_component(controller) + if target.revision: + _check_revision(controller, target.revision) + keyboard.board_revision = target.revision - hardware.keyboards = [ - kb - for kb in hardware.keyboards - if isinstance(kb, Shield) and is_compatible(controller, kb) - ] + # When prompting for a keyboard later, it should only display the shields + # that are compatible with the chosen board. + hardware.filter_compatible_keyboards(keyboard) # Prompt the user for any necessary components they didn't specify - if keyboard is None: - keyboard = show_hardware_menu("Select a keyboard:", hardware.keyboards) - - if isinstance(keyboard, Shield): - if controller is None: - hardware.controllers = [ - c for c in hardware.controllers if is_compatible(c, keyboard) - ] - controller = show_hardware_menu( - "Select a controller:", hardware.controllers - ) - - # Sanity check that everything is compatible - if not is_compatible(controller, keyboard): - raise FatalError( - f'Keyboard "{keyboard.id}" is not compatible with controller "{controller.id}"' - ) - - # Check if the controller needs a revision. - revision = _get_revision(controller, revision) - else: - # If the keyboard isn't a shield, it may need a revision. - revision = _get_revision(keyboard, revision) - - name = keyboard.id - if controller: - name += ", " + controller.id - - if revision: - revision = normalize_revision(revision) - name = append_revision(name, revision) - - if _add_keyboard(repo, keyboard, controller, revision): - console.print(f'Added "{name}".') + if not keyboard.keys_component: + keyboard.add_component( + show_hardware_menu("Select a keyboard:", hardware.keyboards) + ) + + if not keyboard.board: + hardware.filter_compatible_controllers(keyboard) + + keyboard.add_component( + show_hardware_menu("Select a controller:", hardware.controllers) + ) + + # Sanity check the resulting hardware compatibility + if not keyboard.board: + raise FatalError( + "Controller board is missing (this is probably a bug in ZMK CLI)." + ) + + if not keyboard.keys_component: + raise FatalError( + "Component with 'keys' feature is missing (this is probably a bug in ZMK CLI)." + ) + + if keyboard.missing_requirements: + raise FatalError( + f'Keyboard "{keyboard.keys_component}" is not compatible with controller "{keyboard.board}". ' + f"Required interconnects are missing: {', '.join(keyboard.missing_requirements)}" + ) + + # If a revision wasn't already set from the command line, the user may need + # to choose a revision. + if not keyboard.board_revision and keyboard.board_revisions: + keyboard.board_revision = show_revision_menu(keyboard.board) + + if added := _add_keyboard(repo, keyboard): + console.print("[title]Added:") + + for item in added: + console.print(Padding.indent(item, 2)) + + console.print() else: + name = keyboard.keys_component.name console.print(f'"{name}" is already in the build matrix.') - keymap_name = keyboard.get_keymap_path(revision).with_suffix("").name + keymap_name = keyboard.get_keymap_path().with_suffix("").name console.print(f'Run "zmk code {keymap_name}" to edit the keymap.') @@ -174,64 +184,28 @@ def _copy_keyboard_file(repo: Repo, path: Path): shutil.copy2(path, dest_path) -def _get_build_items( - keyboard: Keyboard, controller: Board | None, revision: str | None -): - boards = [] - shields = [] - - match keyboard: - case Shield(id=shield_id, siblings=siblings): - if controller is None: - raise FatalError("controller may not be None if keyboard is a shield") - - shields = siblings or [shield_id] - boards = [append_revision(controller.id, revision)] +def _check_revision(board: Board | Shield, revision: Revision): + if not isinstance(board, Board): + raise FatalError(f"{board.id} is a shield. Only boards support revisions.") - case Board(id=board_id, siblings=siblings): - boards = siblings or [board_id] - boards = [append_revision(board, revision) for board in boards] - - case _: - raise FatalError("Unexpected keyboard/controller combination") - - if shields: - return [ - BuildItem(board=b, shield=s) for b, s in itertools.product(boards, shields) - ] - - return [BuildItem(board=b) for b in boards] - - -def _get_revision(board: Hardware, revision: str | None): - # If no revision was specified and the board uses revisions, prompt to - # select a revision. - return revision or show_revision_menu(board) - - -def _check_revision(board: Hardware, revision: str): - if board.has_revision(revision): + if revision in board.revisions: # Revision is OK return - supported_revisions = board.get_revisions() - - if not supported_revisions: + if not board.revisions: raise FatalError(f"{board.id} does not have any revisions.") raise FatalError( f'{board.id} does not support revision "@{revision}". Use one of:\n' - + "\n".join(f" @{normalize_revision(rev)}" for rev in supported_revisions) + + "\n".join(f" @{rev}" for rev in board.revisions) ) -def _add_keyboard( - repo: Repo, keyboard: Keyboard, controller: Board | None, revision: str | None -): - _copy_keyboard_file(repo, keyboard.get_keymap_path(revision)) - _copy_keyboard_file(repo, keyboard.get_config_path(revision)) +def _add_keyboard(repo: Repo, keyboard: Keyboard): + items = keyboard.get_build_items() - items = _get_build_items(keyboard, controller, revision) + _copy_keyboard_file(repo, keyboard.get_keymap_path()) + _copy_keyboard_file(repo, keyboard.get_config_path()) matrix = BuildMatrix.from_repo(repo) added = matrix.append(items) diff --git a/zmk/commands/keyboard/list.py b/zmk/commands/keyboard/list.py index 6853658..553de7a 100644 --- a/zmk/commands/keyboard/list.py +++ b/zmk/commands/keyboard/list.py @@ -14,18 +14,11 @@ from zmk import styles -from ...build import BuildItem, BuildMatrix +from ...build import BuildMatrix from ...config import get_config from ...exceptions import FatalError -from ...hardware import ( - Board, - Hardware, - Shield, - append_revision, - get_hardware, - is_compatible, - normalize_revision, -) +from ...hardware import Board, BoardTarget, BuildItem, Hardware, Keyboard, Shield +from ...hardware_list import get_hardware from ...util import spinner # TODO: allow output as unformatted list @@ -48,7 +41,7 @@ def _list_build_matrix(ctx: typer.Context, *, value: bool): console = Console( highlighter=styles.chain_highlighters( - [styles.BoardIdHighlighter(), styles.CommandLineHighlighter()] + styles.BoardIdHighlighter(), styles.CommandLineHighlighter() ), theme=styles.THEME, ) @@ -64,7 +57,7 @@ def _list_build_matrix(ctx: typer.Context, *, value: bool): has_artifact_name = any(item.artifact_name for item in include) table = Table( - box=box.SQUARE, + box=box.ROUNDED, border_style="dim blue", header_style="bright_cyan", highlight=True, @@ -123,7 +116,7 @@ def keyboard_list( "--board", "-b", metavar="BOARD", - help="List only keyboards compatible with this controller board.", + help="List keyboards compatible with this controller board.", ), ] = None, shield: Annotated[ @@ -132,7 +125,7 @@ def keyboard_list( "--shield", "-s", metavar="SHIELD", - help="List only controllers compatible with this keyboard shield.", + help="List controllers compatible with this keyboard shield.", ), ] = None, interconnect: Annotated[ @@ -141,7 +134,7 @@ def keyboard_list( "--interconnect", "-i", metavar="INTERCONNECT", - help="List only keyboards and controllers that have this interconnect.", + help="List keyboards and controllers that use this interconnect.", ), ] = None, standalone: Annotated[ @@ -173,11 +166,9 @@ def keyboard_list( if item is None: raise FatalError(f'Could not find controller board "{board}".') - groups.keyboards = [ - kb - for kb in groups.keyboards - if isinstance(kb, Shield) and is_compatible(item, kb) - ] + keyboard = Keyboard() + keyboard.add_component(item) + groups.filter_compatible_keyboards(keyboard) list_type = ListType.KEYBOARD elif shield: @@ -189,23 +180,21 @@ def keyboard_list( if not isinstance(item, Shield): raise FatalError(f'Keyboard "{shield}" is a standalone keyboard.') - groups.controllers = [c for c in groups.controllers if is_compatible(c, item)] + keyboard = Keyboard() + keyboard.add_component(item) + groups.filter_compatible_controllers(keyboard) list_type = ListType.CONTROLLER elif interconnect: - # Filter to controllers that provide an interconnect and keyboards that use it. + # Filter to controllers that provide an interconnect and keyboards that + # use or provide it. item = groups.find_interconnect(interconnect) if item is None: raise FatalError(f'Could not find interconnect "{interconnect}".') - groups.controllers = [ - c for c in groups.controllers if c.exposes and item.id in c.exposes - ] - groups.keyboards = [ - kb - for kb in groups.keyboards - if isinstance(kb, Shield) and kb.requires and item.id in kb.requires - ] + groups.filter_to_interconnect(item) + + # When filtering to an interconnect, don't show interconnects. groups.interconnects = [] elif standalone: @@ -214,14 +203,16 @@ def keyboard_list( list_type = ListType.KEYBOARD def print_items(header: str, items: Iterable[Hardware]): - if revisions: - names = [ - append_revision(item.id, normalize_revision(rev)) - for item in items - for rev in (item.get_revisions() or [None]) - ] - else: - names = [item.id for item in items] + names: list[str] = [] + + for item in items: + if revisions and isinstance(item, Board): + names.extend( + str(BoardTarget.parse(item.id).with_revision(rev)) + for rev in item.revisions + ) + else: + names.append(item.id) if not names: return @@ -233,6 +224,10 @@ def print_items(header: str, items: Iterable[Hardware]): console.print(columns) console.print() + # TODO: when filtering to an interconnect, we should specify which hardware + # exposes vs. which requires. This would be useful if we start to add things + # like non-keyboard shields to the hardware info. + if list_type in (ListType.ALL, ListType.KEYBOARD): print_items("Keyboards:", groups.keyboards) diff --git a/zmk/commands/keyboard/new.py b/zmk/commands/keyboard/new.py index d884232..593fabd 100644 --- a/zmk/commands/keyboard/new.py +++ b/zmk/commands/keyboard/new.py @@ -14,7 +14,8 @@ from ...config import get_config from ...exceptions import FatalError -from ...hardware import Interconnect, get_hardware, show_hardware_menu +from ...hardware import Interconnect +from ...hardware_list import get_hardware, show_hardware_menu from ...menu import detail_list, show_menu from ...repo import Repo from ...templates import get_template_files diff --git a/zmk/commands/keyboard/remove.py b/zmk/commands/keyboard/remove.py index e0a2808..e0efb95 100644 --- a/zmk/commands/keyboard/remove.py +++ b/zmk/commands/keyboard/remove.py @@ -2,12 +2,14 @@ "zmk remove" command. """ -import rich import typer +from rich.console import Console +from rich.padding import Padding from ...build import BuildMatrix from ...config import get_config from ...menu import show_menu +from ...styles import MENU_THEME, THEME, BoardIdHighlighter # TODO: add options to select items from command line @@ -19,10 +21,14 @@ def keyboard_remove(ctx: typer.Context) -> None: matrix = BuildMatrix.from_repo(repo) items = matrix.include - item = show_menu("Select a build to remove:", items) + console = Console(theme=THEME, highlighter=BoardIdHighlighter()) + + with console.use_theme(MENU_THEME): + item = show_menu("Select a build to remove:", items, console=console) if removed := matrix.remove(item): - items = ", ".join(f'"{item.__rich__()}"' for item in removed) - rich.print(f"Removed {items} from the build.") + console.print("[title]Removed:") + for item in removed: + console.print(Padding.indent(item, 2)) matrix.write() diff --git a/zmk/commands/module/remove.py b/zmk/commands/module/remove.py index 6891876..5fb6334 100644 --- a/zmk/commands/module/remove.py +++ b/zmk/commands/module/remove.py @@ -5,17 +5,19 @@ import shutil import stat import subprocess +from collections.abc import Iterable from dataclasses import dataclass from pathlib import Path from typing import Annotated, Any import rich import typer +from rich.console import RenderableType from west.manifest import Project from ...config import get_config from ...exceptions import FatalError -from ...menu import Detail, detail_list, show_menu +from ...menu import show_menu from ...repo import Repo from ...util import spinner from ...yaml import YAML @@ -78,7 +80,12 @@ def remove_readonly(func, path, _): with spinner("Deleting module files."): try: # Make sure Git isn't locking the folder first. - subprocess.call(["git", "fsmonitor--daemon", "stop"], cwd=module_path) + subprocess.call( + ["git", "fsmonitor--daemon", "stop"], + cwd=module_path, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) shutil.rmtree(module_path, onexc=remove_readonly) except FileNotFoundError: @@ -91,10 +98,10 @@ def remove_readonly(func, path, _): def _prompt_project(projects: list[Project]): - items = detail_list((ProjectWrapper(p), p.url) for p in projects) + items = [ProjectWrapper(p) for p in projects] result = show_menu("Select the module to remove:", items, filter_func=_filter) - return result.data.project + return result.project @dataclass @@ -103,10 +110,11 @@ class ProjectWrapper: project: Project - def __str__(self): - return self.project.name + def __menu_row__(self) -> Iterable[RenderableType]: + yield self.project.name + yield self.project.url -def _filter(item: Detail[ProjectWrapper], text: str): +def _filter(item: ProjectWrapper, text: str): text = text.casefold() - return text in item.data.project.name.casefold() or text in item.detail + return text in item.project.name.casefold() or text in item.project.url.casefold() diff --git a/zmk/hardware.py b/zmk/hardware.py index 170c053..bbae1fa 100644 --- a/zmk/hardware.py +++ b/zmk/hardware.py @@ -1,20 +1,16 @@ -""" -Hardware metadata discovery and processing. -""" - +import itertools import re -from collections.abc import Generator, Iterable -from dataclasses import dataclass, field -from functools import reduce +from collections.abc import Iterable +from dataclasses import dataclass, field, replace from pathlib import Path -from typing import Any, Literal, Self, TypeAlias, TypedDict, TypeGuard, TypeVar +from typing import Literal, Self, TypeAlias, TypedDict, cast import dacite +from rich.console import RenderableType +from rich.measure import Measurement -from .menu import show_menu -from .repo import Repo -from .util import flatten -from .yaml import read_yaml +from .revision import Revision +from .util import horizontal_group, union Feature: TypeAlias = ( Literal["keys", "display", "encoder", "underglow", "backlight", "pointer", "studio"] @@ -32,7 +28,97 @@ class VariantDict(TypedDict): Variant: TypeAlias = str | VariantDict -_HW = TypeVar("_HW", bound="Hardware") + +@dataclass +class BoardTarget: + """ + Zephyr board target. Identifies a unique combination of board name, + board revision, and board qualifiers (SoC, CPU cluster, and variant). + """ + + name: str = "" + """Zephyr Board ID (not the ZMK display name)""" + revision: Revision = field(default_factory=Revision) + """Optional board revision""" + qualifiers: str = "" + """Optional board qualifiers (including forward slashes)""" + + @classmethod + def parse(cls, target: str): + """ + Parse a board target into its parts. + + Examples: + ``` + BoardTarget.parse("nice_nano//zmk") + # returns + BoardTarget(name="nice_nano", revision=Revision(), qualifiers="//zmk") + + BoardTarget.parse("bl5340_dvk@1.2.0/nrf5340/cpuapp/ns") + # returns + BoardTarget( + name="bl5340_dvk", + revision=Revision("1.2.0"), + qualifiers="/nrf5340/cpuapp/ns" + ) + ``` + """ + name, qualifiers = split_board_qualifiers(target) + name, _, revision = name.partition("@") + + return BoardTarget( + name=name, revision=Revision(revision), qualifiers=qualifiers + ) + + def __str__(self): + return self.name + self.revision.at_str + self.qualifiers + + # Can't use just __rich__ here, or this won't display properly in tables. + # https://github.com/Textualize/rich/issues/3188 + + def __rich_console__(self, console, options): + yield str(self) + + def __rich_measure__(self, console, options): + length = len(str(self)) + return Measurement(length, length) + + def with_revision(self, revision: Revision): + """Get a copy of this BoardTarget() with a different revision.""" + return replace(self, revision=revision) + + +@dataclass +class BuildItem: + """Entry in the build.yaml file""" + + board: BoardTarget + shield: str | None = None + snippet: str | None = None + cmake_args: str | None = None + artifact_name: str | None = None + + def __rich__(self): + return horizontal_group( + *(item for item in self.__menu_row__() if item), padding=(0, 2) + ) + + def __menu_row__(self) -> Iterable[RenderableType]: + yield self.board + yield self.shield or "" + + extras = [] + if self.snippet: + extras.append(f"snippet: {self.snippet}") + + if self.artifact_name: + extras.append(f"artifact-name: {self.artifact_name}") + + if self.cmake_args: + extras.append(f"cmake-args: {self.cmake_args}") + + if extras: + yield f"[dim]{', '.join(extras)}" @dataclass @@ -41,382 +127,376 @@ class Hardware: directory: Path """Path to the directory containing this hardware""" + file_format: str | None type: str id: str + """Zephyr identifier for the hardware. Board IDs include board qualifiers.""" name: str + """Display name for the hardware""" + url: str | None + description: str | None + manufacturer: str | None - file_format: str | None = None - url: str | None = None - description: str | None = None - manufacturer: str | None = None - version: str | None = None + @classmethod + def from_dict(cls, data) -> Self: + config = dacite.Config(cast=[set, Revision]) + return dacite.from_dict(data_class=cls, data=data, config=config) def __str__(self) -> str: return self.id - def __rich__(self) -> Any: + def __rich__(self) -> RenderableType: return f"{self.id} [dim]{self.name}" - @classmethod - def from_dict(cls, data) -> Self: - """Read a hardware description from a dict""" - return dacite.from_dict(cls, data) - - def has_id(self, hardware_id: str) -> bool: - """Get whether this hardware has the given ID (case insensitive)""" - return hardware_id.casefold() == self.id.casefold() + def __menu_row__(self) -> Iterable[RenderableType]: + return [self.id, self.name] - def has_revision(self, revision: str) -> bool: + def get_normalized_ids(self) -> list[str]: """ - Get whether this hardware supports the given revision. - - Any empty string is treated as "default revision" and always returns True. + Returns a list of names by which this hardware can be matched, for example + in command line parameters. All results are casefolded. """ - # By default, the only supported revision is no revision at all. - return not revision - - def get_revisions(self) -> list[str]: - """Get a list of supported revisions""" - return [] - - def get_default_revision(self) -> str | None: - """Get the default item from get_revisions or None if no default is set""" - return None + return [self.id.casefold()] @dataclass class Interconnect(Hardware): - """Description of the connection between two pieces of hardware""" + """ + Description of the connection between two pieces of hardware. + + Matches #/$defs/interconnect from + https://github.com/zmkfirmware/zmk/blob/main/schema/hardware-metadata.schema.json + """ node_labels: dict[str, str] = field(default_factory=dict) design_guideline: str | None = None @dataclass -class Keyboard(Hardware): +class KeyboardComponent(Hardware): """Base class for hardware that forms a keyboard""" - siblings: list[str] | None = field(default_factory=list) - """List of board/shield IDs for a split keyboard""" - exposes: list[str] | None = field(default_factory=list) + siblings: list[str] = field(default_factory=list) + """List of board/shield IDs for a split keyboard. Board IDs include board qualifiers""" + exposes: set[str] = field(default_factory=set) """List of interconnect IDs this board/shield provides""" - features: list[Feature] | None = field(default_factory=list) + features: set[Feature] = field(default_factory=set) """List of features this board/shield supports""" - variants: list[Variant] | None = field(default_factory=list) - - def __post_init__(self): - self.siblings = self.siblings or [] - self.exposes = self.exposes or [] - self.features = self.features or [] - self.variants = self.variants or [] + variants: list[Variant] = field(default_factory=list) - def get_config_path(self, revision: str | None = None) -> Path: - """Path to the .conf file for this keyboard""" - return self._get_keyboard_file(".conf", revision) - - def get_keymap_path(self, revision: str | None = None) -> Path: - """Path to the .keymap file for this keyboard""" - return self._get_keyboard_file(".keymap", revision) - - def _get_revision_suffixes(self, revision: str | None = None) -> Generator[str]: - if revision: - for rev in get_revision_forms(revision): - yield "_" + rev.replace(".", "_") - - def _get_keyboard_file(self, extension: str, revision: str | None = None) -> Path: - if revision: - for rev in get_revision_forms(revision): - path = self.directory / f"{self.id}_{rev.replace('.', '_')}{extension}" - if path.exists(): - return path - - return self.directory / f"{self.id}{extension}" + @property + def has_keys(self) -> bool: + """Get whether this hardware has the "keys" feature.""" + return "keys" in self.features @dataclass -class Board(Keyboard): - """Hardware with a processor. May be a keyboard or a controller.""" - - arch: str | None = None - outputs: list[Output] = field(default_factory=list) - """List of methods by which this board supports sending HID data""" - - revisions: list[str] = field(default_factory=list) - default_revision: str | None = None - - def __post_init__(self): - super().__post_init__() - self.outputs = self.outputs or [] - self.revisions = self.revisions or [] - - def has_revision(self, revision: str): - # Empty string means "use default revision" - if not revision: - return True - - revision = normalize_revision(revision) - - return any(normalize_revision(rev) == revision for rev in self.revisions) - - def get_revisions(self): - return self.revisions - - def get_default_revision(self): - return self.default_revision - - -def split_revision(identifier: str) -> tuple[str, str]: +class Board(KeyboardComponent): """ - Splits a string containing a hardware ID and optionally a revision into the - ID and revision. + Description of a Zephyr board. May be a controller or a standalone keyboard. - Examples: - "foo" -> "foo", "" - "foo@2" -> "foo", "2" + Matches #/$defs/board from + https://github.com/zmkfirmware/zmk/blob/main/schema/hardware-metadata.schema.json """ - hardware_id, _, revision = identifier.partition("@") - return hardware_id, revision + arch: str | None = None + outputs: set[Output] = field(default_factory=set) + """List of methods by which this board supports sending HID data""" -def append_revision(identifier: str, revision: str | None): - """ - Joins a hardware ID with a revision string. - - Examples: - "foo" + None -> "foo" - "foo" + "2" -> "foo@2" - """ - return f"{identifier}@{revision}" if revision else identifier - + revisions: list[Revision] = field(default_factory=list) + default_revision: Revision = field(default_factory=Revision) -def normalize_revision(revision: str | None) -> str: - """ - Normalizes letter revisions to uppercase and shortens numeric versions to - the smallest form with the same meaning. + def get_normalized_ids(self) -> list[str]: + """ + Returns a list of names by which this hardware can be matched, for example + in command line parameters. All results are casefolded. - Examples: - "a" -> "A" - "1.2.0" -> "1.2" - "2.0.0" -> "2" - """ - if not revision: - return "" + To make specifying board IDs as command line parameters easier, the "zmk" + board qualifier is optional, e.g. "nice_nano" matches a board with + `id="nice_nano//zmk"`, and "nrfmicro/nrf52840" matches a board with + `id="nrfmicro/nrf52840/zmk"`. + """ + norm_id = self.id.casefold() - return re.sub(r"(?:\.0){1,2}$", "", revision).upper() + result = [norm_id] + if norm_id.endswith("//zmk"): + result.append(norm_id.removesuffix("//zmk")) + elif norm_id.endswith("/zmk"): + result.append(norm_id.removesuffix("/zmk")) + return result -def get_revision_forms(revision: str) -> list[str]: - """ - Returns a list of all equivalent spellings of a revision. - Examples: - "a" -> ["A", "a"] - "1.2.3" -> ["1.2.3"] - "1.2.0" -> ["1.2", "1.2.0"] - "2.0.0" -> ["2", "2.0", "2.0.0"] +@dataclass +class Shield(KeyboardComponent): """ - revision = normalize_revision(revision) - - if revision.isalpha(): - return [revision.upper(), revision.lower()] - - result = [] - - dot_count = revision.count(".") - if dot_count == 0: - result.append(revision + ".0.0") - if dot_count <= 1: - result.append(revision + ".0") - - result.append(revision) + Description of a Zephyr shield. May be a keyboard or a peripheral. - return result + Matches #/$defs/shield from + https://github.com/zmkfirmware/zmk/blob/main/schema/hardware-metadata.schema.json + """ -@dataclass -class Shield(Keyboard): - """Hardware that attaches to a board. May be a keyboard or a peripheral.""" + requires: set[str] = field(default_factory=set) + """List of interconnects this shield requires to be attached to""" - requires: list[str] | None = field(default_factory=list) - """List of interconnects to which this shield attaches""" - def __post_init__(self): - super().__post_init__() - self.requires = self.requires or [] +class IncompleteKeyboardError(Exception): + pass @dataclass -class GroupedHardware: - """Hardware grouped by type.""" - - keyboards: list[Keyboard] = field(default_factory=list) - """List of boards/shields that are keyboard PCBs""" - controllers: list[Board] = field(default_factory=list) - """List of boards that are controllers for keyboards""" - interconnects: list[Interconnect] = field(default_factory=list) - """List of interconnect descriptions""" - - # TODO: add displays and other peripherals? - - def find_keyboard(self, item_id: str) -> Keyboard | None: - """Find a keyboard by ID""" - item_id = item_id.casefold() - return next((i for i in self.keyboards if i.id.casefold() == item_id), None) - - def find_controller(self, item_id: str) -> Board | None: - """Find a controller by ID""" - item_id = item_id.casefold() - return next((i for i in self.controllers if i.id.casefold() == item_id), None) - - def find_interconnect(self, item_id: str) -> Interconnect | None: - """Find an interconnect by ID""" - item_id = item_id.casefold() - return next( - (i for i in self.interconnects if i.id.casefold() == item_id), - None, - ) +class Keyboard: + """ + Collection of information needed to determine how to build keyboard firmware. + This consists of: + - A board + - Optionally, a specific board revision to use + - Optionally, some number of shields. -def is_keyboard(hardware: Hardware) -> TypeGuard[Keyboard]: - """Test whether an item is a keyboard (board or shield supporting keys)""" - match hardware: - case Keyboard(features=feat) if feat and "keys" in feat: - return True + At least one item between the board and shields is required to have the "keys" + before the keyboard is considered "complete" and can be used to build firmware. + """ - case _: - return False + board: Board | None = None + """The controller board""" + board_revision: Revision = field(default_factory=Revision) + """The board revision selected to build""" + shields: list[Shield] = field(default_factory=list) + """List of shields to attach to the board""" + @property + def board_targets(self) -> list[BoardTarget]: + """ + The Zephyr board target(s) from board and board_revision. -def is_controller(hardware: Hardware) -> TypeGuard[Board]: - """Test whether an item is a keyboard controller (board which isn't a keyboard)""" - return isinstance(hardware, Board) and not is_keyboard(hardware) + If board.siblings is not empty, this returns one item per sibling. + Otherwise, it returns a single item for board.id. + :raises IncompleteKeyboardError: if board is not set. + """ + if not self.board: + raise IncompleteKeyboardError("Cannot get board_target when board is None") -def is_interconnect(hardware: Hardware) -> TypeGuard[Interconnect]: - """Test whether an item is an interconnect description""" - return isinstance(hardware, Interconnect) + board_ids = self.board.siblings or [self.board.id] + revision = self.board_revision or self.board.default_revision + return [ + BoardTarget.parse(board_id).with_revision(revision) + for board_id in board_ids + ] -def is_compatible( - base: Board | Shield | Iterable[Board | Shield], shield: Shield -) -> bool: - """ - Get whether a shield can be attached to the given hardware. + @property + def board_revisions(self) -> list[Revision]: + """The board's supported revisions.""" + return self.board.revisions if self.board else [] - This simply checks whether all the interconnects required by "shield" are - provided by the hardware in "base". If "base" is a list of hardware, it does - not account for the fact that one of the items in "base" may already be using - an interconnect provided by another item. - """ + @property + def keys_component(self) -> Board | Shield | None: + """The first item in [board, *shields] which has the "keys" feature.""" + if self.board and self.board.has_keys: + return self.board - if not shield.requires: - return True + return next((s for s in self.shields if s.has_keys), None) - base = [base] if isinstance(base, Keyboard) else base - exposed = flatten(b.exposes for b in base) + @property + def exposes(self) -> set[str]: + """Set of interconnects exposed by the board and shields.""" + board_exposes = self.board.exposes if self.board else set() - return all(ic in exposed for ic in shield.requires) + return board_exposes | union(s.exposes for s in self.shields) + @property + def requires(self) -> set[str]: + """Set of interconnects required by shields.""" + return union(s.requires for s in self.shields) -def get_board_roots(repo: Repo) -> Iterable[Path]: - """Get the paths that contain hardware definitions for a repo""" - roots = set() + @property + def missing_requirements(self) -> set[str]: + """ + Gets any interconnects in self.requires that are not satisfied by + self.exposes. - if root := repo.board_root: - roots.add(root) + This does not attempt to account for multiple instances of the same + interconnect. + """ + return self.requires - self.exposes - for module in repo.get_modules(): - if root := module.board_root: - roots.add(root) + @property + def revisions(self) -> list[Revision]: + """List of available board revisions""" + return self.board.revisions if self.board else [] - return roots + @property + def default_revision(self) -> Revision: + """Board revision that will be used if one isn't explicitly selected""" + return self.board.default_revision if self.board else Revision() + def add_component(self, component: KeyboardComponent): + """Add a board or shield to the keyboard.""" + match component: + case Board(): + self.board = component -def get_hardware(repo: Repo) -> GroupedHardware: - """Get lists of hardware descriptions, grouped by type, for a repo""" - hardware = flatten(_find_hardware(root) for root in get_board_roots(repo)) + case Shield(): + self.shields.append(component) - def func(groups: GroupedHardware, item: Hardware): - if is_keyboard(item): - groups.keyboards.append(item) - elif is_controller(item): - groups.controllers.append(item) - elif is_interconnect(item): - groups.interconnects.append(item) + case _: + raise TypeError("Unknown component type") - return groups + def is_compatible(self, component: Board | Shield): + """ + Get whether a board or shield can be attached to the keyboard and all + interconnect requirements would be satisfied. - groups = reduce(func, hardware, GroupedHardware()) + If given a shield, this checks whether all the interconnects required by + the shield are provided by the board and/or shields already in the keyboard. - groups.controllers = sorted(groups.controllers, key=lambda x: x.id) - groups.keyboards = sorted(groups.keyboards, key=lambda x: x.id) - groups.interconnects = sorted(groups.interconnects, key=lambda x: x.id) + If given a board, this checks whether adding the board would satisfy all + the requirements for the shields already in the keyboard. + """ + match component: + case Board(): + # This assumes we're replacing any existing board, so determine + # the new set of exposed interconnects with the new board and + # without any existing board. + new_exposes = component.exposes | union(s.exposes for s in self.shields) + return not (self.requires - new_exposes) + + case Shield(): + return not (component.requires - self.exposes) + + def get_build_items(self) -> list[BuildItem]: + """ + Get the individual builds needed to make the firmware. This currently + accounts for splits but not more complex configurations where each build + may need different options. + + Each build item will contain the board and every shield. If any board or + shield has siblings, then this will return one item per possible + combination of siblings. For example: + ``` + self.board = Board(id="nice_nano//zmk", default_revision="2") + self.shields = [Shield(id="two_percent_milk")] + # returns + [BuildItem(board_target="nice_nano@2//zmk", shield="two_percent_milk")] + + self.board = Board(id="nice_nano//zmk", default_revision="2") + self.shields = [Shield(siblings=["a_left", "a_right"]), Shield(id="display")] + # returns + [ + BuildItem(board_target="nice_nano@2//zmk", shield="a_left display"), + BuildItem(board_target="nice_nano@2//zmk", shield="a_right display"), + ] + ``` + + :raises IncompleteKeyboardError: if board is not set or there is no + component with the "keys" feature. + """ + if not self.board or not self.keys_component: + raise IncompleteKeyboardError( + "Cannot get build items for an incomplete keyboard" + ) - return groups + build_items: list[BuildItem] = [] + shield_lists = (shield.siblings or [shield.id] for shield in self.shields) + for combination in itertools.product(self.board_targets, *shield_lists): + # Python's type system can't represent that the first item is always a + # BoardTarget and the rest are always str, so explicit casts are needed. + target = cast("BoardTarget", combination[0]) + shields = cast("tuple[str, ...]", combination[1:]) + shield_str = " ".join(shields) if shields else None -def _find_hardware(path: Path) -> Generator[Hardware, None, None]: - for meta_path in path.rglob("*.zmk.yml"): - meta = read_yaml(meta_path) - meta["directory"] = meta_path.parent + build_items.append(BuildItem(board=target, shield=shield_str)) - match meta.get("type"): - case "board": - yield Board.from_dict(meta) + return build_items - case "shield": - yield Shield.from_dict(meta) + def get_config_path(self) -> Path: + """ + Get the path to the keyboard's .conf file. - case "interconnect": - yield Interconnect.from_dict(meta) + :raises IncompleteKeyboardError: if there is no component with the "keys" feature. + """ + return _get_keyboard_file(self.keys_component, self.board_revision, ".conf") + def get_keymap_path(self) -> Path: + """ + Get the path to the keyboard's .keymap file. -def _filter_hardware(item: Hardware, text: str): - text = text.casefold().strip() - return text in item.id.casefold() or text in item.name.casefold() + :raises IncompleteKeyboardError: if there is no component with the "keys" feature. + """ + return _get_keyboard_file(self.keys_component, self.board_revision, ".keymap") -def show_hardware_menu( - title: str, - items: Iterable[_HW], - **kwargs, -) -> _HW: +def split_board_qualifiers(identifier: str) -> tuple[str, str]: """ - Show a menu to select from a list of Hardware objects. + Splits a string into a board ID and board qualifiers. If the string contains + a revision, it is ignored and returned as part of the first value. - kwargs are passed through to zmk.menu.show_menu(), except for filter_func, - which is set to a function appropriate for filtering Hardware objects. + Examples: + "foo" -> "foo", "" + "foo/bar/baz" -> "foo", "/bar/baz" + "foo@2/bar/baz" -> "foo@2", "/bar/baz" """ - return show_menu(title=title, items=items, **kwargs, filter_func=_filter_hardware) + try: + index = identifier.index("/") + return identifier[0:index], identifier[index:] + except ValueError: + return identifier, "" -def show_revision_menu( - board: Hardware, title: str | None = None, **kwargs -) -> str | None: +def _get_filename(name: str): """ - Show a menu to select from a list of revisions for a board. - - If the board has no revisions, returns None without showing a menu. - If the board has only one revision, returns it without showing a menu. - - kwargs are passed through to zmk.menu.show_menu(), except for default_index, - which is set based on the board's default revision. + Replaces all special characters used in revisions and board qualifiers that + are not valid in a filename with underscores. """ + return re.sub(r"[./]+", "_", name) + - revisions = board.get_revisions() - if not revisions: - return None +def _get_keyboard_file( + keys_component: Board | Shield | None, revision: Revision, suffix: str +) -> Path: + if not keys_component: + raise IncompleteKeyboardError("Cannot get file path for an incomplete keyboard") - if len(revisions) == 1: - return revisions[0] + search_names: list[str] = [keys_component.id] + search_revisions: list[str] = [] + directory = keys_component.directory - default_revision = board.get_default_revision() - default_index = revisions.index(default_revision) if default_revision else 0 + if isinstance(keys_component, Board): + # If the keyboard has board qualifiers, also search without them. + target = BoardTarget.parse(keys_component.id) + if target.qualifiers: + search_names.append(target.name) - return show_menu( - title=title or f"Select a {board.name} revision:", - items=revisions, - default_index=default_index, - **kwargs, - ) + # If the keyboard has revisions, search for files with each possible + # spelling of the revision. + revision = revision or keys_component.default_revision + if revision: + search_revisions = revision.get_spellings() + + # Combine everything into a list of paths to search from most to least + # specific. + search_paths: list[Path] = [] + + for name in search_names: + for rev in search_revisions: + # Despite Zephyr board targets having the revision before the + # qualifiers, paths have the revision at the end for some reason. + path = (directory / _get_filename(name + "_" + rev)).with_suffix(suffix) + search_paths.append(path) + + path = (directory / _get_filename(name)).with_suffix(suffix) + search_paths.append(path) + + for path in search_paths: + if path.exists(): + return path + + # If none of these files exists, then just return the most general of + # the possible paths. For a shield, this will just be the shield ID itself. + # For a board, it will be the ID without revisions or qualifiers. + return search_paths[-1] diff --git a/zmk/hardware_list.py b/zmk/hardware_list.py new file mode 100644 index 0000000..d590054 --- /dev/null +++ b/zmk/hardware_list.py @@ -0,0 +1,205 @@ +from collections.abc import Generator, Iterable +from dataclasses import dataclass, field +from functools import reduce +from pathlib import Path +from typing import TypeVar + +from rich.console import Console + +from .hardware import Board, BoardTarget, Hardware, Interconnect, Keyboard, Shield +from .menu import show_menu +from .repo import Repo +from .revision import Revision +from .styles import MENU_THEME, BoardIdHighlighter +from .util import flatten +from .yaml import read_yaml + +_HW = TypeVar("_HW", bound=Hardware) + + +@dataclass +class HardwareGroups: + """Hardware grouped by type.""" + + keyboards: list[Board | Shield] = field(default_factory=list) + """List of boards/shields that are keyboard PCBs""" + controllers: list[Board] = field(default_factory=list) + """List of boards that are controllers for keyboards""" + interconnects: list[Interconnect] = field(default_factory=list) + """List of interconnect descriptions""" + + # TODO: add displays and other peripherals? + + def find_keyboard(self, item_id: str) -> Board | Shield | None: + """Find a keyboard by ID""" + # Ignore any board revision provided + target = BoardTarget.parse(item_id).with_revision(Revision()) + return _find_by_id(self.keyboards, str(target)) + + def find_controller(self, item_id: str) -> Board | None: + """Find a controller by ID""" + # Ignore any board revision provided + target = BoardTarget.parse(item_id).with_revision(Revision()) + return _find_by_id(self.controllers, str(target)) + + def find_interconnect(self, item_id: str) -> Interconnect | None: + """Find an interconnect by ID""" + return _find_by_id(self.interconnects, item_id) + + def filter_compatible_keyboards(self, keyboard: Keyboard): + """ + Modifies the "keyboards" list so it contains only shields compatible with + the boards and shields in a given keyboard object. + """ + self.keyboards = [ + kb + for kb in self.keyboards + if isinstance(kb, Shield) and keyboard.is_compatible(kb) + ] + + def filter_compatible_controllers(self, keyboard: Keyboard): + """ + Modifies the "controllers" list so it contains only boards compatible + with the shields on a given keyboard object. + """ + self.controllers = [c for c in self.controllers if keyboard.is_compatible(c)] + + def filter_to_interconnect(self, interconnect: Interconnect): + """ + Modifies the "controllers" list so it contains only boards that expose + the given interconnect. + + Modifies the "keyboards" list so it contains only shields that require + the given interconnect and boards/shields that provide it. + """ + self.controllers = [c for c in self.controllers if interconnect.id in c.exposes] + self.keyboards = [ + kb + for kb in self.keyboards + if interconnect.id in kb.exposes + or (isinstance(kb, Shield) and interconnect.id in kb.requires) + ] + + +def _find_by_id(hardware: Iterable[_HW], item_id: str) -> _HW | None: + norm_id = item_id.casefold() + return next((hw for hw in hardware if norm_id in hw.get_normalized_ids()), None) + + +def _find_hardware(path: Path) -> Generator[Hardware, None, None]: + for meta_path in path.rglob("*.zmk.yml"): + meta = read_yaml(meta_path) + meta["directory"] = meta_path.parent + + match meta.get("type"): + case "board": + yield Board.from_dict(meta) + + case "shield": + yield Shield.from_dict(meta) + + case "interconnect": + yield Interconnect.from_dict(meta) + + +def get_board_roots(repo: Repo) -> Iterable[Path]: + """Get the paths that contain hardware definitions for a repo""" + roots = set() + + if root := repo.board_root: + roots.add(root) + + for module in repo.get_modules(): + if root := module.board_root: + roots.add(root) + + return roots + + +def get_hardware(repo: Repo) -> HardwareGroups: + """Get lists of hardware descriptions, grouped by type, for a repo""" + hardware = flatten(_find_hardware(root) for root in get_board_roots(repo)) + + def func(groups: HardwareGroups, item: Hardware): + if isinstance(item, (Shield, Board)) and item.has_keys: + groups.keyboards.append(item) + elif isinstance(item, Board): + groups.controllers.append(item) + elif isinstance(item, Interconnect): + groups.interconnects.append(item) + + return groups + + groups = reduce(func, hardware, HardwareGroups()) + + groups.controllers = sorted(groups.controllers, key=lambda x: x.id) + groups.keyboards = sorted(groups.keyboards, key=lambda x: x.id) + groups.interconnects = sorted(groups.interconnects, key=lambda x: x.id) + + return groups + + +def show_hardware_menu( + title: str, + items: Iterable[_HW], + console: Console | None = None, + **kwargs, +) -> _HW: + """ + Show a menu to select from a list of Hardware objects. + + kwargs are passed through to zmk.menu.show_menu(), except for filter_func, + which is set to a function appropriate for filtering Hardware objects. + """ + if console is None: + console = Console(theme=MENU_THEME, highlighter=BoardIdHighlighter()) + + def filter_hardware(item: Hardware, text: str): + text = text.casefold().strip() + return text in item.id.casefold() or text in item.name.casefold() + + return show_menu( + title=title, + items=items, + console=console, + filter_func=filter_hardware, + **kwargs, + ) + + +def show_revision_menu( + board: Board, title: str | None = None, console: Console | None = None, **kwargs +) -> Revision: + """ + Show a menu to select from a list of revisions for a board. + + If the board has no revisions, returns Revision() without showing a menu. + If the board has only one revision, returns it without showing a menu. + + kwargs are passed through to zmk.menu.show_menu(), except for default_index, + which is set based on the board's default revision. + """ + + # Revisions could be listed in the .zmk.yml file in any order. Sort them + # descending so the latest revisions appear at the top. + revisions = sorted(board.revisions, reverse=True) + if not revisions: + return Revision() + + if len(revisions) == 1: + return revisions[0] + + default_revision = board.default_revision + default_index = revisions.index(default_revision) if default_revision else 0 + + if console is None: + # Disable the default highlighter so it doesn't colorize numbers + console = Console(highlighter=None) + + return show_menu( + title=title or f"Select a {board.name} revision:", + items=revisions, + default_index=default_index, + console=console, + **kwargs, + ) diff --git a/zmk/menu.py b/zmk/menu.py index 7917a5d..f4a9f6a 100644 --- a/zmk/menu.py +++ b/zmk/menu.py @@ -2,19 +2,24 @@ Terminal menus """ -from collections.abc import Callable, Iterable +from collections.abc import Callable, Generator, Iterable from contextlib import contextmanager, suppress -from typing import Generic, Self, TypeVar +from typing import Generic, Protocol, TypeVar, runtime_checkable import rich -from rich.console import Console, RenderableType +from rich.console import Console, RenderableType, group +from rich.control import Control from rich.highlighter import Highlighter -from rich.style import Style +from rich.live import Live +from rich.padding import Padding +from rich.style import Style, StyleType +from rich.table import Table from rich.text import Text from rich.theme import Theme from . import terminal -from .util import splice +from .styles import chain_highlighters +from .util import horizontal_group, splice class StopMenu(KeyboardInterrupt): @@ -23,7 +28,23 @@ class StopMenu(KeyboardInterrupt): """ -T = TypeVar("T") +@runtime_checkable +class MenuRow(Protocol): + """ + An object that will display multiple values in a menu. + + __menu_row__() functions like __rich__(), except it returns any number of + renderables, and each must be no more than one line tall. They will be + aligned in columns in the menu. + """ + + def __menu_row__(self) -> Iterable[RenderableType]: ... + + +T = TypeVar("T", bound=RenderableType | MenuRow) + +_MenuRow = tuple[list[RenderableType], StyleType | None] +"""Type alias for a list of values to render for a row and the style to apply to them""" class TerminalMenu(Generic[T], Highlighter): @@ -50,7 +71,7 @@ class TerminalMenu(Generic[T], Highlighter): } ) - title: RenderableType | None + title: str | None items: list[T] default_index: int @@ -67,13 +88,14 @@ class TerminalMenu(Generic[T], Highlighter): def __init__( self, - title: RenderableType | None, + title: str | None, items: Iterable[T], *, default_index=0, filter_func: Callable[[T, str], bool] | None = None, console: Console | None = None, theme: Theme | None = None, + padding=3, ): """ An interactive terminal menu. @@ -87,12 +109,14 @@ def __init__( field will be displayed after the menu title. :param console: Console in which to display the menu. :param theme: Theme to apply. See TerminalMenu.DEFAULT_THEME for style names. + :param padding: Number of spaces between columns. """ self.title = title self.items = list(items) self.console = console or rich.get_console() self.theme = theme or self.DEFAULT_THEME self.default_index = default_index + self.padding = padding self._filter_func = filter_func self._filter_text = "" @@ -109,11 +133,17 @@ def __init__( self._num_title_lines = 0 self._last_title_line_len = 0 - if self._get_display_count() == self._max_items_per_page: - self._top_row = 1 + if terminal.cursor_control_supported(): + if self._get_display_count() == self._max_items_per_page: + self._top_row = 1 + else: + _, y = terminal.get_cursor_pos() + self._top_row = min(y, self.console.height - self._get_menu_height()) else: - row, _ = terminal.get_cursor_pos() - self._top_row = min(row, self.console.height - self._get_menu_height()) + # If get_cursor_pos() is unsupported, then we can't move the cursor + # accurately between the end of the menu and the filter text. Disable + # the filter feature so the menu still mostly works. + self._filter_func = None self._apply_filter() @@ -127,28 +157,16 @@ def show(self) -> T: self._focus_index = self.default_index - try: - with self._context(): - while True: - self._scroll_index = self._get_scroll_index() - self._print_menu() - self._move_cursor_to_filter() - - if self.has_filter: - terminal.show_cursor() + with self._context() as live: + while True: + self._scroll_index = self._get_scroll_index() + live.update(self._render_menu(), refresh=True) + with self._move_cursor_to_filter(): if self._handle_input(): - try: + # _focus_index may be invalid if _filter_items is empty. + with suppress(IndexError): return self._filter_items[self._focus_index] - except IndexError: - pass - - if self.has_filter: - terminal.hide_cursor() - - self._move_cursor_to_top() - finally: - self._erase_controls() @property def has_filter(self) -> bool: @@ -173,16 +191,47 @@ def highlight(self, text: Text) -> None: @contextmanager def _context(self): + controls = Text(self.CONTROLS, style="dim", no_wrap=True, overflow="crop") + if self.has_filter: + controls.append(self.FILTER_CONTROLS) + old_highlighter = self.console.highlighter try: - terminal.hide_cursor() - self.console.highlighter = self - - with self.console.use_theme(self.theme): - yield + self.console.show_cursor(show=False) + + # Merge the console's existing highlighter with the menu highlighter + self.console.highlighter = chain_highlighters(old_highlighter, self) + + # Display the title and menu items in a live view, which we will + # update as the user interacts with the menu. Below that, display + # the control help text. This never gets updated, but it uses a + # transient live view to it gets hidden when the menu is closed. + # Both need redirect_stdout=False because the default behavior will + # conflict with our cursor position modifications, resulting in the + # rendered menu not being cleaned up properly on Esc/Ctrl+C. + with ( + self.console.use_theme(self.theme), + Live( + console=self.console, + redirect_stdout=False, + auto_refresh=False, + ) as live, + Live( + controls, + console=self.console, + redirect_stdout=False, + auto_refresh=False, + transient=True, + ), + ): + yield live finally: - terminal.show_cursor() self.console.highlighter = old_highlighter + self.console.show_cursor(show=True) + + # Add one blank line when the menu is closed to give some space + # between the menu and whatever follows it. + self.console.print() def _apply_filter(self): if self._filter_func: @@ -195,78 +244,91 @@ def _apply_filter(self): i for i in self.items if self._filter_func(i, self._filter_text) ] + # If the previously-focused item is still visible, update the focus + # index to that item's new index. Update the scroll index as well to + # try to keep that item in the same place on the screen. if old_focus is not None: with suppress(ValueError): + scroll_offset = self._focus_index - self._scroll_index self._focus_index = self._filter_items.index(old_focus) + self._scroll_index = self._focus_index - scroll_offset else: self._filter_items = self.items self._clamp_focus_index() - def _print_menu(self): + @group() + def _render_menu(self): if self.title: - self.console.print( - f"[title]{self.title}[/title] [filter]{self._filter_text}[/filter]", - justify="left", - highlight=False, + yield Text.assemble( + (self.title, "title"), " ", (self._filter_text, "filter") ) + # Find the items that are visible and render them to a list of rows. + # Organize the rows into a grid to align columns of data. + grid = Table.grid(padding=(0, self.padding)) + grid.highlight = True + + if rows := list(self._render_rows()): + max_columns = max(len(row) for row, _ in rows) + + for _ in range(max_columns): + grid.add_column(no_wrap=True) + + for row, style in rows: + grid.add_row(*row, style=style) + + # Wrap the grid in a Padding() so it clears the entire width of the terminal. + # (expand=True on the grid would also work, but that affects column widths.) + yield Padding(grid) + + def _render_rows(self) -> Generator[_MenuRow]: display_count = self._get_display_count() - for row in range(display_count): - if row == 0 and not self._filter_items: - self.console.print( - "[dim]No matching items", - justify="left", - highlight=False, - no_wrap=True, - ) - continue + # If the filter doesn't match any items, display a message on the first + # line and blank the rest of the menu. + if not self._filter_items: + yield (["No matching items"], Style(dim=True)) + for _ in range(display_count - 1): + yield ([], None) + + return + + scroll_at_top = self._scroll_index == 0 + scroll_at_bottom = self._scroll_index + display_count >= len(self._filter_items) + for row in range(display_count): index = self._scroll_index + row focused = index == self._focus_index - at_start = self._scroll_index == 0 - at_end = self._scroll_index + display_count >= len(self._filter_items) - show_more = (not at_start and row == 0) or ( - not at_end and row == display_count - 1 - ) + is_top_ellipsis = row == 0 and not scroll_at_top + is_bottom_ellipsis = row == display_count - 1 and not scroll_at_bottom - try: - item = self._filter_items[index] - self._print_item(item, focused=focused, show_more=show_more) - except IndexError: - self.console.print(justify="left") + if is_top_ellipsis or is_bottom_ellipsis: + yield ([" ..."], "ellipsis") + else: + try: + yield self._render_item(self._filter_items[index], focused=focused) + except IndexError: + yield ([], None) - controls = self.CONTROLS - if self.has_filter: - controls += self.FILTER_CONTROLS - - self.console.print( - controls, - style="controls", - end="", - highlight=False, - no_wrap=True, - overflow="crop", - ) + def _render_item(self, item: T | str, *, focused: bool) -> _MenuRow: + style = "focus" if focused else "unfocus" - def _print_item(self, item: T | str, *, focused: bool, show_more: bool): - style = "ellipsis" if show_more else "focus" if focused else "unfocus" + columns: list[RenderableType] + if isinstance(item, MenuRow): + columns = list(item.__menu_row__()) or [""] + else: + columns = [item] + # The table has larger padding between columns than we want for the + # focused item indicator or indent on unfocused items, so modify the + # value in the first column to contain the indicator/indent instead of + # putting it in a separate column. indent = "> " if focused else " " - item = "..." if show_more else item - - self.console.print( - indent, - item, - sep="", - style=style, - highlight=True, - justify="left", - no_wrap=True, - overflow="ellipsis", - ) + columns[0] = horizontal_group(indent, columns[0]) + + return ([*columns], style) def _clamp_focus_index(self): self._focus_index = min(max(0, self._focus_index), len(self._filter_items) - 1) @@ -325,7 +387,7 @@ def _handle_input(self): return False def _handle_backspace(self): - if self._cursor_index == 0: + if not self.has_filter or self._cursor_index == 0: return self._filter_text = splice(self._filter_text, self._cursor_index - 1, count=1) @@ -333,13 +395,16 @@ def _handle_backspace(self): self._apply_filter() def _handle_delete(self): - if self._cursor_index == len(self._filter_text): + if not self.has_filter or self._cursor_index == len(self._filter_text): return self._filter_text = splice(self._filter_text, self._cursor_index, count=1) self._apply_filter() def _handle_text(self, key: bytes): + if not self.has_filter: + return + text = key.decode() self._filter_text = splice( self._filter_text, self._cursor_index, insert_text=text @@ -370,44 +435,65 @@ def _get_scroll_index(self): items_count = len(self._filter_items) display_count = self._get_display_count() - if items_count < display_count: + if items_count <= display_count: + # There is enough space to show the whole menu without scrolling. return 0 first_displayed = self._scroll_index last_displayed = first_displayed + display_count - 1 + if last_displayed >= items_count: + # There are more items in the menu than available space, but the + # current scroll position would leave blank spaces at the bottom. + # Scroll up enough to fill every row of the menu with an item. + first_displayed = items_count - display_count + last_displayed = first_displayed + display_count - 1 + if self._focus_index <= first_displayed + self.SCROLL_MARGIN: - return max(0, self._focus_index - 1 - self.SCROLL_MARGIN) + # The focused item is off the top of the screen. Scroll up enough to + # get it in view. Also fit as many menu items in as possible so we + # don't get just a couple visible when there's room for more. + start = min( + items_count - display_count, + self._focus_index - 1 - self.SCROLL_MARGIN, + ) + return max(0, start) if self._focus_index >= last_displayed - self.SCROLL_MARGIN: + # Focused item is off the bottom of the screen. Scroll down enough + # to get it in view. end = min(items_count - 1, self._focus_index + 1 + self.SCROLL_MARGIN) return end - (display_count - 1) - return self._scroll_index - - def _move_cursor_to_top(self): - """Move the cursor to the start of the menu""" - terminal.set_cursor_pos(row=self._top_row) + return first_displayed + @contextmanager def _move_cursor_to_filter(self): - """Move the cursor to the filter text field""" - row = self._top_row + self._num_title_lines - 1 - col = self._last_title_line_len + self._cursor_index + """ + Context manager which move the cursor to the filter text field and shows + it, runs the context, then sets the cursor back where it was and hides it. + """ - terminal.set_cursor_pos(row, col) + if not self.has_filter: + yield + return - def _erase_controls(self): - """Hide the controls text and reset the cursor to after the menu""" - row = self.console.height - 1 + orig_x, orig_y = terminal.get_cursor_pos() - terminal.set_cursor_pos(row=row, col=0) - self.console.print(justify="left") + x = self._last_title_line_len + self._cursor_index + y = self._top_row + self._num_title_lines - 1 - terminal.set_cursor_pos(self._top_row + len(self._filter_items) + 1) + try: + self.console.control(Control.move_to(x, y)) + self.console.show_cursor(show=True) + yield + finally: + self.console.show_cursor(show=False) + self.console.control(Control.move_to(orig_x, orig_y)) def show_menu( - title: RenderableType | None, + title: str | None, items: Iterable[T], *, default_index=0, @@ -442,45 +528,26 @@ def show_menu( class Detail(Generic[T]): - """A menu item with a description appended to the end.""" - - MIN_PAD = 2 + """A menu item with a description.""" data: T detail: str - pad_len: int def __init__(self, data: T, detail: str): self.data = data self.detail = detail - self.pad_len = self.MIN_PAD - - def __rich__(self): - text = Text.assemble(str(self.data), " " * self.pad_len, (self.detail, "dim")) - # Returning the Text object directly works, but it doesn't get highlighted. - return text.markup - - @classmethod - def align(cls, items: Iterable[Self], console: Console | None = None) -> list[Self]: - """Set the padding for each item in the list to align the detail strings.""" - items = list(items) - console = console or rich.get_console() - for item in items: - item.pad_len = console.measure(str(item.data)).minimum - - width = max(item.pad_len for item in items) - - for item in items: - item.pad_len = width - item.pad_len + cls.MIN_PAD + def __menu_row__(self) -> Iterable[RenderableType]: + if isinstance(self.data, MenuRow): + yield from self.data.__menu_row__() + else: + yield self.data - return items + yield f"[dim]{self.detail}" -def detail_list( - items: Iterable[tuple[T, str]], console: Console | None = None -) -> list[Detail[T]]: +def detail_list(items: Iterable[tuple[T, str]]) -> list[Detail[T]]: """ - Create a list of menu items with a description appended to each item. + Create a list of menu items with a description next to each item. """ - return Detail.align([Detail(item, desc) for item, desc in items], console=console) + return [Detail(item, desc) for item, desc in items] diff --git a/zmk/revision.py b/zmk/revision.py new file mode 100644 index 0000000..f197505 --- /dev/null +++ b/zmk/revision.py @@ -0,0 +1,156 @@ +import re + + +class Revision: + """ + Represents a Zephyr board revision. + + Revisions are automatically normalized, so comparing two objects for equality + works even if the same revision is spelled differently, e.g. + ``` + Revision("1") == Revision("1.1") # false + Revision("1") == Revision("1.0") # true + Revision("1.0.0") in {Revision("1"), Revision("2")} # true + ``` + + `Revision()` represents the lack of a revision. Truth testing a revision with + an empty string value returns false, so `Revision | None` is not needed to + represent an optional revision. + """ + + def __init__(self, value: str = ""): + self._value = _normalize(value) + + def __bool__(self): + return bool(self._value) + + def __hash__(self): + return hash(self._value) + + def __eq__(self, other): + return isinstance(other, Revision) and other._value == self._value + + def __lt__(self, other): + if not isinstance(other, Revision): + raise TypeError() + return _revision_lt(self, other) + + def __str__(self): + return self._value + + def __rich__(self): + return str(self) + + def __repr__(self): + return f'Revision("{self._value})"' if self else "Revision()" + + def __menu_row__(self): + # In menus, use the most specific spelling so adjacent items don't have + # varying widths. + return [self.specific_str] + + @property + def at_str(self) -> str: + """ + The revision prefixed with "@" if this object has a value, else "". + """ + return f"@{self}" if self else "" + + @property + def specific_str(self) -> str: + """ + The most specific spelling of this revision, e.g. Revision("1") -> "1.0.0" + """ + if not self: + return "" + + if self._value.isalpha(): + return self._value + + result = self._value + for _ in range(2 - self._value.count(".")): + result += ".0" + + return result + + def get_spellings(self) -> list[str]: + """ + Returns a list of all equivalent spellings of this revision. This may be + necessary for example when you have a normalized revision and need to find + board files that apply to it (e.g. board_1_0_0.overlay). + + Return values are ordered from most to least specific. + + Examples: + ``` + "A" -> ["A", "a"] + "1" -> ["1.0.0", "1.0", "1"] + "1.2" -> ["1.2.0", "1.2"] + "1.2.3" -> ["1.2.3"] + ``` + """ + if not self: + return [""] + + if self._value.isalpha(): + return [self._value, self._value.lower()] + + value = self.specific_str + result = [value] + + while value.endswith(".0"): + value = value[:-2] + result.append(value) + + return result + + +def _normalize(revision: str) -> str: + """ + Normalizes letter revisions to uppercase and shortens numeric versions to + the smallest form with the same meaning. + + Examples: + ``` + "" -> "" + "a" -> "A" + "1.2.0" -> "1.2" + "2.0.0" -> "2" + ``` + """ + return re.sub(r"(?:\.0){1,2}$", "", revision.strip()).upper() + + +def _to_number_list(revision: str) -> list[int]: + """ + Splits a revision string into a list of numbers. Must not be an alphabetical + revision. + """ + return [int(part) for part in revision.split(".")] + + +def _revision_lt(lhs: Revision, rhs: Revision) -> bool: + """ """ + lhs_str = lhs.specific_str + rhs_str = rhs.specific_str + + # Place Revision() before Revision("...") + if not lhs_str: + return bool(rhs_str) + + if not rhs_str: + return False + + # Place Revision("A") before Revision("1"). + if lhs_str.isalpha(): + if rhs_str.isalpha(): + # Both revisions are alphabetical. Can compare directly. + return lhs_str < rhs_str + + return True + + if rhs_str.isalpha(): + return False + + # Both revisions are numerical. Can compare as a list of numbers. + return _to_number_list(lhs_str) < _to_number_list(rhs_str) diff --git a/zmk/styles.py b/zmk/styles.py index 29dcc6c..dc39d1c 100644 --- a/zmk/styles.py +++ b/zmk/styles.py @@ -11,6 +11,8 @@ class KeyValueHighlighter(RegexHighlighter): """Highlight "key=value" items.""" + base_style = "kv." + highlights = [ # noqa: RUF012 https://github.com/astral-sh/ruff/issues/5429 r"(?P[\w.]+)(?P=)(?P.*)" ] @@ -22,7 +24,7 @@ class BoardIdHighlighter(RegexHighlighter): base_style = "board." highlights = [ # noqa: RUF012 https://github.com/astral-sh/ruff/issues/5429 - r"(?P[/])", + r"(?P/[a-zA-Z_]?\w*/[a-zA-Z_]\w*\b)", r"(?P@(?:[A-Z]|([0-9]+(\.[0-9]+){0,2})))", ] @@ -33,30 +35,39 @@ class CommandLineHighlighter(RegexHighlighter): base_style = "cmd." highlights = [ # noqa: RUF012 https://github.com/astral-sh/ruff/issues/5429 - r"(?P-[a-zA-Z]|--[a-zA-Z-_]+)" + r"(?-[a-zA-Z]|--[a-zA-Z-_]+)" ] THEME = Theme( { - "key": "bright_blue", - "equals": "dim blue", + "title": "bright_magenta", + "kv.key": "bright_blue", + "kv.equals": "dim blue", "value": "default", - "board.separator": "dim", - "board.revision": "sky_blue2", + "board.revision": "rgb(135,95,175)", + "board.qualifier": "rgb(135,95,175)", "cmd.flag": "dim", } ) +# Don't use colors in menus, as that will override the focus style. +MENU_THEME = Theme({"board.revision": "dim", "board.qualifier": "dim"}) + + +def chain_highlighters(*highlighters: HighlighterType | None) -> HighlighterType: + """ + Return a new highlighter which runs each of the given highlighters in order. -def chain_highlighters(highlighters: list[HighlighterType]) -> HighlighterType: - """Return a new highlighter which runs each of the given highlighters in order""" + Arguments that are None will be skipped. + """ def run_all_highlighters(text: str | Text): text = text if isinstance(text, Text) else Text(text) for item in highlighters: - text = item(text) + if item: + text = item(text) return text diff --git a/zmk/terminal.py b/zmk/terminal.py index fc58cb9..b56c7cf 100644 --- a/zmk/terminal.py +++ b/zmk/terminal.py @@ -1,5 +1,5 @@ """ -Terminal utilities +Terminal utilities for things not already provided by Rich. """ # Ignore missing attributes for platform-specific modules @@ -13,19 +13,6 @@ from collections.abc import Generator from contextlib import contextmanager - -def hide_cursor() -> None: - """Hides the terminal cursor.""" - sys.stdout.write("\x1b[?25l") - sys.stdout.flush() - - -def show_cursor() -> None: - """Unhides the terminal cursor.""" - sys.stdout.write("\x1b[?25h") - sys.stdout.flush() - - ESCAPE = b"\x1b" BACKSPACE = b"\b" RETURN = b"\n" @@ -48,14 +35,7 @@ def show_cursor() -> None: _STD_INPUT_HANDLE = -10 _STD_OUTPUT_HANDLE = -11 - _ENABLE_PROCESSED_OUTPUT = 1 - _ENABLE_WRAP_AT_EOL_OUTPUT = 2 _ENABLE_VIRTUAL_TERMINAL_PROCESSING = 4 - _VT_FLAGS = ( - _ENABLE_PROCESSED_OUTPUT - | _ENABLE_WRAP_AT_EOL_OUTPUT - | _ENABLE_VIRTUAL_TERMINAL_PROCESSING - ) _WINDOWS_SPECIAL_KEYS = { 71: HOME, @@ -89,25 +69,6 @@ def read_key() -> bytes: return key - @contextmanager - def enable_vt_mode() -> Generator[None, None, None]: - """ - Context manager which enables virtual terminal processing. - """ - kernel32 = windll.kernel32 - stdout_handle = kernel32.GetStdHandle(_STD_OUTPUT_HANDLE) - - old_stdout_mode = wintypes.DWORD() - kernel32.GetConsoleMode(stdout_handle, byref(old_stdout_mode)) - - new_stdout_mode = old_stdout_mode.value | _VT_FLAGS - - try: - kernel32.SetConsoleMode(stdout_handle, new_stdout_mode) - yield - finally: - kernel32.SetConsoleMode(stdout_handle, old_stdout_mode) - @contextmanager def disable_echo() -> Generator[None, None, None]: """ @@ -125,16 +86,22 @@ def disable_echo() -> Generator[None, None, None]: finally: kernel32.SetConsoleMode(stdin_handle, old_stdin_mode) -except ImportError: - import termios - - @contextmanager - def enable_vt_mode() -> Generator[None, None, None]: + def cursor_control_supported() -> bool: """ - Context manager which enables virtual terminal processing. + Gets whether this terminal supports the virtual terminal escape sequence + for getting the cursor position. """ - # Assume that Unix terminals support VT escape sequences by default. - yield + kernel32 = windll.kernel32 + stdout_handle = kernel32.GetStdHandle(_STD_OUTPUT_HANDLE) + + stdout_mode = wintypes.DWORD() + kernel32.GetConsoleMode(stdout_handle, byref(stdout_mode)) + + return bool(stdout_mode.value & _ENABLE_VIRTUAL_TERMINAL_PROCESSING) + + +except ImportError: + import termios @contextmanager def disable_echo() -> Generator[None, None, None]: @@ -165,10 +132,20 @@ def read_key() -> bytes: return key + def cursor_control_supported() -> bool: + """ + Gets whether this terminal supports the virtual terminal escape sequence + for getting the cursor position. + """ + # Assume that Unix terminals support VT escape sequences by default. + return True + def get_cursor_pos() -> tuple[int, int]: """ - Returns the cursor position as a tuple (row, column). Positions are 0-based. + Returns the cursor position as a tuple (x, y). Positions are 0-based. + + This function may not work properly if cursor_control_supported() returns False. """ with disable_echo(): sys.stdout.write("\x1b[6n") @@ -179,13 +156,4 @@ def get_cursor_pos() -> tuple[int, int]: result += sys.stdin.read(1) row, _, col = result.removeprefix("\x1b[").removesuffix("R").partition(";") - return (int(row) - 1, int(col) - 1) - - -def set_cursor_pos(row=0, col=0) -> None: - """ - Sets the cursor to the given row and column. Positions are 0-based. - """ - with disable_echo(): - sys.stdout.write(f"\x1b[{row + 1};{col + 1}H") - sys.stdout.flush() + return (int(col) - 1, int(row) - 1) diff --git a/zmk/util.py b/zmk/util.py index ff38f50..6904b83 100644 --- a/zmk/util.py +++ b/zmk/util.py @@ -4,14 +4,14 @@ import functools import operator -import os -from collections.abc import Generator, Iterable +from collections.abc import Iterable from contextlib import contextmanager -from pathlib import Path from typing import TypeVar -from rich.console import Console +from rich.console import Console, RenderableType +from rich.padding import PaddingDimensions from rich.progress import Progress, SpinnerColumn, TextColumn +from rich.table import Table T = TypeVar("T") @@ -21,7 +21,12 @@ def flatten(items: Iterable[T | Iterable[T]]) -> Iterable[T]: return functools.reduce(operator.iconcat, items, []) -def splice(text: str, index: int, count: int = 0, insert_text: str = ""): +def union(items: Iterable[set[T]]) -> set[T]: + """Compute the union of any number of sets""" + return functools.reduce(operator.or_, items, set()) + + +def splice(text: str, index: int, count: int = 0, insert_text: str = "") -> str: """ Remove `count` characters starting from `index` in `text` and replace them with `insert_text`. @@ -29,16 +34,27 @@ def splice(text: str, index: int, count: int = 0, insert_text: str = ""): return text[0:index] + insert_text + text[index + count :] -@contextmanager -def set_directory(path: Path) -> Generator[None, None, None]: - """Context manager to temporarily change the working directory""" - original = Path().absolute() - - try: - os.chdir(path) - yield - finally: - os.chdir(original) +def horizontal_group( + *renderables: RenderableType, + padding: PaddingDimensions = 0, + collapse_padding=True, + pad_edge=False, + expand=False, + highlight=True, +) -> Table: + """ + Similar to rich.group.Group, but uses a table to place renderables in a row + instead of a column. + """ + grid = Table.grid( + padding=padding, + collapse_padding=collapse_padding, + pad_edge=pad_edge, + expand=expand, + ) + grid.highlight = highlight + grid.add_row(*renderables) + return grid @contextmanager