diff --git a/README.md b/README.md index 7e3fddb..220db79 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ ```sh git clone https://github.com/HorizonRobotics/EmbodiedGen.git cd EmbodiedGen -git checkout v0.1.5 +git checkout v0.1.6 git submodule update --init --recursive --progress conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env. conda activate embodiedgen diff --git a/apps/visualize_asset.py b/apps/visualize_asset.py index 089e329..85e12dd 100644 --- a/apps/visualize_asset.py +++ b/apps/visualize_asset.py @@ -31,8 +31,8 @@ import gradio as gr import pandas as pd -import yaml from app_style import custom_theme, lighting_css +from embodied_gen.utils.tags import VERSION try: from embodied_gen.utils.gpt_clients import GPT_CLIENT as gpt_client @@ -48,7 +48,6 @@ # --- Configuration & Data Loading --- -VERSION = "v0.1.5" RUNNING_MODE = "local" # local or hf_remote CSV_FILE = "dataset_index.csv" diff --git a/docs/install.md b/docs/install.md index 56d200f..8262eba 100644 --- a/docs/install.md +++ b/docs/install.md @@ -7,7 +7,7 @@ hide: ```sh git clone https://github.com/HorizonRobotics/EmbodiedGen.git cd EmbodiedGen -git checkout v0.1.5 +git checkout v0.1.6 git submodule update --init --recursive --progress conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env. conda activate embodiedgen diff --git a/docs/tutorials/any_simulators.md b/docs/tutorials/any_simulators.md index 3c2d83e..0d6e146 100644 --- a/docs/tutorials/any_simulators.md +++ b/docs/tutorials/any_simulators.md @@ -35,7 +35,8 @@ Leverage **EmbodiedGen-generated assets** with *accurate physical collisions* an ## 🧱 Example: Conversion to Target Simulator ```python -from embodied_gen.data.asset_converter import SimAssetMapper, cvt_embodiedgen_asset_to_anysim +from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim +from embodied_gen.utils.enum import AssetType, SimAssetMapper from typing import Literal simulator_name: Literal[ @@ -52,6 +53,10 @@ dst_asset_path = cvt_embodiedgen_asset_to_anysim( "path1_to_embodiedgen_asset/asset.urdf", "path2_to_embodiedgen_asset/asset.urdf", ], + target_dirs=[ + "path1_to_target_dir/asset.usd", + "path2_to_target_dir/asset.usd", + ], target_type=SimAssetMapper[simulator_name], source_type=AssetType.MESH, overwrite=True, diff --git a/embodied_gen/data/asset_converter.py b/embodied_gen/data/asset_converter.py index 3b32c93..f4e1ac6 100644 --- a/embodied_gen/data/asset_converter.py +++ b/embodied_gen/data/asset_converter.py @@ -4,12 +4,12 @@ import os import xml.etree.ElementTree as ET from abc import ABC, abstractmethod -from dataclasses import dataclass from glob import glob from shutil import copy, copytree, rmtree import trimesh from scipy.spatial.transform import Rotation +from embodied_gen.utils.enum import AssetType logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -17,75 +17,62 @@ __all__ = [ "AssetConverterFactory", - "AssetType", "MeshtoMJCFConverter", "MeshtoUSDConverter", "URDFtoUSDConverter", "cvt_embodiedgen_asset_to_anysim", "PhysicsUSDAdder", - "SimAssetMapper", ] -@dataclass -class AssetType(str): - """Asset type enumeration.""" - - MJCF = "mjcf" - USD = "usd" - URDF = "urdf" - MESH = "mesh" - - -class SimAssetMapper: - _mapping = dict( - ISAACSIM=AssetType.USD, - ISAACGYM=AssetType.URDF, - MUJOCO=AssetType.MJCF, - GENESIS=AssetType.MJCF, - SAPIEN=AssetType.URDF, - PYBULLET=AssetType.URDF, - ) - - @classmethod - def __class_getitem__(cls, key: str): - key = key.upper() - if 
key.startswith("SAPIEN"):
-            key = "SAPIEN"
-        return cls._mapping[key]
-
-
 def cvt_embodiedgen_asset_to_anysim(
     urdf_files: list[str],
+    target_dirs: list[str],
     target_type: AssetType,
     source_type: AssetType,
     overwrite: bool = False,
     **kwargs,
 ) -> dict[str, str]:
-    """Convert URDF files generated by EmbodiedGen into the format required by all simulators.
+    """Convert URDF files generated by EmbodiedGen into formats required by simulators.
 
     Supported simulators include SAPIEN, Isaac Sim, MuJoCo, Isaac Gym, Genesis, and Pybullet.
+    Converting to the `USD` format requires `isaacsim` to be installed.
 
     Example:
+    ```py
+    from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim
+    from embodied_gen.utils.enum import AssetType
+
     dst_asset_path = cvt_embodiedgen_asset_to_anysim(
-        urdf_files,
-        target_type=SimAssetMapper[simulator_name],
+        urdf_files=[
+            "path1_to_embodiedgen_asset/asset.urdf",
+            "path2_to_embodiedgen_asset/asset.urdf",
+        ],
+        target_dirs=[
+            "path1_to_target_dir/asset.usd",
+            "path2_to_target_dir/asset.usd",
+        ],
+        target_type=AssetType.USD,
         source_type=AssetType.MESH,
     )
+    ```
 
     Args:
-        urdf_files (List[str]): List of URDF file paths to be converted.
-        target_type (AssetType): The target asset type.
-        source_type (AssetType): The source asset type.
-        overwrite (bool): Whether to overwrite existing converted files.
-        **kwargs: Additional keyword arguments for the converter.
+        urdf_files (list[str]): List of URDF file paths.
+        target_dirs (list[str]): List of target directories.
+        target_type (AssetType): Target asset type.
+        source_type (AssetType): Source asset type.
+        overwrite (bool, optional): Overwrite existing files.
+        **kwargs: Additional converter arguments.
 
     Returns:
-        Dict[str, str]: A dictionary mapping the original URDF file path to the converted asset file path.
+        dict[str, str]: Mapping from URDF file to converted asset file.
     """
     if isinstance(urdf_files, str):
         urdf_files = [urdf_files]
+    if isinstance(target_dirs, str):
+        target_dirs = [target_dirs]
 
     # If the target type is URDF, no conversion is needed.
     if target_type == AssetType.URDF:
@@ -99,18 +86,17 @@ def cvt_embodiedgen_asset_to_anysim(
     asset_paths = dict()
     with asset_converter:
-        for urdf_file in urdf_files:
+        for urdf_file, target_dir in zip(urdf_files, target_dirs):
             filename = os.path.basename(urdf_file).replace(".urdf", "")
-            asset_dir = os.path.dirname(urdf_file)
             if target_type == AssetType.MJCF:
-                target_file = f"{asset_dir}/../mjcf/{filename}.xml"
+                target_file = f"{target_dir}/{filename}.xml"
             elif target_type == AssetType.USD:
-                target_file = f"{asset_dir}/../usd/{filename}.usd"
+                target_file = f"{target_dir}/{filename}.usd"
             else:
                 raise NotImplementedError(
                     f"Target type {target_type} not supported."
                 )
-            if not os.path.exists(target_file):
+            if not os.path.exists(target_file) or overwrite:
                 asset_converter.convert(urdf_file, target_file)
             asset_paths[urdf_file] = target_file
@@ -119,16 +105,35 @@
 class AssetConverterBase(ABC):
-    """Converter abstract base class."""
+    """Abstract base class for asset converters.
+
+    Provides context management and mesh transformation utilities.
+    """
 
     @abstractmethod
     def convert(self, urdf_path: str, output_path: str, **kwargs) -> str:
+        """Convert an asset file.
+
+        Args:
+            urdf_path (str): Path to input URDF file.
+            output_path (str): Path to output file.
+            **kwargs: Additional arguments.
+
+        Returns:
+            str: Path to converted asset.
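+
+        Example (illustrative sketch; uses the concrete `MeshtoMJCFConverter`
+        subclass with placeholder paths):
+        ```py
+        converter = MeshtoMJCFConverter()
+        with converter:
+            converter.convert(
+                "path_to_embodiedgen_asset/asset.urdf",
+                "path_to_target_dir/asset.xml",
+            )
+        ```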
+ """ pass def transform_mesh( self, input_mesh: str, output_mesh: str, mesh_origin: ET.Element ) -> None: - """Apply transform to the mesh based on the origin element in URDF.""" + """Apply transform to mesh based on URDF origin element. + + Args: + input_mesh (str): Path to input mesh. + output_mesh (str): Path to output mesh. + mesh_origin (ET.Element): Origin element from URDF. + """ mesh = trimesh.load(input_mesh, group_material=False) rpy = list(map(float, mesh_origin.get("rpy").split(" "))) rotation = Rotation.from_euler("xyz", rpy, degrees=False) @@ -150,14 +155,19 @@ def transform_mesh( return def __enter__(self): + """Context manager entry.""" return self def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" return False class MeshtoMJCFConverter(AssetConverterBase): - """Convert URDF files into MJCF format.""" + """Converts mesh-based URDF files to MJCF format. + + Handles geometry, materials, and asset copying. + """ def __init__( self, @@ -166,6 +176,12 @@ def __init__( self.kwargs = kwargs def _copy_asset_file(self, src: str, dst: str) -> None: + """Copies asset file if not already present. + + Args: + src (str): Source file path. + dst (str): Destination file path. + """ if os.path.exists(dst): return os.makedirs(os.path.dirname(dst), exist_ok=True) @@ -183,7 +199,19 @@ def add_geometry( material: ET.Element | None = None, is_collision: bool = False, ) -> None: - """Add geometry to the MJCF body from the URDF link.""" + """Adds geometry to MJCF body from URDF link. + + Args: + mujoco_element (ET.Element): MJCF asset element. + link (ET.Element): URDF link element. + body (ET.Element): MJCF body element. + tag (str): Tag name ("visual" or "collision"). + input_dir (str): Input directory. + output_dir (str): Output directory. + mesh_name (str): Mesh name. + material (ET.Element, optional): Material element. + is_collision (bool, optional): If True, treat as collision geometry. + """ element = link.find(tag) geometry = element.find("geometry") mesh = geometry.find("mesh") @@ -242,7 +270,20 @@ def add_materials( name: str, reflectance: float = 0.2, ) -> ET.Element: - """Add materials to the MJCF asset from the URDF link.""" + """Adds materials to MJCF asset from URDF link. + + Args: + mujoco_element (ET.Element): MJCF asset element. + link (ET.Element): URDF link element. + tag (str): Tag name. + input_dir (str): Input directory. + output_dir (str): Output directory. + name (str): Material name. + reflectance (float, optional): Reflectance value. + + Returns: + ET.Element: Material element. + """ element = link.find(tag) geometry = element.find("geometry") mesh = geometry.find("mesh") @@ -282,7 +323,12 @@ def add_materials( return material def convert(self, urdf_path: str, mjcf_path: str): - """Convert a URDF file to MJCF format.""" + """Converts a URDF file to MJCF format. + + Args: + urdf_path (str): Path to URDF file. + mjcf_path (str): Path to output MJCF file. + """ tree = ET.parse(urdf_path) root = tree.getroot() @@ -336,10 +382,22 @@ def convert(self, urdf_path: str, mjcf_path: str): class URDFtoMJCFConverter(MeshtoMJCFConverter): - """Convert URDF files with joints to MJCF format, handling transformations from joints.""" + """Converts URDF files with joints to MJCF format, handling joint transformations. + + Handles fixed joints and hierarchical body structure. + """ def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str: - """Convert a URDF file with joints to MJCF format.""" + """Converts a URDF file with joints to MJCF format. 
+ + Args: + urdf_path (str): Path to URDF file. + mjcf_path (str): Path to output MJCF file. + **kwargs: Additional arguments. + + Returns: + str: Path to converted MJCF file. + """ tree = ET.parse(urdf_path) root = tree.getroot() @@ -423,7 +481,10 @@ def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str: class MeshtoUSDConverter(AssetConverterBase): - """Convert Mesh file from URDF into USD format.""" + """Converts mesh-based URDF files to USD format. + + Adds physics APIs and post-processes collision meshes. + """ DEFAULT_BIND_APIS = [ "MaterialBindingAPI", @@ -443,6 +504,14 @@ def __init__( simulation_app=None, **kwargs, ): + """Initializes the converter. + + Args: + force_usd_conversion (bool, optional): Force USD conversion. + make_instanceable (bool, optional): Make prims instanceable. + simulation_app (optional): Simulation app instance. + **kwargs: Additional arguments. + """ if simulation_app is not None: self.simulation_app = simulation_app @@ -458,6 +527,7 @@ def __init__( ) def __enter__(self): + """Context manager entry, launches simulation app if needed.""" from isaaclab.app import AppLauncher if not hasattr(self, "simulation_app"): @@ -476,6 +546,7 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit, closes simulation app if created.""" # Close the simulation app if it was created here if hasattr(self, "app_launcher") and self.exit_close: self.simulation_app.close() @@ -486,7 +557,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False def convert(self, urdf_path: str, output_file: str): - """Convert a URDF file to USD and post-process collision meshes.""" + """Converts a URDF file to USD and post-processes collision meshes. + + Args: + urdf_path (str): Path to URDF file. + output_file (str): Path to output USD file. + """ from isaaclab.sim.converters import MeshConverter, MeshConverterCfg from pxr import PhysxSchema, Sdf, Usd, UsdShade @@ -556,6 +632,11 @@ def convert(self, urdf_path: str, output_file: str): class PhysicsUSDAdder(MeshtoUSDConverter): + """Adds physics APIs and collision properties to USD assets. + + Useful for post-processing USD files for simulation. + """ + DEFAULT_BIND_APIS = [ "MaterialBindingAPI", # "PhysicsMeshCollisionAPI", @@ -566,6 +647,12 @@ class PhysicsUSDAdder(MeshtoUSDConverter): ] def convert(self, usd_path: str, output_file: str = None): + """Adds physics APIs and collision properties to a USD file. + + Args: + usd_path (str): Path to input USD file. + output_file (str, optional): Path to output USD file. + """ from pxr import PhysxSchema, Sdf, Usd, UsdGeom, UsdPhysics if output_file is None: @@ -626,14 +713,18 @@ def convert(self, usd_path: str, output_file: str = None): class URDFtoUSDConverter(MeshtoUSDConverter): - """Convert URDF files into USD format. + """Converts URDF files to USD format. Args: - fix_base (bool): Whether to fix the base link. - merge_fixed_joints (bool): Whether to merge fixed joints. - make_instanceable (bool): Whether to make prims instanceable. - force_usd_conversion (bool): Force conversion to USD. - collision_from_visuals (bool): Generate collisions from visuals if not provided. + fix_base (bool, optional): Fix the base link. + merge_fixed_joints (bool, optional): Merge fixed joints. + make_instanceable (bool, optional): Make prims instanceable. + force_usd_conversion (bool, optional): Force conversion to USD. + collision_from_visuals (bool, optional): Generate collisions from visuals. 
+ joint_drive (optional): Joint drive configuration. + rotate_wxyz (tuple[float], optional): Quaternion for rotation. + simulation_app (optional): Simulation app instance. + **kwargs: Additional arguments. """ def __init__( @@ -648,6 +739,19 @@ def __init__( simulation_app=None, **kwargs, ): + """Initializes the converter. + + Args: + fix_base (bool, optional): Fix the base link. + merge_fixed_joints (bool, optional): Merge fixed joints. + make_instanceable (bool, optional): Make prims instanceable. + force_usd_conversion (bool, optional): Force conversion to USD. + collision_from_visuals (bool, optional): Generate collisions from visuals. + joint_drive (optional): Joint drive configuration. + rotate_wxyz (tuple[float], optional): Quaternion for rotation. + simulation_app (optional): Simulation app instance. + **kwargs: Additional arguments. + """ self.usd_parms = dict( fix_base=fix_base, merge_fixed_joints=merge_fixed_joints, @@ -662,7 +766,12 @@ def __init__( self.simulation_app = simulation_app def convert(self, urdf_path: str, output_file: str): - """Convert a URDF file to USD and post-process collision meshes.""" + """Converts a URDF file to USD and post-processes collision meshes. + + Args: + urdf_path (str): Path to URDF file. + output_file (str): Path to output USD file. + """ from isaaclab.sim.converters import UrdfConverter, UrdfConverterCfg from pxr import Gf, PhysxSchema, Sdf, Usd, UsdGeom @@ -723,13 +832,36 @@ def convert(self, urdf_path: str, output_file: str): class AssetConverterFactory: - """Factory class for creating asset converters based on target and source types.""" + """Factory for creating asset converters based on target and source types. + + Example: + ```py + from embodied_gen.data.asset_converter import AssetConverterFactory + from embodied_gen.utils.enum import AssetType + + converter = AssetConverterFactory.create( + target_type=AssetType.USD, source_type=AssetType.MESH + ) + with converter: + for urdf_path, output_file in zip(urdf_paths, output_files): + converter.convert(urdf_path, output_file) + ``` + """ @staticmethod def create( target_type: AssetType, source_type: AssetType = "urdf", **kwargs ) -> AssetConverterBase: - """Create an asset converter instance based on target and source types.""" + """Creates an asset converter instance. + + Args: + target_type (AssetType): Target asset type. + source_type (AssetType, optional): Source asset type. + **kwargs: Additional arguments. + + Returns: + AssetConverterBase: Converter instance. 
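+
+        Example (illustrative; mirrors the class-level example for an MJCF target):
+        ```py
+        converter = AssetConverterFactory.create(
+            target_type=AssetType.MJCF, source_type=AssetType.MESH
+        )
+        ```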
+ """ if target_type == AssetType.MJCF and source_type == AssetType.MESH: converter = MeshtoMJCFConverter(**kwargs) elif target_type == AssetType.MJCF and source_type == AssetType.URDF: @@ -751,7 +883,14 @@ def create( # target_asset_type = AssetType.USD urdf_paths = [ - "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf", + 'outputs/EmbodiedGenData/demo_assets/banana/result/banana.urdf', + 'outputs/EmbodiedGenData/demo_assets/book/result/book.urdf', + 'outputs/EmbodiedGenData/demo_assets/lamp/result/lamp.urdf', + 'outputs/EmbodiedGenData/demo_assets/mug/result/mug.urdf', + 'outputs/EmbodiedGenData/demo_assets/remote_control/result/remote_control.urdf', + "outputs/EmbodiedGenData/demo_assets/rubik's_cube/result/rubik's_cube.urdf", + 'outputs/EmbodiedGenData/demo_assets/table/result/table.urdf', + 'outputs/EmbodiedGenData/demo_assets/vase/result/vase.urdf', ] if target_asset_type == AssetType.MJCF: @@ -765,7 +904,14 @@ def create( elif target_asset_type == AssetType.USD: output_files = [ - "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd", + 'outputs/EmbodiedGenData/demo_assets/banana/usd/banana.usd', + 'outputs/EmbodiedGenData/demo_assets/book/usd/book.usd', + 'outputs/EmbodiedGenData/demo_assets/lamp/usd/lamp.usd', + 'outputs/EmbodiedGenData/demo_assets/mug/usd/mug.usd', + 'outputs/EmbodiedGenData/demo_assets/remote_control/usd/remote_control.usd', + "outputs/EmbodiedGenData/demo_assets/rubik's_cube/usd/rubik's_cube.usd", + 'outputs/EmbodiedGenData/demo_assets/table/usd/table.usd', + 'outputs/EmbodiedGenData/demo_assets/vase/usd/vase.usd', ] asset_converter = AssetConverterFactory.create( target_type=AssetType.USD, @@ -776,33 +922,33 @@ def create( for urdf_path, output_file in zip(urdf_paths, output_files): asset_converter.convert(urdf_path, output_file) - urdf_path = "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf" - output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd" - - asset_converter = AssetConverterFactory.create( - target_type=AssetType.USD, - source_type=AssetType.URDF, - rotate_wxyz=(0.7071, 0.7071, 0, 0), # rotate 90 deg around the X-axis - ) - - with asset_converter: - asset_converter.convert(urdf_path, output_file) - - # Convert infinigen urdf to mjcf - urdf_path = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/export_scene/scene.urdf" - output_file = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/mjcf/scene.xml" - asset_converter = AssetConverterFactory.create( - target_type=AssetType.MJCF, - source_type=AssetType.URDF, - keep_materials=["diffuse"], - ) - with asset_converter: - asset_converter.convert(urdf_path, output_file) - - # Convert infinigen usdc to physics usdc - converter = PhysicsUSDAdder() - with converter: - converter.convert( - usd_path="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc/export_scene/export_scene.usdc", - output_file="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc_p3/export_scene/export_scene.usdc", - ) + # urdf_path = "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf" + # output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd" + + # asset_converter = AssetConverterFactory.create( + # target_type=AssetType.USD, + # source_type=AssetType.URDF, + # rotate_wxyz=(0.7071, 0.7071, 0, 0), # rotate 90 deg around the X-axis + # ) + + # with asset_converter: + # 
asset_converter.convert(urdf_path, output_file) + + # # Convert infinigen urdf to mjcf + # urdf_path = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/export_scene/scene.urdf" + # output_file = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/mjcf/scene.xml" + # asset_converter = AssetConverterFactory.create( + # target_type=AssetType.MJCF, + # source_type=AssetType.URDF, + # keep_materials=["diffuse"], + # ) + # with asset_converter: + # asset_converter.convert(urdf_path, output_file) + + # # Convert infinigen usdc to physics usdc + # converter = PhysicsUSDAdder() + # with converter: + # converter.convert( + # usd_path="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc/export_scene/export_scene.usdc", + # output_file="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc_p3/export_scene/export_scene.usdc", + # ) diff --git a/embodied_gen/data/backproject_v2.py b/embodied_gen/data/backproject_v2.py index 420ee39..5908013 100644 --- a/embodied_gen/data/backproject_v2.py +++ b/embodied_gen/data/backproject_v2.py @@ -58,7 +58,16 @@ def _transform_vertices( mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False ) -> torch.Tensor: - """Transform 3D vertices using a projection matrix.""" + """Transforms 3D vertices using a projection matrix. + + Args: + mtx (torch.Tensor): Projection matrix. + pos (torch.Tensor): Vertex positions. + keepdim (bool, optional): If True, keeps the batch dimension. + + Returns: + torch.Tensor: Transformed vertices. + """ t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype) if pos.size(-1) == 3: pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1) @@ -71,7 +80,17 @@ def _transform_vertices( def _bilinear_interpolation_scattering( image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor ) -> torch.Tensor: - """Bilinear interpolation scattering for grid-based value accumulation.""" + """Performs bilinear interpolation scattering for grid-based value accumulation. + + Args: + image_h (int): Image height. + image_w (int): Image width. + coords (torch.Tensor): Normalized coordinates. + values (torch.Tensor): Values to scatter. + + Returns: + torch.Tensor: Interpolated grid. + """ device = values.device dtype = values.dtype C = values.shape[-1] @@ -135,7 +154,18 @@ def _texture_inpaint_smooth( faces: np.ndarray, uv_map: np.ndarray, ) -> tuple[np.ndarray, np.ndarray]: - """Perform texture inpainting using vertex-based color propagation.""" + """Performs texture inpainting using vertex-based color propagation. + + Args: + texture (np.ndarray): Texture image. + mask (np.ndarray): Mask image. + vertices (np.ndarray): Mesh vertices. + faces (np.ndarray): Mesh faces. + uv_map (np.ndarray): UV coordinates. + + Returns: + tuple[np.ndarray, np.ndarray]: Inpainted texture and updated mask. + """ image_h, image_w, C = texture.shape N = vertices.shape[0] @@ -231,29 +261,41 @@ def _texture_inpaint_smooth( class TextureBacker: """Texture baking pipeline for multi-view projection and fusion. - This class performs UV-based texture generation for a 3D mesh using - multi-view color images, depth, and normal information. The pipeline - includes mesh normalization and UV unwrapping, visibility-aware - back-projection, confidence-weighted texture fusion, and inpainting - of missing texture regions. + This class generates UV-based textures for a 3D mesh using multi-view images, + depth, and normal information. 
It includes mesh normalization, UV unwrapping, + visibility-aware back-projection, confidence-weighted fusion, and inpainting. Args: - camera_params (CameraSetting): Camera intrinsics and extrinsics used - for rendering each view. - view_weights (list[float]): A list of weights for each view, used - to blend confidence maps during texture fusion. - render_wh (tuple[int, int], optional): Resolution (width, height) for - intermediate rendering passes. Defaults to (2048, 2048). - texture_wh (tuple[int, int], optional): Output texture resolution - (width, height). Defaults to (2048, 2048). - bake_angle_thresh (int, optional): Maximum angle (in degrees) between - view direction and surface normal for projection to be considered valid. - Defaults to 75. - mask_thresh (float, optional): Threshold applied to visibility masks - during rendering. Defaults to 0.5. - smooth_texture (bool, optional): If True, apply post-processing (e.g., - blurring) to the final texture. Defaults to True. - inpaint_smooth (bool, optional): If True, apply inpainting to smooth. + camera_params (CameraSetting): Camera intrinsics and extrinsics. + view_weights (list[float]): Weights for each view in texture fusion. + render_wh (tuple[int, int], optional): Intermediate rendering resolution. + texture_wh (tuple[int, int], optional): Output texture resolution. + bake_angle_thresh (int, optional): Max angle for valid projection. + mask_thresh (float, optional): Threshold for visibility masks. + smooth_texture (bool, optional): Apply post-processing to texture. + inpaint_smooth (bool, optional): Apply inpainting smoothing. + + Example: + ```py + from embodied_gen.data.backproject_v2 import TextureBacker + from embodied_gen.data.utils import CameraSetting + import trimesh + from PIL import Image + + camera_params = CameraSetting( + num_images=6, + elevation=[20, -10], + distance=5, + resolution_hw=(2048,2048), + fov=math.radians(30), + device='cuda', + ) + view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02] + mesh = trimesh.load('mesh.obj') + images = [Image.open(f'view_{i}.png') for i in range(6)] + texture_backer = TextureBacker(camera_params, view_weights) + textured_mesh = texture_backer(images, mesh, 'output.obj') + ``` """ def __init__( @@ -283,6 +325,12 @@ def __init__( ) def _lazy_init_render(self, camera_params, mask_thresh): + """Lazily initializes the renderer. + + Args: + camera_params (CameraSetting): Camera settings. + mask_thresh (float): Mask threshold. + """ if self.renderer is None: camera = init_kal_camera(camera_params) mv = camera.view_matrix() # (n 4 4) world2cam @@ -301,6 +349,14 @@ def _lazy_init_render(self, camera_params, mask_thresh): ) def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh: + """Normalizes mesh and unwraps UVs. + + Args: + mesh (trimesh.Trimesh): Input mesh. + + Returns: + trimesh.Trimesh: Mesh with normalized vertices and UVs. + """ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices) self.scale, self.center = scale, center @@ -318,6 +374,16 @@ def get_mesh_np_attrs( scale: float = None, center: np.ndarray = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Gets mesh attributes as numpy arrays. + + Args: + mesh (trimesh.Trimesh): Input mesh. + scale (float, optional): Scale factor. + center (np.ndarray, optional): Center offset. 
+ + Returns: + tuple: (vertices, faces, uv_map) + """ vertices = mesh.vertices.copy() faces = mesh.faces.copy() uv_map = mesh.visual.uv.copy() @@ -331,6 +397,14 @@ def get_mesh_np_attrs( return vertices, faces, uv_map def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor: + """Computes edge image from depth map. + + Args: + depth_image (torch.Tensor): Depth map. + + Returns: + torch.Tensor: Edge image. + """ depth_image_np = depth_image.cpu().numpy() depth_image_np = (depth_image_np * 255).astype(np.uint8) depth_edges = cv2.Canny(depth_image_np, 30, 80) @@ -344,6 +418,16 @@ def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor: def compute_enhanced_viewnormal( self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor ) -> torch.Tensor: + """Computes enhanced view normals for mesh faces. + + Args: + mv_mtx (torch.Tensor): View matrices. + vertices (torch.Tensor): Mesh vertices. + faces (torch.Tensor): Mesh faces. + + Returns: + torch.Tensor: View normals. + """ rast, _ = self.renderer.compute_dr_raster(vertices, faces) rendered_view_normals = [] for idx in range(len(mv_mtx)): @@ -376,6 +460,18 @@ def compute_enhanced_viewnormal( def back_project( self, image, vis_mask, depth, normal, uv ) -> tuple[torch.Tensor, torch.Tensor]: + """Back-projects image and confidence to UV texture space. + + Args: + image (PIL.Image or np.ndarray): Input image. + vis_mask (torch.Tensor): Visibility mask. + depth (torch.Tensor): Depth map. + normal (torch.Tensor): Normal map. + uv (torch.Tensor): UV coordinates. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Texture and confidence map. + """ image = np.array(image) image = torch.as_tensor(image, device=self.device, dtype=torch.float32) if image.ndim == 2: @@ -418,6 +514,17 @@ def back_project( ) def _scatter_texture(self, uv, data, mask): + """Scatters data to texture using UV coordinates and mask. + + Args: + uv (torch.Tensor): UV coordinates. + data (torch.Tensor): Data to scatter. + mask (torch.Tensor): Mask for valid pixels. + + Returns: + torch.Tensor: Scattered texture. + """ + def __filter_data(data, mask): return data.view(-1, data.shape[-1])[mask] @@ -432,6 +539,15 @@ def __filter_data(data, mask): def fast_bake_texture( self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor] ) -> tuple[torch.Tensor, torch.Tensor]: + """Fuses multiple textures and confidence maps. + + Args: + textures (list[torch.Tensor]): List of textures. + confidence_maps (list[torch.Tensor]): List of confidence maps. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Fused texture and mask. + """ channel = textures[0].shape[-1] texture_merge = torch.zeros(self.texture_wh + [channel]).to( self.device @@ -451,6 +567,16 @@ def fast_bake_texture( def uv_inpaint( self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray ) -> np.ndarray: + """Inpaints missing regions in the UV texture. + + Args: + mesh (trimesh.Trimesh): Mesh. + texture (np.ndarray): Texture image. + mask (np.ndarray): Mask image. + + Returns: + np.ndarray: Inpainted texture. + """ if self.inpaint_smooth: vertices, faces, uv_map = self.get_mesh_np_attrs(mesh) texture, mask = _texture_inpaint_smooth( @@ -473,6 +599,15 @@ def compute_texture( colors: list[Image.Image], mesh: trimesh.Trimesh, ) -> trimesh.Trimesh: + """Computes the fused texture for the mesh from multi-view images. + + Args: + colors (list[Image.Image]): List of view images. + mesh (trimesh.Trimesh): Mesh to texture. 
+ + Returns: + tuple[np.ndarray, np.ndarray]: Texture and mask. + """ self._lazy_init_render(self.camera_params, self.mask_thresh) vertices = torch.from_numpy(mesh.vertices).to(self.device).float() @@ -517,7 +652,7 @@ def __call__( Args: colors (list[Image.Image]): List of input view images. mesh (trimesh.Trimesh): Input mesh to be textured. - output_path (str): Path to save the output textured mesh (.obj or .glb). + output_path (str): Path to save the output textured mesh. Returns: trimesh.Trimesh: The textured mesh with UV and texture image. @@ -540,6 +675,11 @@ def __call__( def parse_args(): + """Parses command-line arguments for texture backprojection. + + Returns: + argparse.Namespace: Parsed arguments. + """ parser = argparse.ArgumentParser(description="Backproject texture") parser.add_argument( "--color_path", @@ -636,6 +776,16 @@ def entrypoint( imagesr_model: ImageRealESRGAN = None, **kwargs, ) -> trimesh.Trimesh: + """Entrypoint for texture backprojection from multi-view images. + + Args: + delight_model (DelightingModel, optional): Delighting model. + imagesr_model (ImageRealESRGAN, optional): Super-resolution model. + **kwargs: Additional arguments to override CLI. + + Returns: + trimesh.Trimesh: Textured mesh. + """ args = parse_args() for k, v in kwargs.items(): if hasattr(args, k) and v is not None: diff --git a/embodied_gen/data/convex_decomposer.py b/embodied_gen/data/convex_decomposer.py index 88e6084..73c4a5a 100644 --- a/embodied_gen/data/convex_decomposer.py +++ b/embodied_gen/data/convex_decomposer.py @@ -39,6 +39,22 @@ def decompose_convex_coacd( auto_scale: bool = True, scale_factor: float = 1.0, ) -> None: + """Decomposes a mesh using CoACD and saves the result. + + This function loads a mesh from a file, runs the CoACD algorithm with the + given parameters, optionally scales the resulting convex hulls to match the + original mesh's bounding box, and exports the combined result to a file. + + Args: + filename: Path to the input mesh file. + outfile: Path to save the decomposed output mesh. + params: A dictionary of parameters for the CoACD algorithm. + verbose: If True, sets the CoACD log level to 'info'. + auto_scale: If True, automatically computes a scale factor to match the + decomposed mesh's bounding box to the visual mesh's bounding box. + scale_factor: An additional scaling factor applied to the vertices of + the decomposed mesh parts. + """ coacd.set_log_level("info" if verbose else "warn") mesh = trimesh.load(filename, force="mesh") @@ -83,7 +99,38 @@ def decompose_convex_mesh( scale_factor: float = 1.005, verbose: bool = False, ) -> str: - """Decompose a mesh into convex parts using the CoACD algorithm.""" + """Decomposes a mesh into convex parts with retry logic. + + This function serves as a wrapper for `decompose_convex_coacd`, providing + explicit parameters for the CoACD algorithm and implementing a retry + mechanism. If the initial decomposition fails, it attempts again with + `preprocess_mode` set to 'on'. + + Args: + filename: Path to the input mesh file. + outfile: Path to save the decomposed output mesh. + threshold: CoACD parameter. See CoACD documentation for details. + max_convex_hull: CoACD parameter. See CoACD documentation for details. + preprocess_mode: CoACD parameter. See CoACD documentation for details. + preprocess_resolution: CoACD parameter. See CoACD documentation for details. + resolution: CoACD parameter. See CoACD documentation for details. + mcts_nodes: CoACD parameter. See CoACD documentation for details. 
+ mcts_iterations: CoACD parameter. See CoACD documentation for details. + mcts_max_depth: CoACD parameter. See CoACD documentation for details. + pca: CoACD parameter. See CoACD documentation for details. + merge: CoACD parameter. See CoACD documentation for details. + seed: CoACD parameter. See CoACD documentation for details. + auto_scale: If True, automatically scale the output to match the input + bounding box. + scale_factor: Additional scaling factor to apply. + verbose: If True, enables detailed logging. + + Returns: + The path to the output file if decomposition is successful. + + Raises: + RuntimeError: If convex decomposition fails after all attempts. + """ coacd.set_log_level("info" if verbose else "warn") if os.path.exists(outfile): @@ -148,9 +195,37 @@ def decompose_convex_mp( verbose: bool = False, auto_scale: bool = True, ) -> str: - """Decompose a mesh into convex parts using the CoACD algorithm in a separate process. + """Decomposes a mesh into convex parts in a separate process. + + This function uses the `multiprocessing` module to run the CoACD algorithm + in a spawned subprocess. This is useful for isolating the decomposition + process to prevent potential memory leaks or crashes in the main process. + It includes a retry mechanism similar to `decompose_convex_mesh`. See https://simulately.wiki/docs/toolkits/ConvexDecomp for details. + + Args: + filename: Path to the input mesh file. + outfile: Path to save the decomposed output mesh. + threshold: CoACD parameter. + max_convex_hull: CoACD parameter. + preprocess_mode: CoACD parameter. + preprocess_resolution: CoACD parameter. + resolution: CoACD parameter. + mcts_nodes: CoACD parameter. + mcts_iterations: CoACD parameter. + mcts_max_depth: CoACD parameter. + pca: CoACD parameter. + merge: CoACD parameter. + seed: CoACD parameter. + verbose: If True, enables detailed logging in the subprocess. + auto_scale: If True, automatically scale the output. + + Returns: + The path to the output file if decomposition is successful. + + Raises: + RuntimeError: If convex decomposition fails after all attempts. """ params = dict( threshold=threshold, diff --git a/embodied_gen/data/differentiable_render.py b/embodied_gen/data/differentiable_render.py index fdd5a26..52a8406 100644 --- a/embodied_gen/data/differentiable_render.py +++ b/embodied_gen/data/differentiable_render.py @@ -66,6 +66,14 @@ def create_mp4_from_images( fps: int = 10, prompt: str = None, ): + """Creates an MP4 video from a list of images. + + Args: + images (list[np.ndarray]): List of images as numpy arrays. + output_path (str): Path to save the MP4 file. + fps (int, optional): Frames per second. Defaults to 10. + prompt (str, optional): Optional text prompt overlay. + """ font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.5 font_thickness = 1 @@ -96,6 +104,13 @@ def create_mp4_from_images( def create_gif_from_images( images: list[np.ndarray], output_path: str, fps: int = 10 ) -> None: + """Creates a GIF animation from a list of images. + + Args: + images (list[np.ndarray]): List of images as numpy arrays. + output_path (str): Path to save the GIF file. + fps (int, optional): Frames per second. Defaults to 10. + """ pil_images = [] for image in images: image = image.clip(min=0, max=1) @@ -116,32 +131,47 @@ def create_gif_from_images( class ImageRender(object): - """A differentiable mesh renderer supporting multi-view rendering. + """Differentiable mesh renderer supporting multi-view rendering. 
- This class wraps a differentiable rasterization using `nvdiffrast` to - render mesh geometry to various maps (normal, depth, alpha, albedo, etc.). + This class wraps differentiable rasterization using `nvdiffrast` to render mesh + geometry to various maps (normal, depth, alpha, albedo, etc.) and supports + saving images and videos. Args: - render_items (list[RenderItems]): A list of rendering targets to - generate (e.g., IMAGE, DEPTH, NORMAL, etc.). - camera_params (CameraSetting): The camera parameters for rendering, - including intrinsic and extrinsic matrices. - recompute_vtx_normal (bool, optional): If True, recomputes - vertex normals from the mesh geometry. Defaults to True. - with_mtl (bool, optional): Whether to load `.mtl` material files - for meshes. Defaults to False. - gen_color_gif (bool, optional): Generate a GIF of rendered - color images. Defaults to False. - gen_color_mp4 (bool, optional): Generate an MP4 video of rendered - color images. Defaults to False. - gen_viewnormal_mp4 (bool, optional): Generate an MP4 video of - view-space normals. Defaults to False. - gen_glonormal_mp4 (bool, optional): Generate an MP4 video of - global-space normals. Defaults to False. - no_index_file (bool, optional): If True, skip saving the `index.json` - summary file. Defaults to False. - light_factor (float, optional): A scalar multiplier for - PBR light intensity. Defaults to 1.0. + render_items (list[RenderItems]): List of rendering targets. + camera_params (CameraSetting): Camera parameters for rendering. + recompute_vtx_normal (bool, optional): Recompute vertex normals. Defaults to True. + with_mtl (bool, optional): Load mesh material files. Defaults to False. + gen_color_gif (bool, optional): Generate GIF of color images. Defaults to False. + gen_color_mp4 (bool, optional): Generate MP4 of color images. Defaults to False. + gen_viewnormal_mp4 (bool, optional): Generate MP4 of view-space normals. Defaults to False. + gen_glonormal_mp4 (bool, optional): Generate MP4 of global-space normals. Defaults to False. + no_index_file (bool, optional): Skip saving index file. Defaults to False. + light_factor (float, optional): PBR light intensity multiplier. Defaults to 1.0. + + Example: + ```py + from embodied_gen.data.differentiable_render import ImageRender + from embodied_gen.data.utils import CameraSetting + from embodied_gen.utils.enum import RenderItems + + camera_params = CameraSetting( + num_images=6, + elevation=[20, -10], + distance=5, + resolution_hw=(512,512), + fov=math.radians(30), + device='cuda', + ) + render_items = [RenderItems.IMAGE.value, RenderItems.DEPTH.value] + renderer = ImageRender( + render_items, + camera_params, + with_mtl=args.with_mtl, + gen_color_mp4=True, + ) + renderer.render_mesh(mesh_path='mesh.obj', output_root='./renders') + ``` """ def __init__( @@ -198,6 +228,14 @@ def render_mesh( uuid: Union[str, List[str]] = None, prompts: List[str] = None, ) -> None: + """Renders one or more meshes and saves outputs. + + Args: + mesh_path (Union[str, List[str]]): Path(s) to mesh files. + output_root (str): Directory to save outputs. + uuid (Union[str, List[str]], optional): Unique IDs for outputs. + prompts (List[str], optional): Text prompts for videos. + """ mesh_path = as_list(mesh_path) if uuid is None: uuid = [os.path.basename(p).split(".")[0] for p in mesh_path] @@ -227,18 +265,15 @@ def render_mesh( def __call__( self, mesh_path: str, output_dir: str, prompt: str = None ) -> dict[str, str]: - """Render a single mesh and return paths to the rendered outputs. 
- - Processes the input mesh, renders multiple modalities (e.g., normals, - depth, albedo), and optionally saves video or image sequences. + """Renders a single mesh and returns output paths. Args: - mesh_path (str): Path to the mesh file (.obj/.glb). - output_dir (str): Directory to save rendered outputs. - prompt (str, optional): Optional caption prompt for MP4 metadata. + mesh_path (str): Path to mesh file. + output_dir (str): Directory to save outputs. + prompt (str, optional): Caption prompt for MP4 metadata. Returns: - dict[str, str]: A mapping render types to the saved image paths. + dict[str, str]: Mapping of render types to saved image paths. """ try: mesh = import_kaolin_mesh(mesh_path, self.with_mtl) diff --git a/embodied_gen/data/mesh_operator.py b/embodied_gen/data/mesh_operator.py index 4954900..893e459 100644 --- a/embodied_gen/data/mesh_operator.py +++ b/embodied_gen/data/mesh_operator.py @@ -16,17 +16,13 @@ import logging -import multiprocessing as mp -import os from typing import Tuple, Union -import coacd import igraph import numpy as np import pyvista as pv import spaces import torch -import trimesh import utils3d from pymeshfix import _meshfix from tqdm import tqdm diff --git a/embodied_gen/envs/pick_embodiedgen.py b/embodied_gen/envs/pick_embodiedgen.py index b654bcc..a44e5f1 100644 --- a/embodied_gen/envs/pick_embodiedgen.py +++ b/embodied_gen/envs/pick_embodiedgen.py @@ -51,6 +51,33 @@ @register_env("PickEmbodiedGen-v1", max_episode_steps=100) class PickEmbodiedGen(BaseEnv): + """PickEmbodiedGen as gym env example for object pick-and-place tasks. + + This environment simulates a robot interacting with 3D assets in the + embodiedgen generated scene in SAPIEN. It supports multi-environment setups, + dynamic reconfiguration, and hybrid rendering with 3D Gaussian Splatting. + + Example: + Use `gym.make` to create the `PickEmbodiedGen-v1` parallel environment. + ```python + import gymnasium as gym + env = gym.make( + "PickEmbodiedGen-v1", + num_envs=cfg.num_envs, + render_mode=cfg.render_mode, + enable_shadow=cfg.enable_shadow, + layout_file=cfg.layout_file, + control_mode=cfg.control_mode, + camera_cfg=dict( + camera_eye=cfg.camera_eye, + camera_target_pt=cfg.camera_target_pt, + image_hw=cfg.image_hw, + fovy_deg=cfg.fovy_deg, + ), + ) + ``` + """ + SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"] goal_thresh = 0.0 @@ -63,6 +90,19 @@ def __init__( reconfiguration_freq: int = None, **kwargs, ): + """Initializes the PickEmbodiedGen environment. + + Args: + *args: Variable length argument list for the base class. + robot_uids: The robot(s) to use in the environment. + robot_init_qpos_noise: Noise added to the robot's initial joint + positions. + num_envs: The number of parallel environments to create. + reconfiguration_freq: How often to reconfigure the scene. If None, + it is set based on num_envs. + **kwargs: Additional keyword arguments for environment setup, + including layout_file, replace_objs, enable_grasp, etc. + """ self.robot_init_qpos_noise = robot_init_qpos_noise if reconfiguration_freq is None: if num_envs == 1: @@ -116,6 +156,22 @@ def __init__( def init_env_layouts( layout_file: str, num_envs: int, replace_objs: bool ) -> list[LayoutInfo]: + """Initializes and saves layout files for each environment instance. + + For each environment, this method creates a layout configuration. If + `replace_objs` is True, it generates new object placements for each + subsequent environment. The generated layouts are saved as new JSON + files. 
+ + Args: + layout_file: Path to the base layout JSON file. + num_envs: The number of environments to create layouts for. + replace_objs: If True, generates new object placements for each + environment after the first one using BFS placement. + + Returns: + A list of file paths to the generated layout for each environment. + """ layouts = [] for env_idx in range(num_envs): if replace_objs and env_idx > 0: @@ -136,6 +192,18 @@ def init_env_layouts( def compute_robot_init_pose( layouts: list[str], num_envs: int, z_offset: float = 0.0 ) -> list[list[float]]: + """Computes the initial pose for the robot in each environment. + + Args: + layouts: A list of file paths to the environment layouts. + num_envs: The number of environments. + z_offset: An optional vertical offset to apply to the robot's + position to prevent collisions. + + Returns: + A list of initial poses ([x, y, z, qw, qx, qy, qz]) for the robot + in each environment. + """ robot_pose = [] for env_idx in range(num_envs): layout = json.load(open(layouts[env_idx], "r")) @@ -148,6 +216,11 @@ def compute_robot_init_pose( @property def _default_sim_config(self): + """Returns the default simulation configuration. + + Returns: + The default simulation configuration object. + """ return SimConfig( scene_config=SceneConfig( solver_position_iterations=30, @@ -163,6 +236,11 @@ def _default_sim_config(self): @property def _default_sensor_configs(self): + """Returns the default sensor configurations for the agent. + + Returns: + A list containing the default camera configuration. + """ pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1]) return [ @@ -171,6 +249,11 @@ def _default_sensor_configs(self): @property def _default_human_render_camera_configs(self): + """Returns the default camera configuration for human-friendly rendering. + + Returns: + The default camera configuration for the renderer. + """ pose = sapien_utils.look_at( eye=self.camera_cfg["camera_eye"], target=self.camera_cfg["camera_target_pt"], @@ -187,10 +270,24 @@ def _default_human_render_camera_configs(self): ) def _load_agent(self, options: dict): + """Loads the agent (robot) and a ground plane into the scene. + + Args: + options: A dictionary of options for loading the agent. + """ self.ground = build_ground(self.scene) super()._load_agent(options, sapien.Pose(p=[-10, 0, 10])) def _load_scene(self, options: dict): + """Loads all assets, objects, and the goal site into the scene. + + This method iterates through the layouts for each environment, loads the + specified assets, and adds them to the simulation. It also creates a + kinematic sphere to represent the goal site. + + Args: + options: A dictionary of options for loading the scene. + """ all_objects = [] logger.info(f"Loading EmbodiedGen assets...") for env_idx in range(self.num_envs): @@ -222,6 +319,15 @@ def _load_scene(self, options: dict): self._hidden_objects.append(self.goal_site) def _initialize_episode(self, env_idx: torch.Tensor, options: dict): + """Initializes an episode for a given set of environments. + + This method sets the goal position, resets the robot's joint positions + with optional noise, and sets its root pose. + + Args: + env_idx: A tensor of environment indices to initialize. + options: A dictionary of options for initialization. 
+ """ with torch.device(self.device): b = len(env_idx) goal_xyz = torch.zeros((b, 3)) @@ -256,6 +362,21 @@ def _initialize_episode(self, env_idx: torch.Tensor, options: dict): def render_gs3d_images( self, layouts: list[str], num_envs: int, init_quat: list[float] ) -> dict[str, np.ndarray]: + """Renders background images using a pre-trained Gaussian Splatting model. + + This method pre-renders the static background for each environment from + the perspective of all cameras to be used for hybrid rendering. + + Args: + layouts: A list of file paths to the environment layouts. + num_envs: The number of environments. + init_quat: An initial quaternion to orient the Gaussian Splatting + model. + + Returns: + A dictionary mapping a unique key (e.g., 'camera-env_idx') to the + rendered background image as a numpy array. + """ sim_coord_align = ( torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device) ) @@ -293,6 +414,15 @@ def render_gs3d_images( return bg_images def render(self): + """Renders the environment based on the configured render_mode. + + Raises: + RuntimeError: If `render_mode` is not set. + NotImplementedError: If the `render_mode` is not supported. + + Returns: + The rendered output, which varies depending on the render mode. + """ if self.render_mode is None: raise RuntimeError("render_mode is not set.") if self.render_mode == "human": @@ -315,6 +445,17 @@ def render(self): def render_rgb_array( self, camera_name: str = None, return_alpha: bool = False ): + """Renders an RGB image from the human-facing render camera. + + Args: + camera_name: The name of the camera to render from. If None, uses + all human render cameras. + return_alpha: Whether to include the alpha channel in the output. + + Returns: + A numpy array representing the rendered image(s). If multiple + cameras are used, the images are tiled. + """ for obj in self._hidden_objects: obj.show_visual() self.scene.update_render( @@ -335,6 +476,11 @@ def render_rgb_array( return tile_images(images) def render_sensors(self): + """Renders images from all on-board sensor cameras. + + Returns: + A tiled image of all sensor outputs as a numpy array. + """ images = [] sensor_images = self.get_sensor_images() for image in sensor_images.values(): @@ -343,6 +489,14 @@ def render_sensors(self): return tile_images(images) def hybrid_render(self): + """Renders a hybrid image by blending simulated foreground with a background. + + The foreground is rendered with an alpha channel and then blended with + the pre-rendered Gaussian Splatting background image. + + Returns: + A torch tensor of the final blended RGB images. + """ fg_images = self.render_rgb_array( return_alpha=True ) # (n_env, h, w, 3) @@ -362,6 +516,16 @@ def hybrid_render(self): return images[..., :3] def evaluate(self): + """Evaluates the current state of the environment. + + Checks for task success criteria such as whether the object is grasped, + placed at the goal, and if the robot is static. + + Returns: + A dictionary containing boolean tensors for various success + metrics, including 'is_grasped', 'is_obj_placed', and overall + 'success'. + """ obj_to_goal_pos = ( self.obj.pose.p ) # self.goal_site.pose.p - self.obj.pose.p @@ -381,10 +545,31 @@ def evaluate(self): ) def _get_obs_extra(self, info: dict): + """Gets extra information for the observation dictionary. + + Args: + info: A dictionary containing evaluation information. + + Returns: + An empty dictionary, as no extra observations are added. 
+ """ return dict() def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict): + """Computes a dense reward for the current step. + + The reward is a composite of reaching, grasping, placing, and + maintaining a static final pose. + + Args: + obs: The current observation. + action: The action taken in the current step. + info: A dictionary containing evaluation information from `evaluate()`. + + Returns: + A tensor containing the dense reward for each environment. + """ tcp_to_obj_dist = torch.linalg.norm( self.obj.pose.p - self.agent.tcp.pose.p, axis=1 ) @@ -417,4 +602,14 @@ def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict): def compute_normalized_dense_reward( self, obs: any, action: torch.Tensor, info: dict ): + """Computes a dense reward normalized to be between 0 and 1. + + Args: + obs: The current observation. + action: The action taken in the current step. + info: A dictionary containing evaluation information from `evaluate()`. + + Returns: + A tensor containing the normalized dense reward for each environment. + """ return self.compute_dense_reward(obs=obs, action=action, info=info) / 6 diff --git a/embodied_gen/models/delight_model.py b/embodied_gen/models/delight_model.py index 14abb4c..9be7bbb 100644 --- a/embodied_gen/models/delight_model.py +++ b/embodied_gen/models/delight_model.py @@ -40,7 +40,7 @@ class DelightingModel(object): """A model to remove the lighting in image space. This model is encapsulated based on the Hunyuan3D-Delight model - from https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0 # noqa + from `https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0` # noqa Attributes: image_guide_scale (float): Weight of image guidance in diffusion process. diff --git a/embodied_gen/models/image_comm_model.py b/embodied_gen/models/image_comm_model.py index 7a8c30c..a04364d 100644 --- a/embodied_gen/models/image_comm_model.py +++ b/embodied_gen/models/image_comm_model.py @@ -38,26 +38,61 @@ class BasePipelineLoader(ABC): + """Abstract base class for loading Hugging Face image generation pipelines. + + Attributes: + device (str): Device to load the pipeline on. + + Methods: + load(): Loads and returns the pipeline. + """ + def __init__(self, device="cuda"): self.device = device @abstractmethod def load(self): + """Load and return the pipeline instance.""" pass class BasePipelineRunner(ABC): + """Abstract base class for running image generation pipelines. + + Attributes: + pipe: The loaded pipeline. + + Methods: + run(prompt, **kwargs): Runs the pipeline with a prompt. + """ + def __init__(self, pipe): self.pipe = pipe @abstractmethod def run(self, prompt: str, **kwargs) -> Image.Image: + """Run the pipeline with the given prompt. + + Args: + prompt (str): Text prompt for image generation. + **kwargs: Additional pipeline arguments. + + Returns: + Image.Image: Generated image(s). + """ pass # ===== SD3.5-medium ===== class SD35Loader(BasePipelineLoader): + """Loader for Stable Diffusion 3.5 medium pipeline.""" + def load(self): + """Load the Stable Diffusion 3.5 medium pipeline. + + Returns: + StableDiffusion3Pipeline: Loaded pipeline. 
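+
+        Example (illustrative sketch; typically paired with `SD35Runner`):
+        ```py
+        runner = SD35Runner(SD35Loader(device="cuda").load())
+        images = runner.run(prompt="a ceramic mug on a wooden table")
+        ```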
+ """ pipe = StableDiffusion3Pipeline.from_pretrained( "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.float16, @@ -70,12 +105,25 @@ def load(self): class SD35Runner(BasePipelineRunner): + """Runner for Stable Diffusion 3.5 medium pipeline.""" + def run(self, prompt: str, **kwargs) -> Image.Image: + """Generate images using Stable Diffusion 3.5 medium. + + Args: + prompt (str): Text prompt. + **kwargs: Additional arguments. + + Returns: + Image.Image: Generated image(s). + """ return self.pipe(prompt=prompt, **kwargs).images # ===== Cosmos2 ===== class CosmosLoader(BasePipelineLoader): + """Loader for Cosmos2 text-to-image pipeline.""" + def __init__( self, model_id="nvidia/Cosmos-Predict2-2B-Text2Image", @@ -87,6 +135,8 @@ def __init__( self.local_dir = local_dir def _patch(self): + """Patch model and processor for optimized loading.""" + def patch_model(cls): orig = cls.from_pretrained @@ -110,6 +160,11 @@ def new(*args, **kwargs): patch_processor(SiglipProcessor) def load(self): + """Load the Cosmos2 text-to-image pipeline. + + Returns: + Cosmos2TextToImagePipeline: Loaded pipeline. + """ self._patch() snapshot_download( repo_id=self.model_id, @@ -141,7 +196,19 @@ def load(self): class CosmosRunner(BasePipelineRunner): + """Runner for Cosmos2 text-to-image pipeline.""" + def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image: + """Generate images using Cosmos2 pipeline. + + Args: + prompt (str): Text prompt. + negative_prompt (str, optional): Negative prompt. + **kwargs: Additional arguments. + + Returns: + Image.Image: Generated image(s). + """ return self.pipe( prompt=prompt, negative_prompt=negative_prompt, **kwargs ).images @@ -149,7 +216,14 @@ def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image: # ===== Kolors ===== class KolorsLoader(BasePipelineLoader): + """Loader for Kolors pipeline.""" + def load(self): + """Load the Kolors pipeline. + + Returns: + KolorsPipeline: Loaded pipeline. + """ pipe = KolorsPipeline.from_pretrained( "Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, @@ -164,13 +238,31 @@ def load(self): class KolorsRunner(BasePipelineRunner): + """Runner for Kolors pipeline.""" + def run(self, prompt: str, **kwargs) -> Image.Image: + """Generate images using Kolors pipeline. + + Args: + prompt (str): Text prompt. + **kwargs: Additional arguments. + + Returns: + Image.Image: Generated image(s). + """ return self.pipe(prompt=prompt, **kwargs).images # ===== Flux ===== class FluxLoader(BasePipelineLoader): + """Loader for Flux pipeline.""" + def load(self): + """Load the Flux pipeline. + + Returns: + FluxPipeline: Loaded pipeline. + """ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' pipe = FluxPipeline.from_pretrained( "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 @@ -182,20 +274,50 @@ def load(self): class FluxRunner(BasePipelineRunner): + """Runner for Flux pipeline.""" + def run(self, prompt: str, **kwargs) -> Image.Image: + """Generate images using Flux pipeline. + + Args: + prompt (str): Text prompt. + **kwargs: Additional arguments. + + Returns: + Image.Image: Generated image(s). + """ return self.pipe(prompt=prompt, **kwargs).images # ===== Chroma ===== class ChromaLoader(BasePipelineLoader): + """Loader for Chroma pipeline.""" + def load(self): + """Load the Chroma pipeline. + + Returns: + ChromaPipeline: Loaded pipeline. 
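+
+        Example (illustrative sketch; typically paired with `ChromaRunner`):
+        ```py
+        runner = ChromaRunner(ChromaLoader(device="cuda").load())
+        images = runner.run(prompt="a cluttered workbench with tools")
+        ```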
+ """ return ChromaPipeline.from_pretrained( "lodestones/Chroma", torch_dtype=torch.bfloat16 ).to(self.device) class ChromaRunner(BasePipelineRunner): + """Runner for Chroma pipeline.""" + def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image: + """Generate images using Chroma pipeline. + + Args: + prompt (str): Text prompt. + negative_prompt (str, optional): Negative prompt. + **kwargs: Additional arguments. + + Returns: + Image.Image: Generated image(s). + """ return self.pipe( prompt=prompt, negative_prompt=negative_prompt, **kwargs ).images @@ -211,6 +333,22 @@ def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image: def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner: + """Build a Hugging Face image generation pipeline runner by name. + + Args: + name (str): Name of the pipeline (e.g., "sd35", "cosmos"). + device (str): Device to load the pipeline on. + + Returns: + BasePipelineRunner: Pipeline runner instance. + + Example: + ```py + from embodied_gen.models.image_comm_model import build_hf_image_pipeline + runner = build_hf_image_pipeline("sd35") + images = runner.run(prompt="A robot holding a sign that says 'Hello'") + ``` + """ if name not in PIPELINE_REGISTRY: raise ValueError(f"Unsupported model: {name}") loader_cls, runner_cls = PIPELINE_REGISTRY[name] diff --git a/embodied_gen/models/layout.py b/embodied_gen/models/layout.py index 9613269..8edc279 100644 --- a/embodied_gen/models/layout.py +++ b/embodied_gen/models/layout.py @@ -376,6 +376,21 @@ class LayoutDesigner(object): + """A class for querying GPT-based scene layout reasoning and formatting responses. + + Attributes: + prompt (str): The system prompt for GPT. + verbose (bool): Whether to log responses. + gpt_client (GPTclient): The GPT client instance. + + Methods: + query(prompt, params): Query GPT with a prompt and parameters. + format_response(response): Parse and clean JSON response. + format_response_repair(response): Repair and parse JSON response. + save_output(output, save_path): Save output to file. + __call__(prompt, save_path, params): Query and process output. + """ + def __init__( self, gpt_client: GPTclient, @@ -387,6 +402,15 @@ def __init__( self.gpt_client = gpt_client def query(self, prompt: str, params: dict = None) -> str: + """Query GPT with the system prompt and user prompt. + + Args: + prompt (str): User prompt. + params (dict, optional): GPT parameters. + + Returns: + str: GPT response. + """ full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\"" response = self.gpt_client.query( @@ -400,6 +424,17 @@ def query(self, prompt: str, params: dict = None) -> str: return response def format_response(self, response: str) -> dict: + """Format and parse GPT response as JSON. + + Args: + response (str): Raw GPT response. + + Returns: + dict: Parsed JSON output. + + Raises: + json.JSONDecodeError: If parsing fails. + """ cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip()) try: output = json.loads(cleaned) @@ -411,9 +446,23 @@ def format_response(self, response: str) -> dict: return output def format_response_repair(self, response: str) -> dict: + """Repair and parse possibly broken JSON response. + + Args: + response (str): Raw GPT response. + + Returns: + dict: Parsed JSON output. + """ return json_repair.loads(response) def save_output(self, output: dict, save_path: str) -> None: + """Save output dictionary to a file. + + Args: + output (dict): Output data. + save_path (str): Path to save the file. 
+ """ os.makedirs(os.path.dirname(save_path), exist_ok=True) with open(save_path, 'w') as f: json.dump(output, f, indent=4) @@ -421,6 +470,16 @@ def save_output(self, output: dict, save_path: str) -> None: def __call__( self, prompt: str, save_path: str = None, params: dict = None ) -> dict | str: + """Query GPT and process the output. + + Args: + prompt (str): User prompt. + save_path (str, optional): Path to save output. + params (dict, optional): GPT parameters. + + Returns: + dict | str: Output data. + """ response = self.query(prompt, params=params) output = self.format_response_repair(response) self.save_output(output, save_path) if save_path else None @@ -442,6 +501,29 @@ def __call__( def build_scene_layout( task_desc: str, output_path: str = None, gpt_params: dict = None ) -> LayoutInfo: + """Build a 3D scene layout from a natural language task description. + + This function uses GPT-based reasoning to generate a structured scene layout, + including object hierarchy, spatial relations, and style descriptions. + + Args: + task_desc (str): Natural language description of the robotic task. + output_path (str, optional): Path to save the visualized scene tree. + gpt_params (dict, optional): Parameters for GPT queries. + + Returns: + LayoutInfo: Structured layout information for the scene. + + Example: + ```py + from embodied_gen.models.layout import build_scene_layout + layout_info = build_scene_layout( + task_desc="Put the apples on the table on the plate", + output_path="outputs/scene_tree.jpg", + ) + print(layout_info) + ``` + """ layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params) layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params) object_mapping = Scene3DItemEnum.object_mapping(layout_relation) diff --git a/embodied_gen/models/segment_model.py b/embodied_gen/models/segment_model.py index ab92c0c..6f54cfa 100644 --- a/embodied_gen/models/segment_model.py +++ b/embodied_gen/models/segment_model.py @@ -48,12 +48,19 @@ class SAMRemover(object): - """Loading SAM models and performing background removal on images. + """Loads SAM models and performs background removal on images. Attributes: checkpoint (str): Path to the model checkpoint. - model_type (str): Type of the SAM model to load (default: "vit_h"). - area_ratio (float): Area ratio filtering small connected components. + model_type (str): Type of the SAM model to load. + area_ratio (float): Area ratio for filtering small connected components. + + Example: + ```py + from embodied_gen.models.segment_model import SAMRemover + remover = SAMRemover(model_type="vit_h") + result = remover("input.jpg", "output.png") + ``` """ def __init__( @@ -78,6 +85,14 @@ def __init__( self.mask_generator = self._load_sam_model(checkpoint) def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator: + """Loads the SAM model and returns a mask generator. + + Args: + checkpoint (str): Path to model checkpoint. + + Returns: + SamAutomaticMaskGenerator: Mask generator instance. + """ sam = sam_model_registry[self.model_type](checkpoint=checkpoint) sam.to(device=self.device) @@ -89,13 +104,11 @@ def __call__( """Removes the background from an image using the SAM model. Args: - image (Union[str, Image.Image, np.ndarray]): Input image, - can be a file path, PIL Image, or numpy array. - save_path (str): Path to save the output image (default: None). + image (Union[str, Image.Image, np.ndarray]): Input image. + save_path (str, optional): Path to save the output image. 
Returns: - Image.Image: The image with background removed, - including an alpha channel. + Image.Image: Image with background removed (RGBA). """ # Convert input to numpy array if isinstance(image, str): @@ -134,6 +147,15 @@ def __call__( class SAMPredictor(object): + """Loads SAM models and predicts segmentation masks from user points. + + Args: + checkpoint (str, optional): Path to model checkpoint. + model_type (str, optional): SAM model type. + binary_thresh (float, optional): Threshold for binary mask. + device (str, optional): Device for inference. + """ + def __init__( self, checkpoint: str = None, @@ -157,12 +179,28 @@ def __init__( self.binary_thresh = binary_thresh def _load_sam_model(self, checkpoint: str) -> SamPredictor: + """Loads the SAM model and returns a predictor. + + Args: + checkpoint (str): Path to model checkpoint. + + Returns: + SamPredictor: Predictor instance. + """ sam = sam_model_registry[self.model_type](checkpoint=checkpoint) sam.to(device=self.device) return SamPredictor(sam) def preprocess_image(self, image: Image.Image) -> np.ndarray: + """Preprocesses input image for SAM prediction. + + Args: + image (Image.Image): Input image. + + Returns: + np.ndarray: Preprocessed image array. + """ if isinstance(image, str): image = Image.open(image) elif isinstance(image, np.ndarray): @@ -178,6 +216,15 @@ def generate_masks( image: np.ndarray, selected_points: list[list[int]], ) -> np.ndarray: + """Generates segmentation masks from selected points. + + Args: + image (np.ndarray): Input image array. + selected_points (list[list[int]]): List of points and labels. + + Returns: + list[tuple[np.ndarray, str]]: List of masks and names. + """ if len(selected_points) == 0: return [] @@ -220,6 +267,15 @@ def generate_masks( def get_segmented_image( self, image: np.ndarray, masks: list[tuple[np.ndarray, str]] ) -> Image.Image: + """Combines masks and returns segmented image with alpha channel. + + Args: + image (np.ndarray): Input image array. + masks (list[tuple[np.ndarray, str]]): List of masks. + + Returns: + Image.Image: Segmented RGBA image. + """ seg_image = Image.fromarray(image, mode="RGB") alpha_channel = np.zeros( (seg_image.height, seg_image.width), dtype=np.uint8 @@ -241,6 +297,15 @@ def __call__( image: Union[str, Image.Image, np.ndarray], selected_points: list[list[int]], ) -> Image.Image: + """Segments image using selected points. + + Args: + image (Union[str, Image.Image, np.ndarray]): Input image. + selected_points (list[list[int]]): List of points and labels. + + Returns: + Image.Image: Segmented RGBA image. + """ image = self.preprocess_image(image) self.predictor.set_image(image) masks = self.generate_masks(image, selected_points) @@ -249,12 +314,32 @@ def __call__( class RembgRemover(object): + """Removes background from images using the rembg library. + + Example: + ```py + from embodied_gen.models.segment_model import RembgRemover + remover = RembgRemover() + result = remover("input.jpg", "output.png") + ``` + """ + def __init__(self): + """Initializes the RembgRemover.""" self.rembg_session = rembg.new_session("u2net") def __call__( self, image: Union[str, Image.Image, np.ndarray], save_path: str = None ) -> Image.Image: + """Removes background from an image. + + Args: + image (Union[str, Image.Image, np.ndarray]): Input image. + save_path (str, optional): Path to save the output image. + + Returns: + Image.Image: Image with background removed (RGBA). 
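Example (an illustrative in-memory sketch; no file is written when `save_path` is omitted, and the input filename is a placeholder):
```py
from PIL import Image

from embodied_gen.models.segment_model import RembgRemover

remover = RembgRemover()
rgba = remover(Image.open("photo.jpg"))  # RGBA result with a transparent background
rgba.save("photo_nobg.png")
```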
+ """ if isinstance(image, str): image = Image.open(image) elif isinstance(image, np.ndarray): @@ -271,7 +356,18 @@ def __call__( class BMGG14Remover(object): + """Removes background using the RMBG-1.4 segmentation model. + + Example: + ```py + from embodied_gen.models.segment_model import BMGG14Remover + remover = BMGG14Remover() + result = remover("input.jpg", "output.png") + ``` + """ + def __init__(self) -> None: + """Initializes the BMGG14Remover.""" self.model = pipeline( "image-segmentation", model="briaai/RMBG-1.4", @@ -281,6 +377,15 @@ def __init__(self) -> None: def __call__( self, image: Union[str, Image.Image, np.ndarray], save_path: str = None ): + """Removes background from an image. + + Args: + image (Union[str, Image.Image, np.ndarray]): Input image. + save_path (str, optional): Path to save the output image. + + Returns: + Image.Image: Image with background removed. + """ if isinstance(image, str): image = Image.open(image) elif isinstance(image, np.ndarray): @@ -299,6 +404,16 @@ def __call__( def invert_rgba_pil( image: Image.Image, mask: Image.Image, save_path: str = None ) -> Image.Image: + """Inverts the alpha channel of an RGBA image using a mask. + + Args: + image (Image.Image): Input RGB image. + mask (Image.Image): Mask image for alpha inversion. + save_path (str, optional): Path to save the output image. + + Returns: + Image.Image: RGBA image with inverted alpha. + """ mask = (255 - np.array(mask))[..., None] image_array = np.concatenate([np.array(image), mask], axis=-1) inverted_image = Image.fromarray(image_array, "RGBA") @@ -318,6 +433,20 @@ def get_segmented_image_by_agent( save_path: str = None, mode: Literal["loose", "strict"] = "loose", ) -> Image.Image: + """Segments an image using SAM and rembg, with quality checking. + + Args: + image (Image.Image): Input image. + sam_remover (SAMRemover): SAM-based remover. + rbg_remover (RembgRemover): rembg-based remover. + seg_checker (ImageSegChecker, optional): Quality checker. + save_path (str, optional): Path to save the output image. + mode (Literal["loose", "strict"], optional): Segmentation mode. + + Returns: + Image.Image: Segmented RGBA image. + """ + def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool: if seg_checker is None: return True diff --git a/embodied_gen/models/sr_model.py b/embodied_gen/models/sr_model.py index 40310bb..c8489d9 100644 --- a/embodied_gen/models/sr_model.py +++ b/embodied_gen/models/sr_model.py @@ -39,13 +39,38 @@ class ImageStableSR: - """Super-resolution image upscaler using Stable Diffusion x4 upscaling model from StabilityAI.""" + """Super-resolution image upscaler using Stable Diffusion x4 upscaling model. + + This class wraps the StabilityAI Stable Diffusion x4 upscaler for high-quality + image super-resolution. + + Args: + model_path (str, optional): Path or HuggingFace repo for the model. + device (str, optional): Device for inference. + + Example: + ```py + from embodied_gen.models.sr_model import ImageStableSR + from PIL import Image + + sr_model = ImageStableSR() + img = Image.open("input.png") + upscaled = sr_model(img) + upscaled.save("output.png") + ``` + """ def __init__( self, model_path: str = "stabilityai/stable-diffusion-x4-upscaler", device="cuda", ) -> None: + """Initializes the Stable Diffusion x4 upscaler. + + Args: + model_path (str, optional): Model path or repo. + device (str, optional): Device for inference. 
+ """ from diffusers import StableDiffusionUpscalePipeline self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained( @@ -62,6 +87,16 @@ def __call__( prompt: str = "", infer_step: int = 20, ) -> Image.Image: + """Performs super-resolution on the input image. + + Args: + image (Union[Image.Image, np.ndarray]): Input image. + prompt (str, optional): Text prompt for upscaling. + infer_step (int, optional): Number of inference steps. + + Returns: + Image.Image: Upscaled image. + """ if isinstance(image, np.ndarray): image = Image.fromarray(image) @@ -86,9 +121,26 @@ class ImageRealESRGAN: Attributes: outscale (int): The output image scale factor (e.g., 2, 4). model_path (str): Path to the pre-trained model weights. + + Example: + ```py + from embodied_gen.models.sr_model import ImageRealESRGAN + from PIL import Image + + sr_model = ImageRealESRGAN(outscale=4) + img = Image.open("input.png") + upscaled = sr_model(img) + upscaled.save("output.png") + ``` """ def __init__(self, outscale: int, model_path: str = None) -> None: + """Initializes the RealESRGAN upscaler. + + Args: + outscale (int): Output scale factor. + model_path (str, optional): Path to model weights. + """ # monkey patch to support torchvision>=0.16 import torchvision from packaging import version @@ -122,6 +174,7 @@ def __init__(self, outscale: int, model_path: str = None) -> None: self.model_path = model_path def _lazy_init(self): + """Lazily initializes the RealESRGAN model.""" if self.upsampler is None: from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer @@ -145,6 +198,14 @@ def _lazy_init(self): @spaces.GPU def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image: + """Performs super-resolution on the input image. + + Args: + image (Union[Image.Image, np.ndarray]): Input image. + + Returns: + Image.Image: Upscaled image. + """ self._lazy_init() if isinstance(image, Image.Image): diff --git a/embodied_gen/models/text_model.py b/embodied_gen/models/text_model.py index 0807814..59167bb 100644 --- a/embodied_gen/models/text_model.py +++ b/embodied_gen/models/text_model.py @@ -60,6 +60,11 @@ def download_kolors_weights(local_dir: str = "weights/Kolors") -> None: + """Downloads Kolors model weights from HuggingFace. + + Args: + local_dir (str, optional): Local directory to store weights. + """ logger.info(f"Download kolors weights from huggingface...") os.makedirs(local_dir, exist_ok=True) subprocess.run( @@ -93,6 +98,22 @@ def build_text2img_ip_pipeline( ref_scale: float, device: str = "cuda", ) -> StableDiffusionXLPipelineIP: + """Builds a Stable Diffusion XL pipeline with IP-Adapter for text-to-image generation. + + Args: + ckpt_dir (str): Directory containing model checkpoints. + ref_scale (float): Reference scale for IP-Adapter. + device (str, optional): Device for inference. + + Returns: + StableDiffusionXLPipelineIP: Configured pipeline. + + Example: + ```py + from embodied_gen.models.text_model import build_text2img_ip_pipeline + pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3) + ``` + """ download_kolors_weights(ckpt_dir) text_encoder = ChatGLMModel.from_pretrained( @@ -146,6 +167,21 @@ def build_text2img_pipeline( ckpt_dir: str, device: str = "cuda", ) -> StableDiffusionXLPipeline: + """Builds a Stable Diffusion XL pipeline for text-to-image generation. + + Args: + ckpt_dir (str): Directory containing model checkpoints. + device (str, optional): Device for inference. + + Returns: + StableDiffusionXLPipeline: Configured pipeline. 
+ + Example: + ```py + from embodied_gen.models.text_model import build_text2img_pipeline + pipe = build_text2img_pipeline("weights/Kolors") + ``` + """ download_kolors_weights(ckpt_dir) text_encoder = ChatGLMModel.from_pretrained( @@ -185,6 +221,29 @@ def text2img_gen( ip_image_size: int = 512, seed: int = None, ) -> list[Image.Image]: + """Generates images from text prompts using a Stable Diffusion XL pipeline. + + Args: + prompt (str): Text prompt for image generation. + n_sample (int): Number of images to generate. + guidance_scale (float): Guidance scale for diffusion. + pipeline (StableDiffusionXLPipeline | StableDiffusionXLPipelineIP): Pipeline instance. + ip_image (Image.Image | str, optional): Reference image for IP-Adapter. + image_wh (tuple[int, int], optional): Output image size (width, height). + infer_step (int, optional): Number of inference steps. + ip_image_size (int, optional): Size for IP-Adapter image. + seed (int, optional): Random seed. + + Returns: + list[Image.Image]: List of generated images. + + Example: + ```py + from embodied_gen.models.text_model import text2img_gen + images = text2img_gen(prompt="banana", n_sample=3, guidance_scale=7.5) + images[0].save("banana.png") + ``` + """ prompt = PROMPT_KAPPEND.format(object=prompt.strip()) logger.info(f"Processing prompt: {prompt}") diff --git a/embodied_gen/trainer/pono2mesh_trainer.py b/embodied_gen/trainer/pono2mesh_trainer.py index a2fc752..6f04435 100644 --- a/embodied_gen/trainer/pono2mesh_trainer.py +++ b/embodied_gen/trainer/pono2mesh_trainer.py @@ -53,26 +53,31 @@ class Pano2MeshSRPipeline: - """Converting panoramic RGB image into 3D mesh representations, followed by inpainting and mesh refinement. + """Pipeline for converting panoramic RGB images into 3D mesh representations. - This class integrates several key components including: - - Depth estimation from RGB panorama - - Inpainting of missing regions under offsets - - RGB-D to mesh conversion - - Multi-view mesh repair - - 3D Gaussian Splatting (3DGS) dataset generation + This class integrates depth estimation, inpainting, mesh conversion, multi-view mesh repair, + and 3D Gaussian Splatting (3DGS) dataset generation. Args: config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters. Example: - ```python + ```py + from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline + from embodied_gen.utils.config import Pano2MeshSRConfig + + config = Pano2MeshSRConfig() pipeline = Pano2MeshSRPipeline(config) pipeline(pano_image='example.png', output_dir='./output') ``` """ def __init__(self, config: Pano2MeshSRConfig) -> None: + """Initializes the pipeline with models and camera poses. + + Args: + config (Pano2MeshSRConfig): Configuration object. + """ self.cfg = config self.device = config.device @@ -93,6 +98,7 @@ def __init__(self, config: Pano2MeshSRConfig) -> None: self.kernel = torch.from_numpy(kernel).float().to(self.device) def init_mesh_params(self) -> None: + """Initializes mesh parameters and inpaint mask.""" torch.set_default_device(self.device) self.inpaint_mask = torch.ones( (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool @@ -103,6 +109,14 @@ def init_mesh_params(self) -> None: @staticmethod def read_camera_pose_file(filepath: str) -> np.ndarray: + """Reads a camera pose file and returns the pose matrix. + + Args: + filepath (str): Path to the camera pose file. + + Returns: + np.ndarray: 4x4 camera pose matrix. 
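Example (an illustrative sketch; assumes the pose file stores 16 whitespace-separated floats forming one 4x4 matrix, and the file path is a placeholder):
```py
from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline

pose = Pano2MeshSRPipeline.read_camera_pose_file("trajectory/frame_000.txt")
print(pose.shape)  # expected: (4, 4)
```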
+ """ with open(filepath, "r") as f: values = [float(num) for line in f for num in line.split()] @@ -111,6 +125,14 @@ def read_camera_pose_file(filepath: str) -> np.ndarray: def load_camera_poses( self, trajectory_dir: str ) -> tuple[np.ndarray, list[torch.Tensor]]: + """Loads camera poses from a directory. + + Args: + trajectory_dir (str): Directory containing camera pose files. + + Returns: + tuple[np.ndarray, list[torch.Tensor]]: List of relative camera poses. + """ pose_filenames = sorted( [ fname @@ -148,6 +170,14 @@ def load_camera_poses( def load_inpaint_poses( self, poses: torch.Tensor ) -> dict[int, torch.Tensor]: + """Samples and loads poses for inpainting. + + Args: + poses (torch.Tensor): Tensor of camera poses. + + Returns: + dict[int, torch.Tensor]: Dictionary mapping indices to pose tensors. + """ inpaint_poses = dict() sampled_views = poses[:: self.cfg.inpaint_frame_stride] init_pose = torch.eye(4) @@ -162,6 +192,14 @@ def load_inpaint_poses( return inpaint_poses def project(self, world_to_cam: torch.Tensor): + """Projects the mesh to an image using the given camera pose. + + Args: + world_to_cam (torch.Tensor): World-to-camera transformation matrix. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Projected RGB image, inpaint mask, and depth map. + """ ( project_image, project_depth, @@ -185,6 +223,14 @@ def project(self, world_to_cam: torch.Tensor): return project_image[:3, ...], inpaint_mask, project_depth def render_pano(self, pose: torch.Tensor): + """Renders a panorama from the mesh using the given pose. + + Args: + pose (torch.Tensor): Camera pose. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: RGB panorama, depth map, and mask. + """ cubemap_list = [] for cubemap_pose in self.cubemap_w2cs: project_pose = cubemap_pose @ pose @@ -213,6 +259,15 @@ def rgbd_to_mesh( world_to_cam: torch.Tensor = None, using_distance_map: bool = True, ) -> None: + """Converts RGB-D images to mesh and updates mesh parameters. + + Args: + rgb (torch.Tensor): RGB image tensor. + depth (torch.Tensor): Depth map tensor. + inpaint_mask (torch.Tensor): Inpaint mask tensor. + world_to_cam (torch.Tensor, optional): Camera pose. + using_distance_map (bool, optional): Whether to use distance map. + """ if world_to_cam is None: world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device) @@ -239,6 +294,15 @@ def rgbd_to_mesh( def get_edge_image_by_depth( self, depth: torch.Tensor, dilate_iter: int = 1 ) -> np.ndarray: + """Computes edge image from depth map. + + Args: + depth (torch.Tensor): Depth map tensor. + dilate_iter (int, optional): Number of dilation iterations. + + Returns: + np.ndarray: Edge image. + """ if isinstance(depth, torch.Tensor): depth = depth.cpu().detach().numpy() @@ -253,6 +317,15 @@ def get_edge_image_by_depth( def mesh_repair_by_greedy_view_selection( self, pose_dict: dict[str, torch.Tensor], output_dir: str ) -> list: + """Repairs mesh by selecting views greedily and inpainting missing regions. + + Args: + pose_dict (dict[str, torch.Tensor]): Dictionary of poses for inpainting. + output_dir (str): Directory to save visualizations. + + Returns: + list: List of inpainted panoramas with poses. + """ inpainted_panos_w_pose = [] while len(pose_dict) > 0: logger.info(f"Repairing mesh left rounds {len(pose_dict)}") @@ -343,6 +416,17 @@ def inpaint_panorama( distances: torch.Tensor, pano_mask: torch.Tensor, ) -> tuple[torch.Tensor]: + """Inpaints missing regions in a panorama. + + Args: + idx (int): Index of the panorama. 
+ colors (torch.Tensor): RGB image tensor. + distances (torch.Tensor): Distance map tensor. + pano_mask (torch.Tensor): Mask tensor. + + Returns: + tuple[torch.Tensor]: Inpainted RGB image, distances, and normals. + """ mask = (pano_mask[None, ..., None] > 0.5).float() mask = mask.permute(0, 3, 1, 2) mask = dilation(mask, kernel=self.kernel) @@ -364,6 +448,14 @@ def inpaint_panorama( def preprocess_pano( self, image: Image.Image | str ) -> tuple[torch.Tensor, torch.Tensor]: + """Preprocesses a panoramic image for mesh generation. + + Args: + image (Image.Image | str): Input image or path. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Preprocessed RGB and depth tensors. + """ if isinstance(image, str): image = Image.open(image) @@ -387,6 +479,17 @@ def preprocess_pano( def pano_to_perpective( self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float ) -> torch.Tensor: + """Converts a panoramic image to a perspective view. + + Args: + pano_image (torch.Tensor): Panoramic image tensor. + pitch (float): Pitch angle. + yaw (float): Yaw angle. + fov (float): Field of view. + + Returns: + torch.Tensor: Perspective image tensor. + """ rots = dict( roll=0, pitch=pitch, @@ -404,6 +507,14 @@ def pano_to_perpective( return perspective def pano_to_cubemap(self, pano_rgb: torch.Tensor): + """Converts a panoramic RGB image to six cubemap views. + + Args: + pano_rgb (torch.Tensor): Panoramic RGB image tensor. + + Returns: + list: List of cubemap RGB tensors. + """ # Define six canonical cube directions in (pitch, yaw) directions = [ (0, 0), @@ -424,6 +535,11 @@ def pano_to_cubemap(self, pano_rgb: torch.Tensor): return cubemaps_rgb def save_mesh(self, output_path: str) -> None: + """Saves the mesh to a file. + + Args: + output_path (str): Path to save the mesh file. + """ vertices_np = self.vertices.T.cpu().numpy() colors_np = self.colors.T.cpu().numpy() faces_np = self.faces.T.cpu().numpy() @@ -434,6 +550,14 @@ def save_mesh(self, output_path: str) -> None: mesh.export(output_path) def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray: + """Converts mesh pose to 3D Gaussian Splatting pose. + + Args: + mesh_pose (torch.Tensor): Mesh pose tensor. + + Returns: + np.ndarray: Converted pose matrix. + """ pose = mesh_pose.clone() pose[0, :] *= -1 pose[1, :] *= -1 @@ -450,6 +574,15 @@ def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray: return c2w def __call__(self, pano_image: Image.Image | str, output_dir: str): + """Runs the pipeline to generate mesh and 3DGS data from a panoramic image. + + Args: + pano_image (Image.Image | str): Input panoramic image or path. + output_dir (str): Directory to save outputs. + + Returns: + None + """ self.init_mesh_params() pano_rgb, pano_depth = self.preprocess_pano(pano_image) self.sup_pool = SupInfoPool() diff --git a/embodied_gen/utils/enum.py b/embodied_gen/utils/enum.py index f807f81..807d4da 100644 --- a/embodied_gen/utils/enum.py +++ b/embodied_gen/utils/enum.py @@ -24,11 +24,27 @@ "Scene3DItemEnum", "SpatialRelationEnum", "RobotItemEnum", + "LayoutInfo", + "AssetType", + "SimAssetMapper", ] @dataclass class RenderItems(str, Enum): + """Enumeration of render item types for 3D scenes. + + Attributes: + IMAGE: Color image. + ALPHA: Mask image. + VIEW_NORMAL: View-space normal image. + GLOBAL_NORMAL: World-space normal image. + POSITION_MAP: Position map image. + DEPTH: Depth image. + ALBEDO: Albedo image. + DIFFUSE: Diffuse image. 
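Example (illustrative usage of the enumeration values as render output keys):
```py
from embodied_gen.utils.enum import RenderItems

print(RenderItems.IMAGE.value)  # "image_color"
print(RenderItems.ALPHA.value)  # "image_mask"
```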
+ """ + IMAGE = "image_color" ALPHA = "image_mask" VIEW_NORMAL = "image_view_normal" @@ -41,6 +57,21 @@ class RenderItems(str, Enum): @dataclass class Scene3DItemEnum(str, Enum): + """Enumeration of 3D scene item categories. + + Attributes: + BACKGROUND: Background objects. + CONTEXT: Contextual objects. + ROBOT: Robot entity. + MANIPULATED_OBJS: Objects manipulated by the robot. + DISTRACTOR_OBJS: Distractor objects. + OTHERS: Other objects. + + Methods: + object_list(layout_relation): Returns a list of objects in the scene. + object_mapping(layout_relation): Returns a mapping from object to category. + """ + BACKGROUND = "background" CONTEXT = "context" ROBOT = "robot" @@ -50,6 +81,14 @@ class Scene3DItemEnum(str, Enum): @classmethod def object_list(cls, layout_relation: dict) -> list: + """Returns a list of objects in the scene. + + Args: + layout_relation: Dictionary mapping categories to objects. + + Returns: + List of objects in the scene. + """ return ( [ layout_relation[cls.BACKGROUND.value], @@ -61,6 +100,14 @@ def object_list(cls, layout_relation: dict) -> list: @classmethod def object_mapping(cls, layout_relation): + """Returns a mapping from object to category. + + Args: + layout_relation: Dictionary mapping categories to objects. + + Returns: + Dictionary mapping object names to their category. + """ relation_mapping = { # layout_relation[cls.ROBOT.value]: cls.ROBOT.value, layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value, @@ -84,6 +131,15 @@ def object_mapping(cls, layout_relation): @dataclass class SpatialRelationEnum(str, Enum): + """Enumeration of spatial relations for objects in a scene. + + Attributes: + ON: Objects on a surface (e.g., table). + IN: Objects in a container or room. + INSIDE: Objects inside a shelf or rack. + FLOOR: Objects on the floor. + """ + ON = "ON" # objects on the table IN = "IN" # objects in the room INSIDE = "INSIDE" # objects inside the shelf/rack @@ -92,6 +148,14 @@ class SpatialRelationEnum(str, Enum): @dataclass class RobotItemEnum(str, Enum): + """Enumeration of supported robot types. + + Attributes: + FRANKA: Franka robot. + UR5: UR5 robot. + PIPER: Piper robot. + """ + FRANKA = "franka" UR5 = "ur5" PIPER = "piper" @@ -99,6 +163,18 @@ class RobotItemEnum(str, Enum): @dataclass class LayoutInfo(DataClassJsonMixin): + """Data structure for layout information in a 3D scene. + + Attributes: + tree: Hierarchical structure of scene objects. + relation: Spatial relations between objects. + objs_desc: Descriptions of objects. + objs_mapping: Mapping from object names to categories. + assets: Asset file paths for objects. + quality: Quality information for assets. + position: Position coordinates for objects. + """ + tree: dict[str, list] relation: dict[str, str | list[str]] objs_desc: dict[str, str] = field(default_factory=dict) @@ -106,3 +182,64 @@ class LayoutInfo(DataClassJsonMixin): assets: dict[str, str] = field(default_factory=dict) quality: dict[str, str] = field(default_factory=dict) position: dict[str, list[float]] = field(default_factory=dict) + + +@dataclass +class AssetType(str): + """Enumeration for asset types. + + Supported types: + MJCF: MuJoCo XML format. + USD: Universal Scene Description format. + URDF: Unified Robot Description Format. + MESH: Mesh file format. + """ + + MJCF = "mjcf" + USD = "usd" + URDF = "urdf" + MESH = "mesh" + + +class SimAssetMapper: + """Maps simulator names to asset types. + + Provides a mapping from simulator names to their corresponding asset type. 
+ + Example: + ```py + from embodied_gen.utils.enum import SimAssetMapper + asset_type = SimAssetMapper["isaacsim"] + print(asset_type) # Output: 'usd' + ``` + + Methods: + __class_getitem__(key): Returns the asset type for a given simulator name. + """ + + _mapping = dict( + ISAACSIM=AssetType.USD, + ISAACGYM=AssetType.URDF, + MUJOCO=AssetType.MJCF, + GENESIS=AssetType.MJCF, + SAPIEN=AssetType.URDF, + PYBULLET=AssetType.URDF, + ) + + @classmethod + def __class_getitem__(cls, key: str): + """Returns the asset type for a given simulator name. + + Args: + key: Name of the simulator. + + Returns: + AssetType corresponding to the simulator. + + Raises: + KeyError: If the simulator name is not recognized. + """ + key = key.upper() + if key.startswith("SAPIEN"): + key = "SAPIEN" + return cls._mapping[key] diff --git a/embodied_gen/utils/geometry.py b/embodied_gen/utils/geometry.py index 8352ccc..c5dbe85 100644 --- a/embodied_gen/utils/geometry.py +++ b/embodied_gen/utils/geometry.py @@ -45,13 +45,13 @@ def matrix_to_pose(matrix: np.ndarray) -> list[float]: - """Convert a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw). + """Converts a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw). Args: matrix (np.ndarray): 4x4 transformation matrix. Returns: - List[float]: Pose as [x, y, z, qx, qy, qz, qw]. + list[float]: Pose as [x, y, z, qx, qy, qz, qw]. """ x, y, z = matrix[:3, 3] rot_mat = matrix[:3, :3] @@ -62,13 +62,13 @@ def matrix_to_pose(matrix: np.ndarray) -> list[float]: def pose_to_matrix(pose: list[float]) -> np.ndarray: - """Convert pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix. + """Converts pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix. Args: - List[float]: Pose as [x, y, z, qx, qy, qz, qw]. + pose (list[float]): Pose as [x, y, z, qx, qy, qz, qw]. Returns: - matrix (np.ndarray): 4x4 transformation matrix. + np.ndarray: 4x4 transformation matrix. """ x, y, z, qx, qy, qz, qw = pose r = R.from_quat([qx, qy, qz, qw]) @@ -82,6 +82,16 @@ def pose_to_matrix(pose: list[float]) -> np.ndarray: def compute_xy_bbox( vertices: np.ndarray, col_x: int = 0, col_y: int = 1 ) -> list[float]: + """Computes the bounding box in XY plane for given vertices. + + Args: + vertices (np.ndarray): Vertex coordinates. + col_x (int, optional): Column index for X. + col_y (int, optional): Column index for Y. + + Returns: + list[float]: [min_x, max_x, min_y, max_y] + """ x_vals = vertices[:, col_x] y_vals = vertices[:, col_y] return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max() @@ -92,6 +102,16 @@ def has_iou_conflict( placed_boxes: list[list[float]], iou_threshold: float = 0.0, ) -> bool: + """Checks for intersection-over-union conflict between boxes. + + Args: + new_box (list[float]): New box coordinates. + placed_boxes (list[list[float]]): List of placed box coordinates. + iou_threshold (float, optional): IOU threshold. + + Returns: + bool: True if conflict exists, False otherwise. + """ new_min_x, new_max_x, new_min_y, new_max_y = new_box for min_x, max_x, min_y, max_y in placed_boxes: ix1 = max(new_min_x, min_x) @@ -105,7 +125,14 @@ def has_iou_conflict( def with_seed(seed_attr_name: str = "seed"): - """A parameterized decorator that temporarily sets the random seed.""" + """Decorator to temporarily set the random seed for reproducibility. + + Args: + seed_attr_name (str, optional): Name of the seed argument. + + Returns: + function: Decorator function. 
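Example (a minimal sketch; `sample_offset` is a hypothetical function whose `seed` keyword is picked up via `seed_attr_name`):
```py
import numpy as np

from embodied_gen.utils.geometry import with_seed

@with_seed("seed")
def sample_offset(scale: float, seed: int = None) -> np.ndarray:
    # The decorator reads the `seed` argument and temporarily fixes the RNG state,
    # so repeated calls with the same seed should produce the same samples.
    return np.random.uniform(-scale, scale, size=3)

offset = sample_offset(0.1, seed=42)
```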
+ """ def decorator(func): @wraps(func) @@ -143,6 +170,20 @@ def compute_convex_hull_path( y_axis: int = 1, z_axis: int = 2, ) -> Path: + """Computes a dense convex hull path for the top surface of a mesh. + + Args: + vertices (np.ndarray): Mesh vertices. + z_threshold (float, optional): Z threshold for top surface. + interp_per_edge (int, optional): Interpolation points per edge. + margin (float, optional): Margin for polygon buffer. + x_axis (int, optional): X axis index. + y_axis (int, optional): Y axis index. + z_axis (int, optional): Z axis index. + + Returns: + Path: Matplotlib path object for the convex hull. + """ top_vertices = vertices[ vertices[:, z_axis] > vertices[:, z_axis].max() - z_threshold ] @@ -170,6 +211,15 @@ def compute_convex_hull_path( def find_parent_node(node: str, tree: dict) -> str | None: + """Finds the parent node of a given node in a tree. + + Args: + node (str): Node name. + tree (dict): Tree structure. + + Returns: + str | None: Parent node name or None. + """ for parent, children in tree.items(): if any(child[0] == node for child in children): return parent @@ -177,6 +227,16 @@ def find_parent_node(node: str, tree: dict) -> str | None: def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool: + """Checks if at least `threshold` corners of a box are inside a hull. + + Args: + hull (Path): Convex hull path. + box (list): Box coordinates [x1, x2, y1, y2]. + threshold (int, optional): Minimum corners inside. + + Returns: + bool: True if enough corners are inside. + """ x1, x2, y1, y2 = box corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]] @@ -187,6 +247,15 @@ def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool: def compute_axis_rotation_quat( axis: Literal["x", "y", "z"], angle_rad: float ) -> list[float]: + """Computes quaternion for rotation around a given axis. + + Args: + axis (Literal["x", "y", "z"]): Axis of rotation. + angle_rad (float): Rotation angle in radians. + + Returns: + list[float]: Quaternion [x, y, z, w]. + """ if axis.lower() == "x": q = Quaternion(axis=[1, 0, 0], angle=angle_rad) elif axis.lower() == "y": @@ -202,6 +271,15 @@ def compute_axis_rotation_quat( def quaternion_multiply( init_quat: list[float], rotate_quat: list[float] ) -> list[float]: + """Multiplies two quaternions. + + Args: + init_quat (list[float]): Initial quaternion [x, y, z, w]. + rotate_quat (list[float]): Rotation quaternion [x, y, z, w]. + + Returns: + list[float]: Resulting quaternion [x, y, z, w]. + """ qx, qy, qz, qw = init_quat q1 = Quaternion(w=qw, x=qx, y=qy, z=qz) qx, qy, qz, qw = rotate_quat @@ -217,7 +295,17 @@ def check_reachable( min_reach: float = 0.25, max_reach: float = 0.85, ) -> bool: - """Check if the target point is within the reachable range.""" + """Checks if the target point is within the reachable range. + + Args: + base_xyz (np.ndarray): Base position. + reach_xyz (np.ndarray): Target position. + min_reach (float, optional): Minimum reach distance. + max_reach (float, optional): Maximum reach distance. + + Returns: + bool: True if reachable, False otherwise. + """ distance = np.linalg.norm(reach_xyz - base_xyz) return min_reach < distance < max_reach @@ -238,26 +326,31 @@ def bfs_placement( robot_dim: float = 0.12, seed: int = None, ) -> LayoutInfo: - """Place objects in the layout using BFS traversal. + """Places objects in a scene layout using BFS traversal. Args: - layout_file: Path to the JSON file defining the layout structure and assets. 
- floor_margin: Z-offset for the background object, typically for objects placed on the floor. - beside_margin: Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails. - max_attempts: Maximum number of attempts to find a non-overlapping position for an object. - init_rpy: Initial Roll-Pitch-Yaw rotation rad applied to all object meshes to align the mesh's - coordinate system with the world's (e.g., Z-up). - rotate_objs: If True, apply a random rotation around the Z-axis for manipulated and distractor objects. - rotate_bg: If True, apply a random rotation around the Y-axis for the background object. - rotate_context: If True, apply a random rotation around the Z-axis for the context object. - limit_reach_range: If set, enforce a check that manipulated objects are within the robot's reach range, in meter. - max_orient_diff: If set, enforce a check that manipulated objects are within the robot's orientation range, in degree. - robot_dim: The approximate dimension (e.g., diameter) of the robot for box representation. - seed: Random seed for reproducible placement. + layout_file (str): Path to layout JSON file generated from `layout-cli`. + floor_margin (float, optional): Z-offset for objects placed on the floor. + beside_margin (float, optional): Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails. + max_attempts (int, optional): Max attempts for a non-overlapping placement. + init_rpy (tuple, optional): Initial rotation (rpy). + rotate_objs (bool, optional): Whether to random rotate objects. + rotate_bg (bool, optional): Whether to random rotate background. + rotate_context (bool, optional): Whether to random rotate context asset. + limit_reach_range (tuple[float, float] | None, optional): If set, enforce a check that manipulated objects are within the robot's reach range, in meter. + max_orient_diff (float | None, optional): If set, enforce a check that manipulated objects are within the robot's orientation range, in degree. + robot_dim (float, optional): The approximate robot size. + seed (int, optional): Random seed for reproducible placement. Returns: - A :class:`LayoutInfo` object containing the objects and their final computed 7D poses - ([x, y, z, qx, qy, qz, qw]). + LayoutInfo: Layout information with object poses. + + Example: + ```py + from embodied_gen.utils.geometry import bfs_placement + layout = bfs_placement("scene_layout.json", seed=42) + print(layout.position) + ``` """ layout_info = LayoutInfo.from_dict(json.load(open(layout_file, "r"))) asset_dir = os.path.dirname(layout_file) @@ -478,6 +571,13 @@ def bfs_placement( def compose_mesh_scene( layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False ) -> None: + """Composes a mesh scene from layout information and saves to file. + + Args: + layout_info (LayoutInfo): Layout information. + out_scene_path (str): Output scene file path. + with_bg (bool, optional): Include background mesh. + """ object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation) scene = trimesh.Scene() for node in layout_info.assets: @@ -505,6 +605,16 @@ def compose_mesh_scene( def compute_pinhole_intrinsics( image_w: int, image_h: int, fov_deg: float ) -> np.ndarray: + """Computes pinhole camera intrinsic matrix from image size and FOV. + + Args: + image_w (int): Image width. + image_h (int): Image height. + fov_deg (float): Field of view in degrees. + + Returns: + np.ndarray: Intrinsic matrix K. 
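Example (an illustrative sketch; the principal point is expected at the image center, per the usual pinhole convention):
```py
from embodied_gen.utils.geometry import compute_pinhole_intrinsics

K = compute_pinhole_intrinsics(image_w=640, image_h=480, fov_deg=60)
# fx = fy = 640 / (2 * tan(30 deg)) ≈ 554.3 for these values.
```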
+ """ fov_rad = np.deg2rad(fov_deg) fx = image_w / (2 * np.tan(fov_rad / 2)) fy = fx # assuming square pixels diff --git a/embodied_gen/utils/gpt_clients.py b/embodied_gen/utils/gpt_clients.py index de435e2..47f5ce2 100644 --- a/embodied_gen/utils/gpt_clients.py +++ b/embodied_gen/utils/gpt_clients.py @@ -45,7 +45,35 @@ class GPTclient: - """A client to interact with the GPT model via OpenAI or Azure API.""" + """A client to interact with GPT models via OpenAI or Azure API. + + Supports text and image prompts, connection checking, and configurable parameters. + + Args: + endpoint (str): API endpoint URL. + api_key (str): API key for authentication. + model_name (str, optional): Model name to use. + api_version (str, optional): API version (for Azure). + check_connection (bool, optional): Whether to check API connection. + verbose (bool, optional): Enable verbose logging. + + Example: + ```sh + export ENDPOINT="https://yfb-openai-sweden.openai.azure.com" + export API_KEY="xxxxxx" + export API_VERSION="2025-03-01-preview" + export MODEL_NAME="yfb-gpt-4o-sweden" + ``` + ```py + from embodied_gen.utils.gpt_clients import GPT_CLIENT + + response = GPT_CLIENT.query("Describe the physics of a falling apple.") + response = GPT_CLIENT.query( + text_prompt="Describe the content in each image." + image_base64=["path/to/image1.png", "path/to/image2.jpg"], + ) + ``` + """ def __init__( self, @@ -82,6 +110,7 @@ def __init__( stop=(stop_after_attempt(10) | stop_after_delay(30)), ) def completion_with_backoff(self, **kwargs): + """Performs a chat completion request with retry/backoff.""" return self.client.chat.completions.create(**kwargs) def query( @@ -91,19 +120,16 @@ def query( system_role: Optional[str] = None, params: Optional[dict] = None, ) -> Optional[str]: - """Queries the GPT model with a text and optional image prompts. + """Queries the GPT model with text and optional image prompts. Args: - text_prompt (str): The main text input that the model responds to. - image_base64 (Optional[List[str]]): A list of image base64 strings - or local image paths or PIL.Image to accompany the text prompt. - system_role (Optional[str]): Optional system-level instructions - that specify the behavior of the assistant. - params (Optional[dict]): Additional parameters for GPT setting. + text_prompt (str): Main text input. + image_base64 (Optional[list[str | Image.Image]], optional): List of image base64 strings, file paths, or PIL Images. + system_role (Optional[str], optional): System-level instructions. + params (Optional[dict], optional): Additional GPT parameters. Returns: - Optional[str]: The response content generated by the model based on - the prompt. Returns `None` if an error occurs. + Optional[str]: Model response content, or None if error. """ if system_role is None: system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa @@ -177,7 +203,11 @@ def query( return response def check_connection(self) -> None: - """Check whether the GPT API connection is working.""" + """Checks whether the GPT API connection is working. + + Raises: + ConnectionError: If connection fails. 
+ """ try: response = self.completion_with_backoff( messages=[ diff --git a/embodied_gen/utils/process_media.py b/embodied_gen/utils/process_media.py index 88eb8e5..8feb7ec 100644 --- a/embodied_gen/utils/process_media.py +++ b/embodied_gen/utils/process_media.py @@ -69,6 +69,40 @@ def render_asset3d( no_index_file: bool = False, with_mtl: bool = True, ) -> list[str]: + """Renders a 3D mesh asset and returns output image paths. + + Args: + mesh_path (str): Path to the mesh file. + output_root (str): Directory to save outputs. + distance (float, optional): Camera distance. + num_images (int, optional): Number of views to render. + elevation (list[float], optional): Camera elevation angles. + pbr_light_factor (float, optional): PBR lighting factor. + return_key (str, optional): Glob pattern for output images. + output_subdir (str, optional): Subdirectory for outputs. + gen_color_mp4 (bool, optional): Generate color MP4 video. + gen_viewnormal_mp4 (bool, optional): Generate view normal MP4. + gen_glonormal_mp4 (bool, optional): Generate global normal MP4. + no_index_file (bool, optional): Skip index file saving. + with_mtl (bool, optional): Use mesh material. + + Returns: + list[str]: List of output image file paths. + + Example: + ```py + from embodied_gen.utils.process_media import render_asset3d + + image_paths = render_asset3d( + mesh_path="path_to_mesh.obj", + output_root="path_to_save_dir", + num_images=6, + elevation=(30, -30), + output_subdir="renders", + no_index_file=True, + ) + ``` + """ input_args = dict( mesh_path=mesh_path, output_root=output_root, @@ -95,6 +129,13 @@ def render_asset3d( def merge_images_video(color_images, normal_images, output_path) -> None: + """Merges color and normal images into a video. + + Args: + color_images (list[np.ndarray]): List of color images. + normal_images (list[np.ndarray]): List of normal images. + output_path (str): Path to save the output video. + """ width = color_images[0].shape[1] combined_video = [ np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]]) @@ -108,7 +149,13 @@ def merge_images_video(color_images, normal_images, output_path) -> None: def merge_video_video( video_path1: str, video_path2: str, output_path: str ) -> None: - """Merge two videos by the left half and the right half of the videos.""" + """Merges two videos by combining their left and right halves. + + Args: + video_path1 (str): Path to first video. + video_path2 (str): Path to second video. + output_path (str): Path to save the merged video. + """ clip1 = VideoFileClip(video_path1) clip2 = VideoFileClip(video_path2) @@ -127,6 +174,16 @@ def filter_small_connected_components( area_ratio: float, connectivity: int = 8, ) -> np.ndarray: + """Removes small connected components from a binary mask. + + Args: + mask (Union[Image.Image, np.ndarray]): Input mask. + area_ratio (float): Minimum area ratio for components. + connectivity (int, optional): Connectivity for labeling. + + Returns: + np.ndarray: Mask with small components removed. + """ if isinstance(mask, Image.Image): mask = np.array(mask) num_labels, labels, stats, _ = cv2.connectedComponentsWithStats( @@ -152,6 +209,16 @@ def filter_image_small_connected_components( area_ratio: float = 10, connectivity: int = 8, ) -> np.ndarray: + """Removes small connected components from the alpha channel of an image. + + Args: + image (Union[Image.Image, np.ndarray]): Input image. + area_ratio (float, optional): Minimum area ratio. + connectivity (int, optional): Connectivity for labeling. 
+ + Returns: + np.ndarray: Image with filtered alpha channel. + """ if isinstance(image, Image.Image): image = image.convert("RGBA") image = np.array(image) @@ -169,6 +236,24 @@ def combine_images_to_grid( target_wh: tuple[int, int] = (512, 512), image_mode: str = "RGB", ) -> list[Image.Image]: + """Combines multiple images into a grid. + + Args: + images (list[str | Image.Image]): List of image paths or PIL Images. + cat_row_col (tuple[int, int], optional): Grid rows and columns. + target_wh (tuple[int, int], optional): Target image size. + image_mode (str, optional): Image mode. + + Returns: + list[Image.Image]: List containing the grid image. + + Example: + ```py + from embodied_gen.utils.process_media import combine_images_to_grid + grid = combine_images_to_grid(["img1.png", "img2.png"]) + grid[0].save("grid.png") + ``` + """ n_images = len(images) if n_images == 1: return images @@ -196,6 +281,19 @@ def combine_images_to_grid( class SceneTreeVisualizer: + """Visualizes a scene tree layout using networkx and matplotlib. + + Args: + layout_info (LayoutInfo): Layout information for the scene. + + Example: + ```py + from embodied_gen.utils.process_media import SceneTreeVisualizer + visualizer = SceneTreeVisualizer(layout_info) + visualizer.render(save_path="tree.png") + ``` + """ + def __init__(self, layout_info: LayoutInfo) -> None: self.tree = layout_info.tree self.relation = layout_info.relation @@ -274,6 +372,14 @@ def render( dpi=300, title: str = "Scene 3D Hierarchy Tree", ): + """Renders the scene tree and saves to file. + + Args: + save_path (str): Path to save the rendered image. + figsize (tuple, optional): Figure size. + dpi (int, optional): Image DPI. + title (str, optional): Plot image title. + """ node_colors = [ self.role_colors[self._get_node_role(n)] for n in self.G.nodes ] @@ -350,6 +456,14 @@ def render( def load_scene_dict(file_path: str) -> dict: + """Loads a scene description dictionary from a file. + + Args: + file_path (str): Path to the scene description file. + + Returns: + dict: Mapping from scene ID to description. + """ scene_dict = {} with open(file_path, "r", encoding='utf-8') as f: for line in f: @@ -363,12 +477,28 @@ def load_scene_dict(file_path: str) -> dict: def is_image_file(filename: str) -> bool: + """Checks if a filename is an image file. + + Args: + filename (str): Filename to check. + + Returns: + bool: True if image file, False otherwise. + """ mime_type, _ = mimetypes.guess_type(filename) return mime_type is not None and mime_type.startswith('image') def parse_text_prompts(prompts: list[str]) -> list[str]: + """Parses text prompts from a list or file. + + Args: + prompts (list[str]): List of prompts or a file path. + + Returns: + list[str]: List of parsed prompts. + """ if len(prompts) == 1 and prompts[0].endswith(".txt"): with open(prompts[0], "r") as f: prompts = [ @@ -386,13 +516,18 @@ def alpha_blend_rgba( """Alpha blends a foreground RGBA image over a background RGBA image. Args: - fg_image: Foreground image. Can be a file path (str), a PIL Image, - or a NumPy ndarray. - bg_image: Background image. Can be a file path (str), a PIL Image, - or a NumPy ndarray. + fg_image: Foreground image (str, PIL Image, or ndarray). + bg_image: Background image (str, PIL Image, or ndarray). Returns: - A PIL Image representing the alpha-blended result in RGBA mode. + Image.Image: Alpha-blended RGBA image. 
+ + Example: + ```py + from embodied_gen.utils.process_media import alpha_blend_rgba + result = alpha_blend_rgba("fg.png", "bg.png") + result.save("blended.png") + ``` """ if isinstance(fg_image, str): fg_image = Image.open(fg_image) @@ -421,13 +556,11 @@ def check_object_edge_truncated( """Checks if a binary object mask is truncated at the image edges. Args: - mask: A 2D binary NumPy array where nonzero values indicate the object region. - edge_threshold: Number of pixels from each image edge to consider for truncation. - Defaults to 5. + mask (np.ndarray): 2D binary mask. + edge_threshold (int, optional): Edge pixel threshold. Returns: - True if the object is fully enclosed (not truncated). - False if the object touches or crosses any image boundary. + bool: True if object is fully enclosed, False if truncated. """ top = mask[:edge_threshold, :].any() bottom = mask[-edge_threshold:, :].any() @@ -440,6 +573,22 @@ def check_object_edge_truncated( def vcat_pil_images( images: list[Image.Image], image_mode: str = "RGB" ) -> Image.Image: + """Vertically concatenates a list of PIL images. + + Args: + images (list[Image.Image]): List of images. + image_mode (str, optional): Image mode. + + Returns: + Image.Image: Vertically concatenated image. + + Example: + ```py + from embodied_gen.utils.process_media import vcat_pil_images + img = vcat_pil_images([Image.open("a.png"), Image.open("b.png")]) + img.save("vcat.png") + ``` + """ widths, heights = zip(*(img.size for img in images)) total_height = sum(heights) max_width = max(widths) diff --git a/embodied_gen/utils/simulation.py b/embodied_gen/utils/simulation.py index 6925cfb..5ff13b6 100644 --- a/embodied_gen/utils/simulation.py +++ b/embodied_gen/utils/simulation.py @@ -69,6 +69,21 @@ def load_actor_from_urdf( update_mass: bool = False, scale: float | np.ndarray = 1.0, ) -> sapien.pysapien.Entity: + """Load an sapien actor from a URDF file and add it to the scene. + + Args: + scene (sapien.Scene | ManiSkillScene): The simulation scene. + file_path (str): Path to the URDF file. + pose (sapien.Pose | None): Initial pose of the actor. + env_idx (int): Environment index for multi-env setup. + use_static (bool): Whether the actor is static. + update_mass (bool): Whether to update the actor's mass from URDF. + scale (float | np.ndarray): Scale factor for the actor. + + Returns: + sapien.pysapien.Entity: The created actor entity. + """ + def _get_local_pose(origin_tag: ET.Element | None) -> sapien.Pose: local_pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0]) if origin_tag is not None: @@ -154,14 +169,17 @@ def load_assets_from_layout_file( init_quat: list[float] = [0, 0, 0, 1], env_idx: int = None, ) -> dict[str, sapien.pysapien.Entity]: - """Load assets from `EmbodiedGen` layout-gen output and create actors in the scene. + """Load assets from an EmbodiedGen layout file and create sapien actors in the scene. Args: - scene (sapien.Scene | ManiSkillScene): The SAPIEN or ManiSkill scene to load assets into. - layout (str): The layout file path. - z_offset (float): Offset to apply to the Z-coordinate of non-context objects. - init_quat (List[float]): Initial quaternion (x, y, z, w) for orientation adjustment. - env_idx (int): Environment index for multi-environment setup. + scene (ManiSkillScene | sapien.Scene): The sapien simulation scene. + layout (str): Path to the embodiedgen layout file. + z_offset (float): Z offset for non-context objects. + init_quat (list[float]): Initial quaternion for orientation. + env_idx (int): Environment index. 
+ + Returns: + dict[str, sapien.pysapien.Entity]: Mapping from object names to actor entities. """ asset_root = os.path.dirname(layout) layout = LayoutInfo.from_dict(json.load(open(layout, "r"))) @@ -206,6 +224,19 @@ def load_mani_skill_robot( control_mode: str = "pd_joint_pos", backend_str: tuple[str, str] = ("cpu", "gpu"), ) -> BaseAgent: + """Load a ManiSkill robot agent into the scene. + + Args: + scene (sapien.Scene | ManiSkillScene): The simulation scene. + layout (LayoutInfo | str): Layout info or path to layout file. + control_freq (int): Control frequency. + robot_init_qpos_noise (float): Noise for initial joint positions. + control_mode (str): Robot control mode. + backend_str (tuple[str, str]): Simulation/render backend. + + Returns: + BaseAgent: The loaded robot agent. + """ from mani_skill.agents import REGISTERED_AGENTS from mani_skill.envs.scene import ManiSkillScene from mani_skill.envs.utils.system.backend import ( @@ -278,14 +309,14 @@ def render_images( ] ] = None, ) -> dict[str, Image.Image]: - """Render images from a given sapien camera. + """Render images from a given SAPIEN camera. Args: - camera (sapien.render.RenderCameraComponent): The camera to render from. - render_keys (List[str]): Types of images to render (e.g., Color, Segmentation). + camera (sapien.render.RenderCameraComponent): Camera to render from. + render_keys (list[str], optional): Types of images to render. Returns: - Dict[str, Image.Image]: Dictionary of rendered images. + dict[str, Image.Image]: Dictionary of rendered images. """ if render_keys is None: render_keys = [ @@ -341,11 +372,33 @@ def render_images( class SapienSceneManager: - """A class to manage SAPIEN simulator.""" + """Manages SAPIEN simulation scenes, cameras, and rendering. + + This class provides utilities for setting up scenes, adding cameras, + stepping simulation, and rendering images. + + Attributes: + sim_freq (int): Simulation frequency. + ray_tracing (bool): Whether to use ray tracing. + device (str): Device for simulation. + renderer (sapien.SapienRenderer): SAPIEN renderer. + scene (sapien.Scene): Simulation scene. + cameras (list): List of camera components. + actors (dict): Mapping of actor names to entities. + + Example see `embodied_gen/scripts/simulate_sapien.py`. + """ def __init__( self, sim_freq: int, ray_tracing: bool, device: str = "cuda" ) -> None: + """Initialize the scene manager. + + Args: + sim_freq (int): Simulation frequency. + ray_tracing (bool): Enable ray tracing. + device (str): Device for simulation. + """ self.sim_freq = sim_freq self.ray_tracing = ray_tracing self.device = device @@ -355,7 +408,11 @@ def __init__( self.actors: dict[str, sapien.pysapien.Entity] = {} def _setup_scene(self) -> sapien.Scene: - """Set up the SAPIEN scene with lighting and ground.""" + """Set up the SAPIEN scene with lighting and ground. + + Returns: + sapien.Scene: The initialized scene. + """ # Ray tracing settings if self.ray_tracing: sapien.render.set_camera_shader_dir("rt") @@ -397,6 +454,18 @@ def step_action( render_keys: list[str], sim_steps_per_control: int = 1, ) -> dict: + """Step the simulation and render images from cameras. + + Args: + agent (BaseAgent): The robot agent. + action (torch.Tensor): Action to apply. + cameras (list): List of camera components. + render_keys (list[str]): Types of images to render. + sim_steps_per_control (int): Simulation steps per control. + + Returns: + dict: Dictionary of rendered frames per camera. 
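Example (an illustrative sketch; `manager`, `agent`, and `action` denote an initialized `SapienSceneManager`, a loaded robot agent, and a control tensor, respectively):
```py
frames = manager.step_action(
    agent=agent,
    action=action,
    cameras=manager.cameras,
    render_keys=["Color"],
    sim_steps_per_control=4,
)
# `frames` collects the images rendered at each simulation sub-step, keyed per camera.
```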
+        """
         agent.set_action(action)
         frames = defaultdict(list)
         for _ in range(sim_steps_per_control):
@@ -417,13 +486,13 @@ def create_camera(
         image_hw: tuple[int, int],
         fovy_deg: float,
     ) -> sapien.render.RenderCameraComponent:
-        """Create a single camera in the scene.
+        """Create a camera in the scene.

         Args:
-            cam_name (str): Name of the camera.
-            pose (sapien.Pose): Camera pose p=(x, y, z), q=(w, x, y, z)
-            image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
-            fovy_deg (float): Field of view in degrees for cameras.
+            cam_name (str): Camera name.
+            pose (sapien.Pose): Camera pose, p=(x, y, z), q=(w, x, y, z).
+            image_hw (tuple[int, int]): Image resolution (height, width).
+            fovy_deg (float): Field of view in degrees.

         Returns:
             sapien.render.RenderCameraComponent: The created camera.
@@ -456,15 +525,15 @@ def initialize_circular_cameras(
         """Initialize multiple cameras arranged in a circle.

         Args:
-            num_cameras (int): Number of cameras to create.
-            radius (float): Radius of the camera circle.
-            height (float): Fixed Z-coordinate of the cameras.
-            target_pt (list[float]): 3D point (x, y, z) that cameras look at.
-            image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
-            fovy_deg (float): Field of view in degrees for cameras.
+            num_cameras (int): Number of cameras.
+            radius (float): Circle radius.
+            height (float): Camera height (fixed Z-coordinate).
+            target_pt (list[float]): Target point (x, y, z) to look at.
+            image_hw (tuple[int, int]): Image resolution (height, width).
+            fovy_deg (float): Field of view in degrees.

         Returns:
-            List[sapien.render.RenderCameraComponent]: List of created cameras.
+            list[sapien.render.RenderCameraComponent]: List of cameras.
         """
         angle_step = 2 * np.pi / num_cameras
         world_up_vec = np.array([0.0, 0.0, 1.0])
@@ -510,6 +579,19 @@ def initialize_circular_cameras(
 class FrankaPandaGrasper(object):
+    """Provides grasp planning and control for the Franka Panda robot.
+
+    Attributes:
+        agent (BaseAgent): The robot agent.
+        robot: The robot instance.
+        control_freq (float): Control frequency.
+        control_timestep (float): Control timestep.
+        joint_vel_limits (float): Joint velocity limits.
+        joint_acc_limits (float): Joint acceleration limits.
+        finger_length (float): Length of the gripper fingers.
+        planners: Motion planners for each environment.
+    """
+
     def __init__(
         self,
         agent: BaseAgent,
@@ -518,6 +600,7 @@ def __init__(
         joint_acc_limits: float = 1.0,
         finger_length: float = 0.025,
     ) -> None:
+        """Initialize the grasper."""
         self.agent = agent
         self.robot = agent.robot
         self.control_freq = control_freq
@@ -553,6 +636,15 @@ def control_gripper(
         gripper_state: Literal[-1, 1],
         n_step: int = 10,
     ) -> np.ndarray:
+        """Generate gripper control actions.
+
+        Args:
+            gripper_state (Literal[-1, 1]): Desired gripper state.
+            n_step (int): Number of steps.
+
+        Returns:
+            np.ndarray: Array of gripper actions.
+        """
         qpos = self.robot.get_qpos()[0, :-2].cpu().numpy()
         actions = []
         for _ in range(n_step):
@@ -571,6 +663,20 @@ def move_to_pose(
         action_key: str = "position",
         env_idx: int = 0,
     ) -> np.ndarray:
+        """Plan and execute motion to a target pose.
+
+        Args:
+            pose (sapien.Pose): Target pose.
+            control_timestep (float): Control timestep.
+            gripper_state (Literal[-1, 1]): Desired gripper state.
+            use_point_cloud (bool): Use point cloud for planning.
+            n_max_step (int): Max number of steps.
+            action_key (str): Key for action in result.
+            env_idx (int): Environment index.
+
+        Returns:
+            np.ndarray: Array of actions to reach the pose.
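+
+        Example:
+            A minimal sketch; the grasper instance and target pose are
+            illustrative placeholders:
+            ```py
+            import sapien
+
+            # Assumes `grasper` is an initialized FrankaPandaGrasper; the pose
+            # values are placeholders.
+            target = sapien.Pose(p=[0.4, 0.0, 0.3], q=[0, 1, 0, 0])
+            actions = grasper.move_to_pose(
+                target, grasper.control_timestep, gripper_state=1
+            )
+            ```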
+ """ result = self.planners[env_idx].plan_qpos_to_pose( np.concatenate([pose.p, pose.q]), self.robot.get_qpos().cpu().numpy()[0], @@ -608,6 +714,17 @@ def compute_grasp_action( offset: tuple[float, float, float] = [0, 0, -0.05], env_idx: int = 0, ) -> np.ndarray: + """Compute grasp actions for a target actor. + + Args: + actor (sapien.pysapien.Entity): Target actor to grasp. + reach_target_only (bool): Only reach the target pose if True. + offset (tuple[float, float, float]): Offset for reach pose. + env_idx (int): Environment index. + + Returns: + np.ndarray: Array of grasp actions. + """ physx_rigid = actor.components[1] mesh = get_component_mesh(physx_rigid, to_world_frame=True) obb = mesh.bounding_box_oriented diff --git a/embodied_gen/utils/tags.py b/embodied_gen/utils/tags.py index 9c50269..9302331 100644 --- a/embodied_gen/utils/tags.py +++ b/embodied_gen/utils/tags.py @@ -1 +1 @@ -VERSION = "v0.1.5" +VERSION = "v0.1.6" diff --git a/embodied_gen/validators/aesthetic_predictor.py b/embodied_gen/validators/aesthetic_predictor.py index 921f363..6e77449 100644 --- a/embodied_gen/validators/aesthetic_predictor.py +++ b/embodied_gen/validators/aesthetic_predictor.py @@ -27,14 +27,22 @@ class AestheticPredictor: - """Aesthetic Score Predictor. + """Aesthetic Score Predictor using CLIP and a pre-trained MLP. - Checkpoints from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main + Checkpoints from `https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main`. Args: - clip_model_dir (str): Path to the directory of the CLIP model. - sac_model_path (str): Path to the pre-trained SAC model. - device (str): Device to use for computation ("cuda" or "cpu"). + clip_model_dir (str, optional): Path to CLIP model directory. + sac_model_path (str, optional): Path to SAC model weights. + device (str, optional): Device for computation ("cuda" or "cpu"). + + Example: + ```py + from embodied_gen.validators.aesthetic_predictor import AestheticPredictor + predictor = AestheticPredictor(device="cuda") + score = predictor.predict("image.png") + print("Aesthetic score:", score) + ``` """ def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"): @@ -109,7 +117,7 @@ def _load_sac_model(self, model_path, input_size): return model def predict(self, image_path): - """Predict the aesthetic score for a given image. + """Predicts the aesthetic score for a given image. Args: image_path (str): Path to the image file. diff --git a/embodied_gen/validators/quality_checkers.py b/embodied_gen/validators/quality_checkers.py index 65e236c..0e5ff7e 100644 --- a/embodied_gen/validators/quality_checkers.py +++ b/embodied_gen/validators/quality_checkers.py @@ -40,6 +40,16 @@ class BaseChecker: + """Base class for quality checkers using GPT clients. + + Provides a common interface for querying and validating responses. + Subclasses must implement the `query` method. + + Attributes: + prompt (str): The prompt used for queries. + verbose (bool): Whether to enable verbose logging. + """ + def __init__(self, prompt: str = None, verbose: bool = False) -> None: self.prompt = prompt self.verbose = verbose @@ -70,6 +80,15 @@ def __call__(self, *args, **kwargs) -> tuple[bool, str]: def validate( checkers: list["BaseChecker"], images_list: list[list[str]] ) -> list: + """Validates a list of checkers against corresponding image lists. + + Args: + checkers (list[BaseChecker]): List of checker instances. + images_list (list[list[str]]): List of image path lists. 
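+
+    Example:
+        An illustrative sketch; the checker instances and image paths are
+        placeholders:
+        ```py
+        # Assumes `semantic_checker` and `aesthetic_checker` are existing
+        # BaseChecker instances; image paths are placeholders.
+        checkers = [semantic_checker, aesthetic_checker]
+        images_list = [["view_0.png"], ["render_0.png", "render_1.png"]]
+        results = validate(checkers, images_list)
+        ```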
+ + Returns: + list: Validation results with overall outcome. + """ assert len(checkers) == len(images_list) results = [] overall_result = True @@ -192,7 +211,7 @@ def query(self, image_paths: list[str]) -> str: class ImageAestheticChecker(BaseChecker): - """A class for evaluating the aesthetic quality of images. + """Evaluates the aesthetic quality of images using a CLIP-based predictor. Attributes: clip_model_dir (str): Path to the CLIP model directory. @@ -200,6 +219,14 @@ class ImageAestheticChecker(BaseChecker): thresh (float): Threshold above which images are considered aesthetically acceptable. verbose (bool): Whether to print detailed log messages. predictor (AestheticPredictor): The model used to predict aesthetic scores. + + Example: + ```py + from embodied_gen.validators.quality_checkers import ImageAestheticChecker + checker = ImageAestheticChecker(thresh=4.5) + flag, score = checker(["image1.png", "image2.png"]) + print("Aesthetic OK:", flag, "Score:", score) + ``` """ def __init__( @@ -227,6 +254,16 @@ def __call__(self, image_paths: list[str], **kwargs) -> bool: class SemanticConsistChecker(BaseChecker): + """Checks semantic consistency between text descriptions and segmented images. + + Uses GPT to evaluate if the image matches the text in object type, geometry, and color. + + Attributes: + gpt_client (GPTclient): GPT client for queries. + prompt (str): Prompt for consistency evaluation. + verbose (bool): Whether to enable verbose logging. + """ + def __init__( self, gpt_client: GPTclient, @@ -276,6 +313,16 @@ def query(self, text: str, image: list[Image.Image | str]) -> str: class TextGenAlignChecker(BaseChecker): + """Evaluates alignment between text prompts and generated 3D asset images. + + Assesses if the rendered images match the text description in category and geometry. + + Attributes: + gpt_client (GPTclient): GPT client for queries. + prompt (str): Prompt for alignment evaluation. + verbose (bool): Whether to enable verbose logging. + """ + def __init__( self, gpt_client: GPTclient, @@ -489,6 +536,17 @@ def __call__(self, image_paths: str | Image.Image) -> float: class SemanticMatcher(BaseChecker): + """Matches query text to semantically similar scene descriptions. + + Uses GPT to find the most similar scene IDs from a dictionary. + + Attributes: + gpt_client (GPTclient): GPT client for queries. + prompt (str): Prompt for semantic matching. + verbose (bool): Whether to enable verbose logging. + seed (int): Random seed for selection. + """ + def __init__( self, gpt_client: GPTclient, @@ -543,6 +601,17 @@ def __init__( def query( self, text: str, context: dict, rand: bool = True, params: dict = None ) -> str: + """Queries for semantically similar scene IDs. + + Args: + text (str): Query text. + context (dict): Dictionary of scene descriptions. + rand (bool, optional): Whether to randomly select from top matches. + params (dict, optional): Additional GPT parameters. + + Returns: + str: Matched scene ID. + """ match_list = self.gpt_client.query( self.prompt.format(context=context, text=text), params=params, diff --git a/embodied_gen/validators/urdf_convertor.py b/embodied_gen/validators/urdf_convertor.py index 7341ed7..3f070be 100644 --- a/embodied_gen/validators/urdf_convertor.py +++ b/embodied_gen/validators/urdf_convertor.py @@ -80,6 +80,31 @@ class URDFGenerator(object): + """Generates URDF files for 3D assets with physical and semantic attributes. + + Uses GPT to estimate object properties and generates a URDF file with mesh, friction, mass, and metadata. 
+ + Args: + gpt_client (GPTclient): GPT client for attribute estimation. + mesh_file_list (list[str], optional): Additional mesh files to copy. + prompt_template (str, optional): Prompt template for GPT queries. + attrs_name (list[str], optional): List of attribute names to include. + render_dir (str, optional): Directory for rendered images. + render_view_num (int, optional): Number of views to render. + decompose_convex (bool, optional): Whether to decompose mesh for collision. + rotate_xyzw (list[float], optional): Quaternion for mesh rotation. + + Example: + ```py + from embodied_gen.validators.urdf_convertor import URDFGenerator + from embodied_gen.utils.gpt_clients import GPT_CLIENT + + urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4) + urdf_path = urdf_gen(mesh_path="mesh.obj", output_root="output_dir") + print("Generated URDF:", urdf_path) + ``` + """ + def __init__( self, gpt_client: GPTclient, @@ -168,6 +193,14 @@ def __init__( self.rotate_xyzw = rotate_xyzw def parse_response(self, response: str) -> dict[str, any]: + """Parses GPT response to extract asset attributes. + + Args: + response (str): GPT response string. + + Returns: + dict[str, any]: Parsed attributes. + """ lines = response.split("\n") lines = [line.strip() for line in lines if line] category = lines[0].split(": ")[1] @@ -207,11 +240,9 @@ def generate_urdf( Args: input_mesh (str): Path to the input mesh file. - output_dir (str): Directory to store the generated URDF - and processed mesh. - attr_dict (dict): Dictionary containing attributes like height, - mass, and friction coefficients. - output_name (str, optional): Name for the generated URDF and robot. + output_dir (str): Directory to store the generated URDF and mesh. + attr_dict (dict): Dictionary of asset attributes. + output_name (str, optional): Name for the URDF and robot. Returns: str: Path to the generated URDF file. @@ -336,6 +367,16 @@ def get_attr_from_urdf( attr_root: str = ".//link/extra_info", attr_name: str = "scale", ) -> float: + """Extracts an attribute value from a URDF file. + + Args: + urdf_path (str): Path to the URDF file. + attr_root (str, optional): XML path to attribute root. + attr_name (str, optional): Attribute name. + + Returns: + float: Attribute value, or None if not found. + """ if not os.path.exists(urdf_path): raise FileNotFoundError(f"URDF file not found: {urdf_path}") @@ -358,6 +399,13 @@ def get_attr_from_urdf( def add_quality_tag( urdf_path: str, results: list, output_path: str = None ) -> None: + """Adds a quality tag to a URDF file. + + Args: + urdf_path (str): Path to the URDF file. + results (list): List of [checker_name, result] pairs. + output_path (str, optional): Output file path. + """ if output_path is None: output_path = urdf_path @@ -382,6 +430,14 @@ def add_quality_tag( logger.info(f"URDF files saved to {output_path}") def get_estimated_attributes(self, asset_attrs: dict): + """Calculates estimated attributes from asset properties. + + Args: + asset_attrs (dict): Asset attributes. + + Returns: + dict: Estimated attributes (height, mass, mu, category). + """ estimated_attrs = { "height": round( (asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4 @@ -403,6 +459,18 @@ def __call__( category: str = "unknown", **kwargs, ): + """Generates a URDF file for a mesh asset. + + Args: + mesh_path (str): Path to mesh file. + output_root (str): Directory for outputs. + text_prompt (str, optional): Prompt for GPT. + category (str, optional): Asset category. + **kwargs: Additional attributes. 
+ + Returns: + str: Path to generated URDF file. + """ if text_prompt is None or len(text_prompt) == 0: text_prompt = self.prompt_template text_prompt = text_prompt.format(category=category.lower()) diff --git a/pyproject.toml b/pyproject.toml index 4f8d645..be3cc60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ packages = ["embodied_gen"] [project] name = "embodied_gen" -version = "v0.1.5" +version = "v0.1.6" readme = "README.md" license = "Apache-2.0" license-files = ["LICENSE", "NOTICE"] diff --git a/tests/test_examples/test_asset_converter.py b/tests/test_examples/test_asset_converter.py index eeefdbb..6094945 100644 --- a/tests/test_examples/test_asset_converter.py +++ b/tests/test_examples/test_asset_converter.py @@ -4,10 +4,9 @@ from huggingface_hub import snapshot_download from embodied_gen.data.asset_converter import ( AssetConverterFactory, - AssetType, - SimAssetMapper, cvt_embodiedgen_asset_to_anysim, ) +from embodied_gen.utils.enum import AssetType, SimAssetMapper @pytest.fixture(scope="session") @@ -77,7 +76,10 @@ def test_cvt_embodiedgen_asset_to_anysim( ): dst_asset_path = cvt_embodiedgen_asset_to_anysim( urdf_files=[ - "outputs/embodiedgen_assets/demo_assets/remote_control2/result/remote_control.urdf", + "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf", + ], + target_dirs=[ + "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd", ], target_type=SimAssetMapper[simulator_name], source_type=AssetType.MESH,