diff --git a/.gitignore b/.gitignore index acb8253..999fe96 100644 --- a/.gitignore +++ b/.gitignore @@ -63,3 +63,5 @@ dist/* # examples examples/python/view_control.pkl +video/*.mp4 +!video/test.mp4 diff --git a/README.md b/README.md index 3e0f85c..c12af1f 100644 --- a/README.md +++ b/README.md @@ -321,6 +321,24 @@ struct pose_frame { }; ``` +pose definition: +`pose_frame` is defined in the world coordinate system. The axis convention is: +- X: left +- Y: up +- Z: into the page + +right-handed: +``` + Y+ + | / Z+ (into page) + | / +X+ <------O + +``` + +Example visualization output: [docs/vis_pose_result.mp4](docs/vis_pose_result.mp4) + + ## 📖 Usage Guide (python) diff --git a/docs/vis_pose_result.mp4 b/docs/vis_pose_result.mp4 new file mode 100644 index 0000000..7e5e67b Binary files /dev/null and b/docs/vis_pose_result.mp4 differ diff --git a/examples/python/vis_pose.py b/examples/python/vis_pose.py new file mode 100644 index 0000000..fe6c16a --- /dev/null +++ b/examples/python/vis_pose.py @@ -0,0 +1,300 @@ +import os +import subprocess +from typing import Optional + +import numpy as np +import typer +import shutil +from tqdm import tqdm + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +import spatialmp4 as sm + + +def quat_to_rot(qx: float, qy: float, qz: float, qw: float) -> np.ndarray: + norm = qx * qx + qy * qy + qz * qz + qw * qw + if norm < 1e-12: + return np.eye(3) + s = 2.0 / norm + x = qx * s + y = qy * s + z = qz * s + wx = qw * x + wy = qw * y + wz = qw * z + xx = qx * x + xy = qx * y + xz = qx * z + yy = qy * y + yz = qy * z + zz = qz * z + return np.array( + [ + [1.0 - (yy + zz), xy - wz, xz + wy], + [xy + wz, 1.0 - (xx + zz), yz - wx], + [xz - wy, yz + wx, 1.0 - (xx + yy)], + ], + dtype=np.float64, + ) + + +def pose_to_matrix(pose: sm.PoseFrame) -> np.ndarray: + rot = quat_to_rot(pose.qx, pose.qy, pose.qz, pose.qw) + mat = np.eye(4, dtype=np.float64) + mat[:3, :3] = rot + mat[:3, 3] = np.array([pose.x, pose.y, pose.z], dtype=np.float64) + return mat + + +def make_cube(size: float) -> tuple[np.ndarray, list[tuple[int, int]]]: + half = size * 0.5 + vertices = np.array( + [ + [-half, -half, -half], + [half, -half, -half], + [half, half, -half], + [-half, half, -half], + [-half, -half, half], + [half, -half, half], + [half, half, half], + [-half, half, half], + ], + dtype=np.float64, + ) + edges = [ + (0, 1), + (1, 2), + (2, 3), + (3, 0), + (4, 5), + (5, 6), + (6, 7), + (7, 4), + (0, 4), + (1, 5), + (2, 6), + (3, 7), + ] + return vertices, edges + + +def update_cube_lines(lines, vertices, edges, camera_origin, min_lw=0.8, max_lw=3.0): + midpoints = [] + for i, j in edges: + midpoints.append((vertices[i] + vertices[j]) * 0.5) + midpoints = np.array(midpoints, dtype=np.float64) + distances = np.linalg.norm(midpoints - camera_origin, axis=1) + d_min = distances.min() + d_max = distances.max() + if d_max <= d_min + 1e-9: + widths = np.full_like(distances, max_lw) + else: + t = (distances - d_min) / (d_max - d_min) + widths = max_lw - t * (max_lw - min_lw) + for line, (i, j), lw in zip(lines, edges, widths): + xs = [vertices[i, 0], vertices[j, 0]] + ys = [vertices[i, 1], vertices[j, 1]] + zs = [vertices[i, 2], vertices[j, 2]] + line.set_data(xs, ys) + line.set_3d_properties(zs) + line.set_linewidth(float(lw)) + + +def apply_world_axes_transform(T_W_H: np.ndarray) -> np.ndarray: + transform = np.diag([-1.0, 1.0, -1.0, 1.0]) + return transform @ T_W_H + + +def compute_bounds( + reader: sm.Reader, + topk: Optional[int], + sample: int, + total_samples: int, +) -> tuple[np.ndarray, np.ndarray]: + min_xyz = None + max_xyz = None + valid_count = 0 + total_count = 0 + + with tqdm(total=total_samples, desc="Bounds", unit="frame") as pbar: + while reader.has_next(): + if topk is not None and reader.get_index() >= topk: + break + if sample > 1 and reader.get_index() % sample != 0: + reader.load_rgb() + continue + rgb_frame = reader.load_rgb() + pose = rgb_frame.pose + total_count += 1 + if pose.timestamp == 0: + pbar.update(1) + continue + T_W_H = apply_world_axes_transform(pose_to_matrix(pose)) + t = T_W_H[:3, 3] + if min_xyz is None: + min_xyz = t.copy() + max_xyz = t.copy() + else: + min_xyz = np.minimum(min_xyz, t) + max_xyz = np.maximum(max_xyz, t) + valid_count += 1 + pbar.update(1) + + if min_xyz is None: + raise ValueError("No valid pose frames found") + + typer.echo(f"Bounds pass complete: {valid_count} valid pose frames, {total_count} total frames") + return min_xyz, max_xyz + + +def main( + video_file: str, + topk: Optional[int] = typer.Option(None, help="Limit to the first N frames"), + cube_size: float = typer.Option(0.2, help="Cube size in meters"), + axis_len: float = typer.Option(0.4, help="World axis length in meters"), + output_dir: Optional[str] = typer.Option( + None, + help="Output directory (default: