From 4b7991e6dece1c78c78119dd67621d12193f7f7d Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 19:51:55 +0200 Subject: [PATCH 01/17] feat: interactive example using pygame --- examples/interactive.py | 480 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 480 insertions(+) create mode 100644 examples/interactive.py diff --git a/examples/interactive.py b/examples/interactive.py new file mode 100644 index 0000000..f54449a --- /dev/null +++ b/examples/interactive.py @@ -0,0 +1,480 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "world_engine", +# "numpy", +# "pygame-ce", # community fork; ships a newer SDL2 with better Wayland support +# "pillow", +# ] +# +# [tool.uv.sources] +# world_engine = { path = "..", editable = true } +# /// +# +# Minimal interactive client for the Overworld World Engine. +# +# uv run examples/interactive.py Overworld/Waypoint-1.5-1B +# +# Controls: +# WASD / mouse / buttons : forwarded as CtrlInput to the model +# ESC : pause (freeze last frame, release mouse) +# U : reset (re-seed, continues playing) +# Left-click (on pause) : resume +# Close window / Ctrl+C : quit +# +# The engine is pinned to a single dedicated thread for its entire lifetime. +# torch.compile + triton.cudagraphs capture stream state tied to the capturing +# thread, so calling gen_frame() from a *different* thread segfaults. The main +# (pygame/UI) thread communicates with the engine thread via two queues. +# +# Supports both Waypoint-1 / 1.1 (single-frame output) and Waypoint-1.5 +# (4-frame temporally-compressed output). The only model-dependent branches +# live in `prime_seed` and `render`, keyed off `engine.model_cfg.model_type`. + +import argparse +import io +import json +import logging +import queue +import random +import threading +import time +import urllib.request + +import numpy as np +import pygame +import torch +from PIL import Image + +from world_engine import CtrlInput, WorldEngine + + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(threadName)s] %(message)s", + datefmt="%H:%M:%S", +) +log = logging.getLogger("interactive") + + +# GitHub contents API for the Biome `seeds/` directory, pinned to a known ref. +# Same source as examples/gen_sample.py. +BIOME_SEEDS_API = ( + "https://api.github.com/repos/Overworldai/Biome/contents/seeds?ref=14343a6" +) + +WINDOW_SIZE = (1280, 720) +# Aspect ratio for the center-crop applied to seed images. +CROP_ASPECT_W, CROP_ASPECT_H = 16, 9 + +# Map pygame keys / mouse buttons to the Windows VK integers that CtrlInput.button +# expects (see https://github.com/Overworldai/owl-control keycode table). Covers +# the main ANSI rows, space, shift, and the three mouse buttons — enough for +# WASD / spacebar / look-around gameplay without being exhaustive. +# Uses `pygame.K_*` int constants directly so this dict can be built at import +# time (before `pygame.init()`). +PYGAME_TO_VK: dict[int, int] = ( + {getattr(pygame, f"K_{ch}"): ord(ch) for ch in "1234567890"} + | {pygame.K_MINUS: 0xBD, pygame.K_EQUALS: 0xBB} + | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "qwertyuiop"} + | {pygame.K_LEFTBRACKET: 0xDB, pygame.K_RIGHTBRACKET: 0xDD, pygame.K_BACKSLASH: 0xDC} + | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "asdfghjkl"} + | {pygame.K_SEMICOLON: 0xBA, pygame.K_QUOTE: 0xDE} + | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "zxcvbnm"} + | {pygame.K_COMMA: 0xBC, pygame.K_PERIOD: 0xBE, pygame.K_SLASH: 0xBF} + | {pygame.K_SPACE: 0x20, pygame.K_LSHIFT: 0x10, pygame.K_RSHIFT: 0x10} +) +# pygame mouse button ids: 1=left, 2=middle, 3=right. VK: 0x01 LBUTTON, 0x04 MBUTTON, 0x02 RBUTTON. +MOUSE_TO_VK: dict[int, int] = {1: 0x01, 2: 0x04, 3: 0x02} + + +# --- seed loading ----------------------------------------------------------- + +def center_crop(img: Image.Image) -> np.ndarray: + """Center-crop to CROP_ASPECT_W:CROP_ASPECT_H. Returns uint8 (H, W, 3).""" + w, h = img.size + # Pick whichever dimension is the limiting factor and derive the other. + if w * CROP_ASPECT_H > h * CROP_ASPECT_W: + new_w, new_h = h * CROP_ASPECT_W // CROP_ASPECT_H, h + else: + new_w, new_h = w, w * CROP_ASPECT_H // CROP_ASPECT_W + left = (w - new_w) // 2 + top = (h - new_h) // 2 + # `.copy()` — PIL's buffer is read-only and torch.from_numpy requires writable. + return np.asarray(img.crop((left, top, left + new_w, top + new_h)).convert("RGB")).copy() + + +def load_seed_from_path(path: str) -> np.ndarray: + """Load a local image as uint8 (H, W, 3), center-cropped.""" + log.info("loading seed from local file: %s", path) + return center_crop(Image.open(path)) + + +def load_seed_from_github() -> np.ndarray: + """Download a random seed from the pinned Biome `seeds/` directory.""" + log.info("fetching Biome seeds index") + with urllib.request.urlopen(BIOME_SEEDS_API) as res: + entries = [e for e in json.load(res) if e["type"] == "file"] + url = random.choice(entries)["download_url"] + log.info("downloading random Biome seed: %s", url) + with urllib.request.urlopen(url) as res: + img_bytes = res.read() + return center_crop(Image.open(io.BytesIO(img_bytes))) + + +def prime_seed(engine: WorldEngine, seed: np.ndarray) -> None: + """Encode the seed frame into the engine's KV cache.""" + t = torch.from_numpy(seed).to(engine.device) # uint8 (H, W, 3) + tc = engine.model_cfg.temporal_compression + if tc > 1: + # Multi-frame models (e.g. Waypoint-1.5) consume/produce a stack of + # `temporal_compression` frames per step. + t = t.unsqueeze(0).expand(tc, -1, -1, -1).contiguous() + log.info("priming engine with seed shape=%s", tuple(t.shape)) + engine.append_frame(t) + + +# --- rendering -------------------------------------------------------------- + +def _blit_frame(screen: pygame.Surface, frame: np.ndarray) -> pygame.Surface: + """Blit a single (H, W, 3) uint8 numpy frame, scaled to the window. Returns the scaled surface.""" + # pygame.surfarray expects (W, H, 3), so swap the first two axes. + surf = pygame.surfarray.make_surface(frame.swapaxes(0, 1)) + surf = pygame.transform.scale(surf, screen.get_size()) + screen.blit(surf, (0, 0)) + return surf + + +def _draw_hud( + screen: pygame.Surface, + font: pygame.font.Font | None, + model_uri: str, + batch_dt: float, +) -> None: + """Draw FPS / frametime and model name at the top-right corner. No-op if font is None.""" + if font is None: + return + lines: list[tuple[str, tuple[int, int, int]]] = [] + if batch_dt > 0: + lines.append((f"{1.0 / batch_dt:.1f} fps / {batch_dt * 1000:.1f} ms", (255, 255, 255))) + lines.append((model_uri, (160, 160, 160))) + for i, (text, color) in enumerate(lines): + label = font.render(text, True, color) + x = screen.get_width() - label.get_width() - 12 + y = 12 + i * (label.get_height() + 4) + screen.blit(label, (x, y)) + + +def render( + screen: pygame.Surface, + frame_cpu: torch.Tensor, + batch_dt: float, + hud_font: pygame.font.Font | None = None, + model_uri: str = "", +) -> pygame.Surface: + """Display an already-on-CPU frame; return the last surface for pause caching. + + For multi-frame models the tensor is (T, H, W, 3) — we spread the T + sub-frames evenly across `batch_dt` (per README "Waypoint-1.5 Behavior"). + The sleeps are what let the pipeline overlap: while we pace here, the GPU + is already computing the next batch. + """ + arr = frame_cpu.numpy() + if arr.ndim == 3: # single-frame model: (H, W, 3) + last = _blit_frame(screen, arr) + _draw_hud(screen, hud_font, model_uri, batch_dt) + pygame.display.flip() + return last + + # Multi-frame model: (T, H, W, 3) + step_ms = max(0, int(batch_dt * 1000 / arr.shape[0])) + last: pygame.Surface | None = None + for i, sub in enumerate(arr): + if i > 0 and step_ms: + pygame.time.wait(step_ms) + last = _blit_frame(screen, sub) + _draw_hud(screen, hud_font, model_uri, batch_dt) + pygame.display.flip() + assert last is not None + return last + + +def draw_pause_overlay(screen: pygame.Surface, last: pygame.Surface, font: pygame.font.Font) -> None: + """Redraw the cached last frame with a dimmed overlay and centered pause text.""" + screen.blit(last, (0, 0)) + dim = pygame.Surface(screen.get_size(), pygame.SRCALPHA) + dim.fill((0, 0, 0, 128)) # 50% black + screen.blit(dim, (0, 0)) + label = font.render("Paused — click to resume", True, (255, 255, 255)) + rect = label.get_rect(center=screen.get_rect().center) + screen.blit(label, rect) + pygame.display.flip() + + +def draw_status(screen: pygame.Surface, font: pygame.font.Font, text: str) -> None: + """Clear to black and draw a status line in the bottom-left corner.""" + screen.fill((0, 0, 0)) + label = font.render(text, True, (220, 220, 220)) + screen.blit(label, (16, screen.get_height() - label.get_height() - 16)) + pygame.display.flip() + + +# --- engine thread ----------------------------------------------------------- + + +class Engine: + """Owns a WorldEngine on a dedicated thread. Main thread communicates via queues. + + All CUDA work (construction, torch.compile warmup, gen_frame) happens on the + engine thread. Cross-thread invocation of compiled+cudagraphs'd code segfaults. + """ + + def __init__(self) -> None: + self._stop = object() + self._reset = object() + self.ctrl_q: queue.Queue = queue.Queue(maxsize=1) + self.frame_q: queue.Queue = queue.Queue(maxsize=1) + self.status: str = "Starting…" + self.ready = threading.Event() + self.error: BaseException | None = None + + def start(self, model_uri: str, quant: str | None, device: str, seed_path: str | None) -> None: + threading.Thread( + target=self._run, args=(model_uri, quant, device, seed_path), + daemon=True, name="engine", + ).start() + + def stop(self) -> None: + try: + self.ctrl_q.put_nowait(self._stop) + except queue.Full: + pass + + def reset(self) -> None: + """Request engine reset (re-prime seed, produce fresh first frame).""" + self.ctrl_q.put(self._reset) + + def _run_gen(self, eng: WorldEngine, ctrl: CtrlInput) -> torch.Tensor: + """gen_frame -> synchronize -> .cpu(). The sync + immediate CPU copy + mirrors Biome's server: the returned GPU tensor may share storage that + gen_frame reuses on the next call, so it must be materialized before + the next invocation. + """ + frame = eng.gen_frame(ctrl=ctrl) + if torch.cuda.is_available(): + torch.cuda.synchronize() + return frame.cpu() + + def _run(self, model_uri: str, quant: str | None, device: str, seed_path: str | None) -> None: + try: + t0 = time.perf_counter() + self.status = "Loading model…" + log.info("loading model %s (quant=%s, device=%s)", model_uri, quant, device) + eng = WorldEngine(model_uri, quant=quant, device=device) + log.info( + "model loaded: type=%s, temporal_compression=%d", + eng.model_cfg.model_type, eng.model_cfg.temporal_compression, + ) + + self.status = "Loading seed…" + seed = load_seed_from_path(seed_path) if seed_path else load_seed_from_github() + + self.status = "Priming engine…" + eng.reset() + prime_seed(eng, seed) + + self.status = "Warming up (torch.compile)…" + log.info("warming up torch.compile") + w0 = time.perf_counter() + first = self._run_gen(eng, CtrlInput()) + log.info("warmup complete in %.1fs", time.perf_counter() - w0) + self.frame_q.put(first) + + self.status = "Ready — click to start." + log.info("init finished in %.1fs", time.perf_counter() - t0) + self.ready.set() + + # Command loop: pull a CtrlInput (or sentinel), dispatch. + while True: + cmd = self.ctrl_q.get() + if cmd is self._stop: + return + if cmd is self._reset: + # reset() clears the KV cache and all state, so the model + # must be re-seeded with append_frame before it can produce + # coherent output again. + log.info("resetting engine") + eng.reset() + prime_seed(eng, seed) + self.frame_q.put(self._run_gen(eng, CtrlInput())) + continue + self.frame_q.put(self._run_gen(eng, cmd)) + except BaseException as exc: + log.exception("engine thread failed") + self.error = exc + self.ready.set() + + +# --- main loop phases ------------------------------------------------------- + +def grab_mouse(screen: pygame.Surface) -> None: + """Confine + hide the cursor for FPS-style gameplay.""" + pygame.event.set_grab(True) + pygame.mouse.set_visible(False) + pygame.mouse.get_rel() # discard any pre-grab accumulated delta + + +def release_mouse() -> None: + pygame.event.set_grab(False) + pygame.mouse.set_visible(True) + + +def loading_screen(screen: pygame.Surface, font: pygame.font.Font, clock: pygame.time.Clock, engine: Engine) -> bool: + """Phase 1: show status while engine initializes. Returns True when ready, False on user quit.""" + while not engine.ready.is_set(): + for e in pygame.event.get(): + if e.type == pygame.QUIT or (e.type == pygame.KEYDOWN and e.key == pygame.K_ESCAPE): + return False + draw_status(screen, font, engine.status) + clock.tick(30) + return True + + +def gameplay( + screen: pygame.Surface, + font: pygame.font.Font, + hud_font: pygame.font.Font, + clock: pygame.time.Clock, + engine: Engine, + model_uri: str, + mouse_sensitivity: float, +) -> None: + """Phase 2: interactive generation loop. Starts auto-paused on the first frame.""" + first = engine.frame_q.get() + last_surface: pygame.Surface = render(screen, first, 0.0) + paused = True + log.info("ready") + + held_vks: set[int] = set() + pending: torch.Tensor | None = None + batch_dt = 0.0 + + while True: + scroll = 0 + + for e in pygame.event.get(): + if e.type == pygame.QUIT: + return + + # Auto-pause when the cursor leaves the window. Safety net for + # WMs where `set_grab` is advisory and the cursor can escape. + elif e.type == pygame.WINDOWLEAVE and not paused: + if pending is not None: + last_surface = render(screen, pending, batch_dt) + pending = None + paused = True + release_mouse() + continue + + elif e.type == pygame.KEYDOWN: + if e.key == pygame.K_ESCAPE and not paused: + if pending is not None: + last_surface = render(screen, pending, batch_dt) + pending = None + paused = True + release_mouse() + continue + if e.key == pygame.K_u and not paused: + # Reset: re-prime the seed. The engine thread handles + # reset + re-seed + gen_frame and puts the fresh frame on + # frame_q like any normal frame. Gameplay continues. + pending = None + engine.reset() + continue + vk = PYGAME_TO_VK.get(e.key) + if vk is not None: + held_vks.add(vk) + + elif e.type == pygame.KEYUP: + vk = PYGAME_TO_VK.get(e.key) + if vk is not None: + held_vks.discard(vk) + + elif e.type == pygame.MOUSEBUTTONDOWN: + if paused and e.button == 1: + paused = False + grab_mouse(screen) + continue + vk = MOUSE_TO_VK.get(e.button) + if vk is not None: + held_vks.add(vk) + if e.button == 4: + scroll += 1 + elif e.button == 5: + scroll -= 1 + + elif e.type == pygame.MOUSEBUTTONUP: + vk = MOUSE_TO_VK.get(e.button) + if vk is not None: + held_vks.discard(vk) + + if paused: + draw_pause_overlay(screen, last_surface, font) + clock.tick(60) + continue + + dx, dy = pygame.mouse.get_rel() + ctrl = CtrlInput( + button=set(held_vks), + mouse=(dx * mouse_sensitivity, dy * mouse_sensitivity), + scroll_wheel=scroll, + ) + + # Hand the ctrl off to the engine thread; it starts computing + # immediately. We then render the *previous* batch (with pacing) + # while it works. Finally we block on the queue for the new batch. + t0 = time.perf_counter() + engine.ctrl_q.put(ctrl) + if pending is not None: + last_surface = render(screen, pending, batch_dt, hud_font, model_uri) + pending = engine.frame_q.get() + batch_dt = time.perf_counter() - t0 + + +# --- entry point ------------------------------------------------------------- + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("model_uri", help="HF model id, e.g. Overworld/Waypoint-1.5-1B") + ap.add_argument("-s", "--seed", help="Path to a local seed image (defaults to a random Biome seed)") + ap.add_argument("-q", "--quant", choices=["intw8a8", "fp8w8a8", "nvfp4"], default=None) + ap.add_argument("-d", "--device", default="cuda") + ap.add_argument("-m", "--mouse-sensitivity", type=float, default=1.5) + args = ap.parse_args() + + pygame.init() + screen = pygame.display.set_mode(WINDOW_SIZE, pygame.RESIZABLE) + pygame.display.set_caption(args.model_uri) + font = pygame.font.SysFont(None, 36) + hud_font = pygame.font.SysFont(None, 22) + status_font = pygame.font.SysFont(None, 24) + clock = pygame.time.Clock() + + engine = Engine() + engine.start(args.model_uri, args.quant, args.device, args.seed) + + try: + if not loading_screen(screen, status_font, clock, engine): + return + if engine.error: + raise engine.error + gameplay(screen, font, hud_font, clock, engine, args.model_uri, args.mouse_sensitivity) + except KeyboardInterrupt: + pass + finally: + engine.stop() + pygame.quit() + + +if __name__ == "__main__": + main() From ac0d179f655602e2a5499ee7fad0fafec82ba62e Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 19:57:28 +0200 Subject: [PATCH 02/17] refactor(interactive): make single-threaded again --- examples/interactive.py | 213 ++++++++++++++-------------------------- 1 file changed, 76 insertions(+), 137 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index f54449a..76c6dcc 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -22,11 +22,6 @@ # Left-click (on pause) : resume # Close window / Ctrl+C : quit # -# The engine is pinned to a single dedicated thread for its entire lifetime. -# torch.compile + triton.cudagraphs capture stream state tied to the capturing -# thread, so calling gen_frame() from a *different* thread segfaults. The main -# (pygame/UI) thread communicates with the engine thread via two queues. -# # Supports both Waypoint-1 / 1.1 (single-frame output) and Waypoint-1.5 # (4-frame temporally-compressed output). The only model-dependent branches # live in `prime_seed` and `render`, keyed off `engine.model_cfg.model_type`. @@ -35,9 +30,7 @@ import io import json import logging -import queue import random -import threading import time import urllib.request @@ -51,7 +44,7 @@ logging.basicConfig( level=logging.INFO, - format="%(asctime)s [%(threadName)s] %(message)s", + format="%(asctime)s %(message)s", datefmt="%H:%M:%S", ) log = logging.getLogger("interactive") @@ -219,103 +212,7 @@ def draw_status(screen: pygame.Surface, font: pygame.font.Font, text: str) -> No pygame.display.flip() -# --- engine thread ----------------------------------------------------------- - - -class Engine: - """Owns a WorldEngine on a dedicated thread. Main thread communicates via queues. - - All CUDA work (construction, torch.compile warmup, gen_frame) happens on the - engine thread. Cross-thread invocation of compiled+cudagraphs'd code segfaults. - """ - - def __init__(self) -> None: - self._stop = object() - self._reset = object() - self.ctrl_q: queue.Queue = queue.Queue(maxsize=1) - self.frame_q: queue.Queue = queue.Queue(maxsize=1) - self.status: str = "Starting…" - self.ready = threading.Event() - self.error: BaseException | None = None - - def start(self, model_uri: str, quant: str | None, device: str, seed_path: str | None) -> None: - threading.Thread( - target=self._run, args=(model_uri, quant, device, seed_path), - daemon=True, name="engine", - ).start() - - def stop(self) -> None: - try: - self.ctrl_q.put_nowait(self._stop) - except queue.Full: - pass - - def reset(self) -> None: - """Request engine reset (re-prime seed, produce fresh first frame).""" - self.ctrl_q.put(self._reset) - - def _run_gen(self, eng: WorldEngine, ctrl: CtrlInput) -> torch.Tensor: - """gen_frame -> synchronize -> .cpu(). The sync + immediate CPU copy - mirrors Biome's server: the returned GPU tensor may share storage that - gen_frame reuses on the next call, so it must be materialized before - the next invocation. - """ - frame = eng.gen_frame(ctrl=ctrl) - if torch.cuda.is_available(): - torch.cuda.synchronize() - return frame.cpu() - - def _run(self, model_uri: str, quant: str | None, device: str, seed_path: str | None) -> None: - try: - t0 = time.perf_counter() - self.status = "Loading model…" - log.info("loading model %s (quant=%s, device=%s)", model_uri, quant, device) - eng = WorldEngine(model_uri, quant=quant, device=device) - log.info( - "model loaded: type=%s, temporal_compression=%d", - eng.model_cfg.model_type, eng.model_cfg.temporal_compression, - ) - - self.status = "Loading seed…" - seed = load_seed_from_path(seed_path) if seed_path else load_seed_from_github() - - self.status = "Priming engine…" - eng.reset() - prime_seed(eng, seed) - - self.status = "Warming up (torch.compile)…" - log.info("warming up torch.compile") - w0 = time.perf_counter() - first = self._run_gen(eng, CtrlInput()) - log.info("warmup complete in %.1fs", time.perf_counter() - w0) - self.frame_q.put(first) - - self.status = "Ready — click to start." - log.info("init finished in %.1fs", time.perf_counter() - t0) - self.ready.set() - - # Command loop: pull a CtrlInput (or sentinel), dispatch. - while True: - cmd = self.ctrl_q.get() - if cmd is self._stop: - return - if cmd is self._reset: - # reset() clears the KV cache and all state, so the model - # must be re-seeded with append_frame before it can produce - # coherent output again. - log.info("resetting engine") - eng.reset() - prime_seed(eng, seed) - self.frame_q.put(self._run_gen(eng, CtrlInput())) - continue - self.frame_q.put(self._run_gen(eng, cmd)) - except BaseException as exc: - log.exception("engine thread failed") - self.error = exc - self.ready.set() - - -# --- main loop phases ------------------------------------------------------- +# --- mouse helpers ----------------------------------------------------------- def grab_mouse(screen: pygame.Surface) -> None: """Confine + hide the cursor for FPS-style gameplay.""" @@ -329,33 +226,75 @@ def release_mouse() -> None: pygame.mouse.set_visible(True) -def loading_screen(screen: pygame.Surface, font: pygame.font.Font, clock: pygame.time.Clock, engine: Engine) -> bool: - """Phase 1: show status while engine initializes. Returns True when ready, False on user quit.""" - while not engine.ready.is_set(): - for e in pygame.event.get(): - if e.type == pygame.QUIT or (e.type == pygame.KEYDOWN and e.key == pygame.K_ESCAPE): - return False - draw_status(screen, font, engine.status) - clock.tick(30) - return True +# --- engine init ------------------------------------------------------------- + +def init_engine( + screen: pygame.Surface, + font: pygame.font.Font, + model_uri: str, + quant: str | None, + device: str, + seed_path: str | None, +) -> tuple[WorldEngine, np.ndarray, torch.Tensor]: + """Load the model, seed, and produce the first frame (torch.compile warmup). + Draws status to the screen before each blocking step. The window won't pump + events during the heavy steps (model load, compile), but the status text + gives the user a progress indication. + + Returns (engine, seed, first_frame_cpu). + """ + draw_status(screen, font, "Loading model…") + log.info("loading model %s (quant=%s, device=%s)", model_uri, quant, device) + engine = WorldEngine(model_uri, quant=quant, device=device) + log.info( + "model loaded: type=%s, temporal_compression=%d", + engine.model_cfg.model_type, engine.model_cfg.temporal_compression, + ) + + draw_status(screen, font, "Loading seed…") + seed = load_seed_from_path(seed_path) if seed_path else load_seed_from_github() + + draw_status(screen, font, "Priming engine…") + engine.reset() + prime_seed(engine, seed) + + # The first gen_frame triggers torch.compile — the most expensive step. + draw_status(screen, font, "Warming up (torch.compile)…") + log.info("warming up torch.compile") + w0 = time.perf_counter() + first = engine.gen_frame(ctrl=CtrlInput()).cpu() + log.info("warmup complete in %.1fs", time.perf_counter() - w0) + + return engine, seed, first + + +# --- gameplay ---------------------------------------------------------------- def gameplay( screen: pygame.Surface, font: pygame.font.Font, hud_font: pygame.font.Font, clock: pygame.time.Clock, - engine: Engine, + engine: WorldEngine, + seed: np.ndarray, + first_frame: torch.Tensor, model_uri: str, mouse_sensitivity: float, ) -> None: - """Phase 2: interactive generation loop. Starts auto-paused on the first frame.""" - first = engine.frame_q.get() - last_surface: pygame.Surface = render(screen, first, 0.0) + """Interactive generation loop. Starts auto-paused on the first frame. + + Uses the pipelining pattern from the README: gen_frame() queues GPU kernels + and returns immediately; we render the *previous* batch (with pacing sleeps) + while the GPU works; then .cpu() syncs and transfers the result. + """ + last_surface: pygame.Surface = render(screen, first_frame, 0.0) paused = True log.info("ready") held_vks: set[int] = set() + # Pipeline state: `pending` holds the not-yet-rendered CPU frame from the + # previous gen_frame call. We render it while the GPU computes the next one. pending: torch.Tensor | None = None batch_dt = 0.0 @@ -370,7 +309,7 @@ def gameplay( # WMs where `set_grab` is advisory and the cursor can escape. elif e.type == pygame.WINDOWLEAVE and not paused: if pending is not None: - last_surface = render(screen, pending, batch_dt) + last_surface = render(screen, pending, batch_dt, hud_font, model_uri) pending = None paused = True release_mouse() @@ -379,17 +318,18 @@ def gameplay( elif e.type == pygame.KEYDOWN: if e.key == pygame.K_ESCAPE and not paused: if pending is not None: - last_surface = render(screen, pending, batch_dt) + last_surface = render(screen, pending, batch_dt, hud_font, model_uri) pending = None paused = True release_mouse() continue if e.key == pygame.K_u and not paused: - # Reset: re-prime the seed. The engine thread handles - # reset + re-seed + gen_frame and puts the fresh frame on - # frame_q like any normal frame. Gameplay continues. + # reset() clears the KV cache and all state, so the model + # must be re-seeded with append_frame before it can produce + # coherent output again. pending = None engine.reset() + prime_seed(engine, seed) continue vk = PYGAME_TO_VK.get(e.key) if vk is not None: @@ -430,14 +370,14 @@ def gameplay( scroll_wheel=scroll, ) - # Hand the ctrl off to the engine thread; it starts computing - # immediately. We then render the *previous* batch (with pacing) - # while it works. Finally we block on the queue for the new batch. + # Pipeline: kick off generation (GPU kernels queued, returns fast), + # then render the *previous* batch while the GPU works. Finally .cpu() + # syncs and transfers the just-computed batch to CPU. t0 = time.perf_counter() - engine.ctrl_q.put(ctrl) + next_frames = engine.gen_frame(ctrl=ctrl) if pending is not None: last_surface = render(screen, pending, batch_dt, hud_font, model_uri) - pending = engine.frame_q.get() + pending = next_frames.cpu() batch_dt = time.perf_counter() - t0 @@ -460,19 +400,18 @@ def main() -> None: status_font = pygame.font.SysFont(None, 24) clock = pygame.time.Clock() - engine = Engine() - engine.start(args.model_uri, args.quant, args.device, args.seed) - try: - if not loading_screen(screen, status_font, clock, engine): - return - if engine.error: - raise engine.error - gameplay(screen, font, hud_font, clock, engine, args.model_uri, args.mouse_sensitivity) + engine, seed, first = init_engine( + screen, status_font, args.model_uri, args.quant, args.device, args.seed, + ) + gameplay( + screen, font, hud_font, clock, + engine, seed, first, + args.model_uri, args.mouse_sensitivity, + ) except KeyboardInterrupt: pass finally: - engine.stop() pygame.quit() From 8898e58c904fa7f460e07de2c7776cc7951dab3d Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 20:06:37 +0200 Subject: [PATCH 03/17] refactor(interactive): GameState class --- examples/interactive.py | 181 ++++++++++++++++++++++------------------ 1 file changed, 99 insertions(+), 82 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index 76c6dcc..20c782b 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -33,6 +33,7 @@ import random import time import urllib.request +from dataclasses import dataclass, field import numpy as np import pygame @@ -212,20 +213,6 @@ def draw_status(screen: pygame.Surface, font: pygame.font.Font, text: str) -> No pygame.display.flip() -# --- mouse helpers ----------------------------------------------------------- - -def grab_mouse(screen: pygame.Surface) -> None: - """Confine + hide the cursor for FPS-style gameplay.""" - pygame.event.set_grab(True) - pygame.mouse.set_visible(False) - pygame.mouse.get_rel() # discard any pre-grab accumulated delta - - -def release_mouse() -> None: - pygame.event.set_grab(False) - pygame.mouse.set_visible(True) - - # --- engine init ------------------------------------------------------------- def init_engine( @@ -271,103 +258,133 @@ def init_engine( # --- gameplay ---------------------------------------------------------------- -def gameplay( - screen: pygame.Surface, - font: pygame.font.Font, - hud_font: pygame.font.Font, - clock: pygame.time.Clock, - engine: WorldEngine, - seed: np.ndarray, - first_frame: torch.Tensor, - model_uri: str, - mouse_sensitivity: float, -) -> None: - """Interactive generation loop. Starts auto-paused on the first frame. - Uses the pipelining pattern from the README: gen_frame() queues GPU kernels - and returns immediately; we render the *previous* batch (with pacing sleeps) - while the GPU works; then .cpu() syncs and transfers the result. - """ - last_surface: pygame.Surface = render(screen, first_frame, 0.0) - paused = True - log.info("ready") +@dataclass +class GameState: + """Mutable state shared between event handling and the generation loop.""" - held_vks: set[int] = set() + screen: pygame.Surface + engine: WorldEngine + seed: np.ndarray + hud_font: pygame.font.Font + model_uri: str + paused: bool = True + held_vks: set[int] = field(default_factory=set) + scroll: int = 0 # Pipeline state: `pending` holds the not-yet-rendered CPU frame from the # previous gen_frame call. We render it while the GPU computes the next one. pending: torch.Tensor | None = None - batch_dt = 0.0 - - while True: - scroll = 0 + batch_dt: float = 0.0 + last_surface: pygame.Surface | None = None + + def enter_pause(self) -> None: + """Flush any in-flight batch and enter paused state.""" + if self.pending is not None: + self.last_surface = render(self.screen, self.pending, self.batch_dt, self.hud_font, self.model_uri) + self.pending = None + self.paused = True + pygame.event.set_grab(False) + pygame.mouse.set_visible(True) + + def exit_pause(self) -> None: + """Re-grab the cursor and resume gameplay.""" + self.paused = False + pygame.event.set_grab(True) + pygame.mouse.set_visible(False) + pygame.mouse.get_rel() # discard accumulated delta during pause + + def process_events(self) -> bool: + """Drain pygame events and update state. Returns False to quit.""" + self.scroll = 0 for e in pygame.event.get(): if e.type == pygame.QUIT: - return + return False # Auto-pause when the cursor leaves the window. Safety net for # WMs where `set_grab` is advisory and the cursor can escape. - elif e.type == pygame.WINDOWLEAVE and not paused: - if pending is not None: - last_surface = render(screen, pending, batch_dt, hud_font, model_uri) - pending = None - paused = True - release_mouse() - continue + elif e.type == pygame.WINDOWLEAVE and not self.paused: + self.enter_pause() elif e.type == pygame.KEYDOWN: - if e.key == pygame.K_ESCAPE and not paused: - if pending is not None: - last_surface = render(screen, pending, batch_dt, hud_font, model_uri) - pending = None - paused = True - release_mouse() - continue - if e.key == pygame.K_u and not paused: + if e.key == pygame.K_ESCAPE and not self.paused: + self.enter_pause() + elif e.key == pygame.K_u and not self.paused: # reset() clears the KV cache and all state, so the model # must be re-seeded with append_frame before it can produce # coherent output again. - pending = None - engine.reset() - prime_seed(engine, seed) - continue - vk = PYGAME_TO_VK.get(e.key) - if vk is not None: - held_vks.add(vk) + self.pending = None + self.engine.reset() + prime_seed(self.engine, self.seed) + else: + vk = PYGAME_TO_VK.get(e.key) + if vk is not None: + self.held_vks.add(vk) elif e.type == pygame.KEYUP: vk = PYGAME_TO_VK.get(e.key) if vk is not None: - held_vks.discard(vk) + self.held_vks.discard(vk) elif e.type == pygame.MOUSEBUTTONDOWN: - if paused and e.button == 1: - paused = False - grab_mouse(screen) - continue - vk = MOUSE_TO_VK.get(e.button) - if vk is not None: - held_vks.add(vk) - if e.button == 4: - scroll += 1 - elif e.button == 5: - scroll -= 1 + if self.paused and e.button == 1: + self.exit_pause() + else: + vk = MOUSE_TO_VK.get(e.button) + if vk is not None: + self.held_vks.add(vk) + if e.button == 4: + self.scroll += 1 + elif e.button == 5: + self.scroll -= 1 elif e.type == pygame.MOUSEBUTTONUP: vk = MOUSE_TO_VK.get(e.button) if vk is not None: - held_vks.discard(vk) + self.held_vks.discard(vk) + + return True + + +def gameplay( + screen: pygame.Surface, + font: pygame.font.Font, + hud_font: pygame.font.Font, + clock: pygame.time.Clock, + engine: WorldEngine, + seed: np.ndarray, + first_frame: torch.Tensor, + model_uri: str, + mouse_sensitivity: float, +) -> None: + """Interactive generation loop. Starts auto-paused on the first frame. + + Uses the pipelining pattern from the README: gen_frame() queues GPU kernels + and returns immediately; we render the *previous* batch (with pacing sleeps) + while the GPU works; then .cpu() syncs and transfers the result. + """ + state = GameState( + screen=screen, engine=engine, seed=seed, + hud_font=hud_font, model_uri=model_uri, + last_surface=render(screen, first_frame, 0.0), + ) + log.info("ready") + + while True: + if not state.process_events(): + return - if paused: - draw_pause_overlay(screen, last_surface, font) + if state.paused: + assert state.last_surface is not None + draw_pause_overlay(screen, state.last_surface, font) clock.tick(60) continue dx, dy = pygame.mouse.get_rel() ctrl = CtrlInput( - button=set(held_vks), + button=set(state.held_vks), mouse=(dx * mouse_sensitivity, dy * mouse_sensitivity), - scroll_wheel=scroll, + scroll_wheel=state.scroll, ) # Pipeline: kick off generation (GPU kernels queued, returns fast), @@ -375,10 +392,10 @@ def gameplay( # syncs and transfers the just-computed batch to CPU. t0 = time.perf_counter() next_frames = engine.gen_frame(ctrl=ctrl) - if pending is not None: - last_surface = render(screen, pending, batch_dt, hud_font, model_uri) - pending = next_frames.cpu() - batch_dt = time.perf_counter() - t0 + if state.pending is not None: + state.last_surface = render(screen, state.pending, state.batch_dt, hud_font, model_uri) + state.pending = next_frames.cpu() + state.batch_dt = time.perf_counter() - t0 # --- entry point ------------------------------------------------------------- From c93c63cdae9fd42b48dc8f1862ead43be9eed277 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 20:15:58 +0200 Subject: [PATCH 04/17] refactor(interactive): Engine class --- examples/interactive.py | 161 +++++++++++++++++++++------------------- 1 file changed, 83 insertions(+), 78 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index 20c782b..a2561cc 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -116,16 +116,6 @@ def load_seed_from_github() -> np.ndarray: return center_crop(Image.open(io.BytesIO(img_bytes))) -def prime_seed(engine: WorldEngine, seed: np.ndarray) -> None: - """Encode the seed frame into the engine's KV cache.""" - t = torch.from_numpy(seed).to(engine.device) # uint8 (H, W, 3) - tc = engine.model_cfg.temporal_compression - if tc > 1: - # Multi-frame models (e.g. Waypoint-1.5) consume/produce a stack of - # `temporal_compression` frames per step. - t = t.unsqueeze(0).expand(tc, -1, -1, -1).contiguous() - log.info("priming engine with seed shape=%s", tuple(t.shape)) - engine.append_frame(t) # --- rendering -------------------------------------------------------------- @@ -213,47 +203,79 @@ def draw_status(screen: pygame.Surface, font: pygame.font.Font, text: str) -> No pygame.display.flip() -# --- engine init ------------------------------------------------------------- +# --- engine ------------------------------------------------------------------ -def init_engine( - screen: pygame.Surface, - font: pygame.font.Font, - model_uri: str, - quant: str | None, - device: str, - seed_path: str | None, -) -> tuple[WorldEngine, np.ndarray, torch.Tensor]: - """Load the model, seed, and produce the first frame (torch.compile warmup). - Draws status to the screen before each blocking step. The window won't pump - events during the heavy steps (model load, compile), but the status text - gives the user a progress indication. +class Engine: + """Wraps WorldEngine with seed management and the generation pipeline. - Returns (engine, seed, first_frame_cpu). + After construction, the first generated frame is available as `self.pending`. + Subsequent frames are produced by `next_frame()` and should be `.cpu()`'d + into `self.pending` by the caller before the next `next_frame()` call. """ - draw_status(screen, font, "Loading model…") - log.info("loading model %s (quant=%s, device=%s)", model_uri, quant, device) - engine = WorldEngine(model_uri, quant=quant, device=device) - log.info( - "model loaded: type=%s, temporal_compression=%d", - engine.model_cfg.model_type, engine.model_cfg.temporal_compression, - ) - draw_status(screen, font, "Loading seed…") - seed = load_seed_from_path(seed_path) if seed_path else load_seed_from_github() - - draw_status(screen, font, "Priming engine…") - engine.reset() - prime_seed(engine, seed) - - # The first gen_frame triggers torch.compile — the most expensive step. - draw_status(screen, font, "Warming up (torch.compile)…") - log.info("warming up torch.compile") - w0 = time.perf_counter() - first = engine.gen_frame(ctrl=CtrlInput()).cpu() - log.info("warmup complete in %.1fs", time.perf_counter() - w0) + inner: WorldEngine + seed: np.ndarray + model_uri: str + pending: torch.Tensor | None + + def __init__( + self, + screen: pygame.Surface, + font: pygame.font.Font, + model_uri: str, + quant: str | None, + device: str, + seed_path: str | None, + ) -> None: + """Load model, seed, prime, compile-warmup. Shows status on *screen*.""" + draw_status(screen, font, "Loading model…") + log.info("loading model %s (quant=%s, device=%s)", model_uri, quant, device) + self.inner = WorldEngine(model_uri, quant=quant, device=device) + log.info( + "model loaded: type=%s, temporal_compression=%d", + self.inner.model_cfg.model_type, self.inner.model_cfg.temporal_compression, + ) - return engine, seed, first + draw_status(screen, font, "Loading seed…") + self.seed = load_seed_from_path(seed_path) if seed_path else load_seed_from_github() + self.model_uri = model_uri + + draw_status(screen, font, "Priming engine…") + self.inner.reset() + self._prime_seed() + + # The first gen_frame triggers torch.compile — the most expensive step. + draw_status(screen, font, "Warming up (torch.compile)…") + log.info("warming up torch.compile") + w0 = time.perf_counter() + self.pending = self.next_frame(ctrl=CtrlInput()).cpu() + log.info("warmup complete in %.1fs", time.perf_counter() - w0) + + def _prime_seed(self) -> None: + """Encode the seed frame into the KV cache.""" + t = torch.from_numpy(self.seed).to(self.inner.device) # uint8 (H, W, 3) + tc = self.inner.model_cfg.temporal_compression + if tc > 1: + # Multi-frame models (e.g. Waypoint-1.5) consume/produce a stack of + # `temporal_compression` frames per step. + t = t.unsqueeze(0).expand(tc, -1, -1, -1).contiguous() + log.info("priming engine with seed shape=%s", tuple(t.shape)) + self.inner.append_frame(t) + + def next_frame(self, ctrl: CtrlInput) -> torch.Tensor: + """Generate the next frame. Returns a GPU tensor; caller must .cpu() before the next call.""" + return self.inner.gen_frame(ctrl=ctrl) + + def reset(self) -> None: + """Reset all state and re-prime the seed. + + reset() clears the KV cache and all state, so the model must be + re-seeded with append_frame before it can produce coherent output. + """ + self.pending = None + self.inner.reset() + self._prime_seed() # --- gameplay ---------------------------------------------------------------- @@ -264,24 +286,21 @@ class GameState: """Mutable state shared between event handling and the generation loop.""" screen: pygame.Surface - engine: WorldEngine - seed: np.ndarray + engine: Engine hud_font: pygame.font.Font - model_uri: str paused: bool = True held_vks: set[int] = field(default_factory=set) scroll: int = 0 - # Pipeline state: `pending` holds the not-yet-rendered CPU frame from the - # previous gen_frame call. We render it while the GPU computes the next one. - pending: torch.Tensor | None = None batch_dt: float = 0.0 last_surface: pygame.Surface | None = None def enter_pause(self) -> None: """Flush any in-flight batch and enter paused state.""" - if self.pending is not None: - self.last_surface = render(self.screen, self.pending, self.batch_dt, self.hud_font, self.model_uri) - self.pending = None + if self.engine.pending is not None: + self.last_surface = render( + self.screen, self.engine.pending, self.batch_dt, self.hud_font, self.engine.model_uri, + ) + self.engine.pending = None self.paused = True pygame.event.set_grab(False) pygame.mouse.set_visible(True) @@ -310,12 +329,7 @@ def process_events(self) -> bool: if e.key == pygame.K_ESCAPE and not self.paused: self.enter_pause() elif e.key == pygame.K_u and not self.paused: - # reset() clears the KV cache and all state, so the model - # must be re-seeded with append_frame before it can produce - # coherent output again. - self.pending = None self.engine.reset() - prime_seed(self.engine, self.seed) else: vk = PYGAME_TO_VK.get(e.key) if vk is not None: @@ -351,10 +365,7 @@ def gameplay( font: pygame.font.Font, hud_font: pygame.font.Font, clock: pygame.time.Clock, - engine: WorldEngine, - seed: np.ndarray, - first_frame: torch.Tensor, - model_uri: str, + engine: Engine, mouse_sensitivity: float, ) -> None: """Interactive generation loop. Starts auto-paused on the first frame. @@ -364,10 +375,10 @@ def gameplay( while the GPU works; then .cpu() syncs and transfers the result. """ state = GameState( - screen=screen, engine=engine, seed=seed, - hud_font=hud_font, model_uri=model_uri, - last_surface=render(screen, first_frame, 0.0), + screen=screen, engine=engine, hud_font=hud_font, + last_surface=render(screen, engine.pending, 0.0), ) + engine.pending = None log.info("ready") while True: @@ -391,10 +402,10 @@ def gameplay( # then render the *previous* batch while the GPU works. Finally .cpu() # syncs and transfers the just-computed batch to CPU. t0 = time.perf_counter() - next_frames = engine.gen_frame(ctrl=ctrl) - if state.pending is not None: - state.last_surface = render(screen, state.pending, state.batch_dt, hud_font, model_uri) - state.pending = next_frames.cpu() + next_frames = engine.next_frame(ctrl=ctrl) + if engine.pending is not None: + state.last_surface = render(screen, engine.pending, state.batch_dt, hud_font, engine.model_uri) + engine.pending = next_frames.cpu() state.batch_dt = time.perf_counter() - t0 @@ -418,14 +429,8 @@ def main() -> None: clock = pygame.time.Clock() try: - engine, seed, first = init_engine( - screen, status_font, args.model_uri, args.quant, args.device, args.seed, - ) - gameplay( - screen, font, hud_font, clock, - engine, seed, first, - args.model_uri, args.mouse_sensitivity, - ) + engine = Engine(screen, status_font, args.model_uri, args.quant, args.device, args.seed) + gameplay(screen, font, hud_font, clock, engine, args.mouse_sensitivity) except KeyboardInterrupt: pass finally: From 42f0bd15c6da4a9e3c69f64687e2089978e600ad Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 20:24:50 +0200 Subject: [PATCH 05/17] refactor(interactive): Renderer class --- examples/interactive.py | 188 ++++++++++++++++++---------------------- 1 file changed, 82 insertions(+), 106 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index a2561cc..2ee7604 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -120,87 +120,77 @@ def load_seed_from_github() -> np.ndarray: # --- rendering -------------------------------------------------------------- -def _blit_frame(screen: pygame.Surface, frame: np.ndarray) -> pygame.Surface: - """Blit a single (H, W, 3) uint8 numpy frame, scaled to the window. Returns the scaled surface.""" - # pygame.surfarray expects (W, H, 3), so swap the first two axes. - surf = pygame.surfarray.make_surface(frame.swapaxes(0, 1)) - surf = pygame.transform.scale(surf, screen.get_size()) - screen.blit(surf, (0, 0)) - return surf - - -def _draw_hud( - screen: pygame.Surface, - font: pygame.font.Font | None, - model_uri: str, - batch_dt: float, -) -> None: - """Draw FPS / frametime and model name at the top-right corner. No-op if font is None.""" - if font is None: - return - lines: list[tuple[str, tuple[int, int, int]]] = [] - if batch_dt > 0: - lines.append((f"{1.0 / batch_dt:.1f} fps / {batch_dt * 1000:.1f} ms", (255, 255, 255))) - lines.append((model_uri, (160, 160, 160))) - for i, (text, color) in enumerate(lines): - label = font.render(text, True, color) - x = screen.get_width() - label.get_width() - 12 - y = 12 + i * (label.get_height() + 4) - screen.blit(label, (x, y)) - - -def render( - screen: pygame.Surface, - frame_cpu: torch.Tensor, - batch_dt: float, - hud_font: pygame.font.Font | None = None, - model_uri: str = "", -) -> pygame.Surface: - """Display an already-on-CPU frame; return the last surface for pause caching. - - For multi-frame models the tensor is (T, H, W, 3) — we spread the T - sub-frames evenly across `batch_dt` (per README "Waypoint-1.5 Behavior"). - The sleeps are what let the pipeline overlap: while we pace here, the GPU - is already computing the next batch. - """ - arr = frame_cpu.numpy() - if arr.ndim == 3: # single-frame model: (H, W, 3) - last = _blit_frame(screen, arr) - _draw_hud(screen, hud_font, model_uri, batch_dt) - pygame.display.flip() - return last - - # Multi-frame model: (T, H, W, 3) - step_ms = max(0, int(batch_dt * 1000 / arr.shape[0])) - last: pygame.Surface | None = None - for i, sub in enumerate(arr): - if i > 0 and step_ms: - pygame.time.wait(step_ms) - last = _blit_frame(screen, sub) - _draw_hud(screen, hud_font, model_uri, batch_dt) + +class Renderer: + """Owns the screen surface and fonts; handles all drawing.""" + + def __init__(self, model_uri: str) -> None: + self.screen = pygame.display.set_mode(WINDOW_SIZE, pygame.RESIZABLE) + pygame.display.set_caption(model_uri) + self.model_uri = model_uri + self.font = pygame.font.SysFont(None, 36) + self.hud_font = pygame.font.SysFont(None, 22) + self.status_font = pygame.font.SysFont(None, 24) + self._last_surface: pygame.Surface | None = None + + def _present(self, frame: np.ndarray, batch_dt: float) -> None: + """Blit a single (H, W, 3) frame, draw HUD, flip, and cache for pause.""" + # pygame.surfarray expects (W, H, 3), so swap the first two axes. + surf = pygame.surfarray.make_surface(frame.swapaxes(0, 1)) + surf = pygame.transform.scale(surf, self.screen.get_size()) + self.screen.blit(surf, (0, 0)) + self._last_surface = surf + + lines: list[tuple[str, tuple[int, int, int]]] = [] + if batch_dt > 0: + lines.append((f"{1.0 / batch_dt:.1f} fps / {batch_dt * 1000:.1f} ms", (255, 255, 255))) + lines.append((self.model_uri, (160, 160, 160))) + for i, (text, color) in enumerate(lines): + label = self.hud_font.render(text, True, color) + x = self.screen.get_width() - label.get_width() - 12 + y = 12 + i * (label.get_height() + 4) + self.screen.blit(label, (x, y)) + pygame.display.flip() - assert last is not None - return last + def render_frame(self, frame_cpu: torch.Tensor, batch_dt: float) -> None: + """Display an already-on-CPU frame and cache it for the pause overlay. -def draw_pause_overlay(screen: pygame.Surface, last: pygame.Surface, font: pygame.font.Font) -> None: - """Redraw the cached last frame with a dimmed overlay and centered pause text.""" - screen.blit(last, (0, 0)) - dim = pygame.Surface(screen.get_size(), pygame.SRCALPHA) - dim.fill((0, 0, 0, 128)) # 50% black - screen.blit(dim, (0, 0)) - label = font.render("Paused — click to resume", True, (255, 255, 255)) - rect = label.get_rect(center=screen.get_rect().center) - screen.blit(label, rect) - pygame.display.flip() + For multi-frame models the tensor is (T, H, W, 3) — we spread the T + sub-frames evenly across `batch_dt` (per README "Waypoint-1.5 Behavior"). + The sleeps are what let the pipeline overlap: while we pace here, the + GPU is already computing the next batch. + """ + arr = frame_cpu.numpy() + if arr.ndim == 3: # single-frame model: (H, W, 3) + self._present(arr, batch_dt) + return + # Multi-frame model: (T, H, W, 3) + step_ms = max(0, int(batch_dt * 1000 / arr.shape[0])) + for i, sub in enumerate(arr): + if i > 0 and step_ms: + pygame.time.wait(step_ms) + self._present(sub, batch_dt) + + def draw_pause(self) -> None: + """Redraw the cached last frame with a dimmed overlay and centered pause text.""" + assert self._last_surface is not None + self.screen.blit(self._last_surface, (0, 0)) + dim = pygame.Surface(self.screen.get_size(), pygame.SRCALPHA) + dim.fill((0, 0, 0, 128)) # 50% black + self.screen.blit(dim, (0, 0)) + label = self.font.render("Paused — click to resume", True, (255, 255, 255)) + rect = label.get_rect(center=self.screen.get_rect().center) + self.screen.blit(label, rect) + pygame.display.flip() -def draw_status(screen: pygame.Surface, font: pygame.font.Font, text: str) -> None: - """Clear to black and draw a status line in the bottom-left corner.""" - screen.fill((0, 0, 0)) - label = font.render(text, True, (220, 220, 220)) - screen.blit(label, (16, screen.get_height() - label.get_height() - 16)) - pygame.display.flip() + def draw_status(self, text: str) -> None: + """Clear to black and draw a status line in the bottom-left corner.""" + self.screen.fill((0, 0, 0)) + label = self.status_font.render(text, True, (220, 220, 220)) + self.screen.blit(label, (16, self.screen.get_height() - label.get_height() - 16)) + pygame.display.flip() # --- engine ------------------------------------------------------------------ @@ -221,15 +211,14 @@ class Engine: def __init__( self, - screen: pygame.Surface, - font: pygame.font.Font, + renderer: "Renderer", model_uri: str, quant: str | None, device: str, seed_path: str | None, ) -> None: - """Load model, seed, prime, compile-warmup. Shows status on *screen*.""" - draw_status(screen, font, "Loading model…") + """Load model, seed, prime, compile-warmup. Shows status via *renderer*.""" + renderer.draw_status("Loading model…") log.info("loading model %s (quant=%s, device=%s)", model_uri, quant, device) self.inner = WorldEngine(model_uri, quant=quant, device=device) log.info( @@ -237,16 +226,16 @@ def __init__( self.inner.model_cfg.model_type, self.inner.model_cfg.temporal_compression, ) - draw_status(screen, font, "Loading seed…") + renderer.draw_status("Loading seed…") self.seed = load_seed_from_path(seed_path) if seed_path else load_seed_from_github() self.model_uri = model_uri - draw_status(screen, font, "Priming engine…") + renderer.draw_status("Priming engine…") self.inner.reset() self._prime_seed() # The first gen_frame triggers torch.compile — the most expensive step. - draw_status(screen, font, "Warming up (torch.compile)…") + renderer.draw_status("Warming up (torch.compile)…") log.info("warming up torch.compile") w0 = time.perf_counter() self.pending = self.next_frame(ctrl=CtrlInput()).cpu() @@ -285,21 +274,17 @@ def reset(self) -> None: class GameState: """Mutable state shared between event handling and the generation loop.""" - screen: pygame.Surface + renderer: Renderer engine: Engine - hud_font: pygame.font.Font paused: bool = True held_vks: set[int] = field(default_factory=set) scroll: int = 0 batch_dt: float = 0.0 - last_surface: pygame.Surface | None = None def enter_pause(self) -> None: """Flush any in-flight batch and enter paused state.""" if self.engine.pending is not None: - self.last_surface = render( - self.screen, self.engine.pending, self.batch_dt, self.hud_font, self.engine.model_uri, - ) + self.renderer.render_frame(self.engine.pending, self.batch_dt) self.engine.pending = None self.paused = True pygame.event.set_grab(False) @@ -361,9 +346,7 @@ def process_events(self) -> bool: def gameplay( - screen: pygame.Surface, - font: pygame.font.Font, - hud_font: pygame.font.Font, + renderer: Renderer, clock: pygame.time.Clock, engine: Engine, mouse_sensitivity: float, @@ -374,11 +357,9 @@ def gameplay( and returns immediately; we render the *previous* batch (with pacing sleeps) while the GPU works; then .cpu() syncs and transfers the result. """ - state = GameState( - screen=screen, engine=engine, hud_font=hud_font, - last_surface=render(screen, engine.pending, 0.0), - ) + renderer.render_frame(engine.pending, 0.0) engine.pending = None + state = GameState(renderer=renderer, engine=engine) log.info("ready") while True: @@ -386,8 +367,7 @@ def gameplay( return if state.paused: - assert state.last_surface is not None - draw_pause_overlay(screen, state.last_surface, font) + renderer.draw_pause() clock.tick(60) continue @@ -404,7 +384,7 @@ def gameplay( t0 = time.perf_counter() next_frames = engine.next_frame(ctrl=ctrl) if engine.pending is not None: - state.last_surface = render(screen, engine.pending, state.batch_dt, hud_font, engine.model_uri) + renderer.render_frame(engine.pending, state.batch_dt) engine.pending = next_frames.cpu() state.batch_dt = time.perf_counter() - t0 @@ -421,16 +401,12 @@ def main() -> None: args = ap.parse_args() pygame.init() - screen = pygame.display.set_mode(WINDOW_SIZE, pygame.RESIZABLE) - pygame.display.set_caption(args.model_uri) - font = pygame.font.SysFont(None, 36) - hud_font = pygame.font.SysFont(None, 22) - status_font = pygame.font.SysFont(None, 24) + renderer = Renderer(args.model_uri) clock = pygame.time.Clock() try: - engine = Engine(screen, status_font, args.model_uri, args.quant, args.device, args.seed) - gameplay(screen, font, hud_font, clock, engine, args.mouse_sensitivity) + engine = Engine(renderer, args.model_uri, args.quant, args.device, args.seed) + gameplay(renderer, clock, engine, args.mouse_sensitivity) except KeyboardInterrupt: pass finally: From fd4dad96efa34464c3bb914da096d00c84097b69 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 20:57:51 +0200 Subject: [PATCH 06/17] refactor(interactive): split up Engine so that it doesn't need to take renderer --- examples/interactive.py | 136 +++++++++++++++++----------------------- 1 file changed, 58 insertions(+), 78 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index 2ee7604..1bbac71 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -51,15 +51,7 @@ log = logging.getLogger("interactive") -# GitHub contents API for the Biome `seeds/` directory, pinned to a known ref. -# Same source as examples/gen_sample.py. -BIOME_SEEDS_API = ( - "https://api.github.com/repos/Overworldai/Biome/contents/seeds?ref=14343a6" -) - WINDOW_SIZE = (1280, 720) -# Aspect ratio for the center-crop applied to seed images. -CROP_ASPECT_W, CROP_ASPECT_H = 16, 9 # Map pygame keys / mouse buttons to the Windows VK integers that CtrlInput.button # expects (see https://github.com/Overworldai/owl-control keycode table). Covers @@ -82,38 +74,6 @@ MOUSE_TO_VK: dict[int, int] = {1: 0x01, 2: 0x04, 3: 0x02} -# --- seed loading ----------------------------------------------------------- - -def center_crop(img: Image.Image) -> np.ndarray: - """Center-crop to CROP_ASPECT_W:CROP_ASPECT_H. Returns uint8 (H, W, 3).""" - w, h = img.size - # Pick whichever dimension is the limiting factor and derive the other. - if w * CROP_ASPECT_H > h * CROP_ASPECT_W: - new_w, new_h = h * CROP_ASPECT_W // CROP_ASPECT_H, h - else: - new_w, new_h = w, w * CROP_ASPECT_H // CROP_ASPECT_W - left = (w - new_w) // 2 - top = (h - new_h) // 2 - # `.copy()` — PIL's buffer is read-only and torch.from_numpy requires writable. - return np.asarray(img.crop((left, top, left + new_w, top + new_h)).convert("RGB")).copy() - - -def load_seed_from_path(path: str) -> np.ndarray: - """Load a local image as uint8 (H, W, 3), center-cropped.""" - log.info("loading seed from local file: %s", path) - return center_crop(Image.open(path)) - - -def load_seed_from_github() -> np.ndarray: - """Download a random seed from the pinned Biome `seeds/` directory.""" - log.info("fetching Biome seeds index") - with urllib.request.urlopen(BIOME_SEEDS_API) as res: - entries = [e for e in json.load(res) if e["type"] == "file"] - url = random.choice(entries)["download_url"] - log.info("downloading random Biome seed: %s", url) - with urllib.request.urlopen(url) as res: - img_bytes = res.read() - return center_crop(Image.open(io.BytesIO(img_bytes))) @@ -199,50 +159,42 @@ def draw_status(self, text: str) -> None: class Engine: """Wraps WorldEngine with seed management and the generation pipeline. - After construction, the first generated frame is available as `self.pending`. - Subsequent frames are produced by `next_frame()` and should be `.cpu()`'d - into `self.pending` by the caller before the next `next_frame()` call. + Frames are produced by `next_frame()` and should be `.cpu()`'d into + `self.pending` by the caller before the next `next_frame()` call. """ - inner: WorldEngine - seed: np.ndarray - model_uri: str - pending: torch.Tensor | None - - def __init__( - self, - renderer: "Renderer", - model_uri: str, - quant: str | None, - device: str, - seed_path: str | None, - ) -> None: - """Load model, seed, prime, compile-warmup. Shows status via *renderer*.""" - renderer.draw_status("Loading model…") + def __init__(self, model_uri: str, quant: str | None, device: str) -> None: log.info("loading model %s (quant=%s, device=%s)", model_uri, quant, device) self.inner = WorldEngine(model_uri, quant=quant, device=device) + self.model_uri = model_uri + self.seed: np.ndarray | None = None + self.pending: torch.Tensor | None = None log.info( "model loaded: type=%s, temporal_compression=%d", self.inner.model_cfg.model_type, self.inner.model_cfg.temporal_compression, ) - renderer.draw_status("Loading seed…") - self.seed = load_seed_from_path(seed_path) if seed_path else load_seed_from_github() - self.model_uri = model_uri - - renderer.draw_status("Priming engine…") + def set_seed(self, img: Image.Image) -> None: + """Center-crop the image to the expected aspect ratio and store as seed.""" + # TODO: Calculate aspect ratio automatically once + # https://github.com/Overworldai/world_engine/issues/43 lands + crop_w, crop_h = 16, 9 + w, h = img.size + if w * crop_h > h * crop_w: + new_w, new_h = h * crop_w // crop_h, h + else: + new_w, new_h = w, w * crop_h // crop_w + left = (w - new_w) // 2 + top = (h - new_h) // 2 + cropped = img.crop((left, top, left + new_w, top + new_h)).convert("RGB") + # .copy() — PIL's buffer is read-only and torch.from_numpy requires writable. + self.seed = np.asarray(cropped).copy() + + def prime(self) -> None: + """Reset state and encode the seed frame into the KV cache.""" + assert self.seed is not None, "call set_seed() first" + self.pending = None self.inner.reset() - self._prime_seed() - - # The first gen_frame triggers torch.compile — the most expensive step. - renderer.draw_status("Warming up (torch.compile)…") - log.info("warming up torch.compile") - w0 = time.perf_counter() - self.pending = self.next_frame(ctrl=CtrlInput()).cpu() - log.info("warmup complete in %.1fs", time.perf_counter() - w0) - - def _prime_seed(self) -> None: - """Encode the seed frame into the KV cache.""" t = torch.from_numpy(self.seed).to(self.inner.device) # uint8 (H, W, 3) tc = self.inner.model_cfg.temporal_compression if tc > 1: @@ -252,6 +204,13 @@ def _prime_seed(self) -> None: log.info("priming engine with seed shape=%s", tuple(t.shape)) self.inner.append_frame(t) + def warmup(self) -> None: + """Run one gen_frame to trigger torch.compile. Result stored in `self.pending`.""" + log.info("warming up torch.compile") + w0 = time.perf_counter() + self.pending = self.next_frame(ctrl=CtrlInput()).cpu() + log.info("warmup complete in %.1fs", time.perf_counter() - w0) + def next_frame(self, ctrl: CtrlInput) -> torch.Tensor: """Generate the next frame. Returns a GPU tensor; caller must .cpu() before the next call.""" return self.inner.gen_frame(ctrl=ctrl) @@ -262,9 +221,7 @@ def reset(self) -> None: reset() clears the KV cache and all state, so the model must be re-seeded with append_frame before it can produce coherent output. """ - self.pending = None - self.inner.reset() - self._prime_seed() + self.prime() # --- gameplay ---------------------------------------------------------------- @@ -391,6 +348,22 @@ def gameplay( # --- entry point ------------------------------------------------------------- +def get_seed(path: str | None) -> Image.Image: + """Load a seed image from a local path, or download a random one from Biome.""" + if path is not None: + log.info("loading seed from local file: %s", path) + return Image.open(path) + # GitHub contents API for the Biome `seeds/` directory, pinned to a known ref. + biome_api = "https://api.github.com/repos/Overworldai/Biome/contents/seeds?ref=14343a6" + log.info("fetching Biome seeds index") + with urllib.request.urlopen(biome_api) as res: + entries = [e for e in json.load(res) if e["type"] == "file"] + url = random.choice(entries)["download_url"] + log.info("downloading random Biome seed: %s", url) + with urllib.request.urlopen(url) as res: + return Image.open(io.BytesIO(res.read())) + + def main() -> None: ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("model_uri", help="HF model id, e.g. Overworld/Waypoint-1.5-1B") @@ -405,7 +378,14 @@ def main() -> None: clock = pygame.time.Clock() try: - engine = Engine(renderer, args.model_uri, args.quant, args.device, args.seed) + renderer.draw_status("Loading model…") + engine = Engine(args.model_uri, args.quant, args.device) + renderer.draw_status("Loading seed…") + engine.set_seed(get_seed(args.seed)) + renderer.draw_status("Priming engine…") + engine.prime() + renderer.draw_status("Warming up (torch.compile)…") + engine.warmup() gameplay(renderer, clock, engine, args.mouse_sensitivity) except KeyboardInterrupt: pass From f1c2b3c5459f75c816925afcf53989412a3a09eb Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 21:01:21 +0200 Subject: [PATCH 07/17] refactor(interactive): move vk constants into process_events --- examples/interactive.py | 48 ++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index 1bbac71..f3f3134 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -53,26 +53,6 @@ WINDOW_SIZE = (1280, 720) -# Map pygame keys / mouse buttons to the Windows VK integers that CtrlInput.button -# expects (see https://github.com/Overworldai/owl-control keycode table). Covers -# the main ANSI rows, space, shift, and the three mouse buttons — enough for -# WASD / spacebar / look-around gameplay without being exhaustive. -# Uses `pygame.K_*` int constants directly so this dict can be built at import -# time (before `pygame.init()`). -PYGAME_TO_VK: dict[int, int] = ( - {getattr(pygame, f"K_{ch}"): ord(ch) for ch in "1234567890"} - | {pygame.K_MINUS: 0xBD, pygame.K_EQUALS: 0xBB} - | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "qwertyuiop"} - | {pygame.K_LEFTBRACKET: 0xDB, pygame.K_RIGHTBRACKET: 0xDD, pygame.K_BACKSLASH: 0xDC} - | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "asdfghjkl"} - | {pygame.K_SEMICOLON: 0xBA, pygame.K_QUOTE: 0xDE} - | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "zxcvbnm"} - | {pygame.K_COMMA: 0xBC, pygame.K_PERIOD: 0xBE, pygame.K_SLASH: 0xBF} - | {pygame.K_SPACE: 0x20, pygame.K_LSHIFT: 0x10, pygame.K_RSHIFT: 0x10} -) -# pygame mouse button ids: 1=left, 2=middle, 3=right. VK: 0x01 LBUTTON, 0x04 MBUTTON, 0x02 RBUTTON. -MOUSE_TO_VK: dict[int, int] = {1: 0x01, 2: 0x04, 3: 0x02} - @@ -256,8 +236,26 @@ def exit_pause(self) -> None: def process_events(self) -> bool: """Drain pygame events and update state. Returns False to quit.""" - self.scroll = 0 + # Map pygame keys / mouse buttons to the Windows VK integers that + # CtrlInput.button expects (see https://github.com/Overworldai/owl-control + # keycode table). Covers the main ANSI rows, space, shift, and three + # mouse buttons — enough for WASD / spacebar / look-around gameplay. + key_to_vk: dict[int, int] = ( + {getattr(pygame, f"K_{ch}"): ord(ch) for ch in "1234567890"} + | {pygame.K_MINUS: 0xBD, pygame.K_EQUALS: 0xBB} + | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "qwertyuiop"} + | {pygame.K_LEFTBRACKET: 0xDB, pygame.K_RIGHTBRACKET: 0xDD, pygame.K_BACKSLASH: 0xDC} + | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "asdfghjkl"} + | {pygame.K_SEMICOLON: 0xBA, pygame.K_QUOTE: 0xDE} + | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "zxcvbnm"} + | {pygame.K_COMMA: 0xBC, pygame.K_PERIOD: 0xBE, pygame.K_SLASH: 0xBF} + | {pygame.K_SPACE: 0x20, pygame.K_LSHIFT: 0x10, pygame.K_RSHIFT: 0x10} + ) + # pygame mouse button ids: 1=left, 2=middle, 3=right. + # VK: 0x01 LBUTTON, 0x04 MBUTTON, 0x02 RBUTTON. + mouse_to_vk: dict[int, int] = {1: 0x01, 2: 0x04, 3: 0x02} + self.scroll = 0 for e in pygame.event.get(): if e.type == pygame.QUIT: return False @@ -273,12 +271,12 @@ def process_events(self) -> bool: elif e.key == pygame.K_u and not self.paused: self.engine.reset() else: - vk = PYGAME_TO_VK.get(e.key) + vk = key_to_vk.get(e.key) if vk is not None: self.held_vks.add(vk) elif e.type == pygame.KEYUP: - vk = PYGAME_TO_VK.get(e.key) + vk = key_to_vk.get(e.key) if vk is not None: self.held_vks.discard(vk) @@ -286,7 +284,7 @@ def process_events(self) -> bool: if self.paused and e.button == 1: self.exit_pause() else: - vk = MOUSE_TO_VK.get(e.button) + vk = mouse_to_vk.get(e.button) if vk is not None: self.held_vks.add(vk) if e.button == 4: @@ -295,7 +293,7 @@ def process_events(self) -> bool: self.scroll -= 1 elif e.type == pygame.MOUSEBUTTONUP: - vk = MOUSE_TO_VK.get(e.button) + vk = mouse_to_vk.get(e.button) if vk is not None: self.held_vks.discard(vk) From a320407bc83df865af4b3d67272052ca8525d0fa Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 21:06:29 +0200 Subject: [PATCH 08/17] refactor(interactive): top-level GameState creation + gameplay -> run method --- examples/interactive.py | 96 ++++++++++++++++++++--------------------- 1 file changed, 46 insertions(+), 50 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index f3f3134..6aa2fbe 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -209,16 +209,18 @@ def reset(self) -> None: @dataclass class GameState: - """Mutable state shared between event handling and the generation loop.""" + """Interactive generation loop state. Call `run()` to enter the main loop.""" renderer: Renderer engine: Engine + clock: pygame.time.Clock + mouse_sensitivity: float paused: bool = True held_vks: set[int] = field(default_factory=set) scroll: int = 0 batch_dt: float = 0.0 - def enter_pause(self) -> None: + def _enter_pause(self) -> None: """Flush any in-flight batch and enter paused state.""" if self.engine.pending is not None: self.renderer.render_frame(self.engine.pending, self.batch_dt) @@ -227,14 +229,14 @@ def enter_pause(self) -> None: pygame.event.set_grab(False) pygame.mouse.set_visible(True) - def exit_pause(self) -> None: + def _exit_pause(self) -> None: """Re-grab the cursor and resume gameplay.""" self.paused = False pygame.event.set_grab(True) pygame.mouse.set_visible(False) pygame.mouse.get_rel() # discard accumulated delta during pause - def process_events(self) -> bool: + def _process_events(self) -> bool: """Drain pygame events and update state. Returns False to quit.""" # Map pygame keys / mouse buttons to the Windows VK integers that # CtrlInput.button expects (see https://github.com/Overworldai/owl-control @@ -263,11 +265,11 @@ def process_events(self) -> bool: # Auto-pause when the cursor leaves the window. Safety net for # WMs where `set_grab` is advisory and the cursor can escape. elif e.type == pygame.WINDOWLEAVE and not self.paused: - self.enter_pause() + self._enter_pause() elif e.type == pygame.KEYDOWN: if e.key == pygame.K_ESCAPE and not self.paused: - self.enter_pause() + self._enter_pause() elif e.key == pygame.K_u and not self.paused: self.engine.reset() else: @@ -282,7 +284,7 @@ def process_events(self) -> bool: elif e.type == pygame.MOUSEBUTTONDOWN: if self.paused and e.button == 1: - self.exit_pause() + self._exit_pause() else: vk = mouse_to_vk.get(e.button) if vk is not None: @@ -299,49 +301,43 @@ def process_events(self) -> bool: return True + def run(self) -> None: + """Interactive generation loop. Starts auto-paused on the first frame. -def gameplay( - renderer: Renderer, - clock: pygame.time.Clock, - engine: Engine, - mouse_sensitivity: float, -) -> None: - """Interactive generation loop. Starts auto-paused on the first frame. - - Uses the pipelining pattern from the README: gen_frame() queues GPU kernels - and returns immediately; we render the *previous* batch (with pacing sleeps) - while the GPU works; then .cpu() syncs and transfers the result. - """ - renderer.render_frame(engine.pending, 0.0) - engine.pending = None - state = GameState(renderer=renderer, engine=engine) - log.info("ready") - - while True: - if not state.process_events(): - return - - if state.paused: - renderer.draw_pause() - clock.tick(60) - continue - - dx, dy = pygame.mouse.get_rel() - ctrl = CtrlInput( - button=set(state.held_vks), - mouse=(dx * mouse_sensitivity, dy * mouse_sensitivity), - scroll_wheel=state.scroll, - ) - - # Pipeline: kick off generation (GPU kernels queued, returns fast), - # then render the *previous* batch while the GPU works. Finally .cpu() - # syncs and transfers the just-computed batch to CPU. - t0 = time.perf_counter() - next_frames = engine.next_frame(ctrl=ctrl) - if engine.pending is not None: - renderer.render_frame(engine.pending, state.batch_dt) - engine.pending = next_frames.cpu() - state.batch_dt = time.perf_counter() - t0 + Uses the pipelining pattern from the README: gen_frame() queues GPU + kernels and returns immediately; we render the *previous* batch (with + pacing sleeps) while the GPU works; then .cpu() syncs and transfers + the result. + """ + self.renderer.render_frame(self.engine.pending, 0.0) + self.engine.pending = None + log.info("ready") + + while True: + if not self._process_events(): + return + + if self.paused: + self.renderer.draw_pause() + self.clock.tick(60) + continue + + dx, dy = pygame.mouse.get_rel() + ctrl = CtrlInput( + button=set(self.held_vks), + mouse=(dx * self.mouse_sensitivity, dy * self.mouse_sensitivity), + scroll_wheel=self.scroll, + ) + + # Pipeline: kick off generation (GPU kernels queued, returns fast), + # then render the *previous* batch while the GPU works. Finally + # .cpu() syncs and transfers the just-computed batch to CPU. + t0 = time.perf_counter() + next_frames = self.engine.next_frame(ctrl=ctrl) + if self.engine.pending is not None: + self.renderer.render_frame(self.engine.pending, self.batch_dt) + self.engine.pending = next_frames.cpu() + self.batch_dt = time.perf_counter() - t0 # --- entry point ------------------------------------------------------------- @@ -384,7 +380,7 @@ def main() -> None: engine.prime() renderer.draw_status("Warming up (torch.compile)…") engine.warmup() - gameplay(renderer, clock, engine, args.mouse_sensitivity) + GameState(renderer, engine, clock, args.mouse_sensitivity).run() except KeyboardInterrupt: pass finally: From b9d0543c428329144d0146b5afef14a3d5f177f9 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 21:15:52 +0200 Subject: [PATCH 09/17] refactor(interactive): move pending into GameState --- examples/interactive.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index 6aa2fbe..d1cb51d 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -139,8 +139,8 @@ def draw_status(self, text: str) -> None: class Engine: """Wraps WorldEngine with seed management and the generation pipeline. - Frames are produced by `next_frame()` and should be `.cpu()`'d into - `self.pending` by the caller before the next `next_frame()` call. + Frames are produced by `next_frame()` and should be `.cpu()`'d by the + caller before the next `next_frame()` call (GPU buffers may be reused). """ def __init__(self, model_uri: str, quant: str | None, device: str) -> None: @@ -148,7 +148,6 @@ def __init__(self, model_uri: str, quant: str | None, device: str) -> None: self.inner = WorldEngine(model_uri, quant=quant, device=device) self.model_uri = model_uri self.seed: np.ndarray | None = None - self.pending: torch.Tensor | None = None log.info( "model loaded: type=%s, temporal_compression=%d", self.inner.model_cfg.model_type, self.inner.model_cfg.temporal_compression, @@ -173,7 +172,6 @@ def set_seed(self, img: Image.Image) -> None: def prime(self) -> None: """Reset state and encode the seed frame into the KV cache.""" assert self.seed is not None, "call set_seed() first" - self.pending = None self.inner.reset() t = torch.from_numpy(self.seed).to(self.inner.device) # uint8 (H, W, 3) tc = self.inner.model_cfg.temporal_compression @@ -184,12 +182,13 @@ def prime(self) -> None: log.info("priming engine with seed shape=%s", tuple(t.shape)) self.inner.append_frame(t) - def warmup(self) -> None: - """Run one gen_frame to trigger torch.compile. Result stored in `self.pending`.""" + def warmup(self) -> torch.Tensor: + """Run one gen_frame to trigger torch.compile. Returns the first frame (CPU).""" log.info("warming up torch.compile") w0 = time.perf_counter() - self.pending = self.next_frame(ctrl=CtrlInput()).cpu() + first = self.next_frame(ctrl=CtrlInput()).cpu() log.info("warmup complete in %.1fs", time.perf_counter() - w0) + return first def next_frame(self, ctrl: CtrlInput) -> torch.Tensor: """Generate the next frame. Returns a GPU tensor; caller must .cpu() before the next call.""" @@ -218,13 +217,16 @@ class GameState: paused: bool = True held_vks: set[int] = field(default_factory=set) scroll: int = 0 + # Pipeline state: the not-yet-rendered CPU frame from the previous + # gen_frame call. Rendered while the GPU computes the next one. + pending: torch.Tensor | None = None batch_dt: float = 0.0 def _enter_pause(self) -> None: """Flush any in-flight batch and enter paused state.""" - if self.engine.pending is not None: - self.renderer.render_frame(self.engine.pending, self.batch_dt) - self.engine.pending = None + if self.pending is not None: + self.renderer.render_frame(self.pending, self.batch_dt) + self.pending = None self.paused = True pygame.event.set_grab(False) pygame.mouse.set_visible(True) @@ -271,6 +273,7 @@ def _process_events(self) -> bool: if e.key == pygame.K_ESCAPE and not self.paused: self._enter_pause() elif e.key == pygame.K_u and not self.paused: + self.pending = None self.engine.reset() else: vk = key_to_vk.get(e.key) @@ -309,8 +312,8 @@ def run(self) -> None: pacing sleeps) while the GPU works; then .cpu() syncs and transfers the result. """ - self.renderer.render_frame(self.engine.pending, 0.0) - self.engine.pending = None + self.renderer.render_frame(self.pending, 0.0) + self.pending = None log.info("ready") while True: @@ -334,9 +337,9 @@ def run(self) -> None: # .cpu() syncs and transfers the just-computed batch to CPU. t0 = time.perf_counter() next_frames = self.engine.next_frame(ctrl=ctrl) - if self.engine.pending is not None: - self.renderer.render_frame(self.engine.pending, self.batch_dt) - self.engine.pending = next_frames.cpu() + if self.pending is not None: + self.renderer.render_frame(self.pending, self.batch_dt) + self.pending = next_frames.cpu() self.batch_dt = time.perf_counter() - t0 @@ -379,8 +382,8 @@ def main() -> None: renderer.draw_status("Priming engine…") engine.prime() renderer.draw_status("Warming up (torch.compile)…") - engine.warmup() - GameState(renderer, engine, clock, args.mouse_sensitivity).run() + first = engine.warmup() + GameState(renderer, engine, clock, args.mouse_sensitivity, pending=first).run() except KeyboardInterrupt: pass finally: From bb5105b23d9ecc864985d3f9276a5912f85e597c Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 21:30:12 +0200 Subject: [PATCH 10/17] fix(interactive): show FPS+LFPS --- examples/interactive.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index d1cb51d..751d474 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -73,7 +73,7 @@ def __init__(self, model_uri: str) -> None: self.status_font = pygame.font.SysFont(None, 24) self._last_surface: pygame.Surface | None = None - def _present(self, frame: np.ndarray, batch_dt: float) -> None: + def _present(self, frame: np.ndarray, batch_dt: float, temporal_compression: int) -> None: """Blit a single (H, W, 3) frame, draw HUD, flip, and cache for pause.""" # pygame.surfarray expects (W, H, 3), so swap the first two axes. surf = pygame.surfarray.make_surface(frame.swapaxes(0, 1)) @@ -83,7 +83,12 @@ def _present(self, frame: np.ndarray, batch_dt: float) -> None: lines: list[tuple[str, tuple[int, int, int]]] = [] if batch_dt > 0: - lines.append((f"{1.0 / batch_dt:.1f} fps / {batch_dt * 1000:.1f} ms", (255, 255, 255))) + lfps = 1.0 / batch_dt + if temporal_compression > 1: + lines.append((f"{lfps * temporal_compression:.1f} FPS", (255, 255, 255))) + lines.append((f"{lfps:.1f} LFPS / {batch_dt * 1000:.1f} ms", (255, 255, 255))) + else: + lines.append((f"{lfps:.1f} FPS / {batch_dt * 1000:.1f} ms", (255, 255, 255))) lines.append((self.model_uri, (160, 160, 160))) for i, (text, color) in enumerate(lines): label = self.hud_font.render(text, True, color) @@ -93,7 +98,7 @@ def _present(self, frame: np.ndarray, batch_dt: float) -> None: pygame.display.flip() - def render_frame(self, frame_cpu: torch.Tensor, batch_dt: float) -> None: + def render_frame(self, frame_cpu: torch.Tensor, batch_dt: float, temporal_compression: int) -> None: """Display an already-on-CPU frame and cache it for the pause overlay. For multi-frame models the tensor is (T, H, W, 3) — we spread the T @@ -103,7 +108,7 @@ def render_frame(self, frame_cpu: torch.Tensor, batch_dt: float) -> None: """ arr = frame_cpu.numpy() if arr.ndim == 3: # single-frame model: (H, W, 3) - self._present(arr, batch_dt) + self._present(arr, batch_dt, temporal_compression) return # Multi-frame model: (T, H, W, 3) @@ -111,7 +116,7 @@ def render_frame(self, frame_cpu: torch.Tensor, batch_dt: float) -> None: for i, sub in enumerate(arr): if i > 0 and step_ms: pygame.time.wait(step_ms) - self._present(sub, batch_dt) + self._present(sub, batch_dt, temporal_compression) def draw_pause(self) -> None: """Redraw the cached last frame with a dimmed overlay and centered pause text.""" @@ -153,6 +158,10 @@ def __init__(self, model_uri: str, quant: str | None, device: str) -> None: self.inner.model_cfg.model_type, self.inner.model_cfg.temporal_compression, ) + @property + def temporal_compression(self) -> int: + return getattr(self.inner.model_cfg, "temporal_compression", 1) + def set_seed(self, img: Image.Image) -> None: """Center-crop the image to the expected aspect ratio and store as seed.""" # TODO: Calculate aspect ratio automatically once @@ -222,10 +231,14 @@ class GameState: pending: torch.Tensor | None = None batch_dt: float = 0.0 + @property + def temporal_compression(self) -> int: + return self.engine.temporal_compression + def _enter_pause(self) -> None: """Flush any in-flight batch and enter paused state.""" if self.pending is not None: - self.renderer.render_frame(self.pending, self.batch_dt) + self.renderer.render_frame(self.pending, self.batch_dt, self.temporal_compression) self.pending = None self.paused = True pygame.event.set_grab(False) @@ -312,7 +325,7 @@ def run(self) -> None: pacing sleeps) while the GPU works; then .cpu() syncs and transfers the result. """ - self.renderer.render_frame(self.pending, 0.0) + self.renderer.render_frame(self.pending, 0.0, self.temporal_compression) self.pending = None log.info("ready") @@ -338,7 +351,7 @@ def run(self) -> None: t0 = time.perf_counter() next_frames = self.engine.next_frame(ctrl=ctrl) if self.pending is not None: - self.renderer.render_frame(self.pending, self.batch_dt) + self.renderer.render_frame(self.pending, self.batch_dt, self.temporal_compression) self.pending = next_frames.cpu() self.batch_dt = time.perf_counter() - t0 From a3d1917812bb08475bb1a37a833e405bd793412c Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 21:38:25 +0200 Subject: [PATCH 11/17] refactor(interactive): dataclass fields + docs --- examples/interactive.py | 40 ++++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index 751d474..3547601 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -33,7 +33,7 @@ import random import time import urllib.request -from dataclasses import dataclass, field +from dataclasses import InitVar, dataclass, field import numpy as np import pygame @@ -61,17 +61,24 @@ # --- rendering -------------------------------------------------------------- +@dataclass class Renderer: """Owns the screen surface and fonts; handles all drawing.""" - def __init__(self, model_uri: str) -> None: + model_uri: str + screen: pygame.Surface = field(init=False) + font: pygame.font.Font = field(init=False) + hud_font: pygame.font.Font = field(init=False) + status_font: pygame.font.Font = field(init=False) + _last_surface: pygame.Surface | None = field(init=False, default=None, repr=False) + """Cached surface of the most recently rendered frame, used for the pause overlay.""" + + def __post_init__(self) -> None: self.screen = pygame.display.set_mode(WINDOW_SIZE, pygame.RESIZABLE) - pygame.display.set_caption(model_uri) - self.model_uri = model_uri + pygame.display.set_caption(self.model_uri) self.font = pygame.font.SysFont(None, 36) self.hud_font = pygame.font.SysFont(None, 22) self.status_font = pygame.font.SysFont(None, 24) - self._last_surface: pygame.Surface | None = None def _present(self, frame: np.ndarray, batch_dt: float, temporal_compression: int) -> None: """Blit a single (H, W, 3) frame, draw HUD, flip, and cache for pause.""" @@ -141,6 +148,7 @@ def draw_status(self, text: str) -> None: # --- engine ------------------------------------------------------------------ +@dataclass class Engine: """Wraps WorldEngine with seed management and the generation pipeline. @@ -148,11 +156,16 @@ class Engine: caller before the next `next_frame()` call (GPU buffers may be reused). """ - def __init__(self, model_uri: str, quant: str | None, device: str) -> None: - log.info("loading model %s (quant=%s, device=%s)", model_uri, quant, device) - self.inner = WorldEngine(model_uri, quant=quant, device=device) - self.model_uri = model_uri - self.seed: np.ndarray | None = None + model_uri: str + quant: InitVar[str | None] + device: InitVar[str] + inner: WorldEngine = field(init=False, repr=False) + seed: np.ndarray | None = field(init=False, default=None, repr=False) + """Center-cropped uint8 (H, W, 3) numpy array, set via `set_seed()`.""" + + def __post_init__(self, quant: str | None, device: str) -> None: + log.info("loading model %s (quant=%s, device=%s)", self.model_uri, quant, device) + self.inner = WorldEngine(self.model_uri, quant=quant, device=device) log.info( "model loaded: type=%s, temporal_compression=%d", self.inner.model_cfg.model_type, self.inner.model_cfg.temporal_compression, @@ -225,11 +238,14 @@ class GameState: mouse_sensitivity: float paused: bool = True held_vks: set[int] = field(default_factory=set) + """Currently pressed Windows VK codes, forwarded as `CtrlInput.button`.""" scroll: int = 0 - # Pipeline state: the not-yet-rendered CPU frame from the previous - # gen_frame call. Rendered while the GPU computes the next one. pending: torch.Tensor | None = None + """Not-yet-rendered CPU frame from the previous `gen_frame` call. + Rendered while the GPU computes the next batch (pipeline overlap).""" batch_dt: float = 0.0 + """Wall-clock seconds the last `gen_frame` + `.cpu()` cycle took. Used to + pace multi-frame sub-frames evenly across the generation interval.""" @property def temporal_compression(self) -> int: From 9d5e9d4eda21e591a08003944a6da53c58930952 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 23:36:54 +0200 Subject: [PATCH 12/17] fix(interactive): minor input tweaks --- examples/interactive.py | 53 ++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index 3547601..f9d6f36 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -34,6 +34,7 @@ import time import urllib.request from dataclasses import InitVar, dataclass, field +from typing import ClassVar import numpy as np import pygame @@ -267,27 +268,26 @@ def _exit_pause(self) -> None: pygame.mouse.set_visible(False) pygame.mouse.get_rel() # discard accumulated delta during pause + # Map pygame keys / mouse buttons to the Windows VK integers that + # CtrlInput.button expects (see https://github.com/Overworldai/owl-control + # keycode table). Covers the main ANSI rows, space, shift, and three + # mouse buttons — enough for WASD / spacebar / look-around gameplay. + _KEY_TO_VK: ClassVar[dict[int, int]] = ( + {getattr(pygame, f"K_{ch}"): ord(ch) for ch in "1234567890"} + | {pygame.K_MINUS: 0xBD, pygame.K_EQUALS: 0xBB} + | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "qwertyuiop"} + | {pygame.K_LEFTBRACKET: 0xDB, pygame.K_RIGHTBRACKET: 0xDD, pygame.K_BACKSLASH: 0xDC} + | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "asdfghjkl"} + | {pygame.K_SEMICOLON: 0xBA, pygame.K_QUOTE: 0xDE} + | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "zxcvbnm"} + | {pygame.K_COMMA: 0xBC, pygame.K_PERIOD: 0xBE, pygame.K_SLASH: 0xBF} + | {pygame.K_SPACE: 0x20, pygame.K_LSHIFT: 0x10, pygame.K_RSHIFT: 0x10} + ) + _MOUSE_TO_VK: ClassVar[dict[int, int]] = {1: 0x01, 2: 0x04, 3: 0x02} + """pygame button ids 1/2/3 → VK 0x01 LBUTTON / 0x04 MBUTTON / 0x02 RBUTTON.""" + def _process_events(self) -> bool: """Drain pygame events and update state. Returns False to quit.""" - # Map pygame keys / mouse buttons to the Windows VK integers that - # CtrlInput.button expects (see https://github.com/Overworldai/owl-control - # keycode table). Covers the main ANSI rows, space, shift, and three - # mouse buttons — enough for WASD / spacebar / look-around gameplay. - key_to_vk: dict[int, int] = ( - {getattr(pygame, f"K_{ch}"): ord(ch) for ch in "1234567890"} - | {pygame.K_MINUS: 0xBD, pygame.K_EQUALS: 0xBB} - | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "qwertyuiop"} - | {pygame.K_LEFTBRACKET: 0xDB, pygame.K_RIGHTBRACKET: 0xDD, pygame.K_BACKSLASH: 0xDC} - | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "asdfghjkl"} - | {pygame.K_SEMICOLON: 0xBA, pygame.K_QUOTE: 0xDE} - | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "zxcvbnm"} - | {pygame.K_COMMA: 0xBC, pygame.K_PERIOD: 0xBE, pygame.K_SLASH: 0xBF} - | {pygame.K_SPACE: 0x20, pygame.K_LSHIFT: 0x10, pygame.K_RSHIFT: 0x10} - ) - # pygame mouse button ids: 1=left, 2=middle, 3=right. - # VK: 0x01 LBUTTON, 0x04 MBUTTON, 0x02 RBUTTON. - mouse_to_vk: dict[int, int] = {1: 0x01, 2: 0x04, 3: 0x02} - self.scroll = 0 for e in pygame.event.get(): if e.type == pygame.QUIT: @@ -305,12 +305,12 @@ def _process_events(self) -> bool: self.pending = None self.engine.reset() else: - vk = key_to_vk.get(e.key) + vk = self._KEY_TO_VK.get(e.key) if vk is not None: self.held_vks.add(vk) elif e.type == pygame.KEYUP: - vk = key_to_vk.get(e.key) + vk = self._KEY_TO_VK.get(e.key) if vk is not None: self.held_vks.discard(vk) @@ -318,16 +318,15 @@ def _process_events(self) -> bool: if self.paused and e.button == 1: self._exit_pause() else: - vk = mouse_to_vk.get(e.button) + vk = self._MOUSE_TO_VK.get(e.button) if vk is not None: self.held_vks.add(vk) - if e.button == 4: - self.scroll += 1 - elif e.button == 5: - self.scroll -= 1 + + elif e.type == pygame.MOUSEWHEEL: + self.scroll += e.y elif e.type == pygame.MOUSEBUTTONUP: - vk = mouse_to_vk.get(e.button) + vk = self._MOUSE_TO_VK.get(e.button) if vk is not None: self.held_vks.discard(vk) From c019ed50f8f6fefdd627ddd92a7c7c3d54803319 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 14 Apr 2026 23:57:39 +0200 Subject: [PATCH 13/17] fix(interactive): frame pacing - ensure full utilisation + matching of inference framerate + consistent dispatch --- examples/interactive.py | 124 ++++++++++++++++++++++++++++++---------- 1 file changed, 95 insertions(+), 29 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index f9d6f36..29efdc5 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -23,8 +23,31 @@ # Close window / Ctrl+C : quit # # Supports both Waypoint-1 / 1.1 (single-frame output) and Waypoint-1.5 -# (4-frame temporally-compressed output). The only model-dependent branches -# live in `prime_seed` and `render`, keyed off `engine.model_cfg.model_type`. +# (4-frame temporally-compressed output). +# +# Frame pacing strategy +# --------------------- +# gen_frame() dispatches GPU kernels and returns a not-yet-ready tensor. We +# render the *previous* batch's sub-frames (with pacing sleeps) while the GPU +# computes, then call .cpu() to sync + transfer the result. +# +# For multi-frame models (temporal_compression > 1), each gen_frame produces T +# sub-frames that must be spread over a pacing interval. The interval is: +# +# pace_s = max(batch_dt * SLEEP_RATIO, target_s - overhead) +# +# where target_s = T / inference_fps (the model's intended visual frame time), +# batch_dt is the previous cycle's wall-clock time, and overhead is the +# measured non-render portion of the cycle (dispatch + .cpu() + events). +# +# - The `target_s - overhead` term ensures the *total* cycle (render + overhead) +# hits the model's target framerate. +# - The `batch_dt * SLEEP_RATIO` floor prevents the render from filling the +# entire cycle, which would create a diverging feedback loop (batch_dt +# includes render time, so pacing to 100% of it grows without bound). +# - Each sub-frame is presented immediately, then we sleep for SLEEP_RATIO of +# the remaining time until the next deadline (yielding CPU), then busy-wait +# the rest for precise timing. See the SLEEP_RATIO constant for details. import argparse import io @@ -53,6 +76,12 @@ WINDOW_SIZE = (1280, 720) +# Fraction of a sleep interval that is yielded to the OS via pygame.time.wait; +# the remainder is covered by a busy-wait spin for precise timing. OS schedulers +# are free to extend sleeps beyond their stated duration, so we deliberately +# undershoot and spin the rest. 0.8 was selected ad-hoc and could be raised +# toward 1.0 on platforms with more accurate sleep granularity. +SLEEP_RATIO = 0.8 @@ -106,25 +135,33 @@ def _present(self, frame: np.ndarray, batch_dt: float, temporal_compression: int pygame.display.flip() - def render_frame(self, frame_cpu: torch.Tensor, batch_dt: float, temporal_compression: int) -> None: + def render_frame( + self, + frame_cpu: torch.Tensor, + batch_dt: float, + temporal_compression: int, + pace_s: float, + ) -> None: """Display an already-on-CPU frame and cache it for the pause overlay. - For multi-frame models the tensor is (T, H, W, 3) — we spread the T - sub-frames evenly across `batch_dt` (per README "Waypoint-1.5 Behavior"). - The sleeps are what let the pipeline overlap: while we pace here, the - GPU is already computing the next batch. + Sub-frames are spread evenly across `pace_s`. The caller computes + `pace_s` to balance GPU overlap headroom with the FPS cap. """ arr = frame_cpu.numpy() - if arr.ndim == 3: # single-frame model: (H, W, 3) - self._present(arr, batch_dt, temporal_compression) - return - - # Multi-frame model: (T, H, W, 3) - step_ms = max(0, int(batch_dt * 1000 / arr.shape[0])) - for i, sub in enumerate(arr): - if i > 0 and step_ms: - pygame.time.wait(step_ms) + # Treat single-frame (H, W, 3) as a batch of one. + frames = [arr] if arr.ndim == 3 else list(arr) + step_s = pace_s / len(frames) + start = time.perf_counter() + for i, sub in enumerate(frames): self._present(sub, batch_dt, temporal_compression) + # Hybrid sleep+spin: yield CPU for SLEEP_RATIO of the remaining + # time, then busy-wait the rest for precise timing. + deadline = start + step_s * (i + 1) + remaining_ms = int((deadline - time.perf_counter()) * SLEEP_RATIO * 1000) + if remaining_ms > 0: + pygame.time.wait(remaining_ms) + while time.perf_counter() < deadline: + pass def draw_pause(self) -> None: """Redraw the cached last frame with a dimmed overlay and centered pause text.""" @@ -176,6 +213,11 @@ def __post_init__(self, quant: str | None, device: str) -> None: def temporal_compression(self) -> int: return getattr(self.inner.model_cfg, "temporal_compression", 1) + @property + def inference_fps(self) -> int: + """Visual framerate the model targets.""" + return getattr(self.inner.model_cfg, "inference_fps", 60) + def set_seed(self, img: Image.Image) -> None: """Center-crop the image to the expected aspect ratio and store as seed.""" # TODO: Calculate aspect ratio automatically once @@ -245,17 +287,38 @@ class GameState: """Not-yet-rendered CPU frame from the previous `gen_frame` call. Rendered while the GPU computes the next batch (pipeline overlap).""" batch_dt: float = 0.0 - """Wall-clock seconds the last `gen_frame` + `.cpu()` cycle took. Used to - pace multi-frame sub-frames evenly across the generation interval.""" + """Wall-clock seconds the last full cycle (dispatch + render + sync) took.""" + _overhead: float = 0.0 + """Non-render portion of the previous cycle (dispatch + .cpu() + events). + Subtracted from the FPS-cap target so pacing compensates for it.""" + _pace_s: float = 0.0 + """Pacing interval used by the last render_frame call.""" @property def temporal_compression(self) -> int: return self.engine.temporal_compression + @property + def inference_fps(self) -> int: + return self.engine.inference_fps + + def _compute_pace(self) -> float: + """Compute pacing interval for render_frame, accounting for overhead.""" + # Target interval from the model's intended visual framerate. + target_s = ( + self.temporal_compression / self.inference_fps + if self.inference_fps > 0 + else 0.0 + ) + # Subtract measured non-render overhead so the *total* cycle hits target_s. + # Use SLEEP_RATIO of batch_dt as a floor so the render never fills the entire + # overlap window (batch_dt includes render time — pacing to 100% diverges). + return max(self.batch_dt * SLEEP_RATIO, target_s - self._overhead) + def _enter_pause(self) -> None: """Flush any in-flight batch and enter paused state.""" if self.pending is not None: - self.renderer.render_frame(self.pending, self.batch_dt, self.temporal_compression) + self.renderer.render_frame(self.pending, self.batch_dt, self.temporal_compression, self._compute_pace()) self.pending = None self.paused = True pygame.event.set_grab(False) @@ -335,12 +398,11 @@ def _process_events(self) -> bool: def run(self) -> None: """Interactive generation loop. Starts auto-paused on the first frame. - Uses the pipelining pattern from the README: gen_frame() queues GPU - kernels and returns immediately; we render the *previous* batch (with - pacing sleeps) while the GPU works; then .cpu() syncs and transfers - the result. + See "Frame pacing strategy" at the top of the file for the full + pipeline design. The first iteration after (un)pause has no previous + batch to render, so there is no GPU overlap on that frame. """ - self.renderer.render_frame(self.pending, 0.0, self.temporal_compression) + self.renderer.render_frame(self.pending, 0.0, self.temporal_compression, 0.0) self.pending = None log.info("ready") @@ -350,7 +412,7 @@ def run(self) -> None: if self.paused: self.renderer.draw_pause() - self.clock.tick(60) + self.clock.tick(self.inference_fps) continue dx, dy = pygame.mouse.get_rel() @@ -360,15 +422,19 @@ def run(self) -> None: scroll_wheel=self.scroll, ) - # Pipeline: kick off generation (GPU kernels queued, returns fast), - # then render the *previous* batch while the GPU works. Finally - # .cpu() syncs and transfers the just-computed batch to CPU. + # Pipeline (see "Frame pacing strategy" at top of file): + # 1. Dispatch gen_frame — GPU kernels are queued, returns fast. + # 2. Render the *previous* batch with pacing sleeps while GPU works. + # 3. .cpu() syncs the GPU and transfers the just-computed batch. + # 4. Measure overhead (non-render time) to feed back into pacing. t0 = time.perf_counter() next_frames = self.engine.next_frame(ctrl=ctrl) if self.pending is not None: - self.renderer.render_frame(self.pending, self.batch_dt, self.temporal_compression) + self._pace_s = self._compute_pace() + self.renderer.render_frame(self.pending, self.batch_dt, self.temporal_compression, self._pace_s) self.pending = next_frames.cpu() self.batch_dt = time.perf_counter() - t0 + self._overhead = self.batch_dt - self._pace_s # --- entry point ------------------------------------------------------------- From ab06b7a91eb2993b52ca2438f1a09ee56d236a97 Mon Sep 17 00:00:00 2001 From: Philpax Date: Wed, 15 Apr 2026 00:08:36 +0200 Subject: [PATCH 14/17] fix(interactive): fmt+lint --- examples/interactive.py | 86 +++++++++++++++++++++++++++++++---------- 1 file changed, 65 insertions(+), 21 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index 29efdc5..7c0a92e 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -66,7 +66,6 @@ from world_engine import CtrlInput, WorldEngine - logging.basicConfig( level=logging.INFO, format="%(asctime)s %(message)s", @@ -84,10 +83,6 @@ SLEEP_RATIO = 0.8 - - - - # --- rendering -------------------------------------------------------------- @@ -101,7 +96,7 @@ class Renderer: hud_font: pygame.font.Font = field(init=False) status_font: pygame.font.Font = field(init=False) _last_surface: pygame.Surface | None = field(init=False, default=None, repr=False) - """Cached surface of the most recently rendered frame, used for the pause overlay.""" + """Cached surface of the most recently rendered frame, used for pause.""" def __post_init__(self) -> None: self.screen = pygame.display.set_mode(WINDOW_SIZE, pygame.RESIZABLE) @@ -110,7 +105,12 @@ def __post_init__(self) -> None: self.hud_font = pygame.font.SysFont(None, 22) self.status_font = pygame.font.SysFont(None, 24) - def _present(self, frame: np.ndarray, batch_dt: float, temporal_compression: int) -> None: + def _present( + self, + frame: np.ndarray, + batch_dt: float, + temporal_compression: int, + ) -> None: """Blit a single (H, W, 3) frame, draw HUD, flip, and cache for pause.""" # pygame.surfarray expects (W, H, 3), so swap the first two axes. surf = pygame.surfarray.make_surface(frame.swapaxes(0, 1)) @@ -122,10 +122,16 @@ def _present(self, frame: np.ndarray, batch_dt: float, temporal_compression: int if batch_dt > 0: lfps = 1.0 / batch_dt if temporal_compression > 1: - lines.append((f"{lfps * temporal_compression:.1f} FPS", (255, 255, 255))) - lines.append((f"{lfps:.1f} LFPS / {batch_dt * 1000:.1f} ms", (255, 255, 255))) + lines.append( + (f"{lfps * temporal_compression:.1f} FPS", (255, 255, 255)), + ) + lines.append( + (f"{lfps:.1f} LFPS / {batch_dt * 1000:.1f} ms", (255, 255, 255)), + ) else: - lines.append((f"{lfps:.1f} FPS / {batch_dt * 1000:.1f} ms", (255, 255, 255))) + lines.append( + (f"{lfps:.1f} FPS / {batch_dt * 1000:.1f} ms", (255, 255, 255)), + ) lines.append((self.model_uri, (160, 160, 160))) for i, (text, color) in enumerate(lines): label = self.hud_font.render(text, True, color) @@ -164,7 +170,7 @@ def render_frame( pass def draw_pause(self) -> None: - """Redraw the cached last frame with a dimmed overlay and centered pause text.""" + """Redraw cached last frame with a dimmed overlay and pause text.""" assert self._last_surface is not None self.screen.blit(self._last_surface, (0, 0)) dim = pygame.Surface(self.screen.get_size(), pygame.SRCALPHA) @@ -179,7 +185,10 @@ def draw_status(self, text: str) -> None: """Clear to black and draw a status line in the bottom-left corner.""" self.screen.fill((0, 0, 0)) label = self.status_font.render(text, True, (220, 220, 220)) - self.screen.blit(label, (16, self.screen.get_height() - label.get_height() - 16)) + self.screen.blit( + label, + (16, self.screen.get_height() - label.get_height() - 16), + ) pygame.display.flip() @@ -202,11 +211,17 @@ class Engine: """Center-cropped uint8 (H, W, 3) numpy array, set via `set_seed()`.""" def __post_init__(self, quant: str | None, device: str) -> None: - log.info("loading model %s (quant=%s, device=%s)", self.model_uri, quant, device) + log.info( + "loading model %s (quant=%s, device=%s)", + self.model_uri, + quant, + device, + ) self.inner = WorldEngine(self.model_uri, quant=quant, device=device) log.info( "model loaded: type=%s, temporal_compression=%d", - self.inner.model_cfg.model_type, self.inner.model_cfg.temporal_compression, + self.inner.model_cfg.model_type, + self.inner.model_cfg.temporal_compression, ) @property @@ -256,7 +271,10 @@ def warmup(self) -> torch.Tensor: return first def next_frame(self, ctrl: CtrlInput) -> torch.Tensor: - """Generate the next frame. Returns a GPU tensor; caller must .cpu() before the next call.""" + """Generate the next frame. Returns a GPU tensor. + + Caller must .cpu() before the next call (GPU buffers may be reused). + """ return self.inner.gen_frame(ctrl=ctrl) def reset(self) -> None: @@ -318,7 +336,12 @@ def _compute_pace(self) -> float: def _enter_pause(self) -> None: """Flush any in-flight batch and enter paused state.""" if self.pending is not None: - self.renderer.render_frame(self.pending, self.batch_dt, self.temporal_compression, self._compute_pace()) + self.renderer.render_frame( + self.pending, + self.batch_dt, + self.temporal_compression, + self._compute_pace(), + ) self.pending = None self.paused = True pygame.event.set_grab(False) @@ -339,7 +362,11 @@ def _exit_pause(self) -> None: {getattr(pygame, f"K_{ch}"): ord(ch) for ch in "1234567890"} | {pygame.K_MINUS: 0xBD, pygame.K_EQUALS: 0xBB} | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "qwertyuiop"} - | {pygame.K_LEFTBRACKET: 0xDB, pygame.K_RIGHTBRACKET: 0xDD, pygame.K_BACKSLASH: 0xDC} + | { + pygame.K_LEFTBRACKET: 0xDB, + pygame.K_RIGHTBRACKET: 0xDD, + pygame.K_BACKSLASH: 0xDC, + } | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "asdfghjkl"} | {pygame.K_SEMICOLON: 0xBA, pygame.K_QUOTE: 0xDE} | {getattr(pygame, f"K_{ch}"): ord(ch.upper()) for ch in "zxcvbnm"} @@ -431,7 +458,12 @@ def run(self) -> None: next_frames = self.engine.next_frame(ctrl=ctrl) if self.pending is not None: self._pace_s = self._compute_pace() - self.renderer.render_frame(self.pending, self.batch_dt, self.temporal_compression, self._pace_s) + self.renderer.render_frame( + self.pending, + self.batch_dt, + self.temporal_compression, + self._pace_s, + ) self.pending = next_frames.cpu() self.batch_dt = time.perf_counter() - t0 self._overhead = self.batch_dt - self._pace_s @@ -439,13 +471,16 @@ def run(self) -> None: # --- entry point ------------------------------------------------------------- + def get_seed(path: str | None) -> Image.Image: """Load a seed image from a local path, or download a random one from Biome.""" if path is not None: log.info("loading seed from local file: %s", path) return Image.open(path) # GitHub contents API for the Biome `seeds/` directory, pinned to a known ref. - biome_api = "https://api.github.com/repos/Overworldai/Biome/contents/seeds?ref=14343a6" + biome_api = ( + "https://api.github.com/repos/Overworldai/Biome/contents/seeds?ref=14343a6" + ) log.info("fetching Biome seeds index") with urllib.request.urlopen(biome_api) as res: entries = [e for e in json.load(res) if e["type"] == "file"] @@ -458,8 +493,17 @@ def get_seed(path: str | None) -> Image.Image: def main() -> None: ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("model_uri", help="HF model id, e.g. Overworld/Waypoint-1.5-1B") - ap.add_argument("-s", "--seed", help="Path to a local seed image (defaults to a random Biome seed)") - ap.add_argument("-q", "--quant", choices=["intw8a8", "fp8w8a8", "nvfp4"], default=None) + ap.add_argument( + "-s", + "--seed", + help="Path to a local seed image (defaults to a random Biome seed)", + ) + ap.add_argument( + "-q", + "--quant", + choices=["intw8a8", "fp8w8a8", "nvfp4"], + default=None, + ) ap.add_argument("-d", "--device", default="cuda") ap.add_argument("-m", "--mouse-sensitivity", type=float, default=1.5) args = ap.parse_args() From 0d9c2aae90e77e2f1f1d625261f07fea4dec4523 Mon Sep 17 00:00:00 2001 From: Philpax Date: Wed, 15 Apr 2026 00:20:37 +0200 Subject: [PATCH 15/17] feat(interactive): --uncap-fps --- examples/interactive.py | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index 7c0a92e..f171b78 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -36,12 +36,15 @@ # # pace_s = max(batch_dt * SLEEP_RATIO, target_s - overhead) # -# where target_s = T / inference_fps (the model's intended visual frame time), -# batch_dt is the previous cycle's wall-clock time, and overhead is the -# measured non-render portion of the cycle (dispatch + .cpu() + events). +# where: +# - batch_dt is the previous cycle's wall-clock time. +# - overhead is the measured non-render portion of the cycle (dispatch + +# .cpu() + events). +# - target_s = T / fps_cap. When fps_cap is 0 (--uncap-fps), target_s is 0 +# and pace_s falls back to `batch_dt * SLEEP_RATIO` — pure GPU-bound pacing. # # - The `target_s - overhead` term ensures the *total* cycle (render + overhead) -# hits the model's target framerate. +# hits the model's target framerate when a cap is active. # - The `batch_dt * SLEEP_RATIO` floor prevents the render from filling the # entire cycle, which would create a diverging feedback loop (batch_dt # includes render time, so pacing to 100% of it grows without bound). @@ -297,6 +300,7 @@ class GameState: engine: Engine clock: pygame.time.Clock mouse_sensitivity: float + uncap_fps: bool = False paused: bool = True held_vks: set[int] = field(default_factory=set) """Currently pressed Windows VK codes, forwarded as `CtrlInput.button`.""" @@ -317,17 +321,14 @@ def temporal_compression(self) -> int: return self.engine.temporal_compression @property - def inference_fps(self) -> int: - return self.engine.inference_fps + def fps_cap(self) -> int: + """0 = uncapped; otherwise the model's inference_fps.""" + return 0 if self.uncap_fps else self.engine.inference_fps def _compute_pace(self) -> float: """Compute pacing interval for render_frame, accounting for overhead.""" # Target interval from the model's intended visual framerate. - target_s = ( - self.temporal_compression / self.inference_fps - if self.inference_fps > 0 - else 0.0 - ) + target_s = self.temporal_compression / self.fps_cap if self.fps_cap > 0 else 0.0 # Subtract measured non-render overhead so the *total* cycle hits target_s. # Use SLEEP_RATIO of batch_dt as a floor so the render never fills the entire # overlap window (batch_dt includes render time — pacing to 100% diverges). @@ -439,7 +440,7 @@ def run(self) -> None: if self.paused: self.renderer.draw_pause() - self.clock.tick(self.inference_fps) + self.clock.tick(self.engine.inference_fps) continue dx, dy = pygame.mouse.get_rel() @@ -506,6 +507,11 @@ def main() -> None: ) ap.add_argument("-d", "--device", default="cuda") ap.add_argument("-m", "--mouse-sensitivity", type=float, default=1.5) + ap.add_argument( + "--uncap-fps", + action="store_true", + help="Disable the inference_fps framerate cap (run as fast as the GPU allows)", + ) args = ap.parse_args() pygame.init() @@ -521,7 +527,14 @@ def main() -> None: engine.prime() renderer.draw_status("Warming up (torch.compile)…") first = engine.warmup() - GameState(renderer, engine, clock, args.mouse_sensitivity, pending=first).run() + GameState( + renderer, + engine, + clock, + args.mouse_sensitivity, + uncap_fps=args.uncap_fps, + pending=first, + ).run() except KeyboardInterrupt: pass finally: From 2959aaca81ee8f869d817a059fb70f28e41c8590 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 12 May 2026 16:31:13 +0200 Subject: [PATCH 16/17] fix(interactive): use self.temporal_compression --- examples/interactive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/interactive.py b/examples/interactive.py index f171b78..f14e09e 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -257,7 +257,7 @@ def prime(self) -> None: assert self.seed is not None, "call set_seed() first" self.inner.reset() t = torch.from_numpy(self.seed).to(self.inner.device) # uint8 (H, W, 3) - tc = self.inner.model_cfg.temporal_compression + tc = self.temporal_compression if tc > 1: # Multi-frame models (e.g. Waypoint-1.5) consume/produce a stack of # `temporal_compression` frames per step. From 5bfc7737369839e62f833f499e812f7b372d812e Mon Sep 17 00:00:00 2001 From: Philpax Date: Wed, 13 May 2026 19:04:13 +0200 Subject: [PATCH 17/17] refactor(interactive): trim some redundancies --- examples/interactive.py | 42 ++++++++++++----------------------------- 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/examples/interactive.py b/examples/interactive.py index f14e09e..a59b2f6 100644 --- a/examples/interactive.py +++ b/examples/interactive.py @@ -252,8 +252,8 @@ def set_seed(self, img: Image.Image) -> None: # .copy() — PIL's buffer is read-only and torch.from_numpy requires writable. self.seed = np.asarray(cropped).copy() - def prime(self) -> None: - """Reset state and encode the seed frame into the KV cache.""" + def reset(self) -> None: + """Clear KV cache and re-encode the seed frame.""" assert self.seed is not None, "call set_seed() first" self.inner.reset() t = torch.from_numpy(self.seed).to(self.inner.device) # uint8 (H, W, 3) @@ -262,7 +262,7 @@ def prime(self) -> None: # Multi-frame models (e.g. Waypoint-1.5) consume/produce a stack of # `temporal_compression` frames per step. t = t.unsqueeze(0).expand(tc, -1, -1, -1).contiguous() - log.info("priming engine with seed shape=%s", tuple(t.shape)) + log.info("resetting engine with seed shape=%s", tuple(t.shape)) self.inner.append_frame(t) def warmup(self) -> torch.Tensor: @@ -280,14 +280,6 @@ def next_frame(self, ctrl: CtrlInput) -> torch.Tensor: """ return self.inner.gen_frame(ctrl=ctrl) - def reset(self) -> None: - """Reset all state and re-prime the seed. - - reset() clears the KV cache and all state, so the model must be - re-seeded with append_frame before it can produce coherent output. - """ - self.prime() - # --- gameplay ---------------------------------------------------------------- @@ -298,10 +290,10 @@ class GameState: renderer: Renderer engine: Engine - clock: pygame.time.Clock mouse_sensitivity: float uncap_fps: bool = False paused: bool = True + clock: pygame.time.Clock = field(init=False, default_factory=pygame.time.Clock) held_vks: set[int] = field(default_factory=set) """Currently pressed Windows VK codes, forwarded as `CtrlInput.button`.""" scroll: int = 0 @@ -316,19 +308,11 @@ class GameState: _pace_s: float = 0.0 """Pacing interval used by the last render_frame call.""" - @property - def temporal_compression(self) -> int: - return self.engine.temporal_compression - - @property - def fps_cap(self) -> int: - """0 = uncapped; otherwise the model's inference_fps.""" - return 0 if self.uncap_fps else self.engine.inference_fps - def _compute_pace(self) -> float: """Compute pacing interval for render_frame, accounting for overhead.""" - # Target interval from the model's intended visual framerate. - target_s = self.temporal_compression / self.fps_cap if self.fps_cap > 0 else 0.0 + # Target interval from the model's intended visual framerate. 0 = uncapped. + fps_cap = 0 if self.uncap_fps else self.engine.inference_fps + target_s = self.engine.temporal_compression / fps_cap if fps_cap > 0 else 0.0 # Subtract measured non-render overhead so the *total* cycle hits target_s. # Use SLEEP_RATIO of batch_dt as a floor so the render never fills the entire # overlap window (batch_dt includes render time — pacing to 100% diverges). @@ -340,7 +324,7 @@ def _enter_pause(self) -> None: self.renderer.render_frame( self.pending, self.batch_dt, - self.temporal_compression, + self.engine.temporal_compression, self._compute_pace(), ) self.pending = None @@ -430,7 +414,7 @@ def run(self) -> None: pipeline design. The first iteration after (un)pause has no previous batch to render, so there is no GPU overlap on that frame. """ - self.renderer.render_frame(self.pending, 0.0, self.temporal_compression, 0.0) + self.renderer.render_frame(self.pending, 0.0, self.engine.temporal_compression, 0.0) self.pending = None log.info("ready") @@ -462,7 +446,7 @@ def run(self) -> None: self.renderer.render_frame( self.pending, self.batch_dt, - self.temporal_compression, + self.engine.temporal_compression, self._pace_s, ) self.pending = next_frames.cpu() @@ -516,21 +500,19 @@ def main() -> None: pygame.init() renderer = Renderer(args.model_uri) - clock = pygame.time.Clock() try: renderer.draw_status("Loading model…") engine = Engine(args.model_uri, args.quant, args.device) renderer.draw_status("Loading seed…") engine.set_seed(get_seed(args.seed)) - renderer.draw_status("Priming engine…") - engine.prime() + renderer.draw_status("Resetting engine…") + engine.reset() renderer.draw_status("Warming up (torch.compile)…") first = engine.warmup() GameState( renderer, engine, - clock, args.mouse_sensitivity, uncap_fps=args.uncap_fps, pending=first,