From 19d94c16887d153b9c2c37016c338e12cf472588 Mon Sep 17 00:00:00 2001 From: Joel Spadin Date: Mon, 2 Mar 2026 00:15:41 -0600 Subject: [PATCH 1/8] fix: Fix incorrect vscode settings The "ruff.enabled" setting is actually named "ruff.enable". The "python.analysis.typeCheckingMode" setting is ignored when the corresponding option is set in pyproject.toml, so having it set in settings.json was unnecessary. --- .vscode/settings.json | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index a7fee3e..771f298 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -21,8 +21,7 @@ "**/templates/common/**": "mako" }, "pylint.enabled": false, - "ruff.enabled": true, + "ruff.enable": true, "python.analysis.importFormat": "relative", - "python.analysis.typeCheckingMode": "standard", "python-envs.pythonProjects": [] } From 2b0eaa4bb7fedd642732277b4c4ee0c094d9a5e5 Mon Sep 17 00:00:00 2001 From: Joel Spadin Date: Sat, 7 Mar 2026 19:23:06 -0600 Subject: [PATCH 2/8] fix: Use Ruff for import sorting The "source.organizeImports" code action could be handled by other extensions such as isort. This can result in some strange behaviors like isort organizing imports on multiple lines, Ruff reformatting them back to a single line, and the last line of the file getting duplicated. Explicitly state that we want Ruff to handle organizing imports. --- .vscode/settings.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 771f298..f3ec19f 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,7 +3,7 @@ "[python]": { "editor.defaultFormatter": "charliermarsh.ruff", "editor.codeActionsOnSave": { - "source.organizeImports": "explicit" + "source.organizeImports.ruff": "explicit" } }, "[markdown][yaml]": { From 80c5ea8ad86cd6b35deb20fcbc1ed67e659b1529 Mon Sep 17 00:00:00 2001 From: Joel Spadin Date: Sat, 7 Mar 2026 19:25:21 -0600 Subject: [PATCH 3/8] fix: Add "build" to exclusions for pyright Updated pyright settings to add an exclusion for the "build" directory. This prevents VS Code from finding copies of zmk files in that directory and suggesting things from them in code completion. --- pyproject.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index be9ea13..2a0c3ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,13 @@ build-backend = "setuptools.build_meta" [tool.pyright] venvPath = "." venv = ".venv" +exclude = [ + ".venv", + "build", + "**/.*", + "**/node_modules", + "**/__pycache__", +] [tool.ruff.lint] select = [ From faf6768eedbd1ca7c11ae6e151048ad02333dfd8 Mon Sep 17 00:00:00 2001 From: Joel Spadin Date: Fri, 6 Mar 2026 20:02:43 -0600 Subject: [PATCH 4/8] feat: Rework menu rendering Menus now use Rich's Live object instead of manually printing lines and resetting the cursor to the top of the terminal. This also handles erasing the help line at the bottom of the menu when the menu is closed. Replaced the hide_cursor(), show_cursor(), and set_cursor_pos() functions with ones already provided by Rich. Rich doesn't provide a function to get the cursor position though, so get_cursor_pos() is still needed. The order of the return values was changed from (row, col) to (x, y) to match the Rich functions. Updated the menu to render as a grid. Objects supporting a new MenuRow protocol can now provide multiple cells of data, which will be aligned into columns. This replaces the manual padding that was done to align columns with the Detail class, and it allows Rich to automatically size the columns based on just the items that are visible. Also improved the behavior when filtering the menu to try to keep the focused item in the same spot on the screen, and to prevent cases where the scroll position is too far down, so there are blank rows even though there are more items that could be displayed there. --- zmk/commands/module/remove.py | 17 +- zmk/menu.py | 322 ++++++++++++++++++++-------------- zmk/terminal.py | 28 +-- 3 files changed, 203 insertions(+), 164 deletions(-) diff --git a/zmk/commands/module/remove.py b/zmk/commands/module/remove.py index 6891876..16ae8f0 100644 --- a/zmk/commands/module/remove.py +++ b/zmk/commands/module/remove.py @@ -5,17 +5,19 @@ import shutil import stat import subprocess +from collections.abc import Iterable from dataclasses import dataclass from pathlib import Path from typing import Annotated, Any import rich import typer +from rich.console import RenderableType from west.manifest import Project from ...config import get_config from ...exceptions import FatalError -from ...menu import Detail, detail_list, show_menu +from ...menu import show_menu from ...repo import Repo from ...util import spinner from ...yaml import YAML @@ -91,10 +93,10 @@ def remove_readonly(func, path, _): def _prompt_project(projects: list[Project]): - items = detail_list((ProjectWrapper(p), p.url) for p in projects) + items = [ProjectWrapper(p) for p in projects] result = show_menu("Select the module to remove:", items, filter_func=_filter) - return result.data.project + return result.project @dataclass @@ -103,10 +105,11 @@ class ProjectWrapper: project: Project - def __str__(self): - return self.project.name + def __menu_row__(self) -> Iterable[RenderableType]: + yield self.project.name + yield self.project.url -def _filter(item: Detail[ProjectWrapper], text: str): +def _filter(item: ProjectWrapper, text: str): text = text.casefold() - return text in item.data.project.name.casefold() or text in item.detail + return text in item.project.name.casefold() or text in item.project.url.casefold() diff --git a/zmk/menu.py b/zmk/menu.py index 7917a5d..25e5561 100644 --- a/zmk/menu.py +++ b/zmk/menu.py @@ -2,19 +2,24 @@ Terminal menus """ -from collections.abc import Callable, Iterable +from collections.abc import Callable, Generator, Iterable from contextlib import contextmanager, suppress -from typing import Generic, Self, TypeVar +from typing import Generic, Protocol, TypeVar, runtime_checkable import rich -from rich.console import Console, RenderableType +from rich.console import Console, RenderableType, group +from rich.control import Control from rich.highlighter import Highlighter -from rich.style import Style +from rich.live import Live +from rich.padding import Padding +from rich.style import Style, StyleType +from rich.table import Table from rich.text import Text from rich.theme import Theme from . import terminal -from .util import splice +from .styles import chain_highlighters +from .util import horizontal_group, splice class StopMenu(KeyboardInterrupt): @@ -23,7 +28,23 @@ class StopMenu(KeyboardInterrupt): """ -T = TypeVar("T") +@runtime_checkable +class MenuRow(Protocol): + """ + An object that will display multiple values in a menu. + + __menu_row__() functions like __rich__(), except it returns any number of + renderables, and each must be no more than one line tall. They will be + aligned in columns in the menu. + """ + + def __menu_row__(self) -> Iterable[RenderableType]: ... + + +T = TypeVar("T", bound=RenderableType | MenuRow) + +_MenuRow = tuple[list[RenderableType], StyleType | None] +"""Type alias for a list of values to render for a row and the style to apply to them""" class TerminalMenu(Generic[T], Highlighter): @@ -50,7 +71,7 @@ class TerminalMenu(Generic[T], Highlighter): } ) - title: RenderableType | None + title: str | None items: list[T] default_index: int @@ -67,13 +88,14 @@ class TerminalMenu(Generic[T], Highlighter): def __init__( self, - title: RenderableType | None, + title: str | None, items: Iterable[T], *, default_index=0, filter_func: Callable[[T, str], bool] | None = None, console: Console | None = None, theme: Theme | None = None, + padding=3, ): """ An interactive terminal menu. @@ -87,12 +109,14 @@ def __init__( field will be displayed after the menu title. :param console: Console in which to display the menu. :param theme: Theme to apply. See TerminalMenu.DEFAULT_THEME for style names. + :param padding: Number of spaces between columns. """ self.title = title self.items = list(items) self.console = console or rich.get_console() self.theme = theme or self.DEFAULT_THEME self.default_index = default_index + self.padding = padding self._filter_func = filter_func self._filter_text = "" @@ -112,8 +136,8 @@ def __init__( if self._get_display_count() == self._max_items_per_page: self._top_row = 1 else: - row, _ = terminal.get_cursor_pos() - self._top_row = min(row, self.console.height - self._get_menu_height()) + _, y = terminal.get_cursor_pos() + self._top_row = min(y, self.console.height - self._get_menu_height()) self._apply_filter() @@ -127,28 +151,16 @@ def show(self) -> T: self._focus_index = self.default_index - try: - with self._context(): - while True: - self._scroll_index = self._get_scroll_index() - self._print_menu() - self._move_cursor_to_filter() - - if self.has_filter: - terminal.show_cursor() + with self._context() as live: + while True: + self._scroll_index = self._get_scroll_index() + live.update(self._render_menu(), refresh=True) + with self._move_cursor_to_filter(): if self._handle_input(): - try: + # _focus_index may be invalid if _filter_items is empty. + with suppress(IndexError): return self._filter_items[self._focus_index] - except IndexError: - pass - - if self.has_filter: - terminal.hide_cursor() - - self._move_cursor_to_top() - finally: - self._erase_controls() @property def has_filter(self) -> bool: @@ -173,16 +185,47 @@ def highlight(self, text: Text) -> None: @contextmanager def _context(self): + controls = Text(self.CONTROLS, style="dim", no_wrap=True, overflow="crop") + if self.has_filter: + controls.append(self.FILTER_CONTROLS) + old_highlighter = self.console.highlighter try: - terminal.hide_cursor() - self.console.highlighter = self - - with self.console.use_theme(self.theme): - yield + self.console.show_cursor(show=False) + + # Merge the console's existing highlighter with the menu highlighter + self.console.highlighter = chain_highlighters(old_highlighter, self) + + # Display the title and menu items in a live view, which we will + # update as the user interacts with the menu. Below that, display + # the control help text. This never gets updated, but it uses a + # transient live view to it gets hidden when the menu is closed. + # Both need redirect_stdout=False because the default behavior will + # conflict with our cursor position modifications, resulting in the + # rendered menu not being cleaned up properly on Esc/Ctrl+C. + with ( + self.console.use_theme(self.theme), + Live( + console=self.console, + redirect_stdout=False, + auto_refresh=False, + ) as live, + Live( + controls, + console=self.console, + redirect_stdout=False, + auto_refresh=False, + transient=True, + ), + ): + yield live finally: - terminal.show_cursor() self.console.highlighter = old_highlighter + self.console.show_cursor(show=True) + + # Add one blank line when the menu is closed to give some space + # between the menu and whatever follows it. + self.console.print() def _apply_filter(self): if self._filter_func: @@ -195,78 +238,91 @@ def _apply_filter(self): i for i in self.items if self._filter_func(i, self._filter_text) ] + # If the previously-focused item is still visible, update the focus + # index to that item's new index. Update the scroll index as well to + # try to keep that item in the same place on the screen. if old_focus is not None: with suppress(ValueError): + scroll_offset = self._focus_index - self._scroll_index self._focus_index = self._filter_items.index(old_focus) + self._scroll_index = self._focus_index - scroll_offset else: self._filter_items = self.items self._clamp_focus_index() - def _print_menu(self): + @group() + def _render_menu(self): if self.title: - self.console.print( - f"[title]{self.title}[/title] [filter]{self._filter_text}[/filter]", - justify="left", - highlight=False, + yield Text.assemble( + (self.title, "title"), " ", (self._filter_text, "filter") ) + # Find the items that are visible and render them to a list of rows. + # Organize the rows into a grid to align columns of data. + grid = Table.grid(padding=(0, self.padding)) + grid.highlight = True + + if rows := list(self._render_rows()): + max_columns = max(len(row) for row, _ in rows) + + for _ in range(max_columns): + grid.add_column(no_wrap=True) + + for row, style in rows: + grid.add_row(*row, style=style) + + # Wrap the grid in a Padding() so it clears the entire width of the terminal. + # (expand=True on the grid would also work, but that affects column widths.) + yield Padding(grid) + + def _render_rows(self) -> Generator[_MenuRow]: display_count = self._get_display_count() - for row in range(display_count): - if row == 0 and not self._filter_items: - self.console.print( - "[dim]No matching items", - justify="left", - highlight=False, - no_wrap=True, - ) - continue + # If the filter doesn't match any items, display a message on the first + # line and blank the rest of the menu. + if not self._filter_items: + yield (["No matching items"], Style(dim=True)) + for _ in range(display_count - 1): + yield ([], None) + + return + + scroll_at_top = self._scroll_index == 0 + scroll_at_bottom = self._scroll_index + display_count >= len(self._filter_items) + for row in range(display_count): index = self._scroll_index + row focused = index == self._focus_index - at_start = self._scroll_index == 0 - at_end = self._scroll_index + display_count >= len(self._filter_items) - show_more = (not at_start and row == 0) or ( - not at_end and row == display_count - 1 - ) + is_top_ellipsis = row == 0 and not scroll_at_top + is_bottom_ellipsis = row == display_count - 1 and not scroll_at_bottom - try: - item = self._filter_items[index] - self._print_item(item, focused=focused, show_more=show_more) - except IndexError: - self.console.print(justify="left") + if is_top_ellipsis or is_bottom_ellipsis: + yield ([" ..."], "ellipsis") + else: + try: + yield self._render_item(self._filter_items[index], focused=focused) + except IndexError: + yield ([], None) - controls = self.CONTROLS - if self.has_filter: - controls += self.FILTER_CONTROLS - - self.console.print( - controls, - style="controls", - end="", - highlight=False, - no_wrap=True, - overflow="crop", - ) + def _render_item(self, item: T | str, *, focused: bool) -> _MenuRow: + style = "focus" if focused else "unfocus" - def _print_item(self, item: T | str, *, focused: bool, show_more: bool): - style = "ellipsis" if show_more else "focus" if focused else "unfocus" + columns: list[RenderableType] + if isinstance(item, MenuRow): + columns = list(item.__menu_row__()) or [""] + else: + columns = [item] + # The table has larger padding between columns than we want for the + # focused item indicator or indent on unfocused items, so modify the + # value in the first column to contain the indicator/indent instead of + # putting it in a separate column. indent = "> " if focused else " " - item = "..." if show_more else item - - self.console.print( - indent, - item, - sep="", - style=style, - highlight=True, - justify="left", - no_wrap=True, - overflow="ellipsis", - ) + columns[0] = horizontal_group(indent, columns[0]) + + return ([*columns], style) def _clamp_focus_index(self): self._focus_index = min(max(0, self._focus_index), len(self._filter_items) - 1) @@ -370,44 +426,65 @@ def _get_scroll_index(self): items_count = len(self._filter_items) display_count = self._get_display_count() - if items_count < display_count: + if items_count <= display_count: + # There is enough space to show the whole menu without scrolling. return 0 first_displayed = self._scroll_index last_displayed = first_displayed + display_count - 1 + if last_displayed >= items_count: + # There are more items in the menu than available space, but the + # current scroll position would leave blank spaces at the bottom. + # Scroll up enough to fill every row of the menu with an item. + first_displayed = items_count - display_count + last_displayed = first_displayed + display_count - 1 + if self._focus_index <= first_displayed + self.SCROLL_MARGIN: - return max(0, self._focus_index - 1 - self.SCROLL_MARGIN) + # The focused item is off the top of the screen. Scroll up enough to + # get it in view. Also fit as many menu items in as possible so we + # don't get just a couple visible when there's room for more. + start = min( + items_count - display_count, + self._focus_index - 1 - self.SCROLL_MARGIN, + ) + return max(0, start) if self._focus_index >= last_displayed - self.SCROLL_MARGIN: + # Focused item is off the bottom of the screen. Scroll down enough + # to get it in view. end = min(items_count - 1, self._focus_index + 1 + self.SCROLL_MARGIN) return end - (display_count - 1) - return self._scroll_index - - def _move_cursor_to_top(self): - """Move the cursor to the start of the menu""" - terminal.set_cursor_pos(row=self._top_row) + return first_displayed + @contextmanager def _move_cursor_to_filter(self): - """Move the cursor to the filter text field""" - row = self._top_row + self._num_title_lines - 1 - col = self._last_title_line_len + self._cursor_index + """ + Context manager which move the cursor to the filter text field and shows + it, runs the context, then sets the cursor back where it was and hides it. + """ - terminal.set_cursor_pos(row, col) + if not self.has_filter: + yield + return - def _erase_controls(self): - """Hide the controls text and reset the cursor to after the menu""" - row = self.console.height - 1 + orig_x, orig_y = terminal.get_cursor_pos() - terminal.set_cursor_pos(row=row, col=0) - self.console.print(justify="left") + x = self._last_title_line_len + self._cursor_index + y = self._top_row + self._num_title_lines - 1 - terminal.set_cursor_pos(self._top_row + len(self._filter_items) + 1) + try: + self.console.control(Control.move_to(x, y)) + self.console.show_cursor(show=True) + yield + finally: + self.console.show_cursor(show=False) + self.console.control(Control.move_to(orig_x, orig_y)) def show_menu( - title: RenderableType | None, + title: str | None, items: Iterable[T], *, default_index=0, @@ -442,45 +519,26 @@ def show_menu( class Detail(Generic[T]): - """A menu item with a description appended to the end.""" - - MIN_PAD = 2 + """A menu item with a description.""" data: T detail: str - pad_len: int def __init__(self, data: T, detail: str): self.data = data self.detail = detail - self.pad_len = self.MIN_PAD - - def __rich__(self): - text = Text.assemble(str(self.data), " " * self.pad_len, (self.detail, "dim")) - # Returning the Text object directly works, but it doesn't get highlighted. - return text.markup - - @classmethod - def align(cls, items: Iterable[Self], console: Console | None = None) -> list[Self]: - """Set the padding for each item in the list to align the detail strings.""" - items = list(items) - console = console or rich.get_console() - for item in items: - item.pad_len = console.measure(str(item.data)).minimum - - width = max(item.pad_len for item in items) - - for item in items: - item.pad_len = width - item.pad_len + cls.MIN_PAD + def __menu_row__(self) -> Iterable[RenderableType]: + if isinstance(self.data, MenuRow): + yield from self.data.__menu_row__() + else: + yield self.data - return items + yield f"[dim]{self.detail}" -def detail_list( - items: Iterable[tuple[T, str]], console: Console | None = None -) -> list[Detail[T]]: +def detail_list(items: Iterable[tuple[T, str]]) -> list[Detail[T]]: """ - Create a list of menu items with a description appended to each item. + Create a list of menu items with a description next to each item. """ - return Detail.align([Detail(item, desc) for item, desc in items], console=console) + return [Detail(item, desc) for item, desc in items] diff --git a/zmk/terminal.py b/zmk/terminal.py index fc58cb9..fb8220b 100644 --- a/zmk/terminal.py +++ b/zmk/terminal.py @@ -1,5 +1,5 @@ """ -Terminal utilities +Terminal utilities for things not already provided by Rich. """ # Ignore missing attributes for platform-specific modules @@ -13,19 +13,6 @@ from collections.abc import Generator from contextlib import contextmanager - -def hide_cursor() -> None: - """Hides the terminal cursor.""" - sys.stdout.write("\x1b[?25l") - sys.stdout.flush() - - -def show_cursor() -> None: - """Unhides the terminal cursor.""" - sys.stdout.write("\x1b[?25h") - sys.stdout.flush() - - ESCAPE = b"\x1b" BACKSPACE = b"\b" RETURN = b"\n" @@ -168,7 +155,7 @@ def read_key() -> bytes: def get_cursor_pos() -> tuple[int, int]: """ - Returns the cursor position as a tuple (row, column). Positions are 0-based. + Returns the cursor position as a tuple (x, y). Positions are 0-based. """ with disable_echo(): sys.stdout.write("\x1b[6n") @@ -179,13 +166,4 @@ def get_cursor_pos() -> tuple[int, int]: result += sys.stdin.read(1) row, _, col = result.removeprefix("\x1b[").removesuffix("R").partition(";") - return (int(row) - 1, int(col) - 1) - - -def set_cursor_pos(row=0, col=0) -> None: - """ - Sets the cursor to the given row and column. Positions are 0-based. - """ - with disable_echo(): - sys.stdout.write(f"\x1b[{row + 1};{col + 1}H") - sys.stdout.flush() + return (int(col) - 1, int(row) - 1) From e1f0dde8b620a53e6605a92f2f984c690f27cf6b Mon Sep 17 00:00:00 2001 From: Joel Spadin Date: Sat, 7 Mar 2026 19:36:21 -0600 Subject: [PATCH 5/8] feat: Rework hardware data model This rewrites the classes for representing hardware (boards, shields, and interconnects), hardware revisions, and build matrix items to match the design outlined in #47. A new Revision class simplifies the revision handling. It automatically normalizes revisions to the shortest form for numerical revisions and uppercase for alphabetical revisions, and it provides operators for testing equality and ordering. A new BoardTarget class represents Zephyr board targets. This handles parsing board IDs, board qualifiers, and revisions from strings and reassembling them into the expected order with the revision inbetween the ID and qualifiers. The old Keyboard class is renamed to KeyboardComponent, since it represents a component which might or might not be a complete keyboard. A new Keyboard class now manages combining a board with some number of shields to create a keyboard. Code for discovering hardware and prompting the user to select hardware is extracted to a new hardware_list.py file, since hardware.py is already very large. All the "zmk keyboard ..." commands are updated to use the new classes. There are some minor changes to the messages and formatting of the output as well. Fixes #60. --- zmk/build.py | 50 +-- zmk/commands/keyboard/add.py | 222 +++++------ zmk/commands/keyboard/list.py | 71 ++-- zmk/commands/keyboard/new.py | 3 +- zmk/commands/keyboard/remove.py | 14 +- zmk/hardware.py | 670 ++++++++++++++++++-------------- zmk/hardware_list.py | 205 ++++++++++ zmk/revision.py | 156 ++++++++ zmk/styles.py | 29 +- zmk/util.py | 46 ++- 10 files changed, 945 insertions(+), 521 deletions(-) create mode 100644 zmk/hardware_list.py create mode 100644 zmk/revision.py diff --git a/zmk/build.py b/zmk/build.py index e1aaab7..9b8e1e7 100644 --- a/zmk/build.py +++ b/zmk/build.py @@ -3,47 +3,19 @@ """ from collections.abc import Iterable, Mapping, Sequence -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field from pathlib import Path from typing import Any, TypeVar, cast, overload import dacite +from .hardware import BoardTarget, BuildItem from .repo import Repo from .yaml import YAML T = TypeVar("T") -@dataclass -class BuildItem: - """An item in the build matrix""" - - board: str - shield: str | None = None - snippet: str | None = None - cmake_args: str | None = None - artifact_name: str | None = None - - def __rich__(self) -> str: - parts = [] - parts.append(self.board) - - if self.shield: - parts.append(self.shield) - - if self.snippet: - parts.append(f"[dim]snippet: {self.snippet}[/dim]") - - if self.artifact_name: - parts.append(f"[dim]artifact-name: {self.artifact_name}[/dim]") - - if self.cmake_args: - parts.append(f"[dim]cmake-args: {self.cmake_args}[/dim]") - - return "[dim], [/dim]".join(parts) - - @dataclass class _BuildMatrixWrapper: include: list[BuildItem] = field(default_factory=list) @@ -71,7 +43,7 @@ def __init__(self, path: Path): self._data = None def write(self) -> None: - """Updated the YAML file, creating it if necessary""" + """Update the YAML file, creating it if necessary""" self._yaml.dump(self._data, self._path) @property @@ -86,7 +58,8 @@ def include(self) -> list[BuildItem]: if not normalized: return [] - wrapper = dacite.from_dict(_BuildMatrixWrapper, normalized) + config = dacite.Config(type_hooks={BoardTarget: BoardTarget.parse}) + wrapper = dacite.from_dict(_BuildMatrixWrapper, normalized, config) return wrapper.include def has_item(self, item: BuildItem) -> bool: @@ -183,12 +156,19 @@ def fix_key(key: str): def _to_yaml(item: BuildItem): """ - Convert a BuildItem to a dict with keys changed back from underscores to hyphens. + Convert a BuildItem to a dict with keys changed back from underscores to hyphens + and values changed to YAML-compatible types. """ def fix_key(key: str): return key.replace("_", "-") - data = asdict(item) + def fix_value(value: Any): + match value: + case BoardTarget(): + return str(value) + + case _: + return value - return {fix_key(k): v for k, v in data.items() if v is not None} + return {fix_key(k): fix_value(v) for k, v in item.__dict__.items() if v is not None} diff --git a/zmk/commands/keyboard/add.py b/zmk/commands/keyboard/add.py index 269df6b..3450f94 100644 --- a/zmk/commands/keyboard/add.py +++ b/zmk/commands/keyboard/add.py @@ -2,31 +2,22 @@ "zmk keyboard add" command. """ -import itertools import shutil from pathlib import Path from typing import Annotated import rich import typer +from rich.padding import Padding -from ...build import BuildItem, BuildMatrix +from ...build import BuildMatrix from ...config import get_config from ...exceptions import FatalError -from ...hardware import ( - Board, - Hardware, - Keyboard, - Shield, - append_revision, - get_hardware, - is_compatible, - normalize_revision, - show_hardware_menu, - show_revision_menu, - split_revision, -) +from ...hardware import Board, BoardTarget, Keyboard, Shield +from ...hardware_list import get_hardware, show_hardware_menu, show_revision_menu from ...repo import Repo +from ...revision import Revision +from ...styles import THEME, BoardIdHighlighter, chain_highlighters from ...util import spinner @@ -55,6 +46,8 @@ def keyboard_add( """Add configuration for a keyboard and add it to the build.""" console = rich.get_console() + console.push_theme(THEME) + console.highlighter = chain_highlighters(console.highlighter, BoardIdHighlighter()) cfg = get_config(ctx) repo = cfg.get_repo() @@ -62,94 +55,111 @@ def keyboard_add( with spinner("Finding hardware..."): hardware = get_hardware(repo) - keyboard = None - controller = None - revision = None + keyboard = Keyboard() if keyboard_id: - keyboard_id, keyboard_revision = split_revision(keyboard_id) + # User specified a keyboard. It might be either a board (with optional + # revision and board qualifiers) or a shield. + target = BoardTarget.parse(keyboard_id) + keyboard_id = target.name + target.qualifiers - keyboard = hardware.find_keyboard(keyboard_id) - if keyboard is None: + keys_component = hardware.find_keyboard(keyboard_id) + if not keys_component: raise KeyboardNotFoundError(keyboard_id) - # If the keyboard ID contained a revision, use that. - # Make sure it is valid before continuing to any other prompts. - if keyboard_revision: - revision = keyboard_revision - _check_revision(keyboard, revision) + keyboard.add_component(keys_component) + if target.revision: + _check_revision(keys_component, target.revision) + keyboard.board_revision = target.revision if controller_id: - if not isinstance(keyboard, Shield): + # User also specified a controller board (with optional revision and + # board qualifiers). If the specified keyboard was already a board, + # then this is invalid because we can't have two boards. + if keyboard.board: raise FatalError( - f'Keyboard "{keyboard.id}" has an onboard controller ' + f'Keyboard "{keys_component.id}" has an onboard controller ' "and does not require a controller board." ) + target = BoardTarget.parse(controller_id) + controller_id = target.name + target.qualifiers + controller = hardware.find_controller(controller_id) - if controller is None: + if not controller: raise ControllerNotFoundError(controller_id) + keyboard.add_component(controller) + if target.revision: + _check_revision(controller, target.revision) + keyboard.board_revision = target.revision + elif controller_id: - controller_id, controller_revision = split_revision(controller_id) + # User specified a controller but not a keyboard. Find the controller. + # We will prompt for the shield later. + + target = BoardTarget.parse(controller_id) - # User specified a controller but not a keyboard. Filter the keyboard - # list to just those compatible with the controller. controller = hardware.find_controller(controller_id) - if controller is None: + if not controller: raise ControllerNotFoundError(controller_id) - # If the controller ID contained a revision, use that. - # Make sure it is valid before continuing to any other prompts. - if controller_revision: - revision = controller_revision - _check_revision(controller, revision) + keyboard.add_component(controller) + if target.revision: + _check_revision(controller, target.revision) + keyboard.board_revision = target.revision - hardware.keyboards = [ - kb - for kb in hardware.keyboards - if isinstance(kb, Shield) and is_compatible(controller, kb) - ] + # When prompting for a keyboard later, it should only display the shields + # that are compatible with the chosen board. + hardware.filter_compatible_keyboards(keyboard) # Prompt the user for any necessary components they didn't specify - if keyboard is None: - keyboard = show_hardware_menu("Select a keyboard:", hardware.keyboards) - - if isinstance(keyboard, Shield): - if controller is None: - hardware.controllers = [ - c for c in hardware.controllers if is_compatible(c, keyboard) - ] - controller = show_hardware_menu( - "Select a controller:", hardware.controllers - ) - - # Sanity check that everything is compatible - if not is_compatible(controller, keyboard): - raise FatalError( - f'Keyboard "{keyboard.id}" is not compatible with controller "{controller.id}"' - ) - - # Check if the controller needs a revision. - revision = _get_revision(controller, revision) - else: - # If the keyboard isn't a shield, it may need a revision. - revision = _get_revision(keyboard, revision) - - name = keyboard.id - if controller: - name += ", " + controller.id - - if revision: - revision = normalize_revision(revision) - name = append_revision(name, revision) - - if _add_keyboard(repo, keyboard, controller, revision): - console.print(f'Added "{name}".') + if not keyboard.keys_component: + keyboard.add_component( + show_hardware_menu("Select a keyboard:", hardware.keyboards) + ) + + if not keyboard.board: + hardware.filter_compatible_controllers(keyboard) + + keyboard.add_component( + show_hardware_menu("Select a controller:", hardware.controllers) + ) + + # Sanity check the resulting hardware compatibility + if not keyboard.board: + raise FatalError( + "Controller board is missing (this is probably a bug in ZMK CLI)." + ) + + if not keyboard.keys_component: + raise FatalError( + "Component with 'keys' feature is missing (this is probably a bug in ZMK CLI)." + ) + + if keyboard.missing_requirements: + raise FatalError( + f'Keyboard "{keyboard.keys_component}" is not compatible with controller "{keyboard.board}". ' + f"Required interconnects are missing: {', '.join(keyboard.missing_requirements)}" + ) + + # If a revision wasn't already set from the command line, the user may need + # to choose a revision. + if not keyboard.board_revision and keyboard.board_revisions: + keyboard.board_revision = show_revision_menu(keyboard.board) + + if added := _add_keyboard(repo, keyboard): + console.print("[title]Added:") + + for item in added: + console.print(Padding.indent(item, 2)) + + console.print() else: + name = keyboard.keys_component.name console.print(f'"{name}" is already in the build matrix.') - keymap_name = keyboard.get_keymap_path(revision).with_suffix("").name + keymap_name = keyboard.get_keymap_path().with_suffix("").name console.print(f'Run "zmk code {keymap_name}" to edit the keymap.') @@ -174,64 +184,28 @@ def _copy_keyboard_file(repo: Repo, path: Path): shutil.copy2(path, dest_path) -def _get_build_items( - keyboard: Keyboard, controller: Board | None, revision: str | None -): - boards = [] - shields = [] - - match keyboard: - case Shield(id=shield_id, siblings=siblings): - if controller is None: - raise FatalError("controller may not be None if keyboard is a shield") - - shields = siblings or [shield_id] - boards = [append_revision(controller.id, revision)] +def _check_revision(board: Board | Shield, revision: Revision): + if not isinstance(board, Board): + raise FatalError(f"{board.id} is a shield. Only boards support revisions.") - case Board(id=board_id, siblings=siblings): - boards = siblings or [board_id] - boards = [append_revision(board, revision) for board in boards] - - case _: - raise FatalError("Unexpected keyboard/controller combination") - - if shields: - return [ - BuildItem(board=b, shield=s) for b, s in itertools.product(boards, shields) - ] - - return [BuildItem(board=b) for b in boards] - - -def _get_revision(board: Hardware, revision: str | None): - # If no revision was specified and the board uses revisions, prompt to - # select a revision. - return revision or show_revision_menu(board) - - -def _check_revision(board: Hardware, revision: str): - if board.has_revision(revision): + if revision in board.revisions: # Revision is OK return - supported_revisions = board.get_revisions() - - if not supported_revisions: + if not board.revisions: raise FatalError(f"{board.id} does not have any revisions.") raise FatalError( f'{board.id} does not support revision "@{revision}". Use one of:\n' - + "\n".join(f" @{normalize_revision(rev)}" for rev in supported_revisions) + + "\n".join(f" @{rev}" for rev in board.revisions) ) -def _add_keyboard( - repo: Repo, keyboard: Keyboard, controller: Board | None, revision: str | None -): - _copy_keyboard_file(repo, keyboard.get_keymap_path(revision)) - _copy_keyboard_file(repo, keyboard.get_config_path(revision)) +def _add_keyboard(repo: Repo, keyboard: Keyboard): + items = keyboard.get_build_items() - items = _get_build_items(keyboard, controller, revision) + _copy_keyboard_file(repo, keyboard.get_keymap_path()) + _copy_keyboard_file(repo, keyboard.get_config_path()) matrix = BuildMatrix.from_repo(repo) added = matrix.append(items) diff --git a/zmk/commands/keyboard/list.py b/zmk/commands/keyboard/list.py index 6853658..553de7a 100644 --- a/zmk/commands/keyboard/list.py +++ b/zmk/commands/keyboard/list.py @@ -14,18 +14,11 @@ from zmk import styles -from ...build import BuildItem, BuildMatrix +from ...build import BuildMatrix from ...config import get_config from ...exceptions import FatalError -from ...hardware import ( - Board, - Hardware, - Shield, - append_revision, - get_hardware, - is_compatible, - normalize_revision, -) +from ...hardware import Board, BoardTarget, BuildItem, Hardware, Keyboard, Shield +from ...hardware_list import get_hardware from ...util import spinner # TODO: allow output as unformatted list @@ -48,7 +41,7 @@ def _list_build_matrix(ctx: typer.Context, *, value: bool): console = Console( highlighter=styles.chain_highlighters( - [styles.BoardIdHighlighter(), styles.CommandLineHighlighter()] + styles.BoardIdHighlighter(), styles.CommandLineHighlighter() ), theme=styles.THEME, ) @@ -64,7 +57,7 @@ def _list_build_matrix(ctx: typer.Context, *, value: bool): has_artifact_name = any(item.artifact_name for item in include) table = Table( - box=box.SQUARE, + box=box.ROUNDED, border_style="dim blue", header_style="bright_cyan", highlight=True, @@ -123,7 +116,7 @@ def keyboard_list( "--board", "-b", metavar="BOARD", - help="List only keyboards compatible with this controller board.", + help="List keyboards compatible with this controller board.", ), ] = None, shield: Annotated[ @@ -132,7 +125,7 @@ def keyboard_list( "--shield", "-s", metavar="SHIELD", - help="List only controllers compatible with this keyboard shield.", + help="List controllers compatible with this keyboard shield.", ), ] = None, interconnect: Annotated[ @@ -141,7 +134,7 @@ def keyboard_list( "--interconnect", "-i", metavar="INTERCONNECT", - help="List only keyboards and controllers that have this interconnect.", + help="List keyboards and controllers that use this interconnect.", ), ] = None, standalone: Annotated[ @@ -173,11 +166,9 @@ def keyboard_list( if item is None: raise FatalError(f'Could not find controller board "{board}".') - groups.keyboards = [ - kb - for kb in groups.keyboards - if isinstance(kb, Shield) and is_compatible(item, kb) - ] + keyboard = Keyboard() + keyboard.add_component(item) + groups.filter_compatible_keyboards(keyboard) list_type = ListType.KEYBOARD elif shield: @@ -189,23 +180,21 @@ def keyboard_list( if not isinstance(item, Shield): raise FatalError(f'Keyboard "{shield}" is a standalone keyboard.') - groups.controllers = [c for c in groups.controllers if is_compatible(c, item)] + keyboard = Keyboard() + keyboard.add_component(item) + groups.filter_compatible_controllers(keyboard) list_type = ListType.CONTROLLER elif interconnect: - # Filter to controllers that provide an interconnect and keyboards that use it. + # Filter to controllers that provide an interconnect and keyboards that + # use or provide it. item = groups.find_interconnect(interconnect) if item is None: raise FatalError(f'Could not find interconnect "{interconnect}".') - groups.controllers = [ - c for c in groups.controllers if c.exposes and item.id in c.exposes - ] - groups.keyboards = [ - kb - for kb in groups.keyboards - if isinstance(kb, Shield) and kb.requires and item.id in kb.requires - ] + groups.filter_to_interconnect(item) + + # When filtering to an interconnect, don't show interconnects. groups.interconnects = [] elif standalone: @@ -214,14 +203,16 @@ def keyboard_list( list_type = ListType.KEYBOARD def print_items(header: str, items: Iterable[Hardware]): - if revisions: - names = [ - append_revision(item.id, normalize_revision(rev)) - for item in items - for rev in (item.get_revisions() or [None]) - ] - else: - names = [item.id for item in items] + names: list[str] = [] + + for item in items: + if revisions and isinstance(item, Board): + names.extend( + str(BoardTarget.parse(item.id).with_revision(rev)) + for rev in item.revisions + ) + else: + names.append(item.id) if not names: return @@ -233,6 +224,10 @@ def print_items(header: str, items: Iterable[Hardware]): console.print(columns) console.print() + # TODO: when filtering to an interconnect, we should specify which hardware + # exposes vs. which requires. This would be useful if we start to add things + # like non-keyboard shields to the hardware info. + if list_type in (ListType.ALL, ListType.KEYBOARD): print_items("Keyboards:", groups.keyboards) diff --git a/zmk/commands/keyboard/new.py b/zmk/commands/keyboard/new.py index d884232..593fabd 100644 --- a/zmk/commands/keyboard/new.py +++ b/zmk/commands/keyboard/new.py @@ -14,7 +14,8 @@ from ...config import get_config from ...exceptions import FatalError -from ...hardware import Interconnect, get_hardware, show_hardware_menu +from ...hardware import Interconnect +from ...hardware_list import get_hardware, show_hardware_menu from ...menu import detail_list, show_menu from ...repo import Repo from ...templates import get_template_files diff --git a/zmk/commands/keyboard/remove.py b/zmk/commands/keyboard/remove.py index e0a2808..e0efb95 100644 --- a/zmk/commands/keyboard/remove.py +++ b/zmk/commands/keyboard/remove.py @@ -2,12 +2,14 @@ "zmk remove" command. """ -import rich import typer +from rich.console import Console +from rich.padding import Padding from ...build import BuildMatrix from ...config import get_config from ...menu import show_menu +from ...styles import MENU_THEME, THEME, BoardIdHighlighter # TODO: add options to select items from command line @@ -19,10 +21,14 @@ def keyboard_remove(ctx: typer.Context) -> None: matrix = BuildMatrix.from_repo(repo) items = matrix.include - item = show_menu("Select a build to remove:", items) + console = Console(theme=THEME, highlighter=BoardIdHighlighter()) + + with console.use_theme(MENU_THEME): + item = show_menu("Select a build to remove:", items, console=console) if removed := matrix.remove(item): - items = ", ".join(f'"{item.__rich__()}"' for item in removed) - rich.print(f"Removed {items} from the build.") + console.print("[title]Removed:") + for item in removed: + console.print(Padding.indent(item, 2)) matrix.write() diff --git a/zmk/hardware.py b/zmk/hardware.py index 170c053..bbae1fa 100644 --- a/zmk/hardware.py +++ b/zmk/hardware.py @@ -1,20 +1,16 @@ -""" -Hardware metadata discovery and processing. -""" - +import itertools import re -from collections.abc import Generator, Iterable -from dataclasses import dataclass, field -from functools import reduce +from collections.abc import Iterable +from dataclasses import dataclass, field, replace from pathlib import Path -from typing import Any, Literal, Self, TypeAlias, TypedDict, TypeGuard, TypeVar +from typing import Literal, Self, TypeAlias, TypedDict, cast import dacite +from rich.console import RenderableType +from rich.measure import Measurement -from .menu import show_menu -from .repo import Repo -from .util import flatten -from .yaml import read_yaml +from .revision import Revision +from .util import horizontal_group, union Feature: TypeAlias = ( Literal["keys", "display", "encoder", "underglow", "backlight", "pointer", "studio"] @@ -32,7 +28,97 @@ class VariantDict(TypedDict): Variant: TypeAlias = str | VariantDict -_HW = TypeVar("_HW", bound="Hardware") + +@dataclass +class BoardTarget: + """ + Zephyr board target. Identifies a unique combination of board name, + board revision, and board qualifiers (SoC, CPU cluster, and variant). + """ + + name: str = "" + """Zephyr Board ID (not the ZMK display name)""" + revision: Revision = field(default_factory=Revision) + """Optional board revision""" + qualifiers: str = "" + """Optional board qualifiers (including forward slashes)""" + + @classmethod + def parse(cls, target: str): + """ + Parse a board target into its parts. + + Examples: + ``` + BoardTarget.parse("nice_nano//zmk") + # returns + BoardTarget(name="nice_nano", revision=Revision(), qualifiers="//zmk") + + BoardTarget.parse("bl5340_dvk@1.2.0/nrf5340/cpuapp/ns") + # returns + BoardTarget( + name="bl5340_dvk", + revision=Revision("1.2.0"), + qualifiers="/nrf5340/cpuapp/ns" + ) + ``` + """ + name, qualifiers = split_board_qualifiers(target) + name, _, revision = name.partition("@") + + return BoardTarget( + name=name, revision=Revision(revision), qualifiers=qualifiers + ) + + def __str__(self): + return self.name + self.revision.at_str + self.qualifiers + + # Can't use just __rich__ here, or this won't display properly in tables. + # https://github.com/Textualize/rich/issues/3188 + + def __rich_console__(self, console, options): + yield str(self) + + def __rich_measure__(self, console, options): + length = len(str(self)) + return Measurement(length, length) + + def with_revision(self, revision: Revision): + """Get a copy of this BoardTarget() with a different revision.""" + return replace(self, revision=revision) + + +@dataclass +class BuildItem: + """Entry in the build.yaml file""" + + board: BoardTarget + shield: str | None = None + snippet: str | None = None + cmake_args: str | None = None + artifact_name: str | None = None + + def __rich__(self): + return horizontal_group( + *(item for item in self.__menu_row__() if item), padding=(0, 2) + ) + + def __menu_row__(self) -> Iterable[RenderableType]: + yield self.board + yield self.shield or "" + + extras = [] + if self.snippet: + extras.append(f"snippet: {self.snippet}") + + if self.artifact_name: + extras.append(f"artifact-name: {self.artifact_name}") + + if self.cmake_args: + extras.append(f"cmake-args: {self.cmake_args}") + + if extras: + yield f"[dim]{', '.join(extras)}" @dataclass @@ -41,382 +127,376 @@ class Hardware: directory: Path """Path to the directory containing this hardware""" + file_format: str | None type: str id: str + """Zephyr identifier for the hardware. Board IDs include board qualifiers.""" name: str + """Display name for the hardware""" + url: str | None + description: str | None + manufacturer: str | None - file_format: str | None = None - url: str | None = None - description: str | None = None - manufacturer: str | None = None - version: str | None = None + @classmethod + def from_dict(cls, data) -> Self: + config = dacite.Config(cast=[set, Revision]) + return dacite.from_dict(data_class=cls, data=data, config=config) def __str__(self) -> str: return self.id - def __rich__(self) -> Any: + def __rich__(self) -> RenderableType: return f"{self.id} [dim]{self.name}" - @classmethod - def from_dict(cls, data) -> Self: - """Read a hardware description from a dict""" - return dacite.from_dict(cls, data) - - def has_id(self, hardware_id: str) -> bool: - """Get whether this hardware has the given ID (case insensitive)""" - return hardware_id.casefold() == self.id.casefold() + def __menu_row__(self) -> Iterable[RenderableType]: + return [self.id, self.name] - def has_revision(self, revision: str) -> bool: + def get_normalized_ids(self) -> list[str]: """ - Get whether this hardware supports the given revision. - - Any empty string is treated as "default revision" and always returns True. + Returns a list of names by which this hardware can be matched, for example + in command line parameters. All results are casefolded. """ - # By default, the only supported revision is no revision at all. - return not revision - - def get_revisions(self) -> list[str]: - """Get a list of supported revisions""" - return [] - - def get_default_revision(self) -> str | None: - """Get the default item from get_revisions or None if no default is set""" - return None + return [self.id.casefold()] @dataclass class Interconnect(Hardware): - """Description of the connection between two pieces of hardware""" + """ + Description of the connection between two pieces of hardware. + + Matches #/$defs/interconnect from + https://github.com/zmkfirmware/zmk/blob/main/schema/hardware-metadata.schema.json + """ node_labels: dict[str, str] = field(default_factory=dict) design_guideline: str | None = None @dataclass -class Keyboard(Hardware): +class KeyboardComponent(Hardware): """Base class for hardware that forms a keyboard""" - siblings: list[str] | None = field(default_factory=list) - """List of board/shield IDs for a split keyboard""" - exposes: list[str] | None = field(default_factory=list) + siblings: list[str] = field(default_factory=list) + """List of board/shield IDs for a split keyboard. Board IDs include board qualifiers""" + exposes: set[str] = field(default_factory=set) """List of interconnect IDs this board/shield provides""" - features: list[Feature] | None = field(default_factory=list) + features: set[Feature] = field(default_factory=set) """List of features this board/shield supports""" - variants: list[Variant] | None = field(default_factory=list) - - def __post_init__(self): - self.siblings = self.siblings or [] - self.exposes = self.exposes or [] - self.features = self.features or [] - self.variants = self.variants or [] + variants: list[Variant] = field(default_factory=list) - def get_config_path(self, revision: str | None = None) -> Path: - """Path to the .conf file for this keyboard""" - return self._get_keyboard_file(".conf", revision) - - def get_keymap_path(self, revision: str | None = None) -> Path: - """Path to the .keymap file for this keyboard""" - return self._get_keyboard_file(".keymap", revision) - - def _get_revision_suffixes(self, revision: str | None = None) -> Generator[str]: - if revision: - for rev in get_revision_forms(revision): - yield "_" + rev.replace(".", "_") - - def _get_keyboard_file(self, extension: str, revision: str | None = None) -> Path: - if revision: - for rev in get_revision_forms(revision): - path = self.directory / f"{self.id}_{rev.replace('.', '_')}{extension}" - if path.exists(): - return path - - return self.directory / f"{self.id}{extension}" + @property + def has_keys(self) -> bool: + """Get whether this hardware has the "keys" feature.""" + return "keys" in self.features @dataclass -class Board(Keyboard): - """Hardware with a processor. May be a keyboard or a controller.""" - - arch: str | None = None - outputs: list[Output] = field(default_factory=list) - """List of methods by which this board supports sending HID data""" - - revisions: list[str] = field(default_factory=list) - default_revision: str | None = None - - def __post_init__(self): - super().__post_init__() - self.outputs = self.outputs or [] - self.revisions = self.revisions or [] - - def has_revision(self, revision: str): - # Empty string means "use default revision" - if not revision: - return True - - revision = normalize_revision(revision) - - return any(normalize_revision(rev) == revision for rev in self.revisions) - - def get_revisions(self): - return self.revisions - - def get_default_revision(self): - return self.default_revision - - -def split_revision(identifier: str) -> tuple[str, str]: +class Board(KeyboardComponent): """ - Splits a string containing a hardware ID and optionally a revision into the - ID and revision. + Description of a Zephyr board. May be a controller or a standalone keyboard. - Examples: - "foo" -> "foo", "" - "foo@2" -> "foo", "2" + Matches #/$defs/board from + https://github.com/zmkfirmware/zmk/blob/main/schema/hardware-metadata.schema.json """ - hardware_id, _, revision = identifier.partition("@") - return hardware_id, revision + arch: str | None = None + outputs: set[Output] = field(default_factory=set) + """List of methods by which this board supports sending HID data""" -def append_revision(identifier: str, revision: str | None): - """ - Joins a hardware ID with a revision string. - - Examples: - "foo" + None -> "foo" - "foo" + "2" -> "foo@2" - """ - return f"{identifier}@{revision}" if revision else identifier - + revisions: list[Revision] = field(default_factory=list) + default_revision: Revision = field(default_factory=Revision) -def normalize_revision(revision: str | None) -> str: - """ - Normalizes letter revisions to uppercase and shortens numeric versions to - the smallest form with the same meaning. + def get_normalized_ids(self) -> list[str]: + """ + Returns a list of names by which this hardware can be matched, for example + in command line parameters. All results are casefolded. - Examples: - "a" -> "A" - "1.2.0" -> "1.2" - "2.0.0" -> "2" - """ - if not revision: - return "" + To make specifying board IDs as command line parameters easier, the "zmk" + board qualifier is optional, e.g. "nice_nano" matches a board with + `id="nice_nano//zmk"`, and "nrfmicro/nrf52840" matches a board with + `id="nrfmicro/nrf52840/zmk"`. + """ + norm_id = self.id.casefold() - return re.sub(r"(?:\.0){1,2}$", "", revision).upper() + result = [norm_id] + if norm_id.endswith("//zmk"): + result.append(norm_id.removesuffix("//zmk")) + elif norm_id.endswith("/zmk"): + result.append(norm_id.removesuffix("/zmk")) + return result -def get_revision_forms(revision: str) -> list[str]: - """ - Returns a list of all equivalent spellings of a revision. - Examples: - "a" -> ["A", "a"] - "1.2.3" -> ["1.2.3"] - "1.2.0" -> ["1.2", "1.2.0"] - "2.0.0" -> ["2", "2.0", "2.0.0"] +@dataclass +class Shield(KeyboardComponent): """ - revision = normalize_revision(revision) - - if revision.isalpha(): - return [revision.upper(), revision.lower()] - - result = [] - - dot_count = revision.count(".") - if dot_count == 0: - result.append(revision + ".0.0") - if dot_count <= 1: - result.append(revision + ".0") - - result.append(revision) + Description of a Zephyr shield. May be a keyboard or a peripheral. - return result + Matches #/$defs/shield from + https://github.com/zmkfirmware/zmk/blob/main/schema/hardware-metadata.schema.json + """ -@dataclass -class Shield(Keyboard): - """Hardware that attaches to a board. May be a keyboard or a peripheral.""" + requires: set[str] = field(default_factory=set) + """List of interconnects this shield requires to be attached to""" - requires: list[str] | None = field(default_factory=list) - """List of interconnects to which this shield attaches""" - def __post_init__(self): - super().__post_init__() - self.requires = self.requires or [] +class IncompleteKeyboardError(Exception): + pass @dataclass -class GroupedHardware: - """Hardware grouped by type.""" - - keyboards: list[Keyboard] = field(default_factory=list) - """List of boards/shields that are keyboard PCBs""" - controllers: list[Board] = field(default_factory=list) - """List of boards that are controllers for keyboards""" - interconnects: list[Interconnect] = field(default_factory=list) - """List of interconnect descriptions""" - - # TODO: add displays and other peripherals? - - def find_keyboard(self, item_id: str) -> Keyboard | None: - """Find a keyboard by ID""" - item_id = item_id.casefold() - return next((i for i in self.keyboards if i.id.casefold() == item_id), None) - - def find_controller(self, item_id: str) -> Board | None: - """Find a controller by ID""" - item_id = item_id.casefold() - return next((i for i in self.controllers if i.id.casefold() == item_id), None) - - def find_interconnect(self, item_id: str) -> Interconnect | None: - """Find an interconnect by ID""" - item_id = item_id.casefold() - return next( - (i for i in self.interconnects if i.id.casefold() == item_id), - None, - ) +class Keyboard: + """ + Collection of information needed to determine how to build keyboard firmware. + This consists of: + - A board + - Optionally, a specific board revision to use + - Optionally, some number of shields. -def is_keyboard(hardware: Hardware) -> TypeGuard[Keyboard]: - """Test whether an item is a keyboard (board or shield supporting keys)""" - match hardware: - case Keyboard(features=feat) if feat and "keys" in feat: - return True + At least one item between the board and shields is required to have the "keys" + before the keyboard is considered "complete" and can be used to build firmware. + """ - case _: - return False + board: Board | None = None + """The controller board""" + board_revision: Revision = field(default_factory=Revision) + """The board revision selected to build""" + shields: list[Shield] = field(default_factory=list) + """List of shields to attach to the board""" + @property + def board_targets(self) -> list[BoardTarget]: + """ + The Zephyr board target(s) from board and board_revision. -def is_controller(hardware: Hardware) -> TypeGuard[Board]: - """Test whether an item is a keyboard controller (board which isn't a keyboard)""" - return isinstance(hardware, Board) and not is_keyboard(hardware) + If board.siblings is not empty, this returns one item per sibling. + Otherwise, it returns a single item for board.id. + :raises IncompleteKeyboardError: if board is not set. + """ + if not self.board: + raise IncompleteKeyboardError("Cannot get board_target when board is None") -def is_interconnect(hardware: Hardware) -> TypeGuard[Interconnect]: - """Test whether an item is an interconnect description""" - return isinstance(hardware, Interconnect) + board_ids = self.board.siblings or [self.board.id] + revision = self.board_revision or self.board.default_revision + return [ + BoardTarget.parse(board_id).with_revision(revision) + for board_id in board_ids + ] -def is_compatible( - base: Board | Shield | Iterable[Board | Shield], shield: Shield -) -> bool: - """ - Get whether a shield can be attached to the given hardware. + @property + def board_revisions(self) -> list[Revision]: + """The board's supported revisions.""" + return self.board.revisions if self.board else [] - This simply checks whether all the interconnects required by "shield" are - provided by the hardware in "base". If "base" is a list of hardware, it does - not account for the fact that one of the items in "base" may already be using - an interconnect provided by another item. - """ + @property + def keys_component(self) -> Board | Shield | None: + """The first item in [board, *shields] which has the "keys" feature.""" + if self.board and self.board.has_keys: + return self.board - if not shield.requires: - return True + return next((s for s in self.shields if s.has_keys), None) - base = [base] if isinstance(base, Keyboard) else base - exposed = flatten(b.exposes for b in base) + @property + def exposes(self) -> set[str]: + """Set of interconnects exposed by the board and shields.""" + board_exposes = self.board.exposes if self.board else set() - return all(ic in exposed for ic in shield.requires) + return board_exposes | union(s.exposes for s in self.shields) + @property + def requires(self) -> set[str]: + """Set of interconnects required by shields.""" + return union(s.requires for s in self.shields) -def get_board_roots(repo: Repo) -> Iterable[Path]: - """Get the paths that contain hardware definitions for a repo""" - roots = set() + @property + def missing_requirements(self) -> set[str]: + """ + Gets any interconnects in self.requires that are not satisfied by + self.exposes. - if root := repo.board_root: - roots.add(root) + This does not attempt to account for multiple instances of the same + interconnect. + """ + return self.requires - self.exposes - for module in repo.get_modules(): - if root := module.board_root: - roots.add(root) + @property + def revisions(self) -> list[Revision]: + """List of available board revisions""" + return self.board.revisions if self.board else [] - return roots + @property + def default_revision(self) -> Revision: + """Board revision that will be used if one isn't explicitly selected""" + return self.board.default_revision if self.board else Revision() + def add_component(self, component: KeyboardComponent): + """Add a board or shield to the keyboard.""" + match component: + case Board(): + self.board = component -def get_hardware(repo: Repo) -> GroupedHardware: - """Get lists of hardware descriptions, grouped by type, for a repo""" - hardware = flatten(_find_hardware(root) for root in get_board_roots(repo)) + case Shield(): + self.shields.append(component) - def func(groups: GroupedHardware, item: Hardware): - if is_keyboard(item): - groups.keyboards.append(item) - elif is_controller(item): - groups.controllers.append(item) - elif is_interconnect(item): - groups.interconnects.append(item) + case _: + raise TypeError("Unknown component type") - return groups + def is_compatible(self, component: Board | Shield): + """ + Get whether a board or shield can be attached to the keyboard and all + interconnect requirements would be satisfied. - groups = reduce(func, hardware, GroupedHardware()) + If given a shield, this checks whether all the interconnects required by + the shield are provided by the board and/or shields already in the keyboard. - groups.controllers = sorted(groups.controllers, key=lambda x: x.id) - groups.keyboards = sorted(groups.keyboards, key=lambda x: x.id) - groups.interconnects = sorted(groups.interconnects, key=lambda x: x.id) + If given a board, this checks whether adding the board would satisfy all + the requirements for the shields already in the keyboard. + """ + match component: + case Board(): + # This assumes we're replacing any existing board, so determine + # the new set of exposed interconnects with the new board and + # without any existing board. + new_exposes = component.exposes | union(s.exposes for s in self.shields) + return not (self.requires - new_exposes) + + case Shield(): + return not (component.requires - self.exposes) + + def get_build_items(self) -> list[BuildItem]: + """ + Get the individual builds needed to make the firmware. This currently + accounts for splits but not more complex configurations where each build + may need different options. + + Each build item will contain the board and every shield. If any board or + shield has siblings, then this will return one item per possible + combination of siblings. For example: + ``` + self.board = Board(id="nice_nano//zmk", default_revision="2") + self.shields = [Shield(id="two_percent_milk")] + # returns + [BuildItem(board_target="nice_nano@2//zmk", shield="two_percent_milk")] + + self.board = Board(id="nice_nano//zmk", default_revision="2") + self.shields = [Shield(siblings=["a_left", "a_right"]), Shield(id="display")] + # returns + [ + BuildItem(board_target="nice_nano@2//zmk", shield="a_left display"), + BuildItem(board_target="nice_nano@2//zmk", shield="a_right display"), + ] + ``` + + :raises IncompleteKeyboardError: if board is not set or there is no + component with the "keys" feature. + """ + if not self.board or not self.keys_component: + raise IncompleteKeyboardError( + "Cannot get build items for an incomplete keyboard" + ) - return groups + build_items: list[BuildItem] = [] + shield_lists = (shield.siblings or [shield.id] for shield in self.shields) + for combination in itertools.product(self.board_targets, *shield_lists): + # Python's type system can't represent that the first item is always a + # BoardTarget and the rest are always str, so explicit casts are needed. + target = cast("BoardTarget", combination[0]) + shields = cast("tuple[str, ...]", combination[1:]) + shield_str = " ".join(shields) if shields else None -def _find_hardware(path: Path) -> Generator[Hardware, None, None]: - for meta_path in path.rglob("*.zmk.yml"): - meta = read_yaml(meta_path) - meta["directory"] = meta_path.parent + build_items.append(BuildItem(board=target, shield=shield_str)) - match meta.get("type"): - case "board": - yield Board.from_dict(meta) + return build_items - case "shield": - yield Shield.from_dict(meta) + def get_config_path(self) -> Path: + """ + Get the path to the keyboard's .conf file. - case "interconnect": - yield Interconnect.from_dict(meta) + :raises IncompleteKeyboardError: if there is no component with the "keys" feature. + """ + return _get_keyboard_file(self.keys_component, self.board_revision, ".conf") + def get_keymap_path(self) -> Path: + """ + Get the path to the keyboard's .keymap file. -def _filter_hardware(item: Hardware, text: str): - text = text.casefold().strip() - return text in item.id.casefold() or text in item.name.casefold() + :raises IncompleteKeyboardError: if there is no component with the "keys" feature. + """ + return _get_keyboard_file(self.keys_component, self.board_revision, ".keymap") -def show_hardware_menu( - title: str, - items: Iterable[_HW], - **kwargs, -) -> _HW: +def split_board_qualifiers(identifier: str) -> tuple[str, str]: """ - Show a menu to select from a list of Hardware objects. + Splits a string into a board ID and board qualifiers. If the string contains + a revision, it is ignored and returned as part of the first value. - kwargs are passed through to zmk.menu.show_menu(), except for filter_func, - which is set to a function appropriate for filtering Hardware objects. + Examples: + "foo" -> "foo", "" + "foo/bar/baz" -> "foo", "/bar/baz" + "foo@2/bar/baz" -> "foo@2", "/bar/baz" """ - return show_menu(title=title, items=items, **kwargs, filter_func=_filter_hardware) + try: + index = identifier.index("/") + return identifier[0:index], identifier[index:] + except ValueError: + return identifier, "" -def show_revision_menu( - board: Hardware, title: str | None = None, **kwargs -) -> str | None: +def _get_filename(name: str): """ - Show a menu to select from a list of revisions for a board. - - If the board has no revisions, returns None without showing a menu. - If the board has only one revision, returns it without showing a menu. - - kwargs are passed through to zmk.menu.show_menu(), except for default_index, - which is set based on the board's default revision. + Replaces all special characters used in revisions and board qualifiers that + are not valid in a filename with underscores. """ + return re.sub(r"[./]+", "_", name) + - revisions = board.get_revisions() - if not revisions: - return None +def _get_keyboard_file( + keys_component: Board | Shield | None, revision: Revision, suffix: str +) -> Path: + if not keys_component: + raise IncompleteKeyboardError("Cannot get file path for an incomplete keyboard") - if len(revisions) == 1: - return revisions[0] + search_names: list[str] = [keys_component.id] + search_revisions: list[str] = [] + directory = keys_component.directory - default_revision = board.get_default_revision() - default_index = revisions.index(default_revision) if default_revision else 0 + if isinstance(keys_component, Board): + # If the keyboard has board qualifiers, also search without them. + target = BoardTarget.parse(keys_component.id) + if target.qualifiers: + search_names.append(target.name) - return show_menu( - title=title or f"Select a {board.name} revision:", - items=revisions, - default_index=default_index, - **kwargs, - ) + # If the keyboard has revisions, search for files with each possible + # spelling of the revision. + revision = revision or keys_component.default_revision + if revision: + search_revisions = revision.get_spellings() + + # Combine everything into a list of paths to search from most to least + # specific. + search_paths: list[Path] = [] + + for name in search_names: + for rev in search_revisions: + # Despite Zephyr board targets having the revision before the + # qualifiers, paths have the revision at the end for some reason. + path = (directory / _get_filename(name + "_" + rev)).with_suffix(suffix) + search_paths.append(path) + + path = (directory / _get_filename(name)).with_suffix(suffix) + search_paths.append(path) + + for path in search_paths: + if path.exists(): + return path + + # If none of these files exists, then just return the most general of + # the possible paths. For a shield, this will just be the shield ID itself. + # For a board, it will be the ID without revisions or qualifiers. + return search_paths[-1] diff --git a/zmk/hardware_list.py b/zmk/hardware_list.py new file mode 100644 index 0000000..d590054 --- /dev/null +++ b/zmk/hardware_list.py @@ -0,0 +1,205 @@ +from collections.abc import Generator, Iterable +from dataclasses import dataclass, field +from functools import reduce +from pathlib import Path +from typing import TypeVar + +from rich.console import Console + +from .hardware import Board, BoardTarget, Hardware, Interconnect, Keyboard, Shield +from .menu import show_menu +from .repo import Repo +from .revision import Revision +from .styles import MENU_THEME, BoardIdHighlighter +from .util import flatten +from .yaml import read_yaml + +_HW = TypeVar("_HW", bound=Hardware) + + +@dataclass +class HardwareGroups: + """Hardware grouped by type.""" + + keyboards: list[Board | Shield] = field(default_factory=list) + """List of boards/shields that are keyboard PCBs""" + controllers: list[Board] = field(default_factory=list) + """List of boards that are controllers for keyboards""" + interconnects: list[Interconnect] = field(default_factory=list) + """List of interconnect descriptions""" + + # TODO: add displays and other peripherals? + + def find_keyboard(self, item_id: str) -> Board | Shield | None: + """Find a keyboard by ID""" + # Ignore any board revision provided + target = BoardTarget.parse(item_id).with_revision(Revision()) + return _find_by_id(self.keyboards, str(target)) + + def find_controller(self, item_id: str) -> Board | None: + """Find a controller by ID""" + # Ignore any board revision provided + target = BoardTarget.parse(item_id).with_revision(Revision()) + return _find_by_id(self.controllers, str(target)) + + def find_interconnect(self, item_id: str) -> Interconnect | None: + """Find an interconnect by ID""" + return _find_by_id(self.interconnects, item_id) + + def filter_compatible_keyboards(self, keyboard: Keyboard): + """ + Modifies the "keyboards" list so it contains only shields compatible with + the boards and shields in a given keyboard object. + """ + self.keyboards = [ + kb + for kb in self.keyboards + if isinstance(kb, Shield) and keyboard.is_compatible(kb) + ] + + def filter_compatible_controllers(self, keyboard: Keyboard): + """ + Modifies the "controllers" list so it contains only boards compatible + with the shields on a given keyboard object. + """ + self.controllers = [c for c in self.controllers if keyboard.is_compatible(c)] + + def filter_to_interconnect(self, interconnect: Interconnect): + """ + Modifies the "controllers" list so it contains only boards that expose + the given interconnect. + + Modifies the "keyboards" list so it contains only shields that require + the given interconnect and boards/shields that provide it. + """ + self.controllers = [c for c in self.controllers if interconnect.id in c.exposes] + self.keyboards = [ + kb + for kb in self.keyboards + if interconnect.id in kb.exposes + or (isinstance(kb, Shield) and interconnect.id in kb.requires) + ] + + +def _find_by_id(hardware: Iterable[_HW], item_id: str) -> _HW | None: + norm_id = item_id.casefold() + return next((hw for hw in hardware if norm_id in hw.get_normalized_ids()), None) + + +def _find_hardware(path: Path) -> Generator[Hardware, None, None]: + for meta_path in path.rglob("*.zmk.yml"): + meta = read_yaml(meta_path) + meta["directory"] = meta_path.parent + + match meta.get("type"): + case "board": + yield Board.from_dict(meta) + + case "shield": + yield Shield.from_dict(meta) + + case "interconnect": + yield Interconnect.from_dict(meta) + + +def get_board_roots(repo: Repo) -> Iterable[Path]: + """Get the paths that contain hardware definitions for a repo""" + roots = set() + + if root := repo.board_root: + roots.add(root) + + for module in repo.get_modules(): + if root := module.board_root: + roots.add(root) + + return roots + + +def get_hardware(repo: Repo) -> HardwareGroups: + """Get lists of hardware descriptions, grouped by type, for a repo""" + hardware = flatten(_find_hardware(root) for root in get_board_roots(repo)) + + def func(groups: HardwareGroups, item: Hardware): + if isinstance(item, (Shield, Board)) and item.has_keys: + groups.keyboards.append(item) + elif isinstance(item, Board): + groups.controllers.append(item) + elif isinstance(item, Interconnect): + groups.interconnects.append(item) + + return groups + + groups = reduce(func, hardware, HardwareGroups()) + + groups.controllers = sorted(groups.controllers, key=lambda x: x.id) + groups.keyboards = sorted(groups.keyboards, key=lambda x: x.id) + groups.interconnects = sorted(groups.interconnects, key=lambda x: x.id) + + return groups + + +def show_hardware_menu( + title: str, + items: Iterable[_HW], + console: Console | None = None, + **kwargs, +) -> _HW: + """ + Show a menu to select from a list of Hardware objects. + + kwargs are passed through to zmk.menu.show_menu(), except for filter_func, + which is set to a function appropriate for filtering Hardware objects. + """ + if console is None: + console = Console(theme=MENU_THEME, highlighter=BoardIdHighlighter()) + + def filter_hardware(item: Hardware, text: str): + text = text.casefold().strip() + return text in item.id.casefold() or text in item.name.casefold() + + return show_menu( + title=title, + items=items, + console=console, + filter_func=filter_hardware, + **kwargs, + ) + + +def show_revision_menu( + board: Board, title: str | None = None, console: Console | None = None, **kwargs +) -> Revision: + """ + Show a menu to select from a list of revisions for a board. + + If the board has no revisions, returns Revision() without showing a menu. + If the board has only one revision, returns it without showing a menu. + + kwargs are passed through to zmk.menu.show_menu(), except for default_index, + which is set based on the board's default revision. + """ + + # Revisions could be listed in the .zmk.yml file in any order. Sort them + # descending so the latest revisions appear at the top. + revisions = sorted(board.revisions, reverse=True) + if not revisions: + return Revision() + + if len(revisions) == 1: + return revisions[0] + + default_revision = board.default_revision + default_index = revisions.index(default_revision) if default_revision else 0 + + if console is None: + # Disable the default highlighter so it doesn't colorize numbers + console = Console(highlighter=None) + + return show_menu( + title=title or f"Select a {board.name} revision:", + items=revisions, + default_index=default_index, + console=console, + **kwargs, + ) diff --git a/zmk/revision.py b/zmk/revision.py new file mode 100644 index 0000000..f197505 --- /dev/null +++ b/zmk/revision.py @@ -0,0 +1,156 @@ +import re + + +class Revision: + """ + Represents a Zephyr board revision. + + Revisions are automatically normalized, so comparing two objects for equality + works even if the same revision is spelled differently, e.g. + ``` + Revision("1") == Revision("1.1") # false + Revision("1") == Revision("1.0") # true + Revision("1.0.0") in {Revision("1"), Revision("2")} # true + ``` + + `Revision()` represents the lack of a revision. Truth testing a revision with + an empty string value returns false, so `Revision | None` is not needed to + represent an optional revision. + """ + + def __init__(self, value: str = ""): + self._value = _normalize(value) + + def __bool__(self): + return bool(self._value) + + def __hash__(self): + return hash(self._value) + + def __eq__(self, other): + return isinstance(other, Revision) and other._value == self._value + + def __lt__(self, other): + if not isinstance(other, Revision): + raise TypeError() + return _revision_lt(self, other) + + def __str__(self): + return self._value + + def __rich__(self): + return str(self) + + def __repr__(self): + return f'Revision("{self._value})"' if self else "Revision()" + + def __menu_row__(self): + # In menus, use the most specific spelling so adjacent items don't have + # varying widths. + return [self.specific_str] + + @property + def at_str(self) -> str: + """ + The revision prefixed with "@" if this object has a value, else "". + """ + return f"@{self}" if self else "" + + @property + def specific_str(self) -> str: + """ + The most specific spelling of this revision, e.g. Revision("1") -> "1.0.0" + """ + if not self: + return "" + + if self._value.isalpha(): + return self._value + + result = self._value + for _ in range(2 - self._value.count(".")): + result += ".0" + + return result + + def get_spellings(self) -> list[str]: + """ + Returns a list of all equivalent spellings of this revision. This may be + necessary for example when you have a normalized revision and need to find + board files that apply to it (e.g. board_1_0_0.overlay). + + Return values are ordered from most to least specific. + + Examples: + ``` + "A" -> ["A", "a"] + "1" -> ["1.0.0", "1.0", "1"] + "1.2" -> ["1.2.0", "1.2"] + "1.2.3" -> ["1.2.3"] + ``` + """ + if not self: + return [""] + + if self._value.isalpha(): + return [self._value, self._value.lower()] + + value = self.specific_str + result = [value] + + while value.endswith(".0"): + value = value[:-2] + result.append(value) + + return result + + +def _normalize(revision: str) -> str: + """ + Normalizes letter revisions to uppercase and shortens numeric versions to + the smallest form with the same meaning. + + Examples: + ``` + "" -> "" + "a" -> "A" + "1.2.0" -> "1.2" + "2.0.0" -> "2" + ``` + """ + return re.sub(r"(?:\.0){1,2}$", "", revision.strip()).upper() + + +def _to_number_list(revision: str) -> list[int]: + """ + Splits a revision string into a list of numbers. Must not be an alphabetical + revision. + """ + return [int(part) for part in revision.split(".")] + + +def _revision_lt(lhs: Revision, rhs: Revision) -> bool: + """ """ + lhs_str = lhs.specific_str + rhs_str = rhs.specific_str + + # Place Revision() before Revision("...") + if not lhs_str: + return bool(rhs_str) + + if not rhs_str: + return False + + # Place Revision("A") before Revision("1"). + if lhs_str.isalpha(): + if rhs_str.isalpha(): + # Both revisions are alphabetical. Can compare directly. + return lhs_str < rhs_str + + return True + + if rhs_str.isalpha(): + return False + + # Both revisions are numerical. Can compare as a list of numbers. + return _to_number_list(lhs_str) < _to_number_list(rhs_str) diff --git a/zmk/styles.py b/zmk/styles.py index 29dcc6c..dc39d1c 100644 --- a/zmk/styles.py +++ b/zmk/styles.py @@ -11,6 +11,8 @@ class KeyValueHighlighter(RegexHighlighter): """Highlight "key=value" items.""" + base_style = "kv." + highlights = [ # noqa: RUF012 https://github.com/astral-sh/ruff/issues/5429 r"(?P[\w.]+)(?P=)(?P.*)" ] @@ -22,7 +24,7 @@ class BoardIdHighlighter(RegexHighlighter): base_style = "board." highlights = [ # noqa: RUF012 https://github.com/astral-sh/ruff/issues/5429 - r"(?P[/])", + r"(?P/[a-zA-Z_]?\w*/[a-zA-Z_]\w*\b)", r"(?P@(?:[A-Z]|([0-9]+(\.[0-9]+){0,2})))", ] @@ -33,30 +35,39 @@ class CommandLineHighlighter(RegexHighlighter): base_style = "cmd." highlights = [ # noqa: RUF012 https://github.com/astral-sh/ruff/issues/5429 - r"(?P-[a-zA-Z]|--[a-zA-Z-_]+)" + r"(?-[a-zA-Z]|--[a-zA-Z-_]+)" ] THEME = Theme( { - "key": "bright_blue", - "equals": "dim blue", + "title": "bright_magenta", + "kv.key": "bright_blue", + "kv.equals": "dim blue", "value": "default", - "board.separator": "dim", - "board.revision": "sky_blue2", + "board.revision": "rgb(135,95,175)", + "board.qualifier": "rgb(135,95,175)", "cmd.flag": "dim", } ) +# Don't use colors in menus, as that will override the focus style. +MENU_THEME = Theme({"board.revision": "dim", "board.qualifier": "dim"}) + + +def chain_highlighters(*highlighters: HighlighterType | None) -> HighlighterType: + """ + Return a new highlighter which runs each of the given highlighters in order. -def chain_highlighters(highlighters: list[HighlighterType]) -> HighlighterType: - """Return a new highlighter which runs each of the given highlighters in order""" + Arguments that are None will be skipped. + """ def run_all_highlighters(text: str | Text): text = text if isinstance(text, Text) else Text(text) for item in highlighters: - text = item(text) + if item: + text = item(text) return text diff --git a/zmk/util.py b/zmk/util.py index ff38f50..6904b83 100644 --- a/zmk/util.py +++ b/zmk/util.py @@ -4,14 +4,14 @@ import functools import operator -import os -from collections.abc import Generator, Iterable +from collections.abc import Iterable from contextlib import contextmanager -from pathlib import Path from typing import TypeVar -from rich.console import Console +from rich.console import Console, RenderableType +from rich.padding import PaddingDimensions from rich.progress import Progress, SpinnerColumn, TextColumn +from rich.table import Table T = TypeVar("T") @@ -21,7 +21,12 @@ def flatten(items: Iterable[T | Iterable[T]]) -> Iterable[T]: return functools.reduce(operator.iconcat, items, []) -def splice(text: str, index: int, count: int = 0, insert_text: str = ""): +def union(items: Iterable[set[T]]) -> set[T]: + """Compute the union of any number of sets""" + return functools.reduce(operator.or_, items, set()) + + +def splice(text: str, index: int, count: int = 0, insert_text: str = "") -> str: """ Remove `count` characters starting from `index` in `text` and replace them with `insert_text`. @@ -29,16 +34,27 @@ def splice(text: str, index: int, count: int = 0, insert_text: str = ""): return text[0:index] + insert_text + text[index + count :] -@contextmanager -def set_directory(path: Path) -> Generator[None, None, None]: - """Context manager to temporarily change the working directory""" - original = Path().absolute() - - try: - os.chdir(path) - yield - finally: - os.chdir(original) +def horizontal_group( + *renderables: RenderableType, + padding: PaddingDimensions = 0, + collapse_padding=True, + pad_edge=False, + expand=False, + highlight=True, +) -> Table: + """ + Similar to rich.group.Group, but uses a table to place renderables in a row + instead of a column. + """ + grid = Table.grid( + padding=padding, + collapse_padding=collapse_padding, + pad_edge=pad_edge, + expand=expand, + ) + grid.highlight = highlight + grid.add_row(*renderables) + return grid @contextmanager From 1c2e43c32d0c20f62a6cc488a06e731836552397 Mon Sep 17 00:00:00 2001 From: Joel Spadin Date: Sat, 7 Mar 2026 19:40:56 -0600 Subject: [PATCH 6/8] fix: Don't show error if fsmonitor--daemon not running Fixed an issue where "zmk module remove" would display an error when attempting to stop Git's fsmonitor--daemon so it can delete the module, but fsmonitor--daemon was not running. We now silently ignore the error. --- zmk/commands/module/remove.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/zmk/commands/module/remove.py b/zmk/commands/module/remove.py index 16ae8f0..5fb6334 100644 --- a/zmk/commands/module/remove.py +++ b/zmk/commands/module/remove.py @@ -80,7 +80,12 @@ def remove_readonly(func, path, _): with spinner("Deleting module files."): try: # Make sure Git isn't locking the folder first. - subprocess.call(["git", "fsmonitor--daemon", "stop"], cwd=module_path) + subprocess.call( + ["git", "fsmonitor--daemon", "stop"], + cwd=module_path, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) shutil.rmtree(module_path, onexc=remove_readonly) except FileNotFoundError: From aa32bf993f1fe47f059e0d00168e878a4ceef20a Mon Sep 17 00:00:00 2001 From: Joel Spadin Date: Sun, 8 Mar 2026 12:14:07 -0500 Subject: [PATCH 7/8] fix: Disable menu filtering on legacy Windows console Calling terminal.get_cursor_pos() when using the legacy Windows Console Host resulted in the program hanging waiting for a response on stdin to an escape sequence that the terminal didn't respond to. Enabling VT mode temporarily made it respond and provide a cursor position, but either the return values differed from how a modern terminal would report them or this interacted strangely with menu rendering in some other way, as this resulted in parts of menus being rendered to the wrong locations on any menu with a filter text field. To work around this, we now simply detect whether VT mode is enabled, and if it isn't, we disable the filter text field. This allows menus to mostly work, just without the ability to search by typing. --- zmk/menu.py | 21 +++++++++++++----- zmk/terminal.py | 58 ++++++++++++++++++++----------------------------- 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/zmk/menu.py b/zmk/menu.py index 25e5561..f4a9f6a 100644 --- a/zmk/menu.py +++ b/zmk/menu.py @@ -133,11 +133,17 @@ def __init__( self._num_title_lines = 0 self._last_title_line_len = 0 - if self._get_display_count() == self._max_items_per_page: - self._top_row = 1 + if terminal.cursor_control_supported(): + if self._get_display_count() == self._max_items_per_page: + self._top_row = 1 + else: + _, y = terminal.get_cursor_pos() + self._top_row = min(y, self.console.height - self._get_menu_height()) else: - _, y = terminal.get_cursor_pos() - self._top_row = min(y, self.console.height - self._get_menu_height()) + # If get_cursor_pos() is unsupported, then we can't move the cursor + # accurately between the end of the menu and the filter text. Disable + # the filter feature so the menu still mostly works. + self._filter_func = None self._apply_filter() @@ -381,7 +387,7 @@ def _handle_input(self): return False def _handle_backspace(self): - if self._cursor_index == 0: + if not self.has_filter or self._cursor_index == 0: return self._filter_text = splice(self._filter_text, self._cursor_index - 1, count=1) @@ -389,13 +395,16 @@ def _handle_backspace(self): self._apply_filter() def _handle_delete(self): - if self._cursor_index == len(self._filter_text): + if not self.has_filter or self._cursor_index == len(self._filter_text): return self._filter_text = splice(self._filter_text, self._cursor_index, count=1) self._apply_filter() def _handle_text(self, key: bytes): + if not self.has_filter: + return + text = key.decode() self._filter_text = splice( self._filter_text, self._cursor_index, insert_text=text diff --git a/zmk/terminal.py b/zmk/terminal.py index fb8220b..b56c7cf 100644 --- a/zmk/terminal.py +++ b/zmk/terminal.py @@ -35,14 +35,7 @@ _STD_INPUT_HANDLE = -10 _STD_OUTPUT_HANDLE = -11 - _ENABLE_PROCESSED_OUTPUT = 1 - _ENABLE_WRAP_AT_EOL_OUTPUT = 2 _ENABLE_VIRTUAL_TERMINAL_PROCESSING = 4 - _VT_FLAGS = ( - _ENABLE_PROCESSED_OUTPUT - | _ENABLE_WRAP_AT_EOL_OUTPUT - | _ENABLE_VIRTUAL_TERMINAL_PROCESSING - ) _WINDOWS_SPECIAL_KEYS = { 71: HOME, @@ -76,25 +69,6 @@ def read_key() -> bytes: return key - @contextmanager - def enable_vt_mode() -> Generator[None, None, None]: - """ - Context manager which enables virtual terminal processing. - """ - kernel32 = windll.kernel32 - stdout_handle = kernel32.GetStdHandle(_STD_OUTPUT_HANDLE) - - old_stdout_mode = wintypes.DWORD() - kernel32.GetConsoleMode(stdout_handle, byref(old_stdout_mode)) - - new_stdout_mode = old_stdout_mode.value | _VT_FLAGS - - try: - kernel32.SetConsoleMode(stdout_handle, new_stdout_mode) - yield - finally: - kernel32.SetConsoleMode(stdout_handle, old_stdout_mode) - @contextmanager def disable_echo() -> Generator[None, None, None]: """ @@ -112,16 +86,22 @@ def disable_echo() -> Generator[None, None, None]: finally: kernel32.SetConsoleMode(stdin_handle, old_stdin_mode) -except ImportError: - import termios - - @contextmanager - def enable_vt_mode() -> Generator[None, None, None]: + def cursor_control_supported() -> bool: """ - Context manager which enables virtual terminal processing. + Gets whether this terminal supports the virtual terminal escape sequence + for getting the cursor position. """ - # Assume that Unix terminals support VT escape sequences by default. - yield + kernel32 = windll.kernel32 + stdout_handle = kernel32.GetStdHandle(_STD_OUTPUT_HANDLE) + + stdout_mode = wintypes.DWORD() + kernel32.GetConsoleMode(stdout_handle, byref(stdout_mode)) + + return bool(stdout_mode.value & _ENABLE_VIRTUAL_TERMINAL_PROCESSING) + + +except ImportError: + import termios @contextmanager def disable_echo() -> Generator[None, None, None]: @@ -152,10 +132,20 @@ def read_key() -> bytes: return key + def cursor_control_supported() -> bool: + """ + Gets whether this terminal supports the virtual terminal escape sequence + for getting the cursor position. + """ + # Assume that Unix terminals support VT escape sequences by default. + return True + def get_cursor_pos() -> tuple[int, int]: """ Returns the cursor position as a tuple (x, y). Positions are 0-based. + + This function may not work properly if cursor_control_supported() returns False. """ with disable_echo(): sys.stdout.write("\x1b[6n") From d375a9858f9bef13a14b226e8eb9a6e031ff863f Mon Sep 17 00:00:00 2001 From: Joel Spadin Date: Mon, 9 Mar 2026 10:42:55 -0500 Subject: [PATCH 8/8] fix: Exclude uv.lock from TOML checks --- .pre-commit-config.yaml | 4 ++-- pyproject.toml | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a88768..0b8605d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ fail_fast: false repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.2 + rev: v0.15.5 hooks: - id: ruff-check - id: ruff-format @@ -18,7 +18,7 @@ repos: - id: check-added-large-files - id: check-shebang-scripts-are-executable - repo: https://github.com/tombi-toml/tombi-pre-commit - rev: v0.7.32 + rev: v0.9.4 hooks: - id: tombi-format - id: tombi-lint diff --git a/pyproject.toml b/pyproject.toml index 2a0c3ec..a28647d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,3 +114,6 @@ packages = ["zmk"] [tool.setuptools_scm] version_file = "zmk/_version.py" + +[tool.tombi.files] +exclude = ["uv.lock"]