diff --git a/README.md b/README.md index ba885b2..8d68292 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ export HF_TOKEN= from world_engine import WorldEngine, CtrlInput # Create inference engine -engine = WorldEngine("Overworld/Waypoint-1-Small", device="cuda") +engine = WorldEngine("Overworld/Waypoint-1.5-1B", device="cuda") # Specify a prompt engine.set_prompt("A fun game") @@ -77,6 +77,17 @@ for controller_input in [ img = engine.gen_frame(ctrl=controller_input) ``` +## Waypoint-1.5 Behavior +All interfaces and handling for Waypoint-1 (or 1.1) and Waypoint-1.5 remain the same **except** the following: + +In Waypoint-1.5, the `img` passed to `append_frame(...)` and returned by `gen_frame(...)` is now a sequence of 4 frames. Waypoint-1.5 applies temporal compression and generates 4 frames for every controller input. + +Whereas previously, `img` was a uint8 rgb array of shape `[Height, Width, 3]`, **in Waypoint-1.5 it is of shape `[4, Height, Width, 3]`**. + +Additionally, Waypoint-1.5 expects 720p inputs / outputs, therefore `img` is `[4, 720, 1280, 3]`. + +See [examples/gen_sample.py](./examples/gen_sample.py) for reference. 
+ ## Usage ``` from world_engine import WorldEngine, CtrlInput @@ -84,7 +95,7 @@ from world_engine import WorldEngine, CtrlInput Load model to GPU ``` -engine = WorldEngine("Overworld/Waypoint-1-Small", device="cuda") +engine = WorldEngine("Overworld/Waypoint-1.5-1B", device="cuda") ``` Specify a prompt which will be used until this function is called again @@ -118,11 +129,13 @@ Note: returned `img` is always on the same device as `engine.device` @dataclass class CtrlInput: button: Set[int] = field(default_factory=set) # pressed button IDs - mouse: Tuple[float, float] = (0.0, 0.0) # (x, y) position + mouse: Tuple[float, float] = (0.0, 0.0) # (dx, dy) position change + scroll_wheel: int = 0 # down, stationary, or up -> (-1, 0, 1) ``` - `button` keycodes are defined by [Owl-Control](https://github.com/Overworldai/owl-control/blob/main/src/system/keycode.rs) -- `mouse` is the raw mouse velocity vector +- `mouse` is the amount of change in mouse position since the last frame +- `scroll_wheel` is the ternary scroll wheel movement identifier ## Showcase and Examples @@ -138,5 +151,5 @@ class CtrlInput: ### Examples and Reference Code -- ["Hello (Over)World" client](./examples/simple_client.py) +- ["Generate MP4 Sample Given Controller Inputs"](./examples/gen_sample.py) - [Run Performance Benchmarks (`pytest examples/benchmark.py`)](./examples/benchmark.py) diff --git a/examples/benchmark.py b/examples/benchmark.py index 9bd75fa..3d8e623 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -1,12 +1,14 @@ +# MODEL_URI="Overworld/Waypoint-1.5-1B" uv run --dev pytest examples/benchmark.py + +import os import pytest import torch +import random -from world_engine import WorldEngine +from world_engine import WorldEngine, CtrlInput -# TODO -# - benchmark encode img -# - benchmark encode prompt +MODEL_URI = os.environ.get("MODEL_URI", "Overworld/Waypoint-1-Small") def version_with_commit(pkg): @@ -49,7 +51,7 @@ def print_env_info(): def get_warm_engine(model_uri, 
model_overrides=None): - model_config_overrides = {"ae_uri": "OpenWorldLabs/owl_vae_f16_c16_distill_v0_nogan"} + model_config_overrides = {} model_config_overrides.update(model_overrides or {}) engine = WorldEngine( model_uri, @@ -65,8 +67,8 @@ def get_warm_engine(model_uri, model_overrides=None): @pytest.fixture(scope="session") -def engine(model_uri="Overworld/Waypoint-1-Small"): - return get_warm_engine(model_uri) +def engine(): + return get_warm_engine(MODEL_URI) @pytest.fixture(scope="session") @@ -86,14 +88,22 @@ def run(): MODEL_OVERRIDES = [None] +@pytest.mark.parametrize("blocking", [True, False]) @pytest.mark.parametrize("dit_only", [True]) @pytest.mark.parametrize("n_frames", [256]) @pytest.mark.parametrize( "model_overrides", MODEL_OVERRIDES, ids=lambda d: (",".join(f"{k}={v}" for k, v in d.items()) or "") if d else "" ) -def test_ar_rollout(benchmark, dit_only, n_frames, model_overrides): - engine = get_warm_engine("Overworld/Waypoint-1-Small", model_overrides=model_overrides) +def test_ar_rollout(benchmark, dit_only, n_frames, model_overrides, blocking): + engine = get_warm_engine(MODEL_URI, model_overrides=model_overrides) + + try: + total_params = sum(p.numel() for p in engine.model.parameters()) + active_params = int(engine.model.get_active_parameters()) + benchmark.name = f"{benchmark.name} | params={total_params:,} | active={active_params:,}" + except Exception: + pass def setup(): engine.reset() @@ -101,8 +111,17 @@ def setup(): torch.cuda.synchronize() def target(): - for _ in range(n_frames): + ctrls = [ + CtrlInput( + button=set(random.sample(range(1, 65), random.randint(0, 10))), + mouse=(random.random(), random.random()), + scroll_wheel=random.choice((-1, 0, 1)) + ) + for _ in range(n_frames) + ] + for ctrl in ctrls: engine.gen_frame(return_img=not dit_only) - torch.cuda.synchronize() + if blocking: + torch.cuda.synchronize() benchmark.pedantic(target, setup=setup, rounds=20) diff --git a/examples/gen_sample.py b/examples/gen_sample.py 
index 62c27c5..0217896 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -1,22 +1,53 @@ +# uv run --dev examples/gen_sample.py Overworld/Waypoint-1.5-1B + import cv2 -from world_engine import WorldEngine +import imageio.v3 as iio +import random +import sys +import urllib.request +import numpy as np +import torch + +from world_engine import WorldEngine, CtrlInput + + +# Create inference engine +engine = WorldEngine(sys.argv[1], device="cuda") + +# Define sequence of controller inputs applied +controller_sequence = [ + # move mouse, jump, do nothing, trigger, do nothing, trigger+jump, do nothing + CtrlInput(mouse=[0.2, 0.2]), CtrlInput(button={32}), CtrlInput(), CtrlInput(), CtrlInput(), + CtrlInput(button={1}), CtrlInput(), CtrlInput(), CtrlInput(button={1, 32}), + CtrlInput(), CtrlInput(), CtrlInput(), CtrlInput(), CtrlInput(), CtrlInput(), +] * 4 +controller_sequence += [CtrlInput()] * 8 +controller_sequence += ( + [CtrlInput(button={32})] * 10 + # forward + [CtrlInput(button={65})] * 10 + # left + [CtrlInput(button={68})] * 10 + # right + [CtrlInput(button={83})] * 10 # backwards +) +controller_sequence += [CtrlInput()] * 10 -def gen_vid(): - engine = WorldEngine("OpenWorldLabs/CoDCtl-Causal-Flux-SelfForcing", device="cuda") - writer = None - for _ in range(240): - frame = engine.gen_frame().cpu().numpy()[:, :, ::-1] # RGB -> BGR for OpenCV - writer = writer or cv2.VideoWriter( - "out.mp4", - cv2.VideoWriter_fourcc(*"mp4v"), - 60, - (frame.shape[1], frame.shape[0]) - ) - writer.write(frame) - writer.release() +# Set seed frame +url = random.choice([ + "https://gist.github.com/user-attachments/assets/d81c6d26-a838-4afe-9d13-fd67677043c3", + "https://gist.github.com/user-attachments/assets/b6d18c38-098e-43b0-8e61-66a16e5d8946", + "https://gist.github.com/user-attachments/assets/0734a8c1-3eb4-4ffe-8c37-5665c45ab559", + "https://gist.github.com/user-attachments/assets/f9c20d4d-7565-452d-8b02-42a85ea175ed", + 
"https://gist.github.com/user-attachments/assets/68c943a4-008a-4c25-948c-c81ab4c47d21", +]) +seed_frame = cv2.imdecode(np.frombuffer(urllib.request.urlopen(url).read(), np.uint8), cv2.IMREAD_COLOR) +seed_frame_x4 = torch.from_numpy(np.repeat(seed_frame[None], 4, axis=0)) -if __name__ == "__main__": - gen_vid() +# Generate frames conditioned on controller inputs +with iio.imopen("out.mp4", "w", plugin="pyav") as out: + engine.append_frame(seed_frame_x4) + out.write(seed_frame_x4, fps=60, codec="libx264") + for ctrl in controller_sequence: + four_frames = engine.gen_frame(ctrl=ctrl).cpu().numpy() + out.write(four_frames) diff --git a/examples/prof.py b/examples/prof.py index b8cd570..b469076 100644 --- a/examples/prof.py +++ b/examples/prof.py @@ -1,3 +1,9 @@ +""" +Additional Dependencies: N/A +Run: `python3 examples/prof.py Overworld/Waypoint-1.5-1B` +""" +import sys + import torch from torch.profiler import profile, ProfilerActivity @@ -5,7 +11,7 @@ def do_profile(n_frames=64, row_limit=20): - engine = WorldEngine("OpenWorldLabs/CoDCtl-Causal-Flux-SelfForcing", device="cuda") + engine = WorldEngine(sys.argv[1], device="cuda") # warmup for _ in range(4): engine.gen_frame() diff --git a/examples/simple_client.py b/examples/simple_client.py deleted file mode 100644 index 2f5d410..0000000 --- a/examples/simple_client.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import AsyncIterable, AsyncIterator - -import asyncio -import contextlib -import cv2 -import sys -import torch - -from world_engine import WorldEngine, CtrlInput - - -async def render(frames: AsyncIterable[torch.Tensor], win_name="Hello (Over)World (ESC to exit)") -> None: - """Render stream of RGB tensor images.""" - cv2.namedWindow(win_name, cv2.WINDOW_AUTOSIZE | cv2.WINDOW_GUI_NORMAL) - async for t in frames: - cv2.imshow(win_name, t.cpu().numpy()) - await asyncio.sleep(0) - cv2.destroyAllWindows() - - -async def frame_stream(engine: WorldEngine, ctrls: AsyncIterable[CtrlInput]) -> 
AsyncIterator[torch.Tensor]: - """Generate frame by calling Engine for each ctrl.""" - yield await asyncio.to_thread(engine.gen_frame) - async for ctrl in ctrls: - yield await asyncio.to_thread(engine.gen_frame, ctrl=ctrl) - - -async def ctrl_stream(delay: int = 1) -> AsyncIterator[CtrlInput]: - """Accumulate key presses asyncronously. Yield CtrlInput once next() is called.""" - q: asyncio.Queue[int] = asyncio.Queue() - - async def producer() -> None: - while True: - k = cv2.waitKey(delay) - if k != -1: - await q.put(k) - await asyncio.sleep(0) - - prod_task = asyncio.create_task(producer()) - while True: - buttons: set[int] = set() - # Drain everything currently in the queue into this batch - with contextlib.suppress(asyncio.QueueEmpty): - while True: - k = q.get_nowait() - if k == 27: - # End if ESC pressed - prod_task.cancel() - return - buttons.add(k) - - yield CtrlInput(button=buttons) - - -async def main() -> None: - uri = sys.argv[1] if len(sys.argv) > 1 else "OpenWorldLabs/CoDCtl-Causal-Flux-SelfForcing" - engine = WorldEngine(uri, device="cuda") - ctrls = ctrl_stream() - frames = frame_stream(engine, ctrls) - await render(frames) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index d0fb905..414c61f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,14 +4,14 @@ build-backend = "setuptools.build_meta" [project] name = "world_engine" -version = "1.0.0" -requires-python = ">=3.9" +version = "1.5.0" +requires-python = ">=3.10" dependencies = [ + "taehv @ git+https://github.com/madebyollin/taehv.git@7dc60ec6601af2e668e31bc70acc4cb3665e4c22", "torch==2.10.0", "torchvision==0.25.0", "torchaudio==2.10.0", "einops", - "rotary-embedding-torch>=0.8.8", "tensordict==0.10.0", "transformers==4.57.3", "ftfy", @@ -21,8 +21,8 @@ dependencies = [ "accelerate==1.12.0", # Triton (platform-specific) - "triton; sys_platform == 'linux'", - "triton-windows; sys_platform == 'win32'", + "triton==3.6.0; sys_platform == 'linux'", 
+ "triton-windows==3.6.0.post26; sys_platform == 'win32'", ] [tool.setuptools] @@ -33,3 +33,12 @@ packages = [ [tool.setuptools.package-dir] world_engine = "src" + +[dependency-groups] +dev = [ + "pytest", + "pytest-benchmark", + "opencv-python", + "imageio[pyav]", + "numpy", +] diff --git a/src/ae.py b/src/ae.py index 577d765..4b42b3f 100644 --- a/src/ae.py +++ b/src/ae.py @@ -1,28 +1,92 @@ import torch +import torch.nn.functional as F from torch import Tensor -""" -WARNING: -- Always assumes scale=1, shift=0 -""" +class ChunkedStreamingTAEHV: + _ENCODE_SIZES = {(720, 1280): (512, 1024), (360, 640): (256, 512)} + _DECODE_SIZES = {v: k for k, v in _ENCODE_SIZES.items()} + def __init__(self, ae_model, auto_aspect_ratio=True, device=None, dtype=torch.bfloat16): + """ + auto_aspect_ratio: automatically resize so encode input must be 720p or 360p (converted 16:9 to 16:8) + and decode output is 720p or 360p + """ + from taehv import StreamingTAEHV -def bake_weight_norm_(module) -> int: - """ - Removes weight parametrizations (from torch.nn.utils.parametrizations.weight_norm) - and leaves the current parametrized weight as a plain Parameter. - Returns how many modules were de-parametrized. 
- """ - import torch.nn.utils.parametrize as parametrize + self.device = device + self.dtype = dtype + self.auto_aspect_ratio = auto_aspect_ratio + self.streaming_ae_model = StreamingTAEHV(ae_model.eval().to(device=device, dtype=dtype)) + + @classmethod + def from_pretrained(cls, model_uri: str, auto_aspect_ratio=True, **kwargs): + import pathlib + + import huggingface_hub + from taehv import TAEHV + + try: + base = pathlib.Path(huggingface_hub.snapshot_download(model_uri)) + except Exception: + base = pathlib.Path(model_uri) + + ckpt = base if base.is_file() else base / "taehv1_5.pth" + return cls(TAEHV(str(ckpt)), auto_aspect_ratio=auto_aspect_ratio, **kwargs) + + def reset(self): + from taehv import StreamingTAEHV + + # Rebuild streaming state, reuse same weights model + self.streaming_ae_model = StreamingTAEHV(self.streaming_ae_model.taehv) + + def _resize(self, x: Tensor, size: tuple[int, int]) -> Tensor: + return F.interpolate(x[0], size=size, mode="bilinear", align_corners=False)[None] + + @torch.inference_mode() + def encode(self, img: Tensor): + """ + img: [T, H, W, C] uint8 where T == t_downscale + returns: latent [B, C, h, w] + """ + t = self.streaming_ae_model.taehv.t_downscale + assert img.dim() == 4 and img.shape[-1] == 3 and img.shape[0] == t, f"Expected [{t}, H, W, 3] RGB uint8" + + rgb = img.unsqueeze(0)\ + .to(device=self.device, dtype=self.dtype)\ + .permute(0, 1, 4, 2, 3).contiguous().div(255) + + if self.auto_aspect_ratio: + rgb = self._resize(rgb, self._ENCODE_SIZES[img.shape[1:3]]) + + return self.streaming_ae_model.encode(rgb).squeeze(1) + + @torch.inference_mode() + def decode(self, latent: Tensor): + """ + latent: [B, C, h, w] + returns: frames [T, H, W, C] uint8 + """ + assert latent.dim() == 4, "Expected [B, C, h, w] latent tensor" + + z = latent.unsqueeze(1).to(device=self.device, dtype=self.dtype) - n = 0 - for m in module.modules(): - # weight_norm registers a parametrization on "weight" - if hasattr(m, "parametrizations") and 
"weight" in getattr(m, "parametrizations", {}): - parametrize.remove_parametrizations(m, "weight", leave_parametrized=True) - n += 1 - return n + if self.streaming_ae_model.n_frames_decoded == 0: + for _ in range(self.streaming_ae_model.taehv.frames_to_trim): + self.streaming_ae_model.decode(z) + self.streaming_ae_model.flush_decoder() + + first = self.streaming_ae_model.decode(z) + assert first is not None, "Expected decoded output after a latent" + frames = [first, *self.streaming_ae_model.flush_decoder()] + + decoded = torch.cat(frames, dim=1) + + if self.auto_aspect_ratio: + decoded = self._resize(decoded, self._DECODE_SIZES[decoded.shape[-2:]]) + + decoded = (decoded.clamp(0, 1) * 255).round().to(torch.uint8) + return decoded.squeeze(0).permute(0, 2, 3, 1)[..., :3] class InferenceAE: @@ -31,6 +95,9 @@ def __init__(self, ae_model, device=None, dtype=torch.bfloat16): self.dtype = dtype self.ae_model = ae_model.eval().to(device=device, dtype=dtype) + def reset(self): + pass + @classmethod def from_pretrained(cls, model_uri: str, **kwargs): import pathlib @@ -54,10 +121,27 @@ def from_pretrained(cls, model_uri: str, **kwargs): model.encoder.load_state_dict(enc_sd, strict=True) model.decoder.load_state_dict(dec_sd, strict=True) - bake_weight_norm_(model) + cls.bake_weight_norm_(model) return cls(model, **kwargs) + @staticmethod + def bake_weight_norm_(module) -> int: + """ + Removes weight parametrizations (from torch.nn.utils.parametrizations.weight_norm) + and leaves the current parametrized weight as a plain Parameter. + Returns how many modules were de-parametrized. 
+ """ + import torch.nn.utils.parametrize as parametrize + + n = 0 + for m in module.modules(): + # weight_norm registers a parametrization on "weight" + if hasattr(m, "parametrizations") and "weight" in getattr(m, "parametrizations", {}): + parametrize.remove_parametrizations(m, "weight", leave_parametrized=True) + n += 1 + return n + def encode(self, img: Tensor): """RGB -> RGB+D -> latent""" assert img.dim() == 3, "Expected [H, W, C] image tensor" @@ -85,3 +169,10 @@ def decode(self, latent: Tensor): decoded = (decoded / 2 + 0.5).clamp(0, 1) decoded = (decoded * 255).round().to(torch.uint8) return decoded.squeeze(0).permute(1, 2, 0)[..., :3] + + +def get_ae(ae_uri, is_taehv_ae=False, auto_aspect_ratio=True, **kwargs): + if is_taehv_ae: + return ChunkedStreamingTAEHV.from_pretrained(ae_uri, auto_aspect_ratio=auto_aspect_ratio, **kwargs) + else: + return InferenceAE.from_pretrained(ae_uri, **kwargs) diff --git a/src/model/attn.py b/src/model/attn.py index 52c2b57..76405b1 100644 --- a/src/model/attn.py +++ b/src/model/attn.py @@ -4,67 +4,65 @@ from torch.nn.attention.flex_attention import flex_attention -from rotary_embedding_torch import RotaryEmbedding - from .nn import rms_norm, NoCastModule -class RoPE(NoCastModule): +class OrthoRoPEAngles(NoCastModule): + """Functions as a on the fly RoPE angle computer called every fwd pass. 
Should be set up + as a module under WorldDiT, then each forward pass it computes a shared tuple of (rope_cos, rope_sin) + tensors that get passed to every block for their underlying RoPE computations.""" def __init__(self, config): super().__init__() self.config = config - assert not getattr(self.config, "has_audio", False) - freqs = self.get_freqs(config) - self.cos = nn.Buffer(freqs.cos().contiguous(), persistent=False) - self.sin = nn.Buffer(freqs.sin().contiguous(), persistent=False) + d_head = config.d_model // config.n_heads + torch._assert(d_head % 8 == 0, "d_head must be divisible by 8") + d_xy, d_t = d_head // 8, d_head // 4 - def get_angles(self, pos_ids): - t, y, x = pos_ids["t_pos"], pos_ids["y_pos"], pos_ids["x_pos"] # [B,T] - H, W = self.config.height, self.config.width - if not torch.compiler.is_compiling(): - torch._assert((y.max() < H) & (x.max() < W), f"pos_ids out of bounds, {y.max()}, {x.max()}") - flat = t * (H * W) + y * W + x # [B,T] - idx = flat.reshape(-1).to(torch.long) - cos = self.cos.index_select(0, idx).view(*flat.shape, -1) - sin = self.sin.index_select(0, idx).view(*flat.shape, -1) - return cos[:, None], sin[:, None] # add head dim for broadcast + nyq = float(config.rope_nyquist_frac) + max_freq = min(self.config.height, self.config.width) * nyq + n = (d_xy + 1) // 2 + xy = (torch.linspace(1.0, max_freq / 2, n, dtype=torch.float32) * torch.pi).repeat_interleave(2)[:d_xy] - @torch.autocast("cuda", enabled=False) - def forward(self, x, pos_ids): - assert self.cos.dtype == self.sin.dtype == torch.float32 - cos, sin = self.get_angles(pos_ids) - x0, x1 = x.float().unfold(-1, 2, 2).unbind(-1) - y0 = x0 * cos - x1 * sin - y1 = x1 * cos + x0 * sin - return torch.cat((y0, y1), dim=-1).type_as(x) + theta = float(config.rope_theta) + inv_t = 1.0 / (theta ** (torch.arange(0, d_t, 2, dtype=torch.float32) / d_t)) + inv_t = inv_t.repeat_interleave(2) # [d_t] - def get_freqs(self, config): - raise NotImplementedError + self.register_buffer("xy", xy, 
persistent=False) # [d_xy] + self.register_buffer("inv_t", inv_t, persistent=False) # [d_t] + + @torch.autocast("cuda", enabled=False) + def forward(self, pos_ids): + if not torch.compiler.is_compiling(): + torch._assert( + (pos_ids["y_pos"].max() < self.config.height) & (pos_ids["x_pos"].max() < self.config.width), + f"pos_ids out of bounds, {self.config.height}, {self.config.width}" + ) + x = (2.0 * pos_ids["x_pos"].float() + 1.0) / self.config.width - 1.0 + y = (2.0 * pos_ids["y_pos"].float() + 1.0) / self.config.height - 1.0 + t = pos_ids["t_pos"].float() -class OrthoRoPE(RoPE): - """ - RoPE for rotation across orthogonal axes: time, height, and width - Time: Geometric Spectrum -- rotates 1/2 of head dim - Height / Width: Linear Spectrum -- rotates 1/4th of head dim each (1/2 combined) - """ - def get_freqs(self, config): - H, W, T = config.height, config.width, config.n_frames - head_dim = config.d_model // config.n_heads + freqs = torch.cat( + (x.unsqueeze(-1) * self.xy, y.unsqueeze(-1) * self.xy, t.unsqueeze(-1) * self.inv_t), + dim=-1, # [B,T,d_head//2] + ) + # Returns rope_cos, rope_sin angles of shape [B, 1, T, D/2] + return freqs.cos()[:, None], freqs.sin()[:, None] - max_freq = min(H, W) * 0.8 # stay below nyquist - rope_xy = RotaryEmbedding(dim=head_dim // 8, freqs_for='pixel', max_freq=max_freq) - freqs_x = rope_xy(torch.linspace(-1 + 1 / W, 1 - 1 / W, W))[None, :, :] # [1,W,D] - freqs_y = rope_xy(torch.linspace(-1 + 1 / H, 1 - 1 / H, H))[:, None, :] # [H,1,D] - freq_t = RotaryEmbedding(dim=head_dim // 4, freqs_for='lang').forward(torch.arange(T)) +class OrthoRoPE(NoCastModule): + def __init__(self, config): + super().__init__() + self.config = config - return torch.cat([ - eo.repeat(freqs_x.expand(H, W, -1), 'h w d -> (t h w) d', t=T), # X - eo.repeat(freqs_y.expand(H, W, -1), 'h w d -> (t h w) d', t=T), # Y - eo.repeat(freq_t, 't d -> (t h w) d', h=H, w=W) # T - ], dim=-1) + @torch.autocast("cuda", enabled=False) + def forward(self, x, rope_angles): 
+ cos, sin = rope_angles + x0, x1 = x.float().unfold(-1, 2, 2).unbind(-1) + y0 = x0 * cos - x1 * sin + y1 = x1 * cos + x0 * sin + return torch.cat((y0, y1), dim=-1).type_as(x) class Attn(nn.Module): @@ -72,13 +70,13 @@ def __init__(self, config, layer_idx): super().__init__() self.config = config self.layer_idx = layer_idx + self.value_residual = config.value_residual - self.value_residual = getattr(config, "value_residual", False) if self.value_residual: self.v_lamb = nn.Parameter(torch.tensor(0.5)) self.n_heads = config.n_heads - self.n_kv_heads = getattr(config, "n_kv_heads", config.n_heads) + self.n_kv_heads = config.n_kv_heads self.d_head = config.d_model // self.n_heads assert config.d_model % self.n_heads == 0 @@ -91,12 +89,12 @@ def __init__(self, config, layer_idx): self.rope = OrthoRoPE(config) - self.gated_attn = getattr(config, "gated_attn", False) + self.gated_attn = config.gated_attn if self.gated_attn: self.gate_proj = nn.Linear(self.n_heads, self.n_heads, bias=False) # sparse attn gate nn.init.zeros_(self.gate_proj.weight) - def forward(self, x, pos_ids, v1, kv_cache): + def forward(self, x, pos_ids, rope_angles, v1, kv_cache): # Q, K, V proj -> QK-norm -> RoPE q = eo.rearrange(self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads, d=self.d_head) k = eo.rearrange(self.k_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head) @@ -107,7 +105,7 @@ def forward(self, x, pos_ids, v1, kv_cache): v = torch.lerp(v, v1.view_as(v), self.v_lamb) q, k = rms_norm(q), rms_norm(k) - q, k = self.rope(q, pos_ids), self.rope(k, pos_ids) + q, k = self.rope(q, rope_angles), self.rope(k, rope_angles) # Update KV-cache in-place k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx) diff --git a/src/model/base_model.py b/src/model/base_model.py index 2b2905e..f94314e 100644 --- a/src/model/base_model.py +++ b/src/model/base_model.py @@ -1,26 +1,37 @@ import huggingface_hub import os -import tempfile from omegaconf import OmegaConf -from safetensors.torch 
import save_file, load_file +from safetensors.torch import load_file from torch import nn +import torch + + +MODEL_CONFIG_DEFAULTS = OmegaConf.create( + { + "auto_aspect_ratio": True, + "gated_attn": False, + "inference_fps": "${base_fps}", + "model_type": "waypoint-1", + "n_kv_heads": "${n_heads}", + "patch": [1, 1], + "prompt_conditioning": None, + "prompt_encoder_uri": "google/umt5-xl", + "rope_nyquist_frac": 0.8, + "rope_theta": 10000.0, + "taehv_ae": False, + "temporal_compression": 1, + "value_residual": False, + } +) class BaseModel(nn.Module): - def save_pretrained(self, path: str) -> None: - """Save weights (.safetensors) and OmegaConf YAML.""" - os.makedirs(path, exist_ok=True) - save_file( - {k: v.detach().cpu() for k, v in self.state_dict().items()}, - os.path.join(path, "model.safetensors"), - ) - OmegaConf.save(self.config, os.path.join(path, "config.yaml")) - @classmethod - def from_pretrained(cls, path: str, cfg=None, device=None): + def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None, load_weights: bool = True): """Load weights and OmegaConf YAML.""" - device = device or "cpu" + device = torch.get_default_device() if device is None else device + dtype = torch.get_default_dtype() if dtype is None else dtype try: path = huggingface_hub.snapshot_download(path) @@ -29,27 +40,19 @@ def from_pretrained(cls, path: str, cfg=None, device=None): if cfg is None: cfg = cls.load_config(path) - model = cls(cfg) - - if device != "cpu": - model = model.to(device) + model = cls(cfg).to(dtype=dtype, device=device) - # Stream weights straight into `model` (no CPU state_dict first) - safetensors_path = os.path.join(path, "model.safetensors") - model.load_state_dict(load_file(safetensors_path, device=device), strict=True) + if load_weights: + safetensors_path = os.path.join(path, "model.safetensors") + model.load_state_dict(load_file(safetensors_path, device=device), strict=True) return model - def push_to_hub(self, uri: str, **kwargs): - 
huggingface_hub.create_repo(uri, repo_type="model", exist_ok=True, private=True) - with tempfile.TemporaryDirectory() as d: - self.save_pretrained(d) - huggingface_hub.upload_folder(folder_path=d, repo_id=uri, **kwargs) - @staticmethod def load_config(path): if os.path.isdir(path): cfg_path = os.path.join(path, "config.yaml") else: cfg_path = huggingface_hub.hf_hub_download(repo_id=path, filename="config.yaml") - return OmegaConf.load(cfg_path) + cfg = OmegaConf.load(cfg_path) + return OmegaConf.merge(MODEL_CONFIG_DEFAULTS, cfg) diff --git a/src/model/kv_cache.py b/src/model/kv_cache.py index 839c9a5..e6f7049 100644 --- a/src/model/kv_cache.py +++ b/src/model/kv_cache.py @@ -13,53 +13,47 @@ def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask: """ T: Q length for this frame L: KV capacity == written.numel() - written: [L] bool, True where there is valid KV data + written: [L] bool, True where there is valid KV data. + T and L must be exact multiples of the sparse block size; `written` must be + block-aligned, i.e. each block is either all True or all False. 
""" BS = _DEFAULT_SPARSE_BLOCK_SIZE - KV_blocks = (L + BS - 1) // BS - Q_blocks = (T + BS - 1) // BS - # [KV_blocks, BS] - written_blocks = torch.nn.functional.pad(written, (0, KV_blocks * BS - L)).view(KV_blocks, BS) - - # Block-level occupancy - block_any = written_blocks.any(-1) # block has at least one written token - block_all = written_blocks.all(-1) # block is fully written + if not torch.compiler.is_compiling(): + torch._check(T % BS == 0, f"T ({T}) must be a multiple of block size ({BS})") + torch._check(L % BS == 0, f"L ({L}) must be a multiple of block size ({BS})") - # Every Q-block sees the same KV-block pattern - nonzero_bm = block_any[None, :].expand(Q_blocks, KV_blocks) # [Q_blocks, KV_blocks] - full_bm = block_all[None, :].expand_as(nonzero_bm) # [Q_blocks, KV_blocks] - partial_bm = nonzero_bm & ~full_bm # [Q_blocks, KV_blocks] + Q_blocks = T // BS + KV_blocks = L // BS - def dense_to_ordered(dense_mask: torch.Tensor): - # dense_mask: [Q_blocks, KV_blocks] bool - # returns: [1,1,Q_blocks], [1,1,Q_blocks,KV_blocks] - num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32) # [Q_blocks] - indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(torch.int32) - return num_blocks[None, None].contiguous(), indices[None, None].contiguous() + # [KV_blocks, BS] + written_blocks = written.view(KV_blocks, BS) - # Partial blocks (need mask_mod) - kv_num_blocks, kv_indices = dense_to_ordered(partial_bm) + # For a valid block-aligned mask, each block is either all written or all empty. 
+ block_any = written_blocks.any(-1) + if not torch.compiler.is_compiling(): + assert torch.equal(block_any, written_blocks.all(-1)), "written must be block-aligned" - # Full blocks (mask_mod can be skipped entirely) - full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm) + # Every KV block is a full block + full_bm = block_any[None, :].expand(Q_blocks, KV_blocks) + full_kv_num_blocks = full_bm.sum(dim=-1, dtype=torch.int32)[None, None].contiguous() + full_kv_indices = full_bm.argsort(dim=-1, descending=True, stable=True).to(torch.int32)[None, None].contiguous() - def mask_mod(b, h, q, kv): - return written[kv] + # No partial blocks at all. + kv_num_blocks = torch.zeros((1, 1, Q_blocks), dtype=torch.int32, device=written.device) + kv_indices = torch.zeros((1, 1, Q_blocks, KV_blocks), dtype=torch.int32, device=written.device) - bm = BlockMask.from_kv_blocks( + return BlockMask.from_kv_blocks( kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, BLOCK_SIZE=BS, - mask_mod=mask_mod, + mask_mod=None, seq_lengths=(T, L), - compute_q_blocks=False, # no backward, avoids the transpose/_ordered_to_dense path + compute_q_blocks=False, ) - return bm - class LayerKVCache(nn.Module): """ @@ -107,26 +101,21 @@ def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool): t_pos: [B, T], all equal per frame (ignoring -1) """ T = self.tpf - t_pos = pos_ids["t_pos"] + f_pos = pos_ids["f_pos"] if not torch.compiler.is_compiling(): torch._check(kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert") - torch._check(t_pos.shape == (kv.size(1), T), "t_pos must be [B, T]") + torch._check(f_pos.shape == (kv.size(1), T), "t_pos must be [B, T]") torch._check(self.tpf <= self.L, "frame longer than KV ring capacity") - torch._check(self.L % self.tpf == 0, - f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})") - torch._check(self.kv.size(3) == self.capacity, - "KV buffer has unexpected length (expected L + tokens_per_frame)") - 
torch._check( - (t_pos >= 0).all().item(), - "t_pos must be non-negative during inference", - ) - torch._check(((t_pos == t_pos[:, :1]).all()).item(), "t_pos must be constant within frame") + torch._check(self.L % self.tpf == 0, f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})") + torch._check(self.kv.size(3) == self.capacity, "KV buffer too long (expected L + tokens_per_frame)") + torch._check((f_pos >= 0).all().item(), "t_pos must be non-negative during inference") + torch._check(((f_pos == f_pos[:, :1]).all()).item(), "t_pos must be constant within frame") - frame_t = t_pos[0, 0] + frame_idx = f_pos[0, 0] # map frame_t to a bucket, each bucket owns T contiguous slots - bucket = (frame_t + (self.pinned_dilation - 1)) // self.pinned_dilation + bucket = (frame_idx + (self.pinned_dilation - 1)) // self.pinned_dilation slot = bucket % self.num_buckets base = slot * T @@ -137,7 +126,7 @@ def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool): # this is the "self-attention component" for the current frame. 
self.kv.index_copy_(3, self.current_idx, kv) - write_step = (frame_t.remainder(self.pinned_dilation) == 0) + write_step = (frame_idx.remainder(self.pinned_dilation) == 0) mask_written = self._mask_written mask_written.copy_(self.written) mask_written[ring_idx] = mask_written[ring_idx] & ~write_step @@ -158,17 +147,17 @@ class StaticKVCache(nn.Module): def __init__(self, config, batch_size, dtype): super().__init__() - self.tpf = config.tokens_per_frame + self.tpf = config.height * config.width local_L = config.local_window * self.tpf global_L = config.global_window * self.tpf period = config.global_attn_period - off = getattr(config, "global_attn_offset", 0) % period + off = config.global_attn_offset % period self.layers = nn.ModuleList([ LayerKVCache( batch_size, - getattr(config, "n_kv_heads", config.n_heads), + config.n_kv_heads, global_L if ((layer_idx - off) % period == 0) else local_L, config.d_model // config.n_heads, dtype, @@ -185,6 +174,18 @@ def reset(self): layer.reset() self._is_frozen = True + @torch.inference_mode() + def get_state(self): + layers = [(layer.kv.detach().clone(), layer.written.detach().clone()) for layer in self.layers] + return {"_is_frozen": self._is_frozen, "layers": layers} + + @torch.inference_mode() + def load_state(self, state): + self._is_frozen = bool(state.get("_is_frozen", True)) + for layer, (kv, written) in zip(self.layers, state["layers"]): + layer.kv.copy_(kv) + layer.written.copy_(written) + def set_frozen(self, is_frozen: bool): self._is_frozen = is_frozen diff --git a/src/model/world_model.py b/src/model/world_model.py index aec0ce5..b17d872 100644 --- a/src/model/world_model.py +++ b/src/model/world_model.py @@ -9,7 +9,7 @@ from torch import nn import torch.nn.functional as F -from .attn import Attn, CrossAttention +from .attn import Attn, CrossAttention, OrthoRoPEAngles from .nn import AdaLN, ada_gate, ada_rmsnorm, NoiseConditioner from .base_model import BaseModel @@ -149,10 +149,10 @@ def __init__(self, config, 
layer_idx): do_prompt_cond = config.prompt_conditioning is not None and layer_idx % config.prompt_conditioning_period == 0 self.prompt_cross_attn = CrossAttention(config, config.prompt_embedding_dim) if do_prompt_cond else None - do_ctrl_cond = layer_idx % config.ctrl_conditioning_period == 0 + do_ctrl_cond = config.ctrl_conditioning_period is not None and layer_idx % config.ctrl_conditioning_period == 0 self.ctrl_mlpfusion = MLPFusion(config) if do_ctrl_cond else None - def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None): + def forward(self, x, pos_ids, rope_angles, cond, ctx, v, kv_cache=None): """ 0) Causal Frame Attention 1) Frame->CTX Cross Attention @@ -163,7 +163,7 @@ def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None): # Self / Causal Attention residual = x x = ada_rmsnorm(x, s0, b0) - x, v = self.attn(x, pos_ids, v, kv_cache=kv_cache) + x, v = self.attn(x, pos_ids, rope_angles, v, kv_cache=kv_cache) x = ada_gate(x, g0) + residual # Cross Attention Prompt Conditioning @@ -189,6 +189,7 @@ def __init__(self, config): super().__init__() self.config = config self.blocks = nn.ModuleList([WorldDiTBlock(config, idx) for idx in range(config.n_layers)]) + self.rope_angles = OrthoRoPEAngles(config) if self.config.noise_conditioning in ("dit_air", "wan"): ref_proj = self.blocks[0].cond_head.cond_proj @@ -196,15 +197,11 @@ def __init__(self, config): for blk_mod, ref_mod in zip(blk.cond_head.cond_proj, ref_proj): blk_mod.weight = ref_mod.weight - # Shared RoPE buffers - ref_rope = self.blocks[0].attn.rope - for blk in self.blocks[1:]: - blk.attn.rope = ref_rope - def forward(self, x, pos_ids, cond, ctx, kv_cache=None): + rope_angles = self.rope_angles(pos_ids) v = None for i, block in enumerate(self.blocks): - x, v = block(x, pos_ids, cond, ctx, v, kv_cache=kv_cache) + x, v = block(x, pos_ids, rope_angles, cond, ctx, v, kv_cache=kv_cache) return x @@ -234,7 +231,7 @@ def __init__(self, config): self.transformer = WorldDiT(config) - self.patch = 
tuple(getattr(config, "patch", (1, 1))) + self.patch = tuple(config.patch) C, D = config.channels, config.d_model self.patchify = nn.Conv2d(C, D, kernel_size=self.patch, stride=self.patch, bias=False) @@ -253,6 +250,7 @@ def forward( x: Tensor, sigma: Tensor, frame_timestamp: Tensor, + frame_idx: Optional[Tensor] = None, prompt_emb: Optional[Tensor] = None, prompt_pad_mask: Optional[Tensor] = None, mouse: Optional[Tensor] = None, @@ -279,7 +277,12 @@ def forward( torch._assert(B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1") self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f)) pos_ids = TensorDict( - {"t_pos": self._t_pos_1f[None], "y_pos": self._y_pos_1f[None], "x_pos": self._x_pos_1f[None]}, + { + "f_pos": (frame_timestamp if frame_idx is None else frame_idx)[0, 0].expand_as(self._t_pos_1f)[None], + "t_pos": self._t_pos_1f[None], + "y_pos": self._y_pos_1f[None], + "x_pos": self._x_pos_1f[None], + }, batch_size=[1, self._t_pos_1f.numel()], ) @@ -304,3 +307,56 @@ def forward( ) return x + + def load_state_dict(self, state_dict, strict=True, assign=False): + if self.config.model_type != "waypoint-1.5": + return super().load_state_dict(state_dict, strict=strict, assign=assign) + + state_dict = dict(state_dict) + + if "unpatchify.weight" in state_dict and state_dict["unpatchify.weight"].ndim == 4: + w = state_dict["unpatchify.weight"] # [D, C, ph, pw] + state_dict["unpatchify.weight"] = w.permute(1, 2, 3, 0).reshape(-1, w.shape[0]) + if "unpatchify.bias" in state_dict and state_dict["unpatchify.bias"].numel() != self.unpatchify.bias.numel(): + ph, pw = self.patch + state_dict["unpatchify.bias"] = state_dict["unpatchify.bias"][:, None, None].expand(-1, ph, pw).reshape(-1) + + for i in range(self.config.n_layers): + p = f"transformer.blocks.{i}." + + for name in ("fc1.weight", "fc2.weight"): + old = p + "dit_mlp." + name + if old in state_dict: + state_dict.setdefault(p + "mlp." 
+ name, state_dict.pop(old)) + + attn_bias = state_dict.pop(p + "attn_cond_head.bias_in", None) + mlp_bias = state_dict.pop(p + "mlp_cond_head.bias_in", None) + if attn_bias is not None or mlp_bias is not None: + state_dict.setdefault(p + "cond_head.bias_in", mlp_bias if mlp_bias is not None else attn_bias) + + for j in range(3): + attn = state_dict.pop(p + f"attn_cond_head.cond_proj.{j}.weight", None) + mlp = state_dict.pop(p + f"mlp_cond_head.cond_proj.{j}.weight", None) + if attn is not None: + state_dict.setdefault(p + f"cond_head.cond_proj.{j}.weight", attn) + if mlp is not None: + state_dict.setdefault(p + f"cond_head.cond_proj.{j + 3}.weight", mlp) + + x = state_dict.pop(p + "ctrl_mlpfusion.fc1_x.weight", None) + c = state_dict.pop(p + "ctrl_mlpfusion.fc1_c.weight", None) + if x is not None and c is not None: + state_dict.setdefault(p + "ctrl_mlpfusion.mlp.fc1.weight", torch.cat((x, c), dim=1)) + old = state_dict.pop(p + "ctrl_mlpfusion.fc2.weight", None) + if old is not None: + state_dict.setdefault(p + "ctrl_mlpfusion.mlp.fc2.weight", old) + + ref = "transformer.blocks.0.cond_head.cond_proj." + for i in range(1, self.config.n_layers): + p = f"transformer.blocks.{i}.cond_head.cond_proj." + for j in range(6): + k, rk = p + f"{j}.weight", ref + f"{j}.weight" + if k not in state_dict and rk in state_dict: + state_dict[k] = state_dict[rk] + state_dict = {k: v for k, v in state_dict.items() if ".cond_heads." 
not in k} + + return super().load_state_dict(state_dict, strict=strict, assign=assign) diff --git a/src/patch_model.py b/src/patch_model.py index 987ded1..817f220 100644 --- a/src/patch_model.py +++ b/src/patch_model.py @@ -113,7 +113,7 @@ def __init__(self, src: Attn, config): del self.q_proj, self.k_proj, self.v_proj - def forward(self, x, pos_ids, v1, kv_cache): + def forward(self, x, pos_ids, rope_angles, v1, kv_cache): q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1) B, T = x.shape[:2] @@ -126,7 +126,7 @@ def forward(self, x, pos_ids, v1, kv_cache): v = torch.lerp(v, v1.view_as(v), self.v_lamb) q, k = rms_norm(q), rms_norm(k) - q, k = self.rope(q, pos_ids), self.rope(k, pos_ids) + q, k = self.rope(q, rope_angles), self.rope(k, rope_angles) k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx) diff --git a/src/quantize.py b/src/quantize.py index 172b8b1..b74825b 100644 --- a/src/quantize.py +++ b/src/quantize.py @@ -151,7 +151,7 @@ def __init__(self, lin: nn.Linear): if lin.bias is not None else None ) - w_amax = lin.weight.data.clone().amax().float().squeeze() + w_amax = lin.weight.data.abs().amax() w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn) self.register_buffer("w_amax", w_amax) self.register_buffer("weightT", w.t()) diff --git a/src/world_engine.py b/src/world_engine.py index 09a65e3..91fc2ae 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from .model import WorldModel, StaticKVCache, PromptEncoder -from .ae import InferenceAE +from .ae import get_ae from .patch_model import apply_inference_patches from .quantize import quantize_model @@ -16,12 +16,21 @@ # fix graph break: torch._dynamo.config.capture_scalar_outputs = True +COMPILE_OPTIONS = { + "max_autotune": True, + "coordinate_descent_tuning": True, + "triton.cudagraphs": True, + # Negligible improvement in throughput: + # "epilogue_fusion": True, + # "shape_padding": True, +} + 
@dataclass class CtrlInput: button: Set[int] = field(default_factory=set) # pressed button IDs - mouse: Tuple[float, float] = (0.0, 0.0) # (x, y) velocity - scroll_wheel: int = 0 # bwd, stationary, or fwd -> (-1, 0, 1) + mouse: Tuple[float, float] = (0.0, 0.0) # (dx, dy) velocity + scroll_wheel: int = 0 # down, stationary, or up -> (-1, 0, 1) class WorldEngine: @@ -37,53 +46,62 @@ def __init__( """ model_uri: HF URI or local folder containing model.safetensors and config.yaml quant: None | w8a8 | nvfp4 + model_config_overrides: Dict to override model config values + - auto_aspect_ratio: set to False to work in ae raw space, otherwise in/out are 720p or 360p """ - self.device, self.dtype = device, dtype + self.device = torch.get_default_device() if device is None else device + self.dtype = torch.get_default_dtype() if dtype is None else dtype self.model_cfg = WorldModel.load_config(model_uri) if model_config_overrides: self.model_cfg.merge_with(model_config_overrides) - # Model - self.vae = InferenceAE.from_pretrained(self.model_cfg.ae_uri, device=device, dtype=dtype) - - self.prompt_encoder = None - if self.model_cfg.prompt_conditioning is not None: - pe_uri = getattr(self.model_cfg, "prompt_encoder_uri", "google/umt5-xl") - self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).to(device).eval() - - if load_weights: - self.model = WorldModel.from_pretrained(model_uri, cfg=self.model_cfg) - else: - self.model = WorldModel(self.model_cfg) - self.model = self.model.to(device=device, dtype=dtype).eval() - - apply_inference_patches(self.model) - - if quant is not None: - quantize_model(self.model, quant) - - # Inference Scheduler - self.scheduler_sigmas = torch.tensor(self.model_cfg.scheduler_sigmas, device=device, dtype=dtype) - - pH, pW = getattr(self.model_cfg, "patch", [1, 1]) - self.frm_shape = 1, 1, self.model_cfg.channels, self.model_cfg.height * pH, self.model_cfg.width * pW - - # State - self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, 
dtype=dtype).to(device) - self.frame_ts = torch.tensor([[0]], dtype=torch.long, device=device) - - # Static input context tensors - self._ctx = { - "button": torch.zeros((1, 1, self.model_cfg.n_buttons), device=device, dtype=dtype), - "mouse": torch.zeros((1, 1, 2), device=device, dtype=dtype), - "scroll": torch.zeros((1, 1, 1), device=device, dtype=dtype), - "frame_timestamp": torch.empty((1, 1), device=device, dtype=torch.long), - } - - self._prompt_ctx = {"prompt_emb": None, "prompt_pad_mask": None} + with torch.device(self.device): + # Load Model / Modules + self.vae = get_ae( + self.model_cfg.ae_uri, + is_taehv_ae=self.model_cfg.taehv_ae, + auto_aspect_ratio=self.model_cfg.auto_aspect_ratio, + dtype=dtype, + device=device, + ) + + self.prompt_encoder = None + if self.model_cfg.prompt_conditioning is not None: + self.prompt_encoder = PromptEncoder(self.model_cfg.prompt_encoder_uri, dtype=dtype).eval() + + self.model = WorldModel.from_pretrained( + model_uri, cfg=self.model_cfg, device=self.device, dtype=dtype, load_weights=load_weights + ).eval() + apply_inference_patches(self.model) + if quant is not None: + quantize_model(self.model, quant) + + self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device=device) + + # Inference Scheduler + self.scheduler_sigmas = torch.tensor(self.model_cfg.scheduler_sigmas, dtype=dtype, device=device) + + pH, pW = self.model_cfg.patch + self.frm_shape = 1, 1, self.model_cfg.channels, self.model_cfg.height * pH, self.model_cfg.width * pW + + # State + latent_fps = self.model_cfg.inference_fps / self.model_cfg.temporal_compression + self.ts_mult = int(self.model_cfg.base_fps) // latent_fps + self.frame_ts = torch.tensor([[0]], dtype=torch.long) + + # Static input context tensors + self._ctx = { + "button": torch.zeros((1, 1, self.model_cfg.n_buttons), dtype=dtype), + "mouse": torch.zeros((1, 1, 2), dtype=dtype), + "scroll": torch.zeros((1, 1, 1), dtype=dtype), + "frame_timestamp": torch.empty((1, 1), 
dtype=torch.long), + "frame_idx": torch.empty((1, 1), dtype=torch.long), + } + + self._prompt_ctx = {"prompt_emb": None, "prompt_pad_mask": None} @torch.inference_mode() def reset(self): @@ -92,6 +110,18 @@ def reset(self): self.frame_ts.zero_() for v in self._ctx.values(): v.zero_() + self.vae.reset() + + @torch.inference_mode() + def get_state(self): + """Captures a world state to continue via load_state. Doesn't save model""" + return {"kv_cache": self.kv_cache.get_state(), "frame_ts": self.frame_ts.detach().clone()} + + @torch.inference_mode() + def load_state(self, state): + """Loads a world state object saved via save_state. Doesn't load or change model""" + self.kv_cache.load_state(state["kv_cache"]) + self.frame_ts.copy_(state["frame_ts"]) def set_prompt(self, prompt: str): """Apply text conditioning for T2V""" @@ -117,22 +147,25 @@ def gen_frame(self, ctrl: CtrlInput = None, return_img: bool = True): @torch.compile def _prep_inputs(self, x, ctrl=None): + self._ctx["button"].zero_() + self._ctx["button"][..., ctrl.button] = 1.0 + self._ctx["mouse"][0, 0, 0] = ctrl.mouse[0] self._ctx["mouse"][0, 0, 1] = ctrl.mouse[1] + self._ctx["scroll"][0, 0, 0] = ctrl.scroll_wheel - self._ctx["frame_timestamp"].copy_(self.frame_ts) + self._ctx["frame_idx"].copy_(self.frame_ts) + self._ctx["frame_timestamp"].copy_(self.frame_ts).mul_(self.ts_mult) self.frame_ts.add_(1) return self._ctx def prep_inputs(self, x, ctrl=None): ctrl = ctrl if ctrl is not None else CtrlInput() - self._ctx["button"].zero_() - if ctrl.button: - self._ctx["button"][..., list(ctrl.button)] = 1.0 - ctrl.mouse = torch.tensor(ctrl.mouse, device=x.device, dtype=self._ctx["mouse"].dtype) - ctrl.scroll_wheel = torch.sign(torch.tensor(ctrl.scroll_wheel, device=x.device, dtype=self._ctx["scroll"].dtype)) + ctrl.button = torch.as_tensor(list(ctrl.button), dtype=torch.int64).to(x.device, non_blocking=True) + ctrl.mouse = torch.as_tensor(ctrl.mouse).to(x.device, non_blocking=True) + ctrl.scroll_wheel = 
torch.as_tensor(ctrl.scroll_wheel).to(x.device, non_blocking=True) ctx = self._prep_inputs(x, ctrl) # prepare prompt conditioning @@ -142,7 +175,7 @@ def prep_inputs(self, x, ctrl=None): self.set_prompt("An explorable world") return {**ctx, **self._prompt_ctx} - @torch.compile(fullgraph=True, mode="max-autotune", dynamic=False) + @torch.compile(fullgraph=True, dynamic=False, options=COMPILE_OPTIONS) def _denoise_pass(self, x, ctx: Dict[str, Tensor], kv_cache): kv_cache.set_frozen(True) sigma = x.new_empty((x.size(0), x.size(1))) @@ -151,7 +184,7 @@ def _denoise_pass(self, x, ctx: Dict[str, Tensor], kv_cache): x = x + step_dsig * v return x - @torch.compile(fullgraph=True, mode="max-autotune", dynamic=False) + @torch.compile(fullgraph=True, dynamic=False, options=COMPILE_OPTIONS) def _cache_pass(self, x, ctx: Dict[str, Tensor], kv_cache): """Side effect: updates kv cache""" kv_cache.set_frozen(False)