Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
669,136 changes: 1 addition & 669,135 deletions configs/T5_tokenizer/tokenizer.json

Large diffs are not rendered by default.

2,748 changes: 0 additions & 2,748 deletions configs/T5_tokenizer/tokenizer_config.json

Large diffs are not rendered by default.

14 changes: 0 additions & 14 deletions configs/transformer_config_i2v.json
Original file line number Diff line number Diff line change
@@ -1,14 +0,0 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 36,
"model_type": "i2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512
}
184 changes: 0 additions & 184 deletions context.py
Original file line number Diff line number Diff line change
@@ -1,184 +0,0 @@
import numpy as np
from typing import Callable, Optional, List


def ordered_halving(val):
bin_str = f"{val:064b}"
bin_flip = bin_str[::-1]
as_int = int(bin_flip, 2)

return as_int / (1 << 64)

def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
prev_val = -1
for i, val in enumerate(window):
val = val % num_frames
if val < prev_val:
return True, i
prev_val = val
return False, -1

def shift_window_to_start(window: list[int], num_frames: int):
start_val = window[0]
for i in range(len(window)):
# 1) subtract each element by start_val to move vals relative to the start of all frames
# 2) add num_frames and take modulus to get adjusted vals
window[i] = ((window[i] - start_val) + num_frames) % num_frames

def shift_window_to_end(window: list[int], num_frames: int):
# 1) shift window to start
shift_window_to_start(window, num_frames)
end_val = window[-1]
end_delta = num_frames - end_val - 1
for i in range(len(window)):
# 2) add end_delta to each val to slide windows to end
window[i] = window[i] + end_delta

def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
all_indexes = list(range(num_frames))
for w in windows:
for val in w:
try:
all_indexes.remove(val)
except ValueError:
pass
return all_indexes

def uniform_looped(
step: int = ...,
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
if num_frames <= context_size:
yield list(range(num_frames))
return

context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)

for context_step in 1 << np.arange(context_stride):
pad = int(round(num_frames * ordered_halving(step)))
for j in range(
int(ordered_halving(step) * context_step) + pad,
num_frames + pad + (0 if closed_loop else -context_overlap),
(context_size * context_step - context_overlap),
):
yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)]

#from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
def uniform_standard(
step: int = ...,
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
windows = []
if num_frames <= context_size:
windows.append(list(range(num_frames)))
return windows

context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)

for context_step in 1 << np.arange(context_stride):
pad = int(round(num_frames * ordered_halving(step)))
for j in range(
int(ordered_halving(step) * context_step) + pad,
num_frames + pad + (0 if closed_loop else -context_overlap),
(context_size * context_step - context_overlap),
):
windows.append([e % num_frames for e in range(j, j + context_size * context_step, context_step)])

# now that windows are created, shift any windows that loop, and delete duplicate windows
delete_idxs = []
win_i = 0
while win_i < len(windows):
# if window is rolls over itself, need to shift it
is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
if is_roll:
roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
shift_window_to_end(windows[win_i], num_frames=num_frames)
# check if next window (cyclical) is missing roll_val
if roll_val not in windows[(win_i+1) % len(windows)]:
# need to insert new window here - just insert window starting at roll_val
windows.insert(win_i+1, list(range(roll_val, roll_val + context_size)))
# delete window if it's not unique
for pre_i in range(0, win_i):
if windows[win_i] == windows[pre_i]:
delete_idxs.append(win_i)
break
win_i += 1

# reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
delete_idxs.reverse()
for i in delete_idxs:
windows.pop(i)
return windows

def static_standard(
step: int = ...,
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
windows = []
if num_frames <= context_size:
windows.append(list(range(num_frames)))
return windows
# always return the same set of windows
delta = context_size - context_overlap
for start_idx in range(0, num_frames, delta):
# if past the end of frames, move start_idx back to allow same context_length
ending = start_idx + context_size
if ending >= num_frames:
final_delta = ending - num_frames
final_start_idx = start_idx - final_delta
windows.append(list(range(final_start_idx, final_start_idx + context_size)))
break
windows.append(list(range(start_idx, start_idx + context_size)))
return windows

def get_context_scheduler(name: str) -> Callable:
if name == "uniform_looped":
return uniform_looped
elif name == "uniform_standard":
return uniform_standard
elif name == "static_standard":
return static_standard
else:
raise ValueError(f"Unknown context_overlap policy {name}")


def get_total_steps(
scheduler,
timesteps: List[int],
num_steps: Optional[int] = None,
num_frames: int = ...,
context_size: Optional[int] = None,
context_stride: int = 3,
context_overlap: int = 4,
closed_loop: bool = True,
):
return sum(
len(
list(
scheduler(
i,
num_steps,
num_frames,
context_size,
context_stride,
context_overlap,
)
)
)
for i in range(len(timesteps))
)
Loading