-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_plausibilities.py
More file actions
150 lines (121 loc) · 7.05 KB
/
create_plausibilities.py
File metadata and controls
150 lines (121 loc) · 7.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import pickle
import numpy as np
import argparse
from tqdm import tqdm
from collections import defaultdict
def _load_motion_arrays(index_path, codebook_path):
    """Load the per-frame arrays from the motion index and codebook archives.

    Args:
        index_path: Path to an ``.npz`` archive holding ``poses``, ``trans``
            and ``file_indices``.
        codebook_path: Path to an ``.npz`` archive holding ``tokens``.

    Returns:
        Tuple ``(poses, trans, file_indices, codebook_tokens)``.

    Raises:
        FileNotFoundError: If either archive does not exist.
        ValueError: If the four arrays do not have the same number of frames.
    """
    if not os.path.exists(index_path):
        raise FileNotFoundError(f"Index file not found at {index_path}.")
    # NpzFile keeps the archive open until closed; read the arrays inside a
    # ``with`` block so the handle is released deterministically.
    with np.load(index_path) as index_data:
        poses = index_data['poses']
        trans = index_data['trans']
        file_indices = index_data['file_indices']

    if not os.path.exists(codebook_path):
        raise FileNotFoundError(f"Codebook file not found at {codebook_path}.")
    with np.load(codebook_path) as codebook_data:
        codebook_tokens = codebook_data['tokens']

    # Every array is indexed by the same frame number further down, so a
    # length mismatch would silently pair labels with the wrong frames.
    n_frames = len(file_indices)
    lengths = {
        "poses": len(poses),
        "trans": len(trans),
        "tokens": len(codebook_tokens),
    }
    mismatched = {name: n for name, n in lengths.items() if n != n_frames}
    if mismatched:
        raise ValueError(
            f"Per-frame arrays disagree with file_indices ({n_frames} frames): {mismatched}"
        )
    return poses, trans, file_indices, codebook_tokens


def _future_displacement(trans, trans_vel, file_indices, offset):
    """Return ``trans[t + offset] - trans[t]`` for every frame ``t``.

    Frames whose look-ahead frame is past the end of the data or belongs to a
    different file fall back to ``trans_vel[t] * offset``.  Slicing is used
    instead of ``np.roll`` so the tail of the data is never compared against
    the head of the data.
    """
    displacement = trans_vel * offset  # fallback: extrapolate the velocity
    n_frames = len(trans)
    if n_frames > offset:
        same_file = file_indices[:-offset] == file_indices[offset:]
        head = displacement[:n_frames - offset]  # view into ``displacement``
        head[same_file] = trans[offset:][same_file] - trans[:-offset][same_file]
    return displacement


def _compute_motion_terms(poses, trans, file_indices):
    """Compute the successor mask, frame-to-frame velocities and trajectories.

    Returns:
        Tuple ``(valid_mask, pose_vel, trans_vel, traj_15, traj_30)`` where
        ``valid_mask[t]`` is True when frame ``t + 1`` exists and has the same
        file index as frame ``t``.  Velocities are zero where the mask is
        False.
    """
    valid_mask = np.zeros(len(file_indices), dtype=bool)
    valid_mask[:-1] = file_indices[:-1] == file_indices[1:]

    has_next = valid_mask[:-1]
    pose_vel = np.zeros_like(poses)
    trans_vel = np.zeros_like(trans)
    # ``x[:-1]`` is a view, so the masked assignment writes into x itself.
    pose_vel[:-1][has_next] = poses[1:][has_next] - poses[:-1][has_next]
    trans_vel[:-1][has_next] = trans[1:][has_next] - trans[:-1][has_next]

    traj_15 = _future_displacement(trans, trans_vel, file_indices, 15)
    traj_30 = _future_displacement(trans, trans_vel, file_indices, 30)
    return valid_mask, pose_vel, trans_vel, traj_15, traj_30


def _group_frames_by_region(codebook_tokens, valid_mask):
    """Map each region label to the ascending list of its usable frames.

    A frame is usable when its label is not ``-1`` and ``valid_mask`` is True.
    """
    eligible = np.flatnonzero((codebook_tokens != -1) & valid_mask)
    region_to_frames = defaultdict(list)
    # Only eligible frames are visited; ``tolist`` yields plain Python scalars.
    for frame_idx, region in zip(eligible.tolist(), codebook_tokens[eligible].tolist()):
        region_to_frames[region].append(frame_idx)
    return dict(region_to_frames)


def _build_graph(args, file_indices, global_features, region_to_frames):
    """Estimate region-to-region transition edges by sampling.

    For each source region, sampled start frames are advanced by
    ``args.fast_forward_frames`` and the following
    ``args.transition_window_size + 1`` frames form the source window (only
    when the whole span stays inside one file).  For each target region a
    random set of candidate frames is drawn, and every window frame is matched
    to its nearest candidate by squared Euclidean distance over
    ``global_features``.  An edge is stored when the fraction of window frames
    whose best distance is ``<= args.max_plausible_cost`` reaches
    ``args.min_valid_source_ratio``; its weight is the mean best distance of
    those frames.

    Returns:
        ``{source_region: {target_region: cost}}`` with int keys and float
        costs.
    """
    # Imported once per call rather than once per region pair.
    from scipy.spatial.distance import cdist

    # Local legacy generator: same stream as ``np.random.seed(42)`` followed by
    # ``np.random.choice`` calls, without reseeding the process-wide RNG.
    rng = np.random.RandomState(42)

    n_frames = len(file_indices)
    lookahead = args.fast_forward_frames + args.transition_window_size
    window_offsets = np.arange(args.fast_forward_frames, lookahead + 1)
    regions = sorted(region_to_frames)
    graph = defaultdict(dict)

    for reg_a in tqdm(regions, desc="Processing source regions"):
        frames_a = region_to_frames[reg_a]
        # Sample starting frames to bound the amount of work per region.
        num_start_samples = min(args.max_source_samples, len(frames_a))
        starts = rng.choice(frames_a, num_start_samples, replace=False)

        # Keep only starts whose full look-ahead span exists in the same file.
        starts = starts[starts + lookahead < n_frames]
        starts = starts[file_indices[starts] == file_indices[starts + lookahead]]
        window_frames = (starts[:, None] + window_offsets[None, :]).ravel()
        if window_frames.size == 0:
            continue
        feat_windows = global_features[window_frames]  # [W, Feature_Dim]

        for reg_b in regions:
            frames_b = region_to_frames[reg_b]
            num_targets = min(args.max_target_samples, len(frames_b))
            candidates = rng.choice(frames_b, num_targets, replace=False)
            feat_candidates = global_features[candidates]  # [C, Feature_Dim]

            dist_matrix = cdist(feat_windows, feat_candidates, metric='sqeuclidean')
            # Best landing spot for each source window frame.
            min_dists = dist_matrix.min(axis=1)
            valid_sources = min_dists <= args.max_plausible_cost

            # ``any()`` guards the mean over an empty selection (NaN weight)
            # when the ratio threshold is zero or negative.
            if valid_sources.any() and valid_sources.mean() >= args.min_valid_source_ratio:
                graph[int(reg_a)][int(reg_b)] = float(min_dists[valid_sources].mean())

    return dict(graph)


def create_plausibilities(args):
    """Build the codebook-region transition graph and pickle it to disk.

    Args:
        args: Namespace-like object with the attributes ``input_index``,
            ``input_codebook``, ``output_path``, ``fast_forward_frames``,
            ``transition_window_size``, ``min_valid_source_ratio``,
            ``max_plausible_cost``, ``max_source_samples`` and
            ``max_target_samples``.

    Returns:
        None.  The graph ``{source_region: {target_region: cost}}`` is written
        to ``args.output_path`` with ``pickle``.

    Raises:
        FileNotFoundError: If the index or codebook file does not exist.
        ValueError: If the loaded per-frame arrays differ in length.
    """
    # 1. Load Data
    print("Loading indexed data...")
    poses, trans, file_indices, codebook_tokens = _load_motion_arrays(
        args.input_index, args.input_codebook
    )

    # 2. Compute velocities and trajectories matching the engine's MM logic
    print("Computing velocities and trajectories...")
    valid_mask, pose_vel, trans_vel, traj_15, traj_30 = _compute_motion_terms(
        poses, trans, file_indices
    )

    # 3. Group frames by Region
    print("Grouping valid frames by Codebook Region...")
    region_to_frames = _group_frames_by_region(codebook_tokens, valid_mask)
    print(f"Found {len(region_to_frames)} valid codebook regions.")

    # 4. Extract global feature vectors for fast L2 matrix operations.
    # The cost weights are folded into the features as square roots so that a
    # plain squared L2 distance reproduces the weighted cost:
    # MM Cost: pose(x1) + pose_vel(x1) + 10*trans_vel(x10) + 2*traj15(x2) + 2*traj30(x2)
    print("Building global continuous feature array...")
    global_features = np.concatenate([
        poses,
        pose_vel,
        trans_vel * np.sqrt(10.0),
        traj_15 * np.sqrt(2.0),
        traj_30 * np.sqrt(2.0),
    ], axis=1).astype(np.float32)

    # 5. Build Plausibility Graph
    print("Building graph edges...")
    graph = _build_graph(args, file_indices, global_features, region_to_frames)

    # 6. Save the Graph
    # A bare filename has an empty dirname, which os.makedirs rejects.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_path, 'wb') as f:
        pickle.dump(graph, f)

    num_edges = sum(len(targets) for targets in graph.values())
    print(f"Graph generation complete. Saved to {args.output_path}.")
    print(f"Total Regions (Nodes): {len(graph)}")
    print(f"Total Valid Transitions (Edges): {num_edges}")
if __name__ == "__main__":
    # Command-line entry point: declare the I/O paths and the simulation /
    # pruning knobs, parse them, and hand the namespace to the graph builder.
    cli = argparse.ArgumentParser(
        description="Offline Graph Building: Map out codebook transition plausibilities."
    )
    index_dir = os.path.join("data", "index")

    # (flag, value type, default, help text) -- registered in this order so the
    # generated --help output keeps the same layout.
    option_table = [
        ("--input_index", str, os.path.join(index_dir, "motion_index.npz"),
         "Path to motion index"),
        ("--input_codebook", str, os.path.join(index_dir, "codebook.npz"),
         "Path to codebook region labels"),
        ("--output_path", str, os.path.join(index_dir, "plausibility_graph.pkl"),
         "Output path for the graph"),
        # Engine simulation and transition pruning parameters
        ("--fast_forward_frames", int, 30,
         "Frames played before allowing a jump search (respects 30-frame lock)"),
        ("--transition_window_size", int, 15,
         "Simulation transition search window"),
        ("--min_valid_source_ratio", float, 0.5,
         "Minimum percentage (0-1) of source frames that must have a valid landing spot to draw a graph edge"),
        ("--max_plausible_cost", float, 4.0,
         "Max standard MM cost before edge is considered an 'impossible' glitch transition"),
        ("--max_source_samples", int, 300,
         "Number of sequences to forward-simulate per region"),
        ("--max_target_samples", int, 200,
         "Number of random candidate transition targets per region"),
    ]
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    create_plausibilities(cli.parse_args())