* feat(sim): Optimize support for downstream simulators and gym. * feat(sim): Optimize support for downstream simulators and gym. * docs: update docs * update version
340 lines
13 KiB
Python
340 lines
13 KiB
Python
# Project EmbodiedGen
|
|
#
|
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
# implied. See the License for the specific language governing
|
|
# permissions and limitations under the License.
|
|
|
|
|
|
import argparse
|
|
import os
|
|
import random
|
|
import sys
|
|
from glob import glob
|
|
from shutil import copy, copytree, rmtree
|
|
|
|
import numpy as np
|
|
import torch
|
|
import trimesh
|
|
from PIL import Image
|
|
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
|
|
from embodied_gen.data.utils import delete_dir, trellis_preprocess
|
|
from embodied_gen.models.delight_model import DelightingModel
|
|
from embodied_gen.models.gs_model import GaussianOperator
|
|
from embodied_gen.models.segment_model import RembgRemover
|
|
from embodied_gen.models.sr_model import ImageRealESRGAN
|
|
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
|
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
|
from embodied_gen.utils.log import logger
|
|
from embodied_gen.utils.process_media import merge_images_video
|
|
from embodied_gen.utils.tags import VERSION
|
|
from embodied_gen.utils.trender import render_video
|
|
from embodied_gen.validators.quality_checkers import (
|
|
BaseChecker,
|
|
ImageAestheticChecker,
|
|
ImageSegChecker,
|
|
MeshGeoChecker,
|
|
)
|
|
from embodied_gen.validators.urdf_convertor import URDFGenerator
|
|
|
|
current_file_path = os.path.abspath(__file__)
|
|
current_dir = os.path.dirname(current_file_path)
|
|
sys.path.append(os.path.join(current_dir, "../.."))
|
|
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
|
|
|
|
os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
|
|
"~/.cache/torch_extensions"
|
|
)
|
|
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
|
|
os.environ["SPCONV_ALGO"] = "native"
|
|
random.seed(0)
|
|
|
|
logger.info("Loading Image3D Models...")
|
|
DELIGHT = DelightingModel()
|
|
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
|
RBG_REMOVER = RembgRemover()
|
|
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
|
|
"microsoft/TRELLIS-image-large"
|
|
)
|
|
# PIPELINE.cuda()
|
|
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
|
GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
|
|
AESTHETIC_CHECKER = ImageAestheticChecker()
|
|
CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
|
|
parser.add_argument(
|
|
"--image_path", type=str, nargs="+", help="Path to the input images."
|
|
)
|
|
parser.add_argument(
|
|
"--image_root", type=str, help="Path to the input images folder."
|
|
)
|
|
parser.add_argument(
|
|
"--output_root",
|
|
type=str,
|
|
help="Root directory for saving outputs.",
|
|
)
|
|
parser.add_argument(
|
|
"--height_range",
|
|
type=str,
|
|
default=None,
|
|
help="The hight in meter to restore the mesh real size.",
|
|
)
|
|
parser.add_argument(
|
|
"--mass_range",
|
|
type=str,
|
|
default=None,
|
|
help="The mass in kg to restore the mesh real weight.",
|
|
)
|
|
parser.add_argument("--asset_type", type=str, nargs="+", default=None)
|
|
parser.add_argument("--skip_exists", action="store_true")
|
|
parser.add_argument("--version", type=str, default=VERSION)
|
|
parser.add_argument("--keep_intermediate", action="store_true")
|
|
parser.add_argument("--seed", type=int, default=0)
|
|
parser.add_argument(
|
|
"--n_retry",
|
|
type=int,
|
|
default=2,
|
|
)
|
|
parser.add_argument("--disable_decompose_convex", action="store_true")
|
|
args, unknown = parser.parse_known_args()
|
|
|
|
return args
|
|
|
|
|
|
def entrypoint(**kwargs):
|
|
args = parse_args()
|
|
for k, v in kwargs.items():
|
|
if hasattr(args, k) and v is not None:
|
|
setattr(args, k, v)
|
|
|
|
assert (
|
|
args.image_path or args.image_root
|
|
), "Please provide either --image_path or --image_root."
|
|
if not args.image_path:
|
|
args.image_path = glob(os.path.join(args.image_root, "*.png"))
|
|
args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
|
|
args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))
|
|
|
|
for idx, image_path in enumerate(args.image_path):
|
|
try:
|
|
filename = os.path.basename(image_path).split(".")[0]
|
|
output_root = args.output_root
|
|
if args.image_root is not None or len(args.image_path) > 1:
|
|
output_root = os.path.join(output_root, filename)
|
|
os.makedirs(output_root, exist_ok=True)
|
|
|
|
mesh_out = f"{output_root}/{filename}.obj"
|
|
if args.skip_exists and os.path.exists(mesh_out):
|
|
logger.warning(
|
|
f"Skip {image_path}, already processed in {mesh_out}"
|
|
)
|
|
continue
|
|
|
|
image = Image.open(image_path)
|
|
image.save(f"{output_root}/{filename}_raw.png")
|
|
|
|
# Segmentation: Get segmented image using Rembg.
|
|
seg_path = f"{output_root}/{filename}_cond.png"
|
|
seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image
|
|
seg_image = trellis_preprocess(seg_image)
|
|
seg_image.save(seg_path)
|
|
|
|
seed = args.seed
|
|
asset_node = "unknown"
|
|
if isinstance(args.asset_type, list) and args.asset_type[idx]:
|
|
asset_node = args.asset_type[idx]
|
|
for try_idx in range(args.n_retry):
|
|
logger.info(
|
|
f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
|
|
)
|
|
# Run the pipeline
|
|
try:
|
|
PIPELINE.cuda()
|
|
outputs = PIPELINE.run(
|
|
seg_image,
|
|
preprocess_image=False,
|
|
seed=(
|
|
random.randint(0, 100000) if seed is None else seed
|
|
),
|
|
# Optional parameters
|
|
# sparse_structure_sampler_params={
|
|
# "steps": 12,
|
|
# "cfg_strength": 7.5,
|
|
# },
|
|
# slat_sampler_params={
|
|
# "steps": 12,
|
|
# "cfg_strength": 3,
|
|
# },
|
|
)
|
|
PIPELINE.cpu()
|
|
torch.cuda.empty_cache()
|
|
except Exception as e:
|
|
logger.error(
|
|
f"[Pipeline Failed] process {image_path}: {e}, skip."
|
|
)
|
|
continue
|
|
|
|
gs_model = outputs["gaussian"][0]
|
|
mesh_model = outputs["mesh"][0]
|
|
|
|
# Save the raw Gaussian model
|
|
gs_path = mesh_out.replace(".obj", "_gs.ply")
|
|
gs_model.save_ply(gs_path)
|
|
|
|
# Rotate mesh and GS by 90 degrees around Z-axis.
|
|
rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
|
|
gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
|
|
mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
|
|
|
|
# Addtional rotation for GS to align mesh.
|
|
gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
|
|
pose = GaussianOperator.trans_to_quatpose(gs_rot)
|
|
aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
|
|
GaussianOperator.resave_ply(
|
|
in_ply=gs_path,
|
|
out_ply=aligned_gs_path,
|
|
instance_pose=pose,
|
|
device="cpu",
|
|
)
|
|
color_path = os.path.join(output_root, "color.png")
|
|
render_gs_api(aligned_gs_path, color_path)
|
|
|
|
geo_flag, geo_result = GEO_CHECKER(
|
|
[color_path], text=asset_node
|
|
)
|
|
logger.warning(
|
|
f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
|
|
)
|
|
if geo_flag is True or geo_flag is None:
|
|
break
|
|
|
|
seed = random.randint(0, 100000) if seed is not None else None
|
|
|
|
# Render the video for generated 3D asset.
|
|
color_images = render_video(gs_model)["color"]
|
|
normal_images = render_video(mesh_model)["normal"]
|
|
video_path = os.path.join(output_root, "gs_mesh.mp4")
|
|
merge_images_video(color_images, normal_images, video_path)
|
|
|
|
mesh = trimesh.Trimesh(
|
|
vertices=mesh_model.vertices.cpu().numpy(),
|
|
faces=mesh_model.faces.cpu().numpy(),
|
|
)
|
|
mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
|
|
mesh.vertices = mesh.vertices @ np.array(rot_matrix)
|
|
|
|
mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
|
|
mesh.export(mesh_obj_path)
|
|
|
|
mesh = backproject_api(
|
|
delight_model=DELIGHT,
|
|
imagesr_model=IMAGESR_MODEL,
|
|
color_path=color_path,
|
|
mesh_path=mesh_obj_path,
|
|
output_path=mesh_obj_path,
|
|
skip_fix_mesh=False,
|
|
delight=True,
|
|
texture_wh=[2048, 2048],
|
|
)
|
|
|
|
mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
|
|
mesh.export(mesh_glb_path)
|
|
|
|
urdf_convertor = URDFGenerator(
|
|
GPT_CLIENT,
|
|
render_view_num=4,
|
|
decompose_convex=not args.disable_decompose_convex,
|
|
)
|
|
asset_attrs = {
|
|
"version": VERSION,
|
|
"gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
|
|
}
|
|
if args.height_range:
|
|
min_height, max_height = map(
|
|
float, args.height_range.split("-")
|
|
)
|
|
asset_attrs["min_height"] = min_height
|
|
asset_attrs["max_height"] = max_height
|
|
if args.mass_range:
|
|
min_mass, max_mass = map(float, args.mass_range.split("-"))
|
|
asset_attrs["min_mass"] = min_mass
|
|
asset_attrs["max_mass"] = max_mass
|
|
if isinstance(args.asset_type, list) and args.asset_type[idx]:
|
|
asset_attrs["category"] = args.asset_type[idx]
|
|
if args.version:
|
|
asset_attrs["version"] = args.version
|
|
|
|
urdf_root = f"{output_root}/URDF_{filename}"
|
|
urdf_path = urdf_convertor(
|
|
mesh_path=mesh_obj_path,
|
|
output_root=urdf_root,
|
|
**asset_attrs,
|
|
)
|
|
|
|
# Rescale GS and save to URDF/mesh folder.
|
|
real_height = urdf_convertor.get_attr_from_urdf(
|
|
urdf_path, attr_name="real_height"
|
|
)
|
|
out_gs = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply" # noqa
|
|
GaussianOperator.resave_ply(
|
|
in_ply=aligned_gs_path,
|
|
out_ply=out_gs,
|
|
real_height=real_height,
|
|
device="cpu",
|
|
)
|
|
|
|
# Quality check and update .urdf file.
|
|
mesh_out = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}.obj" # noqa
|
|
trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
|
|
|
|
image_dir = f"{urdf_root}/{urdf_convertor.output_render_dir}/image_color" # noqa
|
|
image_paths = glob(f"{image_dir}/*.png")
|
|
images_list = []
|
|
for checker in CHECKERS:
|
|
images = image_paths
|
|
if isinstance(checker, ImageSegChecker):
|
|
images = [
|
|
f"{output_root}/{filename}_raw.png",
|
|
f"{output_root}/{filename}_cond.png",
|
|
]
|
|
images_list.append(images)
|
|
|
|
qa_results = BaseChecker.validate(CHECKERS, images_list)
|
|
urdf_convertor.add_quality_tag(urdf_path, qa_results)
|
|
|
|
# Organize the final result files
|
|
result_dir = f"{output_root}/result"
|
|
if os.path.exists(result_dir):
|
|
rmtree(result_dir, ignore_errors=True)
|
|
os.makedirs(result_dir, exist_ok=True)
|
|
copy(urdf_path, f"{result_dir}/{os.path.basename(urdf_path)}")
|
|
copytree(
|
|
f"{urdf_root}/{urdf_convertor.output_mesh_dir}",
|
|
f"{result_dir}/{urdf_convertor.output_mesh_dir}",
|
|
)
|
|
copy(video_path, f"{result_dir}/video.mp4")
|
|
if not args.keep_intermediate:
|
|
delete_dir(output_root, keep_subs=["result"])
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to process {image_path}: {e}, skip.")
|
|
continue
|
|
|
|
logger.info(f"Processing complete. Outputs saved to {args.output_root}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
entrypoint()
|