From 6198fa1d14e02b00b3a6d98814a39b887a40c27e Mon Sep 17 00:00:00 2001 From: MalarzDawid Date: Sun, 22 Feb 2026 11:54:45 +0100 Subject: [PATCH 01/48] Fix README examples and align default model URI in example scripts --- README.md | 45 +++++++++++++++++++++++++-------------- examples/gen_sample.py | 2 +- examples/prof.py | 2 +- examples/simple_client.py | 2 +- 4 files changed, 32 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index ba885b2..d6c8d07 100644 --- a/README.md +++ b/README.md @@ -57,24 +57,36 @@ export HF_TOKEN= #### Run ``` +import torch from world_engine import WorldEngine, CtrlInput -# Create inference engine -engine = WorldEngine("Overworld/Waypoint-1-Small", device="cuda") +def main(): + # Create inference engine + engine = WorldEngine("Overworld/Waypoint-1-Small", device="cuda") + + # Specify a prompt + engine.set_prompt("A fun game") + + # Optional: Force the next frame to be a specific image + # uint8_img = torch.randint(0, 256, (512, 512, 3), dtype=torch.uint8) + # img = engine.append_frame(uint8_img) # (H, W, 3) -# Specify a prompt -engine.set_prompt("A fun game") -# Optional: Force the next frame to be a specific image -img = pipeline.append_frame(uint8_img) # (H, W, 3) + # Generate 3 video frames conditioned on controller inputs + for controller_input in [ + CtrlInput(button={48, 42}, mouse=[0.4, 0.3]), + CtrlInput(mouse=[0.1, 0.2]), + CtrlInput(button={95, 32, 105}), + ]: + img = engine.gen_frame(ctrl=controller_input) -# Generate 3 video frames conditioned on controller inputs -for controller_input in [ - CtrlInput(button={48, 42}, mouse=[0.4, 0.3]), - CtrlInput(mouse=[0.1, 0.2]), - CtrlInput(button={95, 32, 105}), -]: - img = engine.gen_frame(ctrl=controller_input) +if __name__ == "__main__": + main() +``` + +``` +# Optional: install dependencies used by examples (OpenCV, pytest, benchmark plugin) +pip install --upgrade --ignore-installed "world_engine[examples] @ git+https://github.com/Overworldai/world_engine.git" ``` 
## Usage @@ -92,7 +104,7 @@ Specify a prompt which will be used until this function is called again engine.set_prompt("A fun game") ``` -Generate a image conditioned on current controller input (explicit) and history / prompt (implicit) +Generate an image conditioned on current controller input (explicit) and history / prompt (implicit) ``` controller_input = CtrlInput(button={48, 42}, mouse=[0.4, 0.3]) img = engine.gen_frame(ctrl=controller_input) @@ -102,7 +114,7 @@ Instead of generating, **set** the next frame as a specific image. Typically don ``` # example: random noise image uint8_img = torch.randint(0, 256, (512, 512, 3), dtype=torch.uint8) -img = pipeline.append_frame(uint8_img) # returns passed image +img = engine.append_frame(uint8_img) # returns passed image ``` Note: returned `img` is always on the same device as `engine.device` @@ -118,7 +130,8 @@ Note: returned `img` is always on the same device as `engine.device` @dataclass class CtrlInput: button: Set[int] = field(default_factory=set) # pressed button IDs - mouse: Tuple[float, float] = (0.0, 0.0) # (x, y) position + mouse: Tuple[float, float] = (0.0, 0.0) # (x, y) velocity + scroll_wheel: int = 0 # bwd, stationary, or fwd -> (-1, 0, 1) ``` - `button` keycodes are defined by [Owl-Control](https://github.com/Overworldai/owl-control/blob/main/src/system/keycode.rs) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 62c27c5..517efab 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -3,7 +3,7 @@ def gen_vid(): - engine = WorldEngine("OpenWorldLabs/CoDCtl-Causal-Flux-SelfForcing", device="cuda") + engine = WorldEngine("Overworld/Waypoint-1-Small", device="cuda") writer = None for _ in range(240): frame = engine.gen_frame().cpu().numpy()[:, :, ::-1] # RGB -> BGR for OpenCV diff --git a/examples/prof.py b/examples/prof.py index b8cd570..8055dea 100644 --- a/examples/prof.py +++ b/examples/prof.py @@ -5,7 +5,7 @@ def do_profile(n_frames=64, row_limit=20): - engine = 
WorldEngine("OpenWorldLabs/CoDCtl-Causal-Flux-SelfForcing", device="cuda") + engine = WorldEngine("Overworld/Waypoint-1-Small", device="cuda") # warmup for _ in range(4): engine.gen_frame() diff --git a/examples/simple_client.py b/examples/simple_client.py index 2f5d410..db12d97 100644 --- a/examples/simple_client.py +++ b/examples/simple_client.py @@ -53,7 +53,7 @@ async def producer() -> None: async def main() -> None: - uri = sys.argv[1] if len(sys.argv) > 1 else "OpenWorldLabs/CoDCtl-Causal-Flux-SelfForcing" + uri = sys.argv[1] if len(sys.argv) > 1 else "Overworld/Waypoint-1-Small" engine = WorldEngine(uri, device="cuda") ctrls = ctrl_stream() frames = frame_stream(engine, ctrls) From 991a86140447c462a444625014cb05e6b13a28f6 Mon Sep 17 00:00:00 2001 From: MalarzDawid Date: Sun, 22 Feb 2026 11:56:28 +0100 Subject: [PATCH 02/48] Add examples extra dependencies for OpenCV and benchmark tooling --- pyproject.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index d0fb905..c816fac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,3 +33,10 @@ packages = [ [tool.setuptools.package-dir] world_engine = "src" + +[project.optional-dependencies] +examples = [ + "opencv-python>=4.9", + "pytest>=8", + "pytest-benchmark>=4.0", +] From 3ab6475cbc2e795da166143236ea7d30fb821327 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 23 Feb 2026 08:36:13 -0500 Subject: [PATCH 03/48] implement state loading / saving --- src/model/kv_cache.py | 12 ++++++++++++ src/world_engine.py | 11 +++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/model/kv_cache.py b/src/model/kv_cache.py index 839c9a5..da244a2 100644 --- a/src/model/kv_cache.py +++ b/src/model/kv_cache.py @@ -185,6 +185,18 @@ def reset(self): layer.reset() self._is_frozen = True + @torch.inference_mode() + def get_state(self): + layers = [(layer.kv.detach().clone(), layer.written.detach().clone()) for layer in self.layers] + return {"_is_frozen": self._is_frozen, 
"layers": layers} + + @torch.inference_mode() + def load_state(self, state): + self._is_frozen = bool(state.get("_is_frozen", True)) + for layer, (kv, written) in zip(self.layers, state["layers"]): + layer.kv.copy_(kv) + layer.written.copy_(written) + def set_frozen(self, is_frozen: bool): self._is_frozen = is_frozen diff --git a/src/world_engine.py b/src/world_engine.py index 09a65e3..ae56a94 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -93,6 +93,17 @@ def reset(self): for v in self._ctx.values(): v.zero_() + @torch.inference_mode() + def get_state(self): + """Captures a world state to continue via load_state. Doesn't save model""" + return {"kv_cache": self.kv_cache.get_state(), "frame_ts": self.frame_ts.detach().clone()} + + @torch.inference_mode() + def load_state(self, state): + """Loads a world state object saved via save_state. Doesn't load or change model""" + self.kv_cache.load_state(state["kv_cache"]) + self.frame_ts.copy_(state["frame_ts"]) + def set_prompt(self, prompt: str): """Apply text conditioning for T2V""" if self.prompt_encoder is None: From f1c93e1633295bc47583ac97dc8a0e49daed0b23 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 23 Feb 2026 08:46:22 -0500 Subject: [PATCH 04/48] moe + fbgemm optimization --- pyproject.toml | 1 + src/model/world_model.py | 119 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 118 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d0fb905..a630e90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "torch==2.10.0", "torchvision==0.25.0", "torchaudio==2.10.0", + "fbgemm-gpu-genai==1.5.0", "einops", "rotary-embedding-torch>=0.8.8", "tensordict==0.10.0", diff --git a/src/model/world_model.py b/src/model/world_model.py index aec0ce5..b92d446 100644 --- a/src/model/world_model.py +++ b/src/model/world_model.py @@ -9,6 +9,11 @@ from torch import nn import torch.nn.functional as F + +from fbgemm_gpu.experimental.gen_ai.moe import 
index_shuffling +import fbgemm_gpu.experimental.gen_ai.moe.gather_scatter # noqa + + from .attn import Attn, CrossAttention from .nn import AdaLN, ada_gate, ada_rmsnorm, NoiseConditioner from .base_model import BaseModel @@ -79,6 +84,104 @@ def forward(self, x: torch.Tensor, is_conditioned: Optional[bool] = None) -> tor return x if is_conditioned else null +class MoEWithoutFBGEMM(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.moe_top_k + moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or config.mlp_ratio / config.moe_top_k + d_intermediate = int(config.d_model * moe_mlp_ratio) + self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False) + self.expert_in_proj = nn.Parameter( + torch.empty(config.moe_n_experts, d_intermediate * (2 if config.gated_linear else 1), config.d_model) + ) + self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_intermediate)) + + def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor: + if self.training or torch.is_grad_enabled(): + raise NotImplementedError("inference only") + + orig_shape = x.shape + x = x.reshape(-1, orig_shape[-1]) + logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1)) + + logits_fp32 = logits.float() + scores, expert = logits.topk(self.top_k, dim=-1, sorted=False) + weights = (scores.float() - logits_fp32.logsumexp(dim=-1, keepdim=True)).exp().to(x.dtype) + + expert = expert.flatten() + expert_sorted, sort_idx = expert.sort() + expert_ids = torch.arange(self.expert_in_proj.size(0), device=expert.device, dtype=expert_sorted.dtype) + offsets = torch.searchsorted(expert_sorted, expert_ids, right=True).to(torch.int32) + + # (1) Pad the *indices* instead of cat-copying x_grouped + src = sort_idx // self.top_k + x_grouped = x.index_select(0, torch.cat((src, src[:1]), dim=0)) + h = F.grouped_mm( + x_grouped, + self.expert_in_proj.transpose(-2, -1), + 
offs=offsets + ) + h[-1].zero_() # ensure last row initialized + + if self.config.gated_linear: + gate_act, up = h.chunk(2, dim=-1) + h = F.silu(gate_act) * up + else: + h = F.silu(h) + + y_grouped = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offsets)[:-1] + y = torch.empty_like(y_grouped).index_copy_(0, sort_idx, y_grouped).view(x.size(0), self.top_k, -1) + return (y * weights.unsqueeze(-1)).sum(dim=1).reshape(orig_shape) + + +class MoE(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.moe_top_k + moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or (config.mlp_ratio / config.moe_top_k) + d_int = int(config.d_model * moe_mlp_ratio) + + self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False) + self.expert_in_proj = nn.Parameter( + torch.empty(config.moe_n_experts, d_int * (2 if config.gated_linear else 1), config.d_model) + ) # (E, N, K) for grouped_mm + self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_int)) # (E, N, K) + + def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor: + if self.training or torch.is_grad_enabled(): + raise NotImplementedError("inference only") + + orig = x.shape + x = x.reshape(-1, orig[-1]) + logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1)) + + logits32 = logits.float() + token_counts, expert_sorted, src = index_shuffling(logits32, top_k=self.top_k) + + E = self.expert_in_proj.size(0) + offs = token_counts[:E].cumsum(0).to(torch.int32) + + src = src.to(torch.long) + expert_sorted = expert_sorted.to(torch.long) + logZ = logits32.logsumexp(-1) + w = (logits32[src, expert_sorted] - logZ[src]).exp().to(x.dtype) # [T*K] + + xg = x.index_select(0, torch.cat((src, src[:1]), 0)) # pad by 1 for grouped_mm offs constraint + h = F.grouped_mm(xg, self.expert_in_proj.transpose(-2, -1), offs=offs) + if self.config.gated_linear: + ga, up = h.chunk(2, -1) + h = 
F.silu(ga) * up + else: + h = F.silu(h) + + yg = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offs)[:-1] + out = torch.zeros_like(x) + torch.ops.fbgemm.scatter_add_dense_tokens(out, (yg * w.unsqueeze(-1)).contiguous(), src) + return out.reshape(orig) + + class MLP(nn.Module): def __init__(self, dim_in, dim_middle, dim_out): super().__init__() @@ -144,12 +247,15 @@ def __init__(self, config, layer_idx): super().__init__() self.config = config self.attn = Attn(config, layer_idx) - self.mlp = MLP(config.d_model, config.d_model * config.mlp_ratio, config.d_model) + if getattr(config, "moe", False): + self.mlp = MoE(config) + else: + self.mlp = MLP(config.d_model, config.d_model * config.mlp_ratio, config.d_model) self.cond_head = CondHead(config) do_prompt_cond = config.prompt_conditioning is not None and layer_idx % config.prompt_conditioning_period == 0 self.prompt_cross_attn = CrossAttention(config, config.prompt_embedding_dim) if do_prompt_cond else None - do_ctrl_cond = layer_idx % config.ctrl_conditioning_period == 0 + do_ctrl_cond = config.ctrl_conditioning_period is not None and layer_idx % config.ctrl_conditioning_period == 0 self.ctrl_mlpfusion = MLPFusion(config) if do_ctrl_cond else None def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None): @@ -304,3 +410,12 @@ def forward( ) return x + + def get_active_parameters(self) -> int: + total = sum(p.numel() for p in self.parameters()) + c = self.config + if getattr(c, "moe", False): + moe_mlp_ratio = getattr(c, "moe_mlp_ratio", None) or c.mlp_ratio / c.moe_top_k + hidden, top_k = int(c.d_model * moe_mlp_ratio), min(c.moe_top_k, c.moe_n_experts) + total -= (c.moe_n_experts - top_k) * c.n_layers * c.d_model * hidden * (3 if c.gated_linear else 2) + return total From c6f95be04c7257e54c2f244259efacfe0bbe1f71 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 4 Mar 2026 13:36:14 -0500 Subject: [PATCH 05/48] wp-1.5 staging --- pyproject.toml | 3 +- src/ae.py | 99 
++++++++++++++++++++++++++++++++++++++++ src/model/world_model.py | 74 +++++++++++++++++++++++++++++- src/world_engine.py | 16 +++++-- 4 files changed, 186 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a630e90..c7abd72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,9 +4,10 @@ build-backend = "setuptools.build_meta" [project] name = "world_engine" -version = "1.0.0" +version = "1.5.0" requires-python = ">=3.9" dependencies = [ + "taehv @ git+https://github.com/madebyollin/taehv.git@7dc60ec6601af2e668e31bc70acc4cb3665e4c22", "torch==2.10.0", "torchvision==0.25.0", "torchaudio==2.10.0", diff --git a/src/ae.py b/src/ae.py index 577d765..4bd1912 100644 --- a/src/ae.py +++ b/src/ae.py @@ -85,3 +85,102 @@ def decode(self, latent: Tensor): decoded = (decoded / 2 + 0.5).clamp(0, 1) decoded = (decoded * 255).round().to(torch.uint8) return decoded.squeeze(0).permute(1, 2, 0)[..., :3] + + +class ChunkedStreamingTAEHV(InferenceAE): + def __init__(self, ae_model, device=None, dtype=torch.bfloat16): + self.device = device + self.dtype = dtype + self.ae_model = ae_model.eval().to(device=device, dtype=dtype) + + @classmethod + def from_pretrained(cls, model_uri: str, **kwargs): + import pathlib + + import huggingface_hub + from taehv import TAEHV + + try: + base = pathlib.Path(huggingface_hub.snapshot_download(model_uri)) + except Exception: + base = pathlib.Path(model_uri) + + ckpt = base if base.is_file() else base / "taehv1_5.pth" + return cls(TAEHV(str(ckpt)), **kwargs) + + def reset(self): + from taehv import StreamingTAEHV + + self._encoder = StreamingTAEHV(self.ae_model) + self._decoder = StreamingTAEHV(self.ae_model) + self._is_first_encode = True + + @torch.inference_mode() + def encode(self, img: Tensor): + """ + First call: + img: [H, W, C] uint8 + Later calls: + img: [T, H, W, C] uint8 with T == self.ae_model.t_downscale + + Returns: + latent: [B, C, h, w] + """ + if img.dim() == 3: + img = img.unsqueeze(0) + + assert img.dim() 
== 4 and img.shape[-1] == 3, ( + "Expected [H, W, C] or [T, H, W, C] uint8 image tensor" + ) + + expected_t = 1 if self._is_first_encode else self.ae_model.t_downscale + if img.shape[0] != expected_t: + raise ValueError( + f"Expected {expected_t} frame(s), got {img.shape[0]}" + ) + + rgb = img.unsqueeze(0).to(device=self.device, dtype=self.dtype) + rgb = rgb.permute(0, 1, 4, 2, 3).contiguous().div(255) + + if self._is_first_encode: + rgb = rgb.repeat(1, self.ae_model.t_downscale, 1, 1, 1) + self._is_first_encode = False + + latent = self._encoder.encode(rgb) + if latent is None: + raise RuntimeError("Expected a latent after a full chunk") + + return latent[:, 0] + + @torch.inference_mode() + def decode(self, latent: Tensor): + """ + Input: + latent: [B, C, h, w] + + Returns: + frames: [T, H, W, C] uint8 + """ + latent = latent.unsqueeze(1).to(device=self.device, dtype=self.dtype) + + first = self._decoder.decode(latent) + if first is None: + raise RuntimeError("Expected decoded output after a latent") + + frames = [first] + while True: + frame = self._decoder.decode() + if frame is None: + break + frames.append(frame) + + decoded = torch.cat(frames, dim=1) + decoded = (decoded.clamp(0, 1) * 255).round().to(torch.uint8) + return decoded.squeeze(0).permute(0, 2, 3, 1)[..., :3] + + +def get_ae(ae_uri, is_taehv_ae=False, **kwargs): + if is_taehv_ae: + return ChunkedStreamingTAEHV.from_pretrained(ae_uri, **kwargs) + else: + return InferenceAE.from_pretrained(ae_uri, **kwargs) diff --git a/src/model/world_model.py b/src/model/world_model.py index b92d446..07e8e5e 100644 --- a/src/model/world_model.py +++ b/src/model/world_model.py @@ -340,7 +340,7 @@ def __init__(self, config): self.transformer = WorldDiT(config) - self.patch = tuple(getattr(config, "patch", (1, 1))) + self.patch = tuple(config.patch) C, D = config.channels, config.d_model self.patchify = nn.Conv2d(C, D, kernel_size=self.patch, stride=self.patch, bias=False) @@ -419,3 +419,75 @@ def 
get_active_parameters(self) -> int: hidden, top_k = int(c.d_model * moe_mlp_ratio), min(c.moe_top_k, c.moe_n_experts) total -= (c.moe_n_experts - top_k) * c.n_layers * c.d_model * hidden * (3 if c.gated_linear else 2) return total + + def load_state_dict(self, state_dict, strict=True, assign=False): + if getattr(self.config, "model_type", "waypoint-1") != "waypoint-1.5": + return super().load_state_dict(state_dict, strict=strict, assign=assign) + + state_dict = dict(state_dict) + + if "unpatchify.weight" in state_dict and state_dict["unpatchify.weight"].ndim == 4: + w = state_dict["unpatchify.weight"] # [D, C, ph, pw] + state_dict["unpatchify.weight"] = w.permute(1, 2, 3, 0).reshape(-1, w.shape[0]) + if "unpatchify.bias" in state_dict and state_dict["unpatchify.bias"].numel() != self.unpatchify.bias.numel(): + ph, pw = self.patch + state_dict["unpatchify.bias"] = state_dict["unpatchify.bias"][:, None, None].expand(-1, ph, pw).reshape(-1) + + for i in range(self.config.n_layers): + p = f"transformer.blocks.{i}." + + for name in ("fc1.weight", "fc2.weight"): + old = p + "dit_mlp." + name + if old in state_dict: + state_dict.setdefault(p + "mlp." 
+ name, state_dict.pop(old)) + + attn_bias = state_dict.pop(p + "attn_cond_head.bias_in", None) + mlp_bias = state_dict.pop(p + "mlp_cond_head.bias_in", None) + if attn_bias is not None or mlp_bias is not None: + state_dict.setdefault(p + "cond_head.bias_in", mlp_bias if mlp_bias is not None else attn_bias) + + for j in range(3): + attn = state_dict.pop(p + f"attn_cond_head.cond_proj.{j}.weight", None) + mlp = state_dict.pop(p + f"mlp_cond_head.cond_proj.{j}.weight", None) + if attn is not None: + state_dict.setdefault(p + f"cond_head.cond_proj.{j}.weight", attn) + if mlp is not None: + state_dict.setdefault(p + f"cond_head.cond_proj.{j + 3}.weight", mlp) + + x = state_dict.pop(p + "ctrl_mlpfusion.fc1_x.weight", None) + c = state_dict.pop(p + "ctrl_mlpfusion.fc1_c.weight", None) + if x is not None and c is not None: + state_dict.setdefault(p + "ctrl_mlpfusion.mlp.fc1.weight", torch.cat((x, c), dim=1)) + old = state_dict.pop(p + "ctrl_mlpfusion.fc2.weight", None) + if old is not None: + state_dict.setdefault(p + "ctrl_mlpfusion.mlp.fc2.weight", old) + + ref = "transformer.blocks.0.cond_head.cond_proj." + for i in range(1, self.config.n_layers): + p = f"transformer.blocks.{i}.cond_head.cond_proj." + for j in range(6): + k, rk = p + f"{j}.weight", ref + f"{j}.weight" + if k not in state_dict and rk in state_dict: + state_dict[k] = state_dict[rk] + state_dict = {k: v for k, v in state_dict.items() if ".cond_heads." not in k} + + return super().load_state_dict(state_dict, strict=strict, assign=assign) + +""" +TODO: use the below for quantization +import torch.nn as nn + + +class BufferLinear(nn.Linear): + capture = True + + def _load_from_state_dict(self, sd, prefix, *args): + if self.capture: + keep = {prefix + "weight", prefix + "bias"} + for k in list(sd): + if k.startswith(prefix) and k not in keep and "." 
not in k[len(prefix):]: + n = k[len(prefix):] + self.register_buffer(n, sd.pop(k).to(self.weight.device), + persistent=False) + super()._load_from_state_dict(sd, prefix, *args) +""" diff --git a/src/world_engine.py b/src/world_engine.py index ae56a94..363db42 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from .model import WorldModel, StaticKVCache, PromptEncoder -from .ae import InferenceAE +from .ae import get_ae from .patch_model import apply_inference_patches from .quantize import quantize_model @@ -16,6 +16,14 @@ # fix graph break: torch._dynamo.config.capture_scalar_outputs = True +COMPILE_OPTIONS = { + "max_autotune": True, + "coordinate_descent_tuning": True, + "triton.cudagraphs": True, # set False to mimic *-no-cudagraphs + "epilogue_fusion": True, # requires max_autotune + "shape_padding": True, +} + @dataclass class CtrlInput: @@ -47,7 +55,7 @@ def __init__( self.model_cfg.merge_with(model_config_overrides) # Model - self.vae = InferenceAE.from_pretrained(self.model_cfg.ae_uri, device=device, dtype=dtype) + self.vae = get_ae(self.model_cfg.ae_uri, getattr(self.model_cfg, "taehv_ae", False), device=device, dtype=dtype) self.prompt_encoder = None if self.model_cfg.prompt_conditioning is not None: @@ -153,7 +161,7 @@ def prep_inputs(self, x, ctrl=None): self.set_prompt("An explorable world") return {**ctx, **self._prompt_ctx} - @torch.compile(fullgraph=True, mode="max-autotune", dynamic=False) + @torch.compile(fullgraph=True, dynamic=False, options=COMPILE_OPTIONS) def _denoise_pass(self, x, ctx: Dict[str, Tensor], kv_cache): kv_cache.set_frozen(True) sigma = x.new_empty((x.size(0), x.size(1))) @@ -162,7 +170,7 @@ def _denoise_pass(self, x, ctx: Dict[str, Tensor], kv_cache): x = x + step_dsig * v return x - @torch.compile(fullgraph=True, mode="max-autotune", dynamic=False) + @torch.compile(fullgraph=True, dynamic=False, options=COMPILE_OPTIONS) def _cache_pass(self, x, ctx: Dict[str, 
Tensor], kv_cache): """Side effect: updates kv cache""" kv_cache.set_frozen(False) From 7cf8c25aa55a2dc4f99c0c62f3972e410a87f64d Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 4 Mar 2026 16:54:23 -0500 Subject: [PATCH 06/48] clean up and fix ae --- src/ae.py | 130 ++++++++++++++++++++++++------------------------------ 1 file changed, 57 insertions(+), 73 deletions(-) diff --git a/src/ae.py b/src/ae.py index 4bd1912..624be42 100644 --- a/src/ae.py +++ b/src/ae.py @@ -2,29 +2,6 @@ from torch import Tensor -""" -WARNING: -- Always assumes scale=1, shift=0 -""" - - -def bake_weight_norm_(module) -> int: - """ - Removes weight parametrizations (from torch.nn.utils.parametrizations.weight_norm) - and leaves the current parametrized weight as a plain Parameter. - Returns how many modules were de-parametrized. - """ - import torch.nn.utils.parametrize as parametrize - - n = 0 - for m in module.modules(): - # weight_norm registers a parametrization on "weight" - if hasattr(m, "parametrizations") and "weight" in getattr(m, "parametrizations", {}): - parametrize.remove_parametrizations(m, "weight", leave_parametrized=True) - n += 1 - return n - - class InferenceAE: def __init__(self, ae_model, device=None, dtype=torch.bfloat16): self.device = device @@ -54,10 +31,27 @@ def from_pretrained(cls, model_uri: str, **kwargs): model.encoder.load_state_dict(enc_sd, strict=True) model.decoder.load_state_dict(dec_sd, strict=True) - bake_weight_norm_(model) + cls.bake_weight_norm_(model) return cls(model, **kwargs) + @staticmethod + def bake_weight_norm_(module) -> int: + """ + Removes weight parametrizations (from torch.nn.utils.parametrizations.weight_norm) + and leaves the current parametrized weight as a plain Parameter. + Returns how many modules were de-parametrized. 
+ """ + import torch.nn.utils.parametrize as parametrize + + n = 0 + for m in module.modules(): + # weight_norm registers a parametrization on "weight" + if hasattr(m, "parametrizations") and "weight" in getattr(m, "parametrizations", {}): + parametrize.remove_parametrizations(m, "weight", leave_parametrized=True) + n += 1 + return n + def encode(self, img: Tensor): """RGB -> RGB+D -> latent""" assert img.dim() == 3, "Expected [H, W, C] image tensor" @@ -87,11 +81,16 @@ def decode(self, latent: Tensor): return decoded.squeeze(0).permute(1, 2, 0)[..., :3] -class ChunkedStreamingTAEHV(InferenceAE): + +class ChunkedStreamingTAEHV: def __init__(self, ae_model, device=None, dtype=torch.bfloat16): + from taehv import StreamingTAEHV + self.device = device self.dtype = dtype - self.ae_model = ae_model.eval().to(device=device, dtype=dtype) + self.streaming_ae_model = StreamingTAEHV( + ae_model.eval().to(device=device, dtype=dtype) + ) @classmethod def from_pretrained(cls, model_uri: str, **kwargs): @@ -111,68 +110,53 @@ def from_pretrained(cls, model_uri: str, **kwargs): def reset(self): from taehv import StreamingTAEHV - self._encoder = StreamingTAEHV(self.ae_model) - self._decoder = StreamingTAEHV(self.ae_model) - self._is_first_encode = True + # Rebuild streaming state, reuse same weights model + self.streaming_ae_model = StreamingTAEHV(self.streaming_ae_model.taehv) @torch.inference_mode() def encode(self, img: Tensor): """ - First call: - img: [H, W, C] uint8 - Later calls: - img: [T, H, W, C] uint8 with T == self.ae_model.t_downscale - - Returns: - latent: [B, C, h, w] + img: [T, H, W, C] uint8 where T == t_downscale + returns: latent [B, C, h, w] """ - if img.dim() == 3: - img = img.unsqueeze(0) - - assert img.dim() == 4 and img.shape[-1] == 3, ( - "Expected [H, W, C] or [T, H, W, C] uint8 image tensor" + assert img.dim() == 4 and img.shape[-1] == 3, "Expected [T, H, W, C] RGB uint8" + + t = self.streaming_ae_model.taehv.t_downscale + if img.shape[0] != t: + raise 
ValueError(f"Expected {t} frames, got {img.shape[0]}") + + rgb = ( + img.unsqueeze(0) + .to(device=self.device, dtype=self.dtype) + .permute(0, 1, 4, 2, 3) + .contiguous() + .div(255) ) - expected_t = 1 if self._is_first_encode else self.ae_model.t_downscale - if img.shape[0] != expected_t: - raise ValueError( - f"Expected {expected_t} frame(s), got {img.shape[0]}" - ) - - rgb = img.unsqueeze(0).to(device=self.device, dtype=self.dtype) - rgb = rgb.permute(0, 1, 4, 2, 3).contiguous().div(255) - - if self._is_first_encode: - rgb = rgb.repeat(1, self.ae_model.t_downscale, 1, 1, 1) - self._is_first_encode = False - - latent = self._encoder.encode(rgb) + latent = self.streaming_ae_model.encode(rgb) if latent is None: raise RuntimeError("Expected a latent after a full chunk") - return latent[:, 0] + return latent.squeeze(1) @torch.inference_mode() def decode(self, latent: Tensor): """ - Input: - latent: [B, C, h, w] - - Returns: - frames: [T, H, W, C] uint8 + latent: [B, C, h, w] + returns: frames [T, H, W, C] uint8 """ - latent = latent.unsqueeze(1).to(device=self.device, dtype=self.dtype) - - first = self._decoder.decode(latent) - if first is None: - raise RuntimeError("Expected decoded output after a latent") - - frames = [first] - while True: - frame = self._decoder.decode() - if frame is None: - break - frames.append(frame) + assert latent.dim() == 4, "Expected [B, C, h, w] latent tensor" + + z = latent.unsqueeze(1).to(device=self.device, dtype=self.dtype) + + if self.streaming_ae_model.n_frames_decoded == 0: + for _ in range(self.streaming_ae_model.taehv.frames_to_trim): + self.streaming_ae_model.decode(z) + self.streaming_ae_model.flush_decoder() + + first = self.streaming_ae_model.decode(z) + assert first is not None, "Expected decoded output after a latent" + frames = [first, *self.streaming_ae_model.flush_decoder()] decoded = torch.cat(frames, dim=1) decoded = (decoded.clamp(0, 1) * 255).round().to(torch.uint8) From 586a3c0a438def8a1f00433bbf9fe9819279f753 Mon 
Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 4 Mar 2026 17:13:31 -0500 Subject: [PATCH 07/48] fix temporal compression rope bugs --- src/model/kv_cache.py | 25 ++++++++++--------------- src/model/world_model.py | 8 +++++++- src/world_engine.py | 8 +++++++- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/model/kv_cache.py b/src/model/kv_cache.py index da244a2..898b8a9 100644 --- a/src/model/kv_cache.py +++ b/src/model/kv_cache.py @@ -107,26 +107,21 @@ def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool): t_pos: [B, T], all equal per frame (ignoring -1) """ T = self.tpf - t_pos = pos_ids["t_pos"] + f_pos = pos_ids["f_pos"] if not torch.compiler.is_compiling(): torch._check(kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert") - torch._check(t_pos.shape == (kv.size(1), T), "t_pos must be [B, T]") + torch._check(f_pos.shape == (kv.size(1), T), "t_pos must be [B, T]") torch._check(self.tpf <= self.L, "frame longer than KV ring capacity") - torch._check(self.L % self.tpf == 0, - f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})") - torch._check(self.kv.size(3) == self.capacity, - "KV buffer has unexpected length (expected L + tokens_per_frame)") - torch._check( - (t_pos >= 0).all().item(), - "t_pos must be non-negative during inference", - ) - torch._check(((t_pos == t_pos[:, :1]).all()).item(), "t_pos must be constant within frame") + torch._check(self.L % self.tpf == 0, f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})") + torch._check(self.kv.size(3) == self.capacity, "KV buffer too long (expected L + tokens_per_frame)") + torch._check((f_pos >= 0).all().item(), "t_pos must be non-negative during inference") + torch._check(((f_pos == f_pos[:, :1]).all()).item(), "t_pos must be constant within frame") - frame_t = t_pos[0, 0] + frame_idx = f_pos[0, 0] # map frame_t to a bucket, each bucket owns T contiguous slots - bucket = (frame_t + (self.pinned_dilation - 1)) // 
self.pinned_dilation + bucket = (frame_idx + (self.pinned_dilation - 1)) // self.pinned_dilation slot = bucket % self.num_buckets base = slot * T @@ -137,7 +132,7 @@ def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool): # this is the "self-attention component" for the current frame. self.kv.index_copy_(3, self.current_idx, kv) - write_step = (frame_t.remainder(self.pinned_dilation) == 0) + write_step = (frame_idx.remainder(self.pinned_dilation) == 0) mask_written = self._mask_written mask_written.copy_(self.written) mask_written[ring_idx] = mask_written[ring_idx] & ~write_step @@ -158,7 +153,7 @@ class StaticKVCache(nn.Module): def __init__(self, config, batch_size, dtype): super().__init__() - self.tpf = config.tokens_per_frame + self.tpf = config.height * config.width local_L = config.local_window * self.tpf global_L = config.global_window * self.tpf diff --git a/src/model/world_model.py b/src/model/world_model.py index 07e8e5e..1c4ca23 100644 --- a/src/model/world_model.py +++ b/src/model/world_model.py @@ -359,6 +359,7 @@ def forward( x: Tensor, sigma: Tensor, frame_timestamp: Tensor, + frame_idx: Optional[Tensor] = None, prompt_emb: Optional[Tensor] = None, prompt_pad_mask: Optional[Tensor] = None, mouse: Optional[Tensor] = None, @@ -385,7 +386,12 @@ def forward( torch._assert(B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1") self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f)) pos_ids = TensorDict( - {"t_pos": self._t_pos_1f[None], "y_pos": self._y_pos_1f[None], "x_pos": self._x_pos_1f[None]}, + { + "f_pos": (frame_timestamp if frame_idx is None else frame_idx)[0, 0].expand_as(self._t_pos_1f)[None], + "t_pos": self._t_pos_1f[None], + "y_pos": self._y_pos_1f[None], + "x_pos": self._x_pos_1f[None], + }, batch_size=[1, self._t_pos_1f.numel()], ) diff --git a/src/world_engine.py b/src/world_engine.py index 363db42..ca441db 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -83,12 +83,17 @@ def 
__init__( self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device) self.frame_ts = torch.tensor([[0]], dtype=torch.long, device=device) + inference_fps = getattr(self.model_cfg, "inference_fps", self.model_cfg.base_fps) + latent_fps = inference_fps / getattr(self.model_cfg, "temporal_compression", 1) + self.ts_mult = int(self.model_cfg.base_fps) // latent_fps + # Static input context tensors self._ctx = { "button": torch.zeros((1, 1, self.model_cfg.n_buttons), device=device, dtype=dtype), "mouse": torch.zeros((1, 1, 2), device=device, dtype=dtype), "scroll": torch.zeros((1, 1, 1), device=device, dtype=dtype), "frame_timestamp": torch.empty((1, 1), device=device, dtype=torch.long), + "frame_idx": torch.empty((1, 1), device=device, dtype=torch.long), } self._prompt_ctx = {"prompt_emb": None, "prompt_pad_mask": None} @@ -140,7 +145,8 @@ def _prep_inputs(self, x, ctrl=None): self._ctx["mouse"][0, 0, 1] = ctrl.mouse[1] self._ctx["scroll"][0, 0, 0] = ctrl.scroll_wheel - self._ctx["frame_timestamp"].copy_(self.frame_ts) + self._ctx["frame_idx"].copy_(self.frame_ts) + self._ctx["frame_timestamp"].copy_(self.frame_ts).mul_(self.ts_mult) self.frame_ts.add_(1) return self._ctx From 5125dc1fa4e57d97a41bce6b63f8b50d392ca4f2 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 4 Mar 2026 17:48:41 -0500 Subject: [PATCH 08/48] vae reset in world_engine.reset --- src/ae.py | 3 +++ src/world_engine.py | 1 + 2 files changed, 4 insertions(+) diff --git a/src/ae.py b/src/ae.py index 624be42..fe38090 100644 --- a/src/ae.py +++ b/src/ae.py @@ -8,6 +8,9 @@ def __init__(self, ae_model, device=None, dtype=torch.bfloat16): self.dtype = dtype self.ae_model = ae_model.eval().to(device=device, dtype=dtype) + def reset(self): + pass + @classmethod def from_pretrained(cls, model_uri: str, **kwargs): import pathlib diff --git a/src/world_engine.py b/src/world_engine.py index ca441db..5a71783 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -105,6 +105,7 @@ 
def reset(self): self.frame_ts.zero_() for v in self._ctx.values(): v.zero_() + self.vae.reset() @torch.inference_mode() def get_state(self): From 9ba9b4db585c1a3dfcbb0d272aa06778b56c30b8 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 5 Mar 2026 18:10:50 -0500 Subject: [PATCH 09/48] reduce peak memory --- src/ae.py | 1 - src/model/base_model.py | 6 +++--- src/world_engine.py | 7 +++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/ae.py b/src/ae.py index fe38090..d8d537c 100644 --- a/src/ae.py +++ b/src/ae.py @@ -84,7 +84,6 @@ def decode(self, latent: Tensor): return decoded.squeeze(0).permute(1, 2, 0)[..., :3] - class ChunkedStreamingTAEHV: def __init__(self, ae_model, device=None, dtype=torch.bfloat16): from taehv import StreamingTAEHV diff --git a/src/model/base_model.py b/src/model/base_model.py index 2b2905e..d70e07b 100644 --- a/src/model/base_model.py +++ b/src/model/base_model.py @@ -18,7 +18,7 @@ def save_pretrained(self, path: str) -> None: OmegaConf.save(self.config, os.path.join(path, "config.yaml")) @classmethod - def from_pretrained(cls, path: str, cfg=None, device=None): + def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None): """Load weights and OmegaConf YAML.""" device = device or "cpu" @@ -31,8 +31,8 @@ def from_pretrained(cls, path: str, cfg=None, device=None): cfg = cls.load_config(path) model = cls(cfg) - if device != "cpu": - model = model.to(device) + if dtype is not None: + model = model.to(dtype=dtype, device=device) # Stream weights straight into `model` (no CPU state_dict first) safetensors_path = os.path.join(path, "model.safetensors") diff --git a/src/world_engine.py b/src/world_engine.py index 5a71783..28f1268 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -63,12 +63,11 @@ def __init__( self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).to(device).eval() if load_weights: - self.model = WorldModel.from_pretrained(model_uri, cfg=self.model_cfg) + self.model = 
WorldModel.from_pretrained(model_uri, cfg=self.model_cfg, dtype=dtype) else: - self.model = WorldModel(self.model_cfg) - self.model = self.model.to(device=device, dtype=dtype).eval() - + self.model = WorldModel(self.model_cfg).to(dtype=dtype) apply_inference_patches(self.model) + self.model = self.model.to(device=device).eval() if quant is not None: quantize_model(self.model, quant) From 4c5ecb54426adeba83ca53e7f174016bd63922f7 Mon Sep 17 00:00:00 2001 From: Clydingus <40514241+Clydingus@users.noreply.github.com> Date: Tue, 10 Mar 2026 01:23:34 +0800 Subject: [PATCH 10/48] Implements the orthorope angles computation instead of precomputing (#25) * fix: uv sync issue with python version 3.9 * fix: VRAM explosion * refactor: init on gpu device directly * fix: don't use fbgemm on windows for now * feat: orthoropeangles * fix: NoCastModule OrthoRoPEAngles * fix: remove pos_ids from args * fix: remove old src rope replacement patch * fix: remove out of scope ae changes * fix: remove out of scope text encoder changes * fix: patch_model pos_ids --------- Co-authored-by: Philpax --- pyproject.toml | 4 +- src/model/attn.py | 92 ++++++++++++++++++++-------------------- src/model/world_model.py | 25 +++++------ src/patch_model.py | 4 +- src/world_engine.py | 2 +- 5 files changed, 64 insertions(+), 63 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c7abd72..df92c93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,13 +5,13 @@ build-backend = "setuptools.build_meta" [project] name = "world_engine" version = "1.5.0" -requires-python = ">=3.9" +requires-python = ">=3.10" dependencies = [ "taehv @ git+https://github.com/madebyollin/taehv.git@7dc60ec6601af2e668e31bc70acc4cb3665e4c22", "torch==2.10.0", "torchvision==0.25.0", "torchaudio==2.10.0", - "fbgemm-gpu-genai==1.5.0", + "fbgemm-gpu-genai==1.5.0; sys_platform == 'linux'", "einops", "rotary-embedding-torch>=0.8.8", "tensordict==0.10.0", diff --git a/src/model/attn.py b/src/model/attn.py index 
52c2b57..c6024e8 100644 --- a/src/model/attn.py +++ b/src/model/attn.py @@ -9,63 +9,63 @@ from .nn import rms_norm, NoCastModule -class RoPE(NoCastModule): +class OrthoRoPEAngles(NoCastModule): + """Functions as a on the fly RoPE angle computer called every fwd pass. Should be setup + as a module under WordDiT, then each forward pass it computes a shared tuple of (rope_cos, rope_sin) + tensors that get passed to every block for their underlying RoPE computations.""" def __init__(self, config): super().__init__() self.config = config - assert not getattr(self.config, "has_audio", False) - freqs = self.get_freqs(config) - self.cos = nn.Buffer(freqs.cos().contiguous(), persistent=False) - self.sin = nn.Buffer(freqs.sin().contiguous(), persistent=False) + d_head = config.d_model // config.n_heads + torch._assert(d_head % 8 == 0, "d_head must be divisible by 8") + d_xy, d_t = d_head // 8, d_head // 4 + + nyq = float(getattr(config, "rope_nyquist_frac", 0.8)) + max_freq = min(self.config.height, self.config.width) * nyq + n = (d_xy + 1) // 2 + xy = (torch.linspace(1.0, max_freq / 2, n, dtype=torch.float32) * torch.pi).repeat_interleave(2)[:d_xy] + + theta = float(getattr(config, "rope_theta", 10000.0)) + inv_t = 1.0 / (theta ** (torch.arange(0, d_t, 2, dtype=torch.float32) / d_t)) + inv_t = inv_t.repeat_interleave(2) # [d_t] + + self.register_buffer("xy", xy, persistent=False) # [d_xy] + self.register_buffer("inv_t", inv_t, persistent=False) # [d_t] - def get_angles(self, pos_ids): - t, y, x = pos_ids["t_pos"], pos_ids["y_pos"], pos_ids["x_pos"] # [B,T] - H, W = self.config.height, self.config.width + @torch.autocast("cuda", enabled=False) + def forward(self, pos_ids): if not torch.compiler.is_compiling(): - torch._assert((y.max() < H) & (x.max() < W), f"pos_ids out of bounds, {y.max()}, {x.max()}") - flat = t * (H * W) + y * W + x # [B,T] - idx = flat.reshape(-1).to(torch.long) - cos = self.cos.index_select(0, idx).view(*flat.shape, -1) - sin = self.sin.index_select(0, 
idx).view(*flat.shape, -1) - return cos[:, None], sin[:, None] # add head dim for broadcast + torch._assert( + (pos_ids["y_pos"].max() < self.config.height) & (pos_ids["x_pos"].max() < self.config.width), + f"pos_ids out of bounds, {self.config.height}, {self.config.width}" + ) + + x = (2.0 * pos_ids["x_pos"].float() + 1.0) / self.config.width - 1.0 + y = (2.0 * pos_ids["y_pos"].float() + 1.0) / self.config.height - 1.0 + t = pos_ids["t_pos"].float() + + freqs = torch.cat( + (x.unsqueeze(-1) * self.xy, y.unsqueeze(-1) * self.xy, t.unsqueeze(-1) * self.inv_t), + dim=-1, # [B,T,d_head//2] + ) + # Returns rope_cos, rope_sin angles of shape [B, 1, T, D/2] + return freqs.cos()[:, None], freqs.sin()[:, None] + +class OrthoRoPE(NoCastModule): + def __init__(self, config): + super().__init__() + self.config = config + assert not getattr(self.config, "has_audio", False) @torch.autocast("cuda", enabled=False) - def forward(self, x, pos_ids): - assert self.cos.dtype == self.sin.dtype == torch.float32 - cos, sin = self.get_angles(pos_ids) + def forward(self, x, rope_angles): + cos, sin = rope_angles x0, x1 = x.float().unfold(-1, 2, 2).unbind(-1) y0 = x0 * cos - x1 * sin y1 = x1 * cos + x0 * sin return torch.cat((y0, y1), dim=-1).type_as(x) - def get_freqs(self, config): - raise NotImplementedError - - -class OrthoRoPE(RoPE): - """ - RoPE for rotation across orthogonal axes: time, height, and width - Time: Geometric Spectrum -- rotates 1/2 of head dim - Height / Width: Linear Spectrum -- rotates 1/4th of head dim each (1/2 combined) - """ - def get_freqs(self, config): - H, W, T = config.height, config.width, config.n_frames - head_dim = config.d_model // config.n_heads - - max_freq = min(H, W) * 0.8 # stay below nyquist - rope_xy = RotaryEmbedding(dim=head_dim // 8, freqs_for='pixel', max_freq=max_freq) - freqs_x = rope_xy(torch.linspace(-1 + 1 / W, 1 - 1 / W, W))[None, :, :] # [1,W,D] - freqs_y = rope_xy(torch.linspace(-1 + 1 / H, 1 - 1 / H, H))[:, None, :] # [H,1,D] - - 
freq_t = RotaryEmbedding(dim=head_dim // 4, freqs_for='lang').forward(torch.arange(T)) - - return torch.cat([ - eo.repeat(freqs_x.expand(H, W, -1), 'h w d -> (t h w) d', t=T), # X - eo.repeat(freqs_y.expand(H, W, -1), 'h w d -> (t h w) d', t=T), # Y - eo.repeat(freq_t, 't d -> (t h w) d', h=H, w=W) # T - ], dim=-1) - class Attn(nn.Module): def __init__(self, config, layer_idx): @@ -96,7 +96,7 @@ def __init__(self, config, layer_idx): self.gate_proj = nn.Linear(self.n_heads, self.n_heads, bias=False) # sparse attn gate nn.init.zeros_(self.gate_proj.weight) - def forward(self, x, pos_ids, v1, kv_cache): + def forward(self, x, pos_ids, rope_angles, v1, kv_cache): # Q, K, V proj -> QK-norm -> RoPE q = eo.rearrange(self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads, d=self.d_head) k = eo.rearrange(self.k_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head) @@ -107,7 +107,7 @@ def forward(self, x, pos_ids, v1, kv_cache): v = torch.lerp(v, v1.view_as(v), self.v_lamb) q, k = rms_norm(q), rms_norm(k) - q, k = self.rope(q, pos_ids), self.rope(k, pos_ids) + q, k = self.rope(q, rope_angles), self.rope(k, rope_angles) # Update KV-cache in-place k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx) diff --git a/src/model/world_model.py b/src/model/world_model.py index 1c4ca23..8ea6053 100644 --- a/src/model/world_model.py +++ b/src/model/world_model.py @@ -10,11 +10,15 @@ import torch.nn.functional as F -from fbgemm_gpu.experimental.gen_ai.moe import index_shuffling -import fbgemm_gpu.experimental.gen_ai.moe.gather_scatter # noqa +try: + from fbgemm_gpu.experimental.gen_ai.moe import index_shuffling + import fbgemm_gpu.experimental.gen_ai.moe.gather_scatter # noqa + HAS_FBGEMM = True +except ImportError: + HAS_FBGEMM = False -from .attn import Attn, CrossAttention +from .attn import Attn, CrossAttention, OrthoRoPEAngles from .nn import AdaLN, ada_gate, ada_rmsnorm, NoiseConditioner from .base_model import BaseModel @@ -248,7 +252,7 @@ def __init__(self, 
config, layer_idx): self.config = config self.attn = Attn(config, layer_idx) if getattr(config, "moe", False): - self.mlp = MoE(config) + self.mlp = MoE(config) if HAS_FBGEMM else MoEWithoutFBGEMM(config) else: self.mlp = MLP(config.d_model, config.d_model * config.mlp_ratio, config.d_model) self.cond_head = CondHead(config) @@ -258,7 +262,7 @@ def __init__(self, config, layer_idx): do_ctrl_cond = config.ctrl_conditioning_period is not None and layer_idx % config.ctrl_conditioning_period == 0 self.ctrl_mlpfusion = MLPFusion(config) if do_ctrl_cond else None - def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None): + def forward(self, x, pos_ids, rope_angles, cond, ctx, v, kv_cache=None): """ 0) Causal Frame Attention 1) Frame->CTX Cross Attention @@ -269,7 +273,7 @@ def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None): # Self / Causal Attention residual = x x = ada_rmsnorm(x, s0, b0) - x, v = self.attn(x, pos_ids, v, kv_cache=kv_cache) + x, v = self.attn(x, pos_ids, rope_angles, v, kv_cache=kv_cache) x = ada_gate(x, g0) + residual # Cross Attention Prompt Conditioning @@ -295,6 +299,7 @@ def __init__(self, config): super().__init__() self.config = config self.blocks = nn.ModuleList([WorldDiTBlock(config, idx) for idx in range(config.n_layers)]) + self.rope_angles = OrthoRoPEAngles(config) if self.config.noise_conditioning in ("dit_air", "wan"): ref_proj = self.blocks[0].cond_head.cond_proj @@ -302,15 +307,11 @@ def __init__(self, config): for blk_mod, ref_mod in zip(blk.cond_head.cond_proj, ref_proj): blk_mod.weight = ref_mod.weight - # Shared RoPE buffers - ref_rope = self.blocks[0].attn.rope - for blk in self.blocks[1:]: - blk.attn.rope = ref_rope - def forward(self, x, pos_ids, cond, ctx, kv_cache=None): + rope_angles = self.rope_angles(pos_ids) v = None for i, block in enumerate(self.blocks): - x, v = block(x, pos_ids, cond, ctx, v, kv_cache=kv_cache) + x, v = block(x, pos_ids, rope_angles, cond, ctx, v, kv_cache=kv_cache) return x diff --git 
a/src/patch_model.py b/src/patch_model.py index 987ded1..817f220 100644 --- a/src/patch_model.py +++ b/src/patch_model.py @@ -113,7 +113,7 @@ def __init__(self, src: Attn, config): del self.q_proj, self.k_proj, self.v_proj - def forward(self, x, pos_ids, v1, kv_cache): + def forward(self, x, pos_ids, rope_angles, v1, kv_cache): q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1) B, T = x.shape[:2] @@ -126,7 +126,7 @@ def forward(self, x, pos_ids, v1, kv_cache): v = torch.lerp(v, v1.view_as(v), self.v_lamb) q, k = rms_norm(q), rms_norm(k) - q, k = self.rope(q, pos_ids), self.rope(k, pos_ids) + q, k = self.rope(q, rope_angles), self.rope(k, rope_angles) k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx) diff --git a/src/world_engine.py b/src/world_engine.py index 5a71783..aba598c 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -63,7 +63,7 @@ def __init__( self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).to(device).eval() if load_weights: - self.model = WorldModel.from_pretrained(model_uri, cfg=self.model_cfg) + self.model = WorldModel.from_pretrained(model_uri, cfg=self.model_cfg, device=device) else: self.model = WorldModel(self.model_cfg) self.model = self.model.to(device=device, dtype=dtype).eval() From bf905201744128b602073e207530136d5c7b3fc5 Mon Sep 17 00:00:00 2001 From: Clydingus <40514241+Clydingus@users.noreply.github.com> Date: Tue, 10 Mar 2026 01:42:00 +0800 Subject: [PATCH 11/48] test: revert direct device init (#28) --- src/world_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/world_engine.py b/src/world_engine.py index aba598c..5a71783 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -63,7 +63,7 @@ def __init__( self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).to(device).eval() if load_weights: - self.model = WorldModel.from_pretrained(model_uri, cfg=self.model_cfg, device=device) + self.model = WorldModel.from_pretrained(model_uri, 
cfg=self.model_cfg) else: self.model = WorldModel(self.model_cfg) self.model = self.model.to(device=device, dtype=dtype).eval() From 177101f7d2339cc1049ca0bb0575101558ae91c1 Mon Sep 17 00:00:00 2001 From: Philpax Date: Fri, 6 Mar 2026 22:57:17 +0100 Subject: [PATCH 12/48] feat: use built triton-windows fork to fix long-path issue --- pyproject.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index df92c93..7aeba41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,8 +23,10 @@ dependencies = [ "accelerate==1.12.0", # Triton (platform-specific) - "triton; sys_platform == 'linux'", - "triton-windows; sys_platform == 'win32'", + "triton==3.6.0; sys_platform == 'linux'", + # TODO: move back to mainline triton-windows once long-path fix is merged + # (see https://github.com/triton-lang/triton-windows/pull/11#issuecomment-4014081904) + "triton-windows @ https://github.com/Overworldai/triton-windows/releases/download/v3.6.0-windows-longpath/triton_windows-3.6.0+git7df5604d-cp313-cp313-win_amd64.whl ; sys_platform == 'win32'", ] [tool.setuptools] From fe5873d16ac7cc534e6d5d4252ffd61c131f0ce1 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 9 Mar 2026 13:51:48 -0400 Subject: [PATCH 13/48] update gen_sample --- examples/gen_sample.py | 63 +++++++++++++++++++++++++++++++----------- 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 62c27c5..a1c5af2 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -1,22 +1,53 @@ +# python3 examples/gen_sample.py +# e.g. 
python3 examples/gen_sample.py Overworld-Models/Lapp0-WP-Mini-1.4.5-BL-Distill + import cv2 -from world_engine import WorldEngine +import imageio.v3 as iio +import random +import sys +import urllib.request +import numpy as np +import torch + +from world_engine import WorldEngine, CtrlInput + + +# Create inference engine +engine = WorldEngine(sys.argv[1], device="cuda") + +# Set seed frame +url = random.choice([ + "https://gist.github.com/user-attachments/assets/d81c6d26-a838-4afe-9d13-fd67677043c3", + "https://gist.github.com/user-attachments/assets/b6d18c38-098e-43b0-8e61-66a16e5d8946", + "https://gist.github.com/user-attachments/assets/0734a8c1-3eb4-4ffe-8c37-5665c45ab559", + "https://gist.github.com/user-attachments/assets/f9c20d4d-7565-452d-8b02-42a85ea175ed", + "https://gist.github.com/user-attachments/assets/68c943a4-008a-4c25-948c-c81ab4c47d21", +]) +frame = cv2.imdecode(np.frombuffer(urllib.request.urlopen(url).read(), np.uint8), cv2.IMREAD_COLOR) +frame = cv2.resize(frame, (1024, 512))[:, :, ::-1] +engine.append_frame(torch.from_numpy(np.repeat(frame[None], 4, axis=0))) -def gen_vid(): - engine = WorldEngine("OpenWorldLabs/CoDCtl-Causal-Flux-SelfForcing", device="cuda") - writer = None - for _ in range(240): - frame = engine.gen_frame().cpu().numpy()[:, :, ::-1] # RGB -> BGR for OpenCV - writer = writer or cv2.VideoWriter( - "out.mp4", - cv2.VideoWriter_fourcc(*"mp4v"), - 60, - (frame.shape[1], frame.shape[0]) - ) - writer.write(frame) - writer.release() +# Define sequence of controller inputs applied +controller_sequence = [ + # move mouse, jump, do nothing, trigger, do nothing, trigger+jump, do nothing + CtrlInput(mouse=[0.2, 0.2]), CtrlInput(button={32}), CtrlInput(), CtrlInput(), CtrlInput(), + CtrlInput(button={1}), CtrlInput(), CtrlInput(), CtrlInput(button={1, 32}), + CtrlInput(), CtrlInput(), CtrlInput(), CtrlInput(), CtrlInput(), CtrlInput(), +] * 4 +controller_sequence += [CtrlInput()] * 8 +controller_sequence += ( + [CtrlInput(button={32})] * 10 
+ # forward + [CtrlInput(button={65})] * 10 + # left + [CtrlInput(button={68})] * 10 + # right + [CtrlInput(button={83})] * 10 # backwards +) +controller_sequence += [CtrlInput()] * 10 -if __name__ == "__main__": - gen_vid() +# Generate frames conditioned on controller inputs +with iio.imopen("out.mp4", "w", plugin="pyav") as out: + out.write(engine.gen_frame().cpu().numpy(), fps=60, codec="libx264") + for ctrl in controller_sequence: + out.write(engine.gen_frame(ctrl=ctrl).cpu().numpy()) From 1935b6407fd6811a6cca060e3faeda8c4a6b2eea Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 9 Mar 2026 13:52:50 -0400 Subject: [PATCH 14/48] better quant --- src/quantize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/quantize.py b/src/quantize.py index 172b8b1..b74825b 100644 --- a/src/quantize.py +++ b/src/quantize.py @@ -151,7 +151,7 @@ def __init__(self, lin: nn.Linear): if lin.bias is not None else None ) - w_amax = lin.weight.data.clone().amax().float().squeeze() + w_amax = lin.weight.data.abs().amax() w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn) self.register_buffer("w_amax", w_amax) self.register_buffer("weightT", w.t()) From facd12ac3f7975ec428a50032abd99dc222e71c7 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 9 Mar 2026 13:53:07 -0400 Subject: [PATCH 15/48] avoid warning when creating mouse / scroll tensors --- src/world_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/world_engine.py b/src/world_engine.py index 28f1268..6518fc5 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -156,8 +156,8 @@ def prep_inputs(self, x, ctrl=None): self._ctx["button"].zero_() if ctrl.button: self._ctx["button"][..., list(ctrl.button)] = 1.0 - ctrl.mouse = torch.tensor(ctrl.mouse, device=x.device, dtype=self._ctx["mouse"].dtype) - ctrl.scroll_wheel = torch.sign(torch.tensor(ctrl.scroll_wheel, device=x.device, dtype=self._ctx["scroll"].dtype)) + ctrl.mouse = 
torch.as_tensor(ctrl.mouse, device=x.device, dtype=self.dtype) + ctrl.scroll_wheel = torch.sign(torch.as_tensor(ctrl.scroll_wheel, device=x.device, dtype=self.dtype)) ctx = self._prep_inputs(x, ctrl) # prepare prompt conditioning From b2b3fb6567c058b8ffdfdd47abe8a688400d2bff Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 9 Mar 2026 13:53:16 -0400 Subject: [PATCH 16/48] disable unimportant compile options --- src/world_engine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/world_engine.py b/src/world_engine.py index 6518fc5..6ba0568 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -19,9 +19,10 @@ COMPILE_OPTIONS = { "max_autotune": True, "coordinate_descent_tuning": True, - "triton.cudagraphs": True, # set False to mimic *-no-cudagraphs - "epilogue_fusion": True, # requires max_autotune - "shape_padding": True, + "triton.cudagraphs": True, + # Negligible improvement in throughput: + # "epilogue_fusion": True, + # "shape_padding": True, } From 8f847951d1cd707e69bfe668f8fff73c478f5db3 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 9 Mar 2026 14:56:38 -0400 Subject: [PATCH 17/48] clean up model loading --- src/model/base_model.py | 5 +---- src/world_engine.py | 6 +++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/model/base_model.py b/src/model/base_model.py index d70e07b..46a938e 100644 --- a/src/model/base_model.py +++ b/src/model/base_model.py @@ -29,10 +29,7 @@ def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None): if cfg is None: cfg = cls.load_config(path) - model = cls(cfg) - - if dtype is not None: - model = model.to(dtype=dtype, device=device) + model = cls(cfg).to(device=device, dtype=dtype) # Stream weights straight into `model` (no CPU state_dict first) safetensors_path = os.path.join(path, "model.safetensors") diff --git a/src/world_engine.py b/src/world_engine.py index 6ba0568..72c237b 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -64,11 
+64,11 @@ def __init__( self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).to(device).eval() if load_weights: - self.model = WorldModel.from_pretrained(model_uri, cfg=self.model_cfg, dtype=dtype) + self.model = WorldModel.from_pretrained(model_uri, cfg=self.model_cfg, dtype=dtype).eval() else: - self.model = WorldModel(self.model_cfg).to(dtype=dtype) + self.model = WorldModel(self.model_cfg).to(dtype=dtype).eval() apply_inference_patches(self.model) - self.model = self.model.to(device=device).eval() + self.model = self.model.to(device=device) if quant is not None: quantize_model(self.model, quant) From 3df610bd55b73b1b382d375985f47ce2e47c9769 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 9 Mar 2026 15:34:40 -0400 Subject: [PATCH 18/48] remove unnecessary push_to_hub --- src/model/base_model.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/model/base_model.py b/src/model/base_model.py index 46a938e..e19d6c7 100644 --- a/src/model/base_model.py +++ b/src/model/base_model.py @@ -37,12 +37,6 @@ def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None): return model - def push_to_hub(self, uri: str, **kwargs): - huggingface_hub.create_repo(uri, repo_type="model", exist_ok=True, private=True) - with tempfile.TemporaryDirectory() as d: - self.save_pretrained(d) - huggingface_hub.upload_folder(folder_path=d, repo_id=uri, **kwargs) - @staticmethod def load_config(path): if os.path.isdir(path): From a437614bf19f051865b3d7cc479559db8eb4999c Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 9 Mar 2026 15:45:04 -0400 Subject: [PATCH 19/48] remove unnecessary save_pretrained --- src/model/base_model.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/model/base_model.py b/src/model/base_model.py index e19d6c7..8d12bfb 100644 --- a/src/model/base_model.py +++ b/src/model/base_model.py @@ -1,22 +1,12 @@ import huggingface_hub import os -import tempfile from omegaconf import OmegaConf -from 
safetensors.torch import save_file, load_file +from safetensors.torch import load_file from torch import nn class BaseModel(nn.Module): - def save_pretrained(self, path: str) -> None: - """Save weights (.safetensors) and OmegaConf YAML.""" - os.makedirs(path, exist_ok=True) - save_file( - {k: v.detach().cpu() for k, v in self.state_dict().items()}, - os.path.join(path, "model.safetensors"), - ) - OmegaConf.save(self.config, os.path.join(path, "config.yaml")) - @classmethod def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None): """Load weights and OmegaConf YAML.""" From 39630e42e35e2ec6375c9daae83fa9406ca9a37c Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 9 Mar 2026 17:15:48 -0400 Subject: [PATCH 20/48] reduce cpu memory --- src/model/base_model.py | 14 ++++--- src/world_engine.py | 85 ++++++++++++++++++++--------------------- 2 files changed, 50 insertions(+), 49 deletions(-) diff --git a/src/model/base_model.py b/src/model/base_model.py index 8d12bfb..5932c57 100644 --- a/src/model/base_model.py +++ b/src/model/base_model.py @@ -4,13 +4,15 @@ from omegaconf import OmegaConf from safetensors.torch import load_file from torch import nn +import torch class BaseModel(nn.Module): @classmethod - def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None): + def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None, load_weights: bool = True): """Load weights and OmegaConf YAML.""" - device = device or "cpu" + device = torch.get_default_device() if device is None else device + dtype = torch.get_default_dtype() if dtype is None else dtype try: path = huggingface_hub.snapshot_download(path) @@ -19,11 +21,11 @@ def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None): if cfg is None: cfg = cls.load_config(path) - model = cls(cfg).to(device=device, dtype=dtype) + model = cls(cfg).to(dtype=dtype, device=device) - # Stream weights straight into `model` (no CPU state_dict first) - safetensors_path = 
os.path.join(path, "model.safetensors") - model.load_state_dict(load_file(safetensors_path, device=device), strict=True) + if load_weights: + safetensors_path = os.path.join(path, "model.safetensors") + model.load_state_dict(load_file(safetensors_path, device=device), strict=True) return model diff --git a/src/world_engine.py b/src/world_engine.py index 72c237b..95a99ff 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -48,55 +48,54 @@ def __init__( quant: None | w8a8 | nvfp4 model_config_overrides: Dict to override model config values """ - self.device, self.dtype = device, dtype + self.device = torch.get_default_device() if device is None else device + self.dtype = torch.get_default_dtype() if dtype is None else dtype self.model_cfg = WorldModel.load_config(model_uri) if model_config_overrides: self.model_cfg.merge_with(model_config_overrides) - # Model - self.vae = get_ae(self.model_cfg.ae_uri, getattr(self.model_cfg, "taehv_ae", False), device=device, dtype=dtype) - - self.prompt_encoder = None - if self.model_cfg.prompt_conditioning is not None: - pe_uri = getattr(self.model_cfg, "prompt_encoder_uri", "google/umt5-xl") - self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).to(device).eval() - - if load_weights: - self.model = WorldModel.from_pretrained(model_uri, cfg=self.model_cfg, dtype=dtype).eval() - else: - self.model = WorldModel(self.model_cfg).to(dtype=dtype).eval() - apply_inference_patches(self.model) - self.model = self.model.to(device=device) - - if quant is not None: - quantize_model(self.model, quant) - - # Inference Scheduler - self.scheduler_sigmas = torch.tensor(self.model_cfg.scheduler_sigmas, device=device, dtype=dtype) - - pH, pW = getattr(self.model_cfg, "patch", [1, 1]) - self.frm_shape = 1, 1, self.model_cfg.channels, self.model_cfg.height * pH, self.model_cfg.width * pW - - # State - self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device) - self.frame_ts = torch.tensor([[0]], dtype=torch.long, 
device=device) - - inference_fps = getattr(self.model_cfg, "inference_fps", self.model_cfg.base_fps) - latent_fps = inference_fps / getattr(self.model_cfg, "temporal_compression", 1) - self.ts_mult = int(self.model_cfg.base_fps) // latent_fps - - # Static input context tensors - self._ctx = { - "button": torch.zeros((1, 1, self.model_cfg.n_buttons), device=device, dtype=dtype), - "mouse": torch.zeros((1, 1, 2), device=device, dtype=dtype), - "scroll": torch.zeros((1, 1, 1), device=device, dtype=dtype), - "frame_timestamp": torch.empty((1, 1), device=device, dtype=torch.long), - "frame_idx": torch.empty((1, 1), device=device, dtype=torch.long), - } - - self._prompt_ctx = {"prompt_emb": None, "prompt_pad_mask": None} + with torch.device(self.device): + # Load Model / Modules + self.vae = get_ae(self.model_cfg.ae_uri, getattr(self.model_cfg, "taehv_ae", False), dtype=dtype) + + self.prompt_encoder = None + if self.model_cfg.prompt_conditioning is not None: + pe_uri = getattr(self.model_cfg, "prompt_encoder_uri", "google/umt5-xl") + self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).eval() + + self.model = WorldModel.from_pretrained( + model_uri, cfg=self.model_cfg, dtype=dtype, load_weights=load_weights + ).eval() + apply_inference_patches(self.model) + if quant is not None: + quantize_model(self.model, quant) + + self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype) + + # Inference Scheduler + self.scheduler_sigmas = torch.tensor(self.model_cfg.scheduler_sigmas, dtype=dtype) + + pH, pW = getattr(self.model_cfg, "patch", [1, 1]) + self.frm_shape = 1, 1, self.model_cfg.channels, self.model_cfg.height * pH, self.model_cfg.width * pW + + # State + inference_fps = getattr(self.model_cfg, "inference_fps", self.model_cfg.base_fps) + latent_fps = inference_fps / getattr(self.model_cfg, "temporal_compression", 1) + self.ts_mult = int(self.model_cfg.base_fps) // latent_fps + self.frame_ts = torch.tensor([[0]], dtype=torch.long) + + # Static input 
context tensors + self._ctx = { + "button": torch.zeros((1, 1, self.model_cfg.n_buttons), dtype=dtype), + "mouse": torch.zeros((1, 1, 2), dtype=dtype), + "scroll": torch.zeros((1, 1, 1), dtype=dtype), + "frame_timestamp": torch.empty((1, 1), dtype=torch.long), + "frame_idx": torch.empty((1, 1), dtype=torch.long), + } + + self._prompt_ctx = {"prompt_emb": None, "prompt_pad_mask": None} @torch.inference_mode() def reset(self): From 235276fc09d8a6e8765b7441e537f9e9599a4380 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 10 Mar 2026 14:20:02 -0400 Subject: [PATCH 21/48] pass device --- src/world_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/world_engine.py b/src/world_engine.py index 95a99ff..7d1a892 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -66,7 +66,7 @@ def __init__( self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).eval() self.model = WorldModel.from_pretrained( - model_uri, cfg=self.model_cfg, dtype=dtype, load_weights=load_weights + model_uri, cfg=self.model_cfg, device=self.device, dtype=dtype, load_weights=load_weights ).eval() apply_inference_patches(self.model) if quant is not None: From d4fd76a364868e04e0138d80416f5ffc619d61dd Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 10 Mar 2026 23:16:58 +0100 Subject: [PATCH 22/48] fix #27 - use triton-windows longpath fix --- pyproject.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7aeba41..480cdf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,7 @@ dependencies = [ # Triton (platform-specific) "triton==3.6.0; sys_platform == 'linux'", - # TODO: move back to mainline triton-windows once long-path fix is merged - # (see https://github.com/triton-lang/triton-windows/pull/11#issuecomment-4014081904) - "triton-windows @ https://github.com/Overworldai/triton-windows/releases/download/v3.6.0-windows-longpath/triton_windows-3.6.0+git7df5604d-cp313-cp313-win_amd64.whl ; 
sys_platform == 'win32'", + "triton-windows==3.6.0.post26; sys_platform == 'win32'", ] [tool.setuptools] From 4469f3ef384e8c1432cdb7f4c7b3771be9d1c8b7 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 18 Mar 2026 11:33:44 -0400 Subject: [PATCH 23/48] cleanup dead code --- src/model/world_model.py | 140 +-------------------------------------- src/world_engine.py | 7 +- 2 files changed, 7 insertions(+), 140 deletions(-) diff --git a/src/model/world_model.py b/src/model/world_model.py index 8ea6053..da78b4f 100644 --- a/src/model/world_model.py +++ b/src/model/world_model.py @@ -9,15 +9,6 @@ from torch import nn import torch.nn.functional as F - -try: - from fbgemm_gpu.experimental.gen_ai.moe import index_shuffling - import fbgemm_gpu.experimental.gen_ai.moe.gather_scatter # noqa - HAS_FBGEMM = True -except ImportError: - HAS_FBGEMM = False - - from .attn import Attn, CrossAttention, OrthoRoPEAngles from .nn import AdaLN, ada_gate, ada_rmsnorm, NoiseConditioner from .base_model import BaseModel @@ -88,104 +79,6 @@ def forward(self, x: torch.Tensor, is_conditioned: Optional[bool] = None) -> tor return x if is_conditioned else null -class MoEWithoutFBGEMM(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.top_k = config.moe_top_k - moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or config.mlp_ratio / config.moe_top_k - d_intermediate = int(config.d_model * moe_mlp_ratio) - self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False) - self.expert_in_proj = nn.Parameter( - torch.empty(config.moe_n_experts, d_intermediate * (2 if config.gated_linear else 1), config.d_model) - ) - self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_intermediate)) - - def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor: - if self.training or torch.is_grad_enabled(): - raise NotImplementedError("inference only") - - orig_shape = x.shape - x = x.reshape(-1, 
orig_shape[-1]) - logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1)) - - logits_fp32 = logits.float() - scores, expert = logits.topk(self.top_k, dim=-1, sorted=False) - weights = (scores.float() - logits_fp32.logsumexp(dim=-1, keepdim=True)).exp().to(x.dtype) - - expert = expert.flatten() - expert_sorted, sort_idx = expert.sort() - expert_ids = torch.arange(self.expert_in_proj.size(0), device=expert.device, dtype=expert_sorted.dtype) - offsets = torch.searchsorted(expert_sorted, expert_ids, right=True).to(torch.int32) - - # (1) Pad the *indices* instead of cat-copying x_grouped - src = sort_idx // self.top_k - x_grouped = x.index_select(0, torch.cat((src, src[:1]), dim=0)) - h = F.grouped_mm( - x_grouped, - self.expert_in_proj.transpose(-2, -1), - offs=offsets - ) - h[-1].zero_() # ensure last row initialized - - if self.config.gated_linear: - gate_act, up = h.chunk(2, dim=-1) - h = F.silu(gate_act) * up - else: - h = F.silu(h) - - y_grouped = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offsets)[:-1] - y = torch.empty_like(y_grouped).index_copy_(0, sort_idx, y_grouped).view(x.size(0), self.top_k, -1) - return (y * weights.unsqueeze(-1)).sum(dim=1).reshape(orig_shape) - - -class MoE(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.top_k = config.moe_top_k - moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or (config.mlp_ratio / config.moe_top_k) - d_int = int(config.d_model * moe_mlp_ratio) - - self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False) - self.expert_in_proj = nn.Parameter( - torch.empty(config.moe_n_experts, d_int * (2 if config.gated_linear else 1), config.d_model) - ) # (E, N, K) for grouped_mm - self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_int)) # (E, N, K) - - def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor: - if self.training or torch.is_grad_enabled(): - 
raise NotImplementedError("inference only") - - orig = x.shape - x = x.reshape(-1, orig[-1]) - logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1)) - - logits32 = logits.float() - token_counts, expert_sorted, src = index_shuffling(logits32, top_k=self.top_k) - - E = self.expert_in_proj.size(0) - offs = token_counts[:E].cumsum(0).to(torch.int32) - - src = src.to(torch.long) - expert_sorted = expert_sorted.to(torch.long) - logZ = logits32.logsumexp(-1) - w = (logits32[src, expert_sorted] - logZ[src]).exp().to(x.dtype) # [T*K] - - xg = x.index_select(0, torch.cat((src, src[:1]), 0)) # pad by 1 for grouped_mm offs constraint - h = F.grouped_mm(xg, self.expert_in_proj.transpose(-2, -1), offs=offs) - if self.config.gated_linear: - ga, up = h.chunk(2, -1) - h = F.silu(ga) * up - else: - h = F.silu(h) - - yg = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offs)[:-1] - out = torch.zeros_like(x) - torch.ops.fbgemm.scatter_add_dense_tokens(out, (yg * w.unsqueeze(-1)).contiguous(), src) - return out.reshape(orig) - - class MLP(nn.Module): def __init__(self, dim_in, dim_middle, dim_out): super().__init__() @@ -251,10 +144,7 @@ def __init__(self, config, layer_idx): super().__init__() self.config = config self.attn = Attn(config, layer_idx) - if getattr(config, "moe", False): - self.mlp = MoE(config) if HAS_FBGEMM else MoEWithoutFBGEMM(config) - else: - self.mlp = MLP(config.d_model, config.d_model * config.mlp_ratio, config.d_model) + self.mlp = MLP(config.d_model, config.d_model * config.mlp_ratio, config.d_model) self.cond_head = CondHead(config) do_prompt_cond = config.prompt_conditioning is not None and layer_idx % config.prompt_conditioning_period == 0 @@ -418,15 +308,6 @@ def forward( return x - def get_active_parameters(self) -> int: - total = sum(p.numel() for p in self.parameters()) - c = self.config - if getattr(c, "moe", False): - moe_mlp_ratio = getattr(c, "moe_mlp_ratio", None) or c.mlp_ratio / c.moe_top_k - hidden, top_k = 
int(c.d_model * moe_mlp_ratio), min(c.moe_top_k, c.moe_n_experts) - total -= (c.moe_n_experts - top_k) * c.n_layers * c.d_model * hidden * (3 if c.gated_linear else 2) - return total - def load_state_dict(self, state_dict, strict=True, assign=False): if getattr(self.config, "model_type", "waypoint-1") != "waypoint-1.5": return super().load_state_dict(state_dict, strict=strict, assign=assign) @@ -479,22 +360,3 @@ def load_state_dict(self, state_dict, strict=True, assign=False): state_dict = {k: v for k, v in state_dict.items() if ".cond_heads." not in k} return super().load_state_dict(state_dict, strict=strict, assign=assign) - -""" -TODO: use the below for quantization -import torch.nn as nn - - -class BufferLinear(nn.Linear): - capture = True - - def _load_from_state_dict(self, sd, prefix, *args): - if self.capture: - keep = {prefix + "weight", prefix + "bias"} - for k in list(sd): - if k.startswith(prefix) and k not in keep and "." not in k[len(prefix):]: - n = k[len(prefix):] - self.register_buffer(n, sd.pop(k).to(self.weight.device), - persistent=False) - super()._load_from_state_dict(sd, prefix, *args) -""" diff --git a/src/world_engine.py b/src/world_engine.py index 7d1a892..44458fc 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -58,7 +58,12 @@ def __init__( with torch.device(self.device): # Load Model / Modules - self.vae = get_ae(self.model_cfg.ae_uri, getattr(self.model_cfg, "taehv_ae", False), dtype=dtype) + self.vae = get_ae( + self.model_cfg.ae_uri, + getattr(self.model_cfg, "taehv_ae", False), + auto_aspect_ratio=getattr(self.model_cfg, "auto_aspect_ratio", True), + dtype=dtype + ) self.prompt_encoder = None if self.model_cfg.prompt_conditioning is not None: From 4511a7b8f5e98ca654ea47379cf45c84eed2dedb Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 18 Mar 2026 15:05:48 -0400 Subject: [PATCH 24/48] auto 720p --- examples/gen_sample.py | 1 - src/ae.py | 172 +++++++++++++++++++++-------------------- src/world_engine.py | 7 +- 3 
files changed, 94 insertions(+), 86 deletions(-) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index a1c5af2..6d465d0 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -25,7 +25,6 @@ "https://gist.github.com/user-attachments/assets/68c943a4-008a-4c25-948c-c81ab4c47d21", ]) frame = cv2.imdecode(np.frombuffer(urllib.request.urlopen(url).read(), np.uint8), cv2.IMREAD_COLOR) -frame = cv2.resize(frame, (1024, 512))[:, :, ::-1] engine.append_frame(torch.from_numpy(np.repeat(frame[None], 4, axis=0))) diff --git a/src/ae.py b/src/ae.py index d8d537c..4b42b3f 100644 --- a/src/ae.py +++ b/src/ae.py @@ -1,7 +1,94 @@ import torch +import torch.nn.functional as F from torch import Tensor +class ChunkedStreamingTAEHV: + _ENCODE_SIZES = {(720, 1280): (512, 1024), (360, 640): (256, 512)} + _DECODE_SIZES = {v: k for k, v in _ENCODE_SIZES.items()} + + def __init__(self, ae_model, auto_aspect_ratio=True, device=None, dtype=torch.bfloat16): + """ + auto_aspect_ratio: automatically resize so encode input must be 720p or 360p (converted 16:9 to 16:8) + and decode output is 720p or 360p + """ + from taehv import StreamingTAEHV + + self.device = device + self.dtype = dtype + self.auto_aspect_ratio = auto_aspect_ratio + self.streaming_ae_model = StreamingTAEHV(ae_model.eval().to(device=device, dtype=dtype)) + + @classmethod + def from_pretrained(cls, model_uri: str, auto_aspect_ratio=True, **kwargs): + import pathlib + + import huggingface_hub + from taehv import TAEHV + + try: + base = pathlib.Path(huggingface_hub.snapshot_download(model_uri)) + except Exception: + base = pathlib.Path(model_uri) + + ckpt = base if base.is_file() else base / "taehv1_5.pth" + return cls(TAEHV(str(ckpt)), auto_aspect_ratio=auto_aspect_ratio, **kwargs) + + def reset(self): + from taehv import StreamingTAEHV + + # Rebuild streaming state, reuse same weights model + self.streaming_ae_model = StreamingTAEHV(self.streaming_ae_model.taehv) + + def _resize(self, x: Tensor, size: 
tuple[int, int]) -> Tensor: + return F.interpolate(x[0], size=size, mode="bilinear", align_corners=False)[None] + + @torch.inference_mode() + def encode(self, img: Tensor): + """ + img: [T, H, W, C] uint8 where T == t_downscale + returns: latent [B, C, h, w] + """ + t = self.streaming_ae_model.taehv.t_downscale + assert img.dim() == 4 and img.shape[-1] == 3 and img.shape[0] == t, f"Expected [{t}, H, W, 3] RGB uint8" + + rgb = img.unsqueeze(0)\ + .to(device=self.device, dtype=self.dtype)\ + .permute(0, 1, 4, 2, 3).contiguous().div(255) + + if self.auto_aspect_ratio: + rgb = self._resize(rgb, self._ENCODE_SIZES[img.shape[1:3]]) + + return self.streaming_ae_model.encode(rgb).squeeze(1) + + @torch.inference_mode() + def decode(self, latent: Tensor): + """ + latent: [B, C, h, w] + returns: frames [T, H, W, C] uint8 + """ + assert latent.dim() == 4, "Expected [B, C, h, w] latent tensor" + + z = latent.unsqueeze(1).to(device=self.device, dtype=self.dtype) + + if self.streaming_ae_model.n_frames_decoded == 0: + for _ in range(self.streaming_ae_model.taehv.frames_to_trim): + self.streaming_ae_model.decode(z) + self.streaming_ae_model.flush_decoder() + + first = self.streaming_ae_model.decode(z) + assert first is not None, "Expected decoded output after a latent" + frames = [first, *self.streaming_ae_model.flush_decoder()] + + decoded = torch.cat(frames, dim=1) + + if self.auto_aspect_ratio: + decoded = self._resize(decoded, self._DECODE_SIZES[decoded.shape[-2:]]) + + decoded = (decoded.clamp(0, 1) * 255).round().to(torch.uint8) + return decoded.squeeze(0).permute(0, 2, 3, 1)[..., :3] + + class InferenceAE: def __init__(self, ae_model, device=None, dtype=torch.bfloat16): self.device = device @@ -84,89 +171,8 @@ def decode(self, latent: Tensor): return decoded.squeeze(0).permute(1, 2, 0)[..., :3] -class ChunkedStreamingTAEHV: - def __init__(self, ae_model, device=None, dtype=torch.bfloat16): - from taehv import StreamingTAEHV - - self.device = device - self.dtype = dtype - 
self.streaming_ae_model = StreamingTAEHV( - ae_model.eval().to(device=device, dtype=dtype) - ) - - @classmethod - def from_pretrained(cls, model_uri: str, **kwargs): - import pathlib - - import huggingface_hub - from taehv import TAEHV - - try: - base = pathlib.Path(huggingface_hub.snapshot_download(model_uri)) - except Exception: - base = pathlib.Path(model_uri) - - ckpt = base if base.is_file() else base / "taehv1_5.pth" - return cls(TAEHV(str(ckpt)), **kwargs) - - def reset(self): - from taehv import StreamingTAEHV - - # Rebuild streaming state, reuse same weights model - self.streaming_ae_model = StreamingTAEHV(self.streaming_ae_model.taehv) - - @torch.inference_mode() - def encode(self, img: Tensor): - """ - img: [T, H, W, C] uint8 where T == t_downscale - returns: latent [B, C, h, w] - """ - assert img.dim() == 4 and img.shape[-1] == 3, "Expected [T, H, W, C] RGB uint8" - - t = self.streaming_ae_model.taehv.t_downscale - if img.shape[0] != t: - raise ValueError(f"Expected {t} frames, got {img.shape[0]}") - - rgb = ( - img.unsqueeze(0) - .to(device=self.device, dtype=self.dtype) - .permute(0, 1, 4, 2, 3) - .contiguous() - .div(255) - ) - - latent = self.streaming_ae_model.encode(rgb) - if latent is None: - raise RuntimeError("Expected a latent after a full chunk") - - return latent.squeeze(1) - - @torch.inference_mode() - def decode(self, latent: Tensor): - """ - latent: [B, C, h, w] - returns: frames [T, H, W, C] uint8 - """ - assert latent.dim() == 4, "Expected [B, C, h, w] latent tensor" - - z = latent.unsqueeze(1).to(device=self.device, dtype=self.dtype) - - if self.streaming_ae_model.n_frames_decoded == 0: - for _ in range(self.streaming_ae_model.taehv.frames_to_trim): - self.streaming_ae_model.decode(z) - self.streaming_ae_model.flush_decoder() - - first = self.streaming_ae_model.decode(z) - assert first is not None, "Expected decoded output after a latent" - frames = [first, *self.streaming_ae_model.flush_decoder()] - - decoded = torch.cat(frames, 
dim=1) - decoded = (decoded.clamp(0, 1) * 255).round().to(torch.uint8) - return decoded.squeeze(0).permute(0, 2, 3, 1)[..., :3] - - -def get_ae(ae_uri, is_taehv_ae=False, **kwargs): +def get_ae(ae_uri, is_taehv_ae=False, auto_aspect_ratio=True, **kwargs): if is_taehv_ae: - return ChunkedStreamingTAEHV.from_pretrained(ae_uri, **kwargs) + return ChunkedStreamingTAEHV.from_pretrained(ae_uri, auto_aspect_ratio=auto_aspect_ratio, **kwargs) else: return InferenceAE.from_pretrained(ae_uri, **kwargs) diff --git a/src/world_engine.py b/src/world_engine.py index 44458fc..a2ffc6c 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -46,7 +46,9 @@ def __init__( """ model_uri: HF URI or local folder containing model.safetensors and config.yaml quant: None | w8a8 | nvfp4 + model_config_overrides: Dict to override model config values + - auto_aspect_ratio: set to False to work in ae raw space, otherwise in/out are 720p or 360p """ self.device = torch.get_default_device() if device is None else device self.dtype = torch.get_default_dtype() if dtype is None else dtype @@ -60,9 +62,10 @@ def __init__( # Load Model / Modules self.vae = get_ae( self.model_cfg.ae_uri, - getattr(self.model_cfg, "taehv_ae", False), + is_taehv_ae=getattr(self.model_cfg, "taehv_ae", False), auto_aspect_ratio=getattr(self.model_cfg, "auto_aspect_ratio", True), - dtype=dtype + dtype=dtype, + device=device, ) self.prompt_encoder = None From 844c44ca311c16188bb967591bdfa1dc73549034 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 19 Mar 2026 15:15:29 -0400 Subject: [PATCH 25/48] ensure correct device --- src/world_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/world_engine.py b/src/world_engine.py index a2ffc6c..00aff1a 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -71,7 +71,7 @@ def __init__( self.prompt_encoder = None if self.model_cfg.prompt_conditioning is not None: pe_uri = getattr(self.model_cfg, "prompt_encoder_uri", "google/umt5-xl") 
- self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).eval() + self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype, device=device).eval() self.model = WorldModel.from_pretrained( model_uri, cfg=self.model_cfg, device=self.device, dtype=dtype, load_weights=load_weights @@ -80,10 +80,10 @@ def __init__( if quant is not None: quantize_model(self.model, quant) - self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype) + self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype, device=device) # Inference Scheduler - self.scheduler_sigmas = torch.tensor(self.model_cfg.scheduler_sigmas, dtype=dtype) + self.scheduler_sigmas = torch.tensor(self.model_cfg.scheduler_sigmas, dtype=dtype, device=device) pH, pW = getattr(self.model_cfg, "patch", [1, 1]) self.frm_shape = 1, 1, self.model_cfg.channels, self.model_cfg.height * pH, self.model_cfg.width * pW From 3612b8b7250f677b6f9efd5b070a55d4e96902cf Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 19 Mar 2026 15:29:02 -0400 Subject: [PATCH 26/48] no internal model URIs, document requirements in docstring at top --- examples/benchmark.py | 24 +++++++++++---- examples/gen_sample.py | 8 +++-- examples/prof.py | 8 ++++- examples/simple_client.py | 64 --------------------------------------- 4 files changed, 30 insertions(+), 74 deletions(-) delete mode 100644 examples/simple_client.py diff --git a/examples/benchmark.py b/examples/benchmark.py index 9bd75fa..a153836 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -1,12 +1,17 @@ +""" +Additional Dependencies: pytest-benchmark +Run: `pytest ./examples/benchmark.py` +Run: `MODEL_URI="Overworld/Waypoint-1.5-1B" pytest ./examples/benchmark.py` +""" + +import os import pytest import torch from world_engine import WorldEngine -# TODO -# - benchmark encode img -# - benchmark encode prompt +MODEL_URI = os.environ.get("MODEL_URI", "Overworld/Waypoint-1-Small") def version_with_commit(pkg): @@ -65,8 +70,8 @@ def 
get_warm_engine(model_uri, model_overrides=None): @pytest.fixture(scope="session") -def engine(model_uri="Overworld/Waypoint-1-Small"): - return get_warm_engine(model_uri) +def engine(): + return get_warm_engine(MODEL_URI) @pytest.fixture(scope="session") @@ -93,7 +98,14 @@ def run(): ids=lambda d: (",".join(f"{k}={v}" for k, v in d.items()) or "") if d else "" ) def test_ar_rollout(benchmark, dit_only, n_frames, model_overrides): - engine = get_warm_engine("Overworld/Waypoint-1-Small", model_overrides=model_overrides) + engine = get_warm_engine(MODEL_URI, model_overrides=model_overrides) + + try: + total_params = sum(p.numel() for p in engine.model.parameters()) + active_params = int(engine.model.get_active_parameters()) + benchmark.name = f"{benchmark.name} | params={total_params:,} | active={active_params:,}" + except Exception: + pass def setup(): engine.reset() diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 6d465d0..ea5743e 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -1,6 +1,8 @@ -# python3 examples/gen_sample.py -# e.g. 
python3 examples/gen_sample.py Overworld-Models/Lapp0-WP-Mini-1.4.5-BL-Distill - +""" +Additional Dependencies: opencv-python imageio[pyav] +Run: `python3 examples/gen_sample.py Overworld/Waypoint-1.5-1B` +Run: `python3 examples/gen_sample.py ` +""" import cv2 import imageio.v3 as iio import random diff --git a/examples/prof.py b/examples/prof.py index b8cd570..b469076 100644 --- a/examples/prof.py +++ b/examples/prof.py @@ -1,3 +1,9 @@ +""" +Additional Dependencies: N/A +Run: `python3 examples/prof.py Overworld/Waypoint-1.5-1B` +""" +import sys + import torch from torch.profiler import profile, ProfilerActivity @@ -5,7 +11,7 @@ def do_profile(n_frames=64, row_limit=20): - engine = WorldEngine("OpenWorldLabs/CoDCtl-Causal-Flux-SelfForcing", device="cuda") + engine = WorldEngine(sys.argv[1], device="cuda") # warmup for _ in range(4): engine.gen_frame() diff --git a/examples/simple_client.py b/examples/simple_client.py deleted file mode 100644 index 2f5d410..0000000 --- a/examples/simple_client.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import AsyncIterable, AsyncIterator - -import asyncio -import contextlib -import cv2 -import sys -import torch - -from world_engine import WorldEngine, CtrlInput - - -async def render(frames: AsyncIterable[torch.Tensor], win_name="Hello (Over)World (ESC to exit)") -> None: - """Render stream of RGB tensor images.""" - cv2.namedWindow(win_name, cv2.WINDOW_AUTOSIZE | cv2.WINDOW_GUI_NORMAL) - async for t in frames: - cv2.imshow(win_name, t.cpu().numpy()) - await asyncio.sleep(0) - cv2.destroyAllWindows() - - -async def frame_stream(engine: WorldEngine, ctrls: AsyncIterable[CtrlInput]) -> AsyncIterator[torch.Tensor]: - """Generate frame by calling Engine for each ctrl.""" - yield await asyncio.to_thread(engine.gen_frame) - async for ctrl in ctrls: - yield await asyncio.to_thread(engine.gen_frame, ctrl=ctrl) - - -async def ctrl_stream(delay: int = 1) -> AsyncIterator[CtrlInput]: - """Accumulate key presses asyncronously. 
Yield CtrlInput once next() is called.""" - q: asyncio.Queue[int] = asyncio.Queue() - - async def producer() -> None: - while True: - k = cv2.waitKey(delay) - if k != -1: - await q.put(k) - await asyncio.sleep(0) - - prod_task = asyncio.create_task(producer()) - while True: - buttons: set[int] = set() - # Drain everything currently in the queue into this batch - with contextlib.suppress(asyncio.QueueEmpty): - while True: - k = q.get_nowait() - if k == 27: - # End if ESC pressed - prod_task.cancel() - return - buttons.add(k) - - yield CtrlInput(button=buttons) - - -async def main() -> None: - uri = sys.argv[1] if len(sys.argv) > 1 else "OpenWorldLabs/CoDCtl-Causal-Flux-SelfForcing" - engine = WorldEngine(uri, device="cuda") - ctrls = ctrl_stream() - frames = frame_stream(engine, ctrls) - await render(frames) - - -if __name__ == "__main__": - asyncio.run(main()) From f5cc301fb85befbfa04639ac791f32b196c22db3 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 19 Mar 2026 15:45:10 -0400 Subject: [PATCH 27/48] update readme to document WP1.5 --- README.md | 21 ++++++++++++++++----- examples/gen_sample.py | 2 +- src/world_engine.py | 4 ++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index ba885b2..e90e018 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ export HF_TOKEN= from world_engine import WorldEngine, CtrlInput # Create inference engine -engine = WorldEngine("Overworld/Waypoint-1-Small", device="cuda") +engine = WorldEngine("Overworld/Waypoint-1.5-1B", device="cuda") # Specify a prompt engine.set_prompt("A fun game") @@ -77,6 +77,15 @@ for controller_input in [ img = engine.gen_frame(ctrl=controller_input) ``` +## Waypoint-1.5 Behavior +All interfaces between Waypoint-1 (or 1.1) and Waypoint-1.5 **except** the following: + +In Waypoint-1.5, the `img` passed to `append_frame(...)` and returned by `gen_frame(...)` is now a sequence of 4 frames. 
Waypoint-1.5 applies temporal compression and generates 4 frames for every controller input.
+
+Whereas previously, `img` was a uint8 rgb array of shape `[Height, Width, 3]`, **in Waypoint-1.5 it is of shape `[4, Height, Width, 3]`**.
+
+Additionally, Waypoint-1.5 expects 720p inputs / outputs, therefore `img` is `[4, 720, 1280, 3]`.
+
 ## Usage
 ```
 from world_engine import WorldEngine, CtrlInput
@@ -84,7 +93,7 @@ from world_engine import WorldEngine, CtrlInput
 Load model to GPU
 
 ```
-engine = WorldEngine("Overworld/Waypoint-1-Small", device="cuda")
+engine = WorldEngine("Overworld/Waypoint-1.5-1B", device="cuda")
 ```
 
 Specify a prompt which will be used until this function is called again
@@ -118,11 +127,13 @@ Note: returned `img` is always on the same device as `engine.device`
 @dataclass
 class CtrlInput:
     button: Set[int] = field(default_factory=set) # pressed button IDs
-    mouse: Tuple[float, float] = (0.0, 0.0) # (x, y) position
+    mouse: Tuple[float, float] = (0.0, 0.0) # (dx, dy) position change
+    scroll_wheel: int = 0 # down, stationary, or up -> (-1, 0, 1)
 ```
 
 - `button` keycodes are defined by [Owl-Control](https://github.com/Overworldai/owl-control/blob/main/src/system/keycode.rs)
-- `mouse` is the raw mouse velocity vector
+- `mouse` is the amount of change in mouse position since the last frame
+- `scroll_wheel` is the ternary scroll wheel movement identifier
 
 ## Showcase and Examples
 
@@ -138,5 +149,5 @@ class CtrlInput:
 
 ### Examples and Reference Code
 
-- ["Hello (Over)World" client](./examples/simple_client.py)
+- [Generate MP4 Sample Given Controller Inputs](./examples/gen_sample.py)
 - [Run Performance Benchmarks (`pytest examples/benchmark.py`)](./examples/benchmark.py)
diff --git a/examples/gen_sample.py b/examples/gen_sample.py
index ea5743e..6b0aa55 100644
--- a/examples/gen_sample.py
+++ b/examples/gen_sample.py
@@ -1,7 +1,7 @@
 """
 Additional Dependencies: opencv-python imageio[pyav]
 Run: `python3 examples/gen_sample.py Overworld/Waypoint-1.5-1B`
-Run: 
`python3 examples/gen_sample.py ` +Run: `python3 examples/gen_sample.py ` """ import cv2 import imageio.v3 as iio diff --git a/src/world_engine.py b/src/world_engine.py index 00aff1a..6923df5 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -29,8 +29,8 @@ @dataclass class CtrlInput: button: Set[int] = field(default_factory=set) # pressed button IDs - mouse: Tuple[float, float] = (0.0, 0.0) # (x, y) velocity - scroll_wheel: int = 0 # bwd, stationary, or fwd -> (-1, 0, 1) + mouse: Tuple[float, float] = (0.0, 0.0) # (dx, dy) velocity + scroll_wheel: int = 0 # down, stationary, or up -> (-1, 0, 1) class WorldEngine: From 5c910b703b07df05e25fab96d846362ec034a80b Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 19 Mar 2026 15:45:47 -0400 Subject: [PATCH 28/48] update readme to document WP1.5 --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index e90e018..9716499 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,8 @@ Whereas previously, `img` was a uint8 rgb array of shape `[Height, Width, 3]`, * Additionally, Waypoint-1.5 expects 720p inputs / outputs, therefore `img` is `[4, 720, 1280, 3]`. +See [examples/gen_sample.py](./examples/gen_sample.py) for reference. 
+ ## Usage ``` from world_engine import WorldEngine, CtrlInput From 274d685ee857c3665ca6d50a781799f602e04a7e Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 20 Mar 2026 12:32:53 -0400 Subject: [PATCH 29/48] no fbgemm dep --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 480cdf9..4ebfe2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,6 @@ dependencies = [ "torch==2.10.0", "torchvision==0.25.0", "torchaudio==2.10.0", - "fbgemm-gpu-genai==1.5.0; sys_platform == 'linux'", "einops", "rotary-embedding-torch>=0.8.8", "tensordict==0.10.0", From ae9ac616ae4644da8ff5d3f51f44ab0aa766d36e Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 20 Mar 2026 12:33:44 -0400 Subject: [PATCH 30/48] benchmark dont force AE --- examples/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index a153836..222e70c 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -54,7 +54,7 @@ def print_env_info(): def get_warm_engine(model_uri, model_overrides=None): - model_config_overrides = {"ae_uri": "OpenWorldLabs/owl_vae_f16_c16_distill_v0_nogan"} + model_config_overrides = {} model_config_overrides.update(model_overrides or {}) engine = WorldEngine( model_uri, From 25503c1e7296a53e66d80cfb1a902a6d2f515b59 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 20 Mar 2026 12:33:59 -0400 Subject: [PATCH 31/48] move kv cache to appropriate device --- src/world_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/world_engine.py b/src/world_engine.py index 6923df5..5f771d8 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -80,7 +80,7 @@ def __init__( if quant is not None: quantize_model(self.model, quant) - self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype, device=device) + self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device=device) # Inference 
Scheduler self.scheduler_sigmas = torch.tensor(self.model_cfg.scheduler_sigmas, dtype=dtype, device=device) From da66a6df8fc0b00f0d17aebbf2d3eb78f2171100 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 20 Mar 2026 12:36:46 -0400 Subject: [PATCH 32/48] improve example w/ four_frames var --- examples/gen_sample.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 6b0aa55..2e82723 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -49,6 +49,7 @@ # Generate frames conditioned on controller inputs with iio.imopen("out.mp4", "w", plugin="pyav") as out: - out.write(engine.gen_frame().cpu().numpy(), fps=60, codec="libx264") + four_frames = engine.gen_frame().cpu().numpy() # int8 [4, H, W, 3] + out.write(four_frames, fps=60, codec="libx264") for ctrl in controller_sequence: out.write(engine.gen_frame(ctrl=ctrl).cpu().numpy()) From f482fc279d70c376bed1b9a97918ab27ad0a0153 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 20 Mar 2026 12:44:17 -0400 Subject: [PATCH 33/48] improve example w/ four_frames var --- examples/gen_sample.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 2e82723..7bbba05 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -52,4 +52,5 @@ four_frames = engine.gen_frame().cpu().numpy() # int8 [4, H, W, 3] out.write(four_frames, fps=60, codec="libx264") for ctrl in controller_sequence: - out.write(engine.gen_frame(ctrl=ctrl).cpu().numpy()) + four_frames = engine.gen_frame(ctrl=ctrl).cpu().numpy() + out.write(four_frames) From e8cd112ddf5f83efe6058a8572fca9ab65e70be6 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 20 Mar 2026 12:46:41 -0400 Subject: [PATCH 34/48] improve example w/ four_frames var --- examples/gen_sample.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/gen_sample.py 
b/examples/gen_sample.py index 7bbba05..84087f6 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -18,18 +18,6 @@ engine = WorldEngine(sys.argv[1], device="cuda") -# Set seed frame -url = random.choice([ - "https://gist.github.com/user-attachments/assets/d81c6d26-a838-4afe-9d13-fd67677043c3", - "https://gist.github.com/user-attachments/assets/b6d18c38-098e-43b0-8e61-66a16e5d8946", - "https://gist.github.com/user-attachments/assets/0734a8c1-3eb4-4ffe-8c37-5665c45ab559", - "https://gist.github.com/user-attachments/assets/f9c20d4d-7565-452d-8b02-42a85ea175ed", - "https://gist.github.com/user-attachments/assets/68c943a4-008a-4c25-948c-c81ab4c47d21", -]) -frame = cv2.imdecode(np.frombuffer(urllib.request.urlopen(url).read(), np.uint8), cv2.IMREAD_COLOR) -engine.append_frame(torch.from_numpy(np.repeat(frame[None], 4, axis=0))) - - # Define sequence of controller inputs applied controller_sequence = [ # move mouse, jump, do nothing, trigger, do nothing, trigger+jump, do nothing @@ -47,10 +35,22 @@ controller_sequence += [CtrlInput()] * 10 +# Set seed frame +url = random.choice([ + "https://gist.github.com/user-attachments/assets/d81c6d26-a838-4afe-9d13-fd67677043c3", + "https://gist.github.com/user-attachments/assets/b6d18c38-098e-43b0-8e61-66a16e5d8946", + "https://gist.github.com/user-attachments/assets/0734a8c1-3eb4-4ffe-8c37-5665c45ab559", + "https://gist.github.com/user-attachments/assets/f9c20d4d-7565-452d-8b02-42a85ea175ed", + "https://gist.github.com/user-attachments/assets/68c943a4-008a-4c25-948c-c81ab4c47d21", +]) +seed_frame = cv2.imdecode(np.frombuffer(urllib.request.urlopen(url).read(), np.uint8), cv2.IMREAD_COLOR) +seed_frame_x4 = torch.from_numpy(np.repeat(seed_frame[None], 4, axis=0)) + + # Generate frames conditioned on controller inputs with iio.imopen("out.mp4", "w", plugin="pyav") as out: - four_frames = engine.gen_frame().cpu().numpy() # int8 [4, H, W, 3] - out.write(four_frames, fps=60, codec="libx264") + 
engine.append_frame(seed_frame_x4) + out.write(seed_frame_x4, fps=60, codec="libx264") for ctrl in controller_sequence: four_frames = engine.gen_frame(ctrl=ctrl).cpu().numpy() out.write(four_frames) From 9308e9e612da97fe79b4157895607098dd296074 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 20 Mar 2026 13:22:22 -0400 Subject: [PATCH 35/48] dev dependency group for examples, uv docs --- examples/benchmark.py | 6 +----- examples/gen_sample.py | 7 ++----- pyproject.toml | 11 ++++++++++- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index 222e70c..d9d3126 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -1,8 +1,4 @@ -""" -Additional Dependencies: pytest-benchmark -Run: `pytest ./examples/benchmark.py` -Run: `MODEL_URI="Overworld/Waypoint-1.5-1B" pytest ./examples/benchmark.py` -""" +# MODEL_URI="Overworld/Waypoint-1.5-1B" uv run --dev pytest examples/benchmark.py Overworld/Waypoint-1.5-1B import os import pytest diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 84087f6..0217896 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -1,8 +1,5 @@ -""" -Additional Dependencies: opencv-python imageio[pyav] -Run: `python3 examples/gen_sample.py Overworld/Waypoint-1.5-1B` -Run: `python3 examples/gen_sample.py ` -""" +# uv run --dev examples/gen_sample.py Overworld/Waypoint-1.5-1B + import cv2 import imageio.v3 as iio import random diff --git a/pyproject.toml b/pyproject.toml index 4ebfe2f..b30b2cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "torchvision==0.25.0", "torchaudio==2.10.0", "einops", - "rotary-embedding-torch>=0.8.8", + "rotary-embedding-torch==0.8.9", "tensordict==0.10.0", "transformers==4.57.3", "ftfy", @@ -34,3 +34,12 @@ packages = [ [tool.setuptools.package-dir] world_engine = "src" + +[dependency-groups] +dev = [ + "pytest", + "pytest-benchmark", + "opencv-python", + "imageio[pyav]", + "numpy", +] From 
c2261dd3fcae1c300fa512202eaa85a72be163e4 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 20 Mar 2026 13:24:16 -0400 Subject: [PATCH 36/48] dev dependency group for examples, uv docs --- examples/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index d9d3126..98724ea 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -1,4 +1,4 @@ -# MODEL_URI="Overworld/Waypoint-1.5-1B" uv run --dev pytest examples/benchmark.py Overworld/Waypoint-1.5-1B +# MODEL_URI="Overworld/Waypoint-1.5-1B" uv run --dev pytest examples/benchmark.py import os import pytest From e2060f0f5dbb782afb843d36ba6fb8bdb28987d5 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 20 Mar 2026 13:36:53 -0400 Subject: [PATCH 37/48] fix a missing word --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9716499..8d68292 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ for controller_input in [ ``` ## Waypoint-1.5 Behavior -All interfaces between Waypoint-1 (or 1.1) and Waypoint-1.5 **except** the following: +All interfaces and handling for Waypoint-1 (or 1.1) and Waypoint-1.5 remain the same **except** the following: In Waypoint-1.5, the `img` passed to `append_frame(...)` and returned by `gen_frame(...)` is now a sequence of 4 frames. Waypoint-1.5 applies temporal compression and generates 4 frames for every controller input. 
From 3d8b327da9457474f508fb51f5580a8b43f19960 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 21 Mar 2026 12:35:28 -0400 Subject: [PATCH 38/48] remove rotary embedding pytorch dependency --- pyproject.toml | 1 - src/model/attn.py | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b30b2cc..414c61f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,6 @@ dependencies = [ "torchvision==0.25.0", "torchaudio==2.10.0", "einops", - "rotary-embedding-torch==0.8.9", "tensordict==0.10.0", "transformers==4.57.3", "ftfy", diff --git a/src/model/attn.py b/src/model/attn.py index c6024e8..bac2e95 100644 --- a/src/model/attn.py +++ b/src/model/attn.py @@ -4,14 +4,12 @@ from torch.nn.attention.flex_attention import flex_attention -from rotary_embedding_torch import RotaryEmbedding - from .nn import rms_norm, NoCastModule class OrthoRoPEAngles(NoCastModule): """Functions as a on the fly RoPE angle computer called every fwd pass. Should be setup - as a module under WordDiT, then each forward pass it computes a shared tuple of (rope_cos, rope_sin) + as a module under WordDiT, then each forward pass it computes a shared tuple of (rope_cos, rope_sin) tensors that get passed to every block for their underlying RoPE computations.""" def __init__(self, config): super().__init__() From ebe31bb57fc3a385b5c0e60bb8f575c9216a9180 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 21 Mar 2026 14:55:31 -0400 Subject: [PATCH 39/48] improve throughput by 4% --- src/model/kv_cache.py | 54 +++++++++++++++++++------------------------ src/world_engine.py | 1 + 2 files changed, 25 insertions(+), 30 deletions(-) diff --git a/src/model/kv_cache.py b/src/model/kv_cache.py index 898b8a9..6694049 100644 --- a/src/model/kv_cache.py +++ b/src/model/kv_cache.py @@ -13,53 +13,47 @@ def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask: """ T: Q length for this frame L: KV capacity == written.numel() - written: [L] 
bool, True where there is valid KV data + written: [L] bool, True where there is valid KV data. + T and L must be exact multiples of the sparse block size; `written` must be + block-aligned, i.e. each block is either all True or all False. """ BS = _DEFAULT_SPARSE_BLOCK_SIZE - KV_blocks = (L + BS - 1) // BS - Q_blocks = (T + BS - 1) // BS - # [KV_blocks, BS] - written_blocks = torch.nn.functional.pad(written, (0, KV_blocks * BS - L)).view(KV_blocks, BS) - - # Block-level occupancy - block_any = written_blocks.any(-1) # block has at least one written token - block_all = written_blocks.all(-1) # block is fully written + if not torch.compiler.is_compiling(): + torch._check(T % BS == 0, f"T ({T}) must be a multiple of block size ({BS})") + torch._check(L % BS == 0, f"L ({L}) must be a multiple of block size ({BS})") - # Every Q-block sees the same KV-block pattern - nonzero_bm = block_any[None, :].expand(Q_blocks, KV_blocks) # [Q_blocks, KV_blocks] - full_bm = block_all[None, :].expand_as(nonzero_bm) # [Q_blocks, KV_blocks] - partial_bm = nonzero_bm & ~full_bm # [Q_blocks, KV_blocks] + Q_blocks = T // BS + KV_blocks = L // BS - def dense_to_ordered(dense_mask: torch.Tensor): - # dense_mask: [Q_blocks, KV_blocks] bool - # returns: [1,1,Q_blocks], [1,1,Q_blocks,KV_blocks] - num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32) # [Q_blocks] - indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(torch.int32) - return num_blocks[None, None].contiguous(), indices[None, None].contiguous() + # [KV_blocks, BS] + written_blocks = written.view(KV_blocks, BS) - # Partial blocks (need mask_mod) - kv_num_blocks, kv_indices = dense_to_ordered(partial_bm) + # For a valid block-aligned mask, each block is either all written or all empty. 
+ block_any = written_blocks.any(-1) + if not torch.compiler.is_compiling(): + assert torch.equal(block_any, written_blocks.all(-1)), "written must be block-aligned" - # Full blocks (mask_mod can be skipped entirely) - full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm) + # Every KV block is a full block + full_bm = block_any[None, :].expand(Q_blocks, KV_blocks) + full_kv_num_blocks = full_bm.sum(dim=-1, dtype=torch.int32)[None, None].contiguous() + full_kv_indices = full_bm.argsort(dim=-1, descending=True, stable=True).to(torch.int32)[None, None].contiguous() - def mask_mod(b, h, q, kv): - return written[kv] + # No partial blocks at all. + kv_num_blocks = torch.zeros((1, 1, Q_blocks), dtype=torch.int32, device=written.device) + kv_indices = torch.zeros((1, 1, Q_blocks, KV_blocks), dtype=torch.int32, device=written.device) - bm = BlockMask.from_kv_blocks( + return BlockMask.from_kv_blocks( kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, BLOCK_SIZE=BS, - mask_mod=mask_mod, + mask_mod=None, seq_lengths=(T, L), - compute_q_blocks=False, # no backward, avoids the transpose/_ordered_to_dense path + compute_q_blocks=False, ) - return bm - class LayerKVCache(nn.Module): """ diff --git a/src/world_engine.py b/src/world_engine.py index 5f771d8..8c32843 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -159,6 +159,7 @@ def _prep_inputs(self, x, ctrl=None): return self._ctx + @torch.compile def prep_inputs(self, x, ctrl=None): ctrl = ctrl if ctrl is not None else CtrlInput() self._ctx["button"].zero_() From 3ff84dd0c5ffd2b1c785d0d203eedf22dbfb161b Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 21 Mar 2026 15:51:47 -0400 Subject: [PATCH 40/48] add non-blocking benchmark --- examples/benchmark.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index 98724ea..d420133 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -87,13 +87,14 @@ def run(): 
MODEL_OVERRIDES = [None] +@pytest.mark.parametrize("blocking", [True, False]) @pytest.mark.parametrize("dit_only", [True]) @pytest.mark.parametrize("n_frames", [256]) @pytest.mark.parametrize( "model_overrides", MODEL_OVERRIDES, ids=lambda d: (",".join(f"{k}={v}" for k, v in d.items()) or "") if d else "" ) -def test_ar_rollout(benchmark, dit_only, n_frames, model_overrides): +def test_ar_rollout(benchmark, dit_only, n_frames, model_overrides, blocking): engine = get_warm_engine(MODEL_URI, model_overrides=model_overrides) try: @@ -111,6 +112,7 @@ def setup(): def target(): for _ in range(n_frames): engine.gen_frame(return_img=not dit_only) - torch.cuda.synchronize() + if blocking: + torch.cuda.synchronize() benchmark.pedantic(target, setup=setup, rounds=20) From 60c886486803d167cc381d1c3f96217cf340eeb0 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 23 Mar 2026 14:46:56 -0400 Subject: [PATCH 41/48] no compile prep inputs --- src/world_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/world_engine.py b/src/world_engine.py index 8c32843..5f771d8 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -159,7 +159,6 @@ def _prep_inputs(self, x, ctrl=None): return self._ctx - @torch.compile def prep_inputs(self, x, ctrl=None): ctrl = ctrl if ctrl is not None else CtrlInput() self._ctx["button"].zero_() From f5bf64e9f2bcf715ec55ca77c7393c919c5589f8 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 23 Mar 2026 15:31:28 -0400 Subject: [PATCH 42/48] compile prep inputs after converting to tensor, avoid blocking --- src/world_engine.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/world_engine.py b/src/world_engine.py index 5f771d8..4b61496 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -149,8 +149,12 @@ def gen_frame(self, ctrl: CtrlInput = None, return_img: bool = True): @torch.compile def _prep_inputs(self, x, ctrl=None): + self._ctx["button"].zero_() + self._ctx["button"][..., ctrl.button] = 
1.0 + self._ctx["mouse"][0, 0, 0] = ctrl.mouse[0] self._ctx["mouse"][0, 0, 1] = ctrl.mouse[1] + self._ctx["scroll"][0, 0, 0] = ctrl.scroll_wheel self._ctx["frame_idx"].copy_(self.frame_ts) @@ -161,11 +165,9 @@ def _prep_inputs(self, x, ctrl=None): def prep_inputs(self, x, ctrl=None): ctrl = ctrl if ctrl is not None else CtrlInput() - self._ctx["button"].zero_() - if ctrl.button: - self._ctx["button"][..., list(ctrl.button)] = 1.0 - ctrl.mouse = torch.as_tensor(ctrl.mouse, device=x.device, dtype=self.dtype) - ctrl.scroll_wheel = torch.sign(torch.as_tensor(ctrl.scroll_wheel, device=x.device, dtype=self.dtype)) + ctrl.button = torch.as_tensor(x, dtype=torch.int64).to(x.device, non_blocking=True) + ctrl.mouse = torch.as_tensor(ctrl.mouse).to(x.device, non_blocking=True) + ctrl.scroll_wheel = torch.as_tensor(ctrl.scroll_wheel).to(x.device, non_blocking=True) ctx = self._prep_inputs(x, ctrl) # prepare prompt conditioning From 2926cfb57d3165bb5ff26950a0361216bf565bb2 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 23 Mar 2026 17:27:06 -0400 Subject: [PATCH 43/48] fix, don't pass device to prompt encoder --- src/world_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/world_engine.py b/src/world_engine.py index 4b61496..b912ac1 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -71,7 +71,7 @@ def __init__( self.prompt_encoder = None if self.model_cfg.prompt_conditioning is not None: pe_uri = getattr(self.model_cfg, "prompt_encoder_uri", "google/umt5-xl") - self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype, device=device).eval() + self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).eval() self.model = WorldModel.from_pretrained( model_uri, cfg=self.model_cfg, device=self.device, dtype=dtype, load_weights=load_weights From 946be2c5057ec2c1cb3397c3f7ffe49a100bcf6f Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 23 Mar 2026 17:29:16 -0400 Subject: [PATCH 44/48] fix button incorrect --- src/world_engine.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/world_engine.py b/src/world_engine.py index b912ac1..c97599a 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -165,7 +165,7 @@ def _prep_inputs(self, x, ctrl=None): def prep_inputs(self, x, ctrl=None): ctrl = ctrl if ctrl is not None else CtrlInput() - ctrl.button = torch.as_tensor(x, dtype=torch.int64).to(x.device, non_blocking=True) + ctrl.button = torch.as_tensor(ctrl.button, dtype=torch.int64).to(x.device, non_blocking=True) ctrl.mouse = torch.as_tensor(ctrl.mouse).to(x.device, non_blocking=True) ctrl.scroll_wheel = torch.as_tensor(ctrl.scroll_wheel).to(x.device, non_blocking=True) ctx = self._prep_inputs(x, ctrl) From f0be311cc3994c181c84226784cb9a6b5cae3a89 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 23 Mar 2026 17:32:01 -0400 Subject: [PATCH 45/48] fix button incorrect --- src/world_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/world_engine.py b/src/world_engine.py index c97599a..8e22f35 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -165,7 +165,7 @@ def _prep_inputs(self, x, ctrl=None): def prep_inputs(self, x, ctrl=None): ctrl = ctrl if ctrl is not None else CtrlInput() - ctrl.button = torch.as_tensor(ctrl.button, dtype=torch.int64).to(x.device, non_blocking=True) + ctrl.button = torch.as_tensor(list(ctrl.button), dtype=torch.int64).to(x.device, non_blocking=True) ctrl.mouse = torch.as_tensor(ctrl.mouse).to(x.device, non_blocking=True) ctrl.scroll_wheel = torch.as_tensor(ctrl.scroll_wheel).to(x.device, non_blocking=True) ctx = self._prep_inputs(x, ctrl) From a236313591526d6c0a48df3f96d938e2ccc50100 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 25 Mar 2026 13:58:31 -0400 Subject: [PATCH 46/48] benchmark w/ ctrls --- examples/benchmark.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index d420133..b42c4b4 100644 --- a/examples/benchmark.py +++ 
b/examples/benchmark.py @@ -110,7 +110,15 @@ def setup(): torch.cuda.synchronize() def target(): - for _ in range(n_frames): + ctrls = [ + CtrlInput( + button=set(random.sample(range(1, 65), random.randint(0, 10))), + mouse=(random.random(), random.random()), + scroll_wheel=random.choice((-1, 0, 1)) + ) + for _ in range(n_frames) + ] + for ctrl in ctrls: engine.gen_frame(ctrl=ctrl, return_img=not dit_only) if blocking: torch.cuda.synchronize() From 9607bcf763f1e8c3a14dabbbd41aee2148a7af9e Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 25 Mar 2026 13:58:37 -0400 Subject: [PATCH 47/48] benchmark w/ ctrls --- examples/benchmark.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index b42c4b4..3d8e623 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -3,8 +3,9 @@ import os import pytest import torch +import random -from world_engine import WorldEngine +from world_engine import WorldEngine, CtrlInput MODEL_URI = os.environ.get("MODEL_URI", "Overworld/Waypoint-1-Small") From 1d286e168ef27fc0304720c642838bc3be8ef943 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 25 Mar 2026 14:00:09 -0400 Subject: [PATCH 48/48] update config defaults --- src/model/attn.py | 20 ++++++++++---------- src/model/base_model.py | 22 +++++++++++++++++++++- src/model/kv_cache.py | 4 ++-- src/model/world_model.py | 2 +- src/world_engine.py | 12 +++++------- 5 files changed, 39 insertions(+), 21 deletions(-) diff --git a/src/model/attn.py b/src/model/attn.py index bac2e95..76405b1 100644 --- a/src/model/attn.py +++ b/src/model/attn.py @@ -19,25 +19,25 @@ def __init__(self, config): torch._assert(d_head % 8 == 0, "d_head must be divisible by 8") d_xy, d_t = d_head // 8, d_head // 4 - nyq = float(getattr(config, "rope_nyquist_frac", 0.8)) + nyq = float(config.rope_nyquist_frac) max_freq = min(self.config.height, self.config.width) * nyq n = (d_xy + 1) // 2 xy = (torch.linspace(1.0, max_freq / 2, n, 
dtype=torch.float32) * torch.pi).repeat_interleave(2)[:d_xy] - theta = float(getattr(config, "rope_theta", 10000.0)) + theta = float(config.rope_theta) inv_t = 1.0 / (theta ** (torch.arange(0, d_t, 2, dtype=torch.float32) / d_t)) inv_t = inv_t.repeat_interleave(2) # [d_t] - self.register_buffer("xy", xy, persistent=False) # [d_xy] - self.register_buffer("inv_t", inv_t, persistent=False) # [d_t] + self.register_buffer("xy", xy, persistent=False) # [d_xy] + self.register_buffer("inv_t", inv_t, persistent=False) # [d_t] @torch.autocast("cuda", enabled=False) def forward(self, pos_ids): if not torch.compiler.is_compiling(): torch._assert( - (pos_ids["y_pos"].max() < self.config.height) & (pos_ids["x_pos"].max() < self.config.width), + (pos_ids["y_pos"].max() < self.config.height) & (pos_ids["x_pos"].max() < self.config.width), f"pos_ids out of bounds, {self.config.height}, {self.config.width}" - ) + ) x = (2.0 * pos_ids["x_pos"].float() + 1.0) / self.config.width - 1.0 y = (2.0 * pos_ids["y_pos"].float() + 1.0) / self.config.height - 1.0 @@ -50,11 +50,11 @@ def forward(self, pos_ids): # Returns rope_cos, rope_sin angles of shape [B, 1, T, D/2] return freqs.cos()[:, None], freqs.sin()[:, None] + class OrthoRoPE(NoCastModule): def __init__(self, config): super().__init__() self.config = config - assert not getattr(self.config, "has_audio", False) @torch.autocast("cuda", enabled=False) def forward(self, x, rope_angles): @@ -70,13 +70,13 @@ def __init__(self, config, layer_idx): super().__init__() self.config = config self.layer_idx = layer_idx + self.value_residual = config.value_residual - self.value_residual = getattr(config, "value_residual", False) if self.value_residual: self.v_lamb = nn.Parameter(torch.tensor(0.5)) self.n_heads = config.n_heads - self.n_kv_heads = getattr(config, "n_kv_heads", config.n_heads) + self.n_kv_heads = config.n_kv_heads self.d_head = config.d_model // self.n_heads assert config.d_model % self.n_heads == 0 @@ -89,7 +89,7 @@ def 
__init__(self, config, layer_idx): self.rope = OrthoRoPE(config) - self.gated_attn = getattr(config, "gated_attn", False) + self.gated_attn = config.gated_attn if self.gated_attn: self.gate_proj = nn.Linear(self.n_heads, self.n_heads, bias=False) # sparse attn gate nn.init.zeros_(self.gate_proj.weight) diff --git a/src/model/base_model.py b/src/model/base_model.py index 5932c57..f94314e 100644 --- a/src/model/base_model.py +++ b/src/model/base_model.py @@ -7,6 +7,25 @@ import torch +MODEL_CONFIG_DEFAULTS = OmegaConf.create( + { + "auto_aspect_ratio": True, + "gated_attn": False, + "inference_fps": "${base_fps}", + "model_type": "waypoint-1", + "n_kv_heads": "${n_heads}", + "patch": [1, 1], + "prompt_conditioning": None, + "prompt_encoder_uri": "google/umt5-xl", + "rope_nyquist_frac": 0.8, + "rope_theta": 10000.0, + "taehv_ae": False, + "temporal_compression": 1, + "value_residual": False, + } +) + + class BaseModel(nn.Module): @classmethod def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None, load_weights: bool = True): @@ -35,4 +54,5 @@ def load_config(path): cfg_path = os.path.join(path, "config.yaml") else: cfg_path = huggingface_hub.hf_hub_download(repo_id=path, filename="config.yaml") - return OmegaConf.load(cfg_path) + cfg = OmegaConf.load(cfg_path) + return OmegaConf.merge(MODEL_CONFIG_DEFAULTS, cfg) diff --git a/src/model/kv_cache.py b/src/model/kv_cache.py index 6694049..e6f7049 100644 --- a/src/model/kv_cache.py +++ b/src/model/kv_cache.py @@ -153,11 +153,11 @@ def __init__(self, config, batch_size, dtype): global_L = config.global_window * self.tpf period = config.global_attn_period - off = getattr(config, "global_attn_offset", 0) % period + off = config.global_attn_offset % period self.layers = nn.ModuleList([ LayerKVCache( batch_size, - getattr(config, "n_kv_heads", config.n_heads), + config.n_kv_heads, global_L if ((layer_idx - off) % period == 0) else local_L, config.d_model // config.n_heads, dtype, diff --git 
a/src/model/world_model.py b/src/model/world_model.py index da78b4f..b17d872 100644 --- a/src/model/world_model.py +++ b/src/model/world_model.py @@ -309,7 +309,7 @@ def forward( return x def load_state_dict(self, state_dict, strict=True, assign=False): - if getattr(self.config, "model_type", "waypoint-1") != "waypoint-1.5": + if self.config.model_type != "waypoint-1.5": return super().load_state_dict(state_dict, strict=strict, assign=assign) state_dict = dict(state_dict) diff --git a/src/world_engine.py b/src/world_engine.py index 8e22f35..91fc2ae 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -62,16 +62,15 @@ def __init__( # Load Model / Modules self.vae = get_ae( self.model_cfg.ae_uri, - is_taehv_ae=getattr(self.model_cfg, "taehv_ae", False), - auto_aspect_ratio=getattr(self.model_cfg, "auto_aspect_ratio", True), + is_taehv_ae=self.model_cfg.taehv_ae, + auto_aspect_ratio=self.model_cfg.auto_aspect_ratio, dtype=dtype, device=device, ) self.prompt_encoder = None if self.model_cfg.prompt_conditioning is not None: - pe_uri = getattr(self.model_cfg, "prompt_encoder_uri", "google/umt5-xl") - self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).eval() + self.prompt_encoder = PromptEncoder(self.model_cfg.prompt_encoder_uri, dtype=dtype).eval() self.model = WorldModel.from_pretrained( model_uri, cfg=self.model_cfg, device=self.device, dtype=dtype, load_weights=load_weights @@ -85,12 +84,11 @@ def __init__( # Inference Scheduler self.scheduler_sigmas = torch.tensor(self.model_cfg.scheduler_sigmas, dtype=dtype, device=device) - pH, pW = getattr(self.model_cfg, "patch", [1, 1]) + pH, pW = self.model_cfg.patch self.frm_shape = 1, 1, self.model_cfg.channels, self.model_cfg.height * pH, self.model_cfg.width * pW # State - inference_fps = getattr(self.model_cfg, "inference_fps", self.model_cfg.base_fps) - latent_fps = inference_fps / getattr(self.model_cfg, "temporal_compression", 1) + latent_fps = self.model_cfg.inference_fps / 
self.model_cfg.temporal_compression self.ts_mult = int(self.model_cfg.base_fps) // latent_fps self.frame_ts = torch.tensor([[0]], dtype=torch.long)