-
Notifications
You must be signed in to change notification settings - Fork 8
Waypoint-1.5 Support #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
lapp0
wants to merge
53
commits into
main
Choose a base branch
from
wp-1.5
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+507
−311
Open
Changes from all commits
Commits
Show all changes
53 commits
Select commit
Hold shift + click to select a range
6198fa1
Fix README examples and align default model URI in example scripts
MalarzDawid 991a861
Add examples extra dependencies for OpenCV and benchmark tooling
MalarzDawid 3ab6475
implement state loading / saving
lapp0 f1c93e1
moe + fbgemm optimization
lapp0 c6f95be
wp-1.5 staging
lapp0 7cf8c25
clean up and fix ae
lapp0 586a3c0
fix temporal compression rope bugs
lapp0 5125dc1
vae reset in world_engine.reset
lapp0 9ba9b4d
reduce peak memory
lapp0 4c5ecb5
Implements the orthorope angles computation instead of precomputing (…
Clydingus bf90520
test: revert direct device init (#28)
Clydingus 177101f
feat: use built triton-windows fork to fix long-path issue
philpax fe5873d
update gen_sample
lapp0 1935b64
better quant
lapp0 facd12a
avoid warning when creating mouse / scroll tensors
lapp0 b2b3fb6
disable unimportant compile options
lapp0 48d5f68
Merge remote-tracking branch 'origin/wp-1.5' into wp1.5
lapp0 8f84795
clean up model loading
lapp0 3df610b
remove unnecessary push_to_hub
lapp0 a437614
remove unnecessary save_pretrained
lapp0 c076f88
Merge pull request #29 from Overworldai/use-patched-triton-windows
philpax 39630e4
reduce cpu memory
lapp0 5ba3689
Merge remote-tracking branch 'origin/wp-1.5' into wp1.5
lapp0 235276f
pass device
lapp0 d4fd76a
fix #27 - use triton-windows longpath fix
philpax 4469f3e
cleanup dead code
lapp0 aa02df5
Merge remote-tracking branch 'refs/remotes/origin/wp-1.5' into wp-1.5
lapp0 4511a7b
auto 720p
lapp0 844c44c
ensure correct device
lapp0 3612b8b
no internal model URIs, document requirements in docstring at top
lapp0 f5cc301
update readme to document WP1.5
lapp0 5c910b7
update readme to document WP1.5
lapp0 274d685
no fbgemm dep
lapp0 ae9ac61
benchmark dont force AE
lapp0 25503c1
move kv cache to appropriate device
lapp0 da66a6d
improve example w/ four_frames var
lapp0 f482fc2
improve example w/ four_frames var
lapp0 e8cd112
improve example w/ four_frames var
lapp0 9308e9e
dev dependency group for examples, uv docs
lapp0 c2261dd
dev dependency group for examples, uv docs
lapp0 e2060f0
fix a missing word
lapp0 c487603
Credit PR #20
lapp0 3d8b327
remove rotary embedding pytorch dependency
lapp0 ebe31bb
improve throughput by 4%
lapp0 3ff84dd
add non-blocking benchmark
lapp0 60c8864
no compile prep inputs
lapp0 f5bf64e
compile prep inputs after converting to tensor, avoid blocking
lapp0 2926cfb
fix, don't pass device to prompt encoder
lapp0 946be2c
fix button incorrect
lapp0 f0be311
fix button incorrect
lapp0 a236313
benchmark w/ ctrls
lapp0 9607bcf
benchmark w/ ctrls
lapp0 1d286e1
update config defaults
lapp0 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,7 +60,7 @@ export HF_TOKEN=<your access token> | |
| from world_engine import WorldEngine, CtrlInput | ||
|
|
||
| # Create inference engine | ||
| engine = WorldEngine("Overworld/Waypoint-1-Small", device="cuda") | ||
| engine = WorldEngine("Overworld/Waypoint-1.5-1B", device="cuda") | ||
|
|
||
| # Specify a prompt | ||
| engine.set_prompt("A fun game") | ||
|
|
@@ -77,14 +77,25 @@ for controller_input in [ | |
| img = engine.gen_frame(ctrl=controller_input) | ||
| ``` | ||
|
|
||
| ## Waypoint-1.5 Behavior | ||
| All interfaces and handling for Waypoint-1 (or 1.1) and Waypoint-1.5 remain the same **except** the following: | ||
|
|
||
| In Waypoint-1.5, the `img` passed to `append_frame(...)` and returned by `gen_frame(...)` is now a sequence of 4 frames. Waypoint-1.5 applies temporal compression and generates 4 frames for every controller input. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. describe the implications this has for frame pacing; what's the correct way to feed inputs and display the rendered frames to the user? |
||
|
|
||
| Whereas previously, `img` was a uint8 rgb array of shape `[Height, Width, 3]`, **in Waypoint-1.5 it is of shape `[4, Height, Width, 3]`**. | ||
|
|
||
| Additionally, Waypoint-1.5 expects 720p inputs / outputs, therefore `img` is `[4, 720, 1280, 3]`. | ||
|
|
||
| See [examples/gen_sample.py](./examples/gen_sample.py) for reference. | ||
|
|
||
| ## Usage | ||
| ``` | ||
| from world_engine import WorldEngine, CtrlInput | ||
| ``` | ||
|
|
||
| Load model to GPU | ||
| ``` | ||
| engine = WorldEngine("Overworld/Waypoint-1-Small", device="cuda") | ||
| engine = WorldEngine("Overworld/Waypoint-1.5-1B", device="cuda") | ||
| ``` | ||
|
|
||
| Specify a prompt which will be used until this function is called again | ||
|
|
@@ -118,11 +129,13 @@ Note: returned `img` is always on the same device as `engine.device` | |
| @dataclass | ||
| class CtrlInput: | ||
| button: Set[int] = field(default_factory=set) # pressed button IDs | ||
| mouse: Tuple[float, float] = (0.0, 0.0) # (x, y) position | ||
| mouse: Tuple[float, float] = (0.0, 0.0) # (dx, dy) position change | ||
| scroll_wheel: int = 0 # down, stationary, or up -> (-1, 0, 1) | ||
| ``` | ||
|
|
||
| - `button` keycodes are defined by [Owl-Control](https://github.com/Overworldai/owl-control/blob/main/src/system/keycode.rs) | ||
| - `mouse` is the raw mouse velocity vector | ||
| - `mouse` is the the amount the change in mouse since last frame | ||
| - `scroll_wheel` is the ternary scroll wheel movement identifier | ||
|
|
||
|
|
||
| ## Showcase and Examples | ||
|
|
@@ -138,5 +151,5 @@ class CtrlInput: | |
|
|
||
| ### Examples and Reference Code | ||
|
|
||
| - ["Hello (Over)World" client](./examples/simple_client.py) | ||
| - ["Generate MP4 Sample Given Controller Inputs](./examples/gen_sample.py) | ||
| - [Run Performance Benchmarks (`pytest examples/benchmark.py`)](./examples/benchmark.py) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,22 +1,53 @@ | ||
| # uv run --dev examples/gen_sample.py Overworld/Waypoint-1.5-1B | ||
|
|
||
| import cv2 | ||
| from world_engine import WorldEngine | ||
| import imageio.v3 as iio | ||
| import random | ||
| import sys | ||
| import urllib.request | ||
| import numpy as np | ||
| import torch | ||
|
|
||
| from world_engine import WorldEngine, CtrlInput | ||
|
|
||
|
|
||
| # Create inference engine | ||
| engine = WorldEngine(sys.argv[1], device="cuda") | ||
|
|
||
|
|
||
| # Define sequence of controller inputs applied | ||
| controller_sequence = [ | ||
| # move mouse, jump, do nothing, trigger, do nothing, trigger+jump, do nothing | ||
| CtrlInput(mouse=[0.2, 0.2]), CtrlInput(button={32}), CtrlInput(), CtrlInput(), CtrlInput(), | ||
| CtrlInput(button={1}), CtrlInput(), CtrlInput(), CtrlInput(button={1, 32}), | ||
| CtrlInput(), CtrlInput(), CtrlInput(), CtrlInput(), CtrlInput(), CtrlInput(), | ||
| ] * 4 | ||
| controller_sequence += [CtrlInput()] * 8 | ||
| controller_sequence += ( | ||
| [CtrlInput(button={32})] * 10 + # forward | ||
| [CtrlInput(button={65})] * 10 + # left | ||
| [CtrlInput(button={68})] * 10 + # right | ||
| [CtrlInput(button={83})] * 10 # backwards | ||
| ) | ||
| controller_sequence += [CtrlInput()] * 10 | ||
|
|
||
| def gen_vid(): | ||
| engine = WorldEngine("OpenWorldLabs/CoDCtl-Causal-Flux-SelfForcing", device="cuda") | ||
| writer = None | ||
| for _ in range(240): | ||
| frame = engine.gen_frame().cpu().numpy()[:, :, ::-1] # RGB -> BGR for OpenCV | ||
| writer = writer or cv2.VideoWriter( | ||
| "out.mp4", | ||
| cv2.VideoWriter_fourcc(*"mp4v"), | ||
| 60, | ||
| (frame.shape[1], frame.shape[0]) | ||
| ) | ||
| writer.write(frame) | ||
|
|
||
| writer.release() | ||
| # Set seed frame | ||
| url = random.choice([ | ||
| "https://gist.github.com/user-attachments/assets/d81c6d26-a838-4afe-9d13-fd67677043c3", | ||
| "https://gist.github.com/user-attachments/assets/b6d18c38-098e-43b0-8e61-66a16e5d8946", | ||
| "https://gist.github.com/user-attachments/assets/0734a8c1-3eb4-4ffe-8c37-5665c45ab559", | ||
| "https://gist.github.com/user-attachments/assets/f9c20d4d-7565-452d-8b02-42a85ea175ed", | ||
| "https://gist.github.com/user-attachments/assets/68c943a4-008a-4c25-948c-c81ab4c47d21", | ||
| ]) | ||
| seed_frame = cv2.imdecode(np.frombuffer(urllib.request.urlopen(url).read(), np.uint8), cv2.IMREAD_COLOR) | ||
| seed_frame_x4 = torch.from_numpy(np.repeat(seed_frame[None], 4, axis=0)) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| gen_vid() | ||
| # Generate frames conditioned on controller inputs | ||
| with iio.imopen("out.mp4", "w", plugin="pyav") as out: | ||
| engine.append_frame(seed_frame_x4) | ||
| out.write(seed_frame_x4, fps=60, codec="libx264") | ||
| for ctrl in controller_sequence: | ||
| four_frames = engine.gen_frame(ctrl=ctrl).cpu().numpy() | ||
| out.write(four_frames) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this probably needs to be updated for 4-frame use, or this snippet should be deleted entirely and pointed at one of the examples
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO, the Waypoint-1.5 clarification below on the nature of
imgis sufficientThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd add a comment pointing to the clarification below, so they have an idea of what to expect for the shape of
img