feat(texture): Refine texture back-projection. (#58)

This commit is contained in:
Xinjie 2025-12-04 21:25:33 +08:00 committed by GitHub
parent fb637f9afc
commit 74c3c52a23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 727 additions and 108 deletions

View File

@ -331,4 +331,4 @@ EmbodiedGen builds upon the following amazing projects and models:
## ⚖️ License ## ⚖️ License
This project is licensed under the [Apache License 2.0](LICENSE). See the `LICENSE` file for details. This project is licensed under the [Apache License 2.0](docs/LICENSE). See the `LICENSE` file for details.

View File

@ -4,7 +4,7 @@ from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
lighting_css = """ lighting_css = """
<style> <style>
#lighter_mesh canvas { #lighter_mesh canvas {
filter: brightness(1.9) !important; filter: brightness(2.0) !important;
} }
</style> </style>
""" """

View File

@ -32,8 +32,9 @@ import trimesh
from easydict import EasyDict as edict from easydict import EasyDict as edict
from PIL import Image from PIL import Image
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
from embodied_gen.data.differentiable_render import entrypoint as render_api from embodied_gen.data.differentiable_render import entrypoint as render_api
from embodied_gen.data.utils import trellis_preprocess, zip_files from embodied_gen.data.utils import resize_pil, trellis_preprocess, zip_files
from embodied_gen.models.delight_model import DelightingModel from embodied_gen.models.delight_model import DelightingModel
from embodied_gen.models.gs_model import GaussianOperator from embodied_gen.models.gs_model import GaussianOperator
from embodied_gen.models.segment_model import ( from embodied_gen.models.segment_model import (
@ -131,8 +132,8 @@ def patched_setup_functions(self):
Gaussian.setup_functions = patched_setup_functions Gaussian.setup_functions = patched_setup_functions
DELIGHT = DelightingModel() # DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4) # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
# IMAGESR_MODEL = ImageStableSR() # IMAGESR_MODEL = ImageStableSR()
if os.getenv("GRADIO_APP") == "imageto3d": if os.getenv("GRADIO_APP") == "imageto3d":
RBG_REMOVER = RembgRemover() RBG_REMOVER = RembgRemover()
@ -169,6 +170,8 @@ elif os.getenv("GRADIO_APP") == "textto3d":
) )
os.makedirs(TMP_DIR, exist_ok=True) os.makedirs(TMP_DIR, exist_ok=True)
elif os.getenv("GRADIO_APP") == "texture_edit": elif os.getenv("GRADIO_APP") == "texture_edit":
DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
PIPELINE_IP = build_texture_gen_pipe( PIPELINE_IP = build_texture_gen_pipe(
base_ckpt_dir="./weights", base_ckpt_dir="./weights",
ip_adapt_scale=0.7, ip_adapt_scale=0.7,
@ -205,7 +208,7 @@ def preprocess_image_fn(
elif isinstance(image, np.ndarray): elif isinstance(image, np.ndarray):
image = Image.fromarray(image) image = Image.fromarray(image)
image_cache = image.copy().resize((512, 512)) image_cache = resize_pil(image.copy(), 1024)
bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
image = bg_remover(image) image = bg_remover(image)
@ -221,7 +224,7 @@ def preprocess_sam_image_fn(
image = Image.fromarray(image) image = Image.fromarray(image)
sam_image = SAM_PREDICTOR.preprocess_image(image) sam_image = SAM_PREDICTOR.preprocess_image(image)
image_cache = Image.fromarray(sam_image).resize((512, 512)) image_cache = sam_image.copy()
SAM_PREDICTOR.predictor.set_image(sam_image) SAM_PREDICTOR.predictor.set_image(sam_image)
return sam_image, image_cache return sam_image, image_cache
@ -512,6 +515,60 @@ def extract_3d_representations_v2(
return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
def extract_3d_representations_v3(
    state: dict,
    enable_delight: bool,
    texture_size: int,
    req: gr.Request,
):
    """Export the GS model and a textured mesh for the current session.

    The Gaussian-splatting model is saved as-is and as a re-posed ("aligned")
    copy; the mesh is rotated, written as .obj, textured through
    `backproject_api_v3` (which overwrites that .obj path), and finally
    exported as .glb.

    Args:
        state (dict): Packed generation state consumed by `unpack_state`.
        enable_delight (bool): Accepted for interface compatibility; not
            used by this function.
        texture_size (int): Texture size forwarded to the back-projection.
        req (gr.Request): Gradio request; `session_hash` selects the
            per-session output directory under `TMP_DIR`.

    Returns:
        tuple[str, str, str, str]: Paths of the .glb mesh, the raw GS .ply,
            the .obj mesh and the aligned GS .ply.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gs_model, mesh_model = unpack_state(state, device="cpu")

    stem = "sample"
    gs_path = os.path.join(session_dir, f"{stem}_gs.ply")
    gs_model.save_ply(gs_path)

    # Base rotation shared by the mesh and the GS, plus one extra rotation
    # for each of them.
    base_rot = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]])
    gs_extra_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
    mesh_extra_rot = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])

    # Additional rotation for GS to align with the mesh.
    pose = GaussianOperator.trans_to_quatpose(gs_extra_rot @ base_rot)
    aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=aligned_gs_path,
        instance_pose=pose,
        device="cpu",
    )

    mesh = trimesh.Trimesh(
        vertices=mesh_model.vertices.cpu().numpy(),
        faces=mesh_model.faces.cpu().numpy(),
    )
    # Row-vector convention: the extra rotation is applied first, then the
    # shared base rotation.
    mesh.vertices = mesh.vertices @ mesh_extra_rot @ base_rot

    mesh_obj_path = os.path.join(session_dir, f"{stem}.obj")
    mesh.export(mesh_obj_path)

    # Texture back-projection; the result is written back to the same .obj.
    mesh = backproject_api_v3(
        gs_path=aligned_gs_path,
        mesh_path=mesh_obj_path,
        output_path=mesh_obj_path,
        skip_fix_mesh=False,
        texture_size=texture_size,
    )

    mesh_glb_path = os.path.join(session_dir, f"{stem}.glb")
    mesh.export(mesh_glb_path)

    return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
def extract_urdf( def extract_urdf(
gs_path: str, gs_path: str,
mesh_obj_path: str, mesh_obj_path: str,

View File

@ -27,7 +27,7 @@ from common import (
VERSION, VERSION,
active_btn_by_content, active_btn_by_content,
end_session, end_session,
extract_3d_representations_v2, extract_3d_representations_v3,
extract_urdf, extract_urdf,
get_seed, get_seed,
image_to_3d, image_to_3d,
@ -179,17 +179,17 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
) )
generate_btn = gr.Button( generate_btn = gr.Button(
"🚀 1. Generate(~0.5 mins)", "🚀 1. Generate(~2 mins)",
variant="primary", variant="primary",
interactive=False, interactive=False,
) )
model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False) model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
with gr.Row(): # with gr.Row():
extract_rep3d_btn = gr.Button( # extract_rep3d_btn = gr.Button(
"🔍 2. Extract 3D Representation(~2 mins)", # "🔍 2. Extract 3D Representation(~2 mins)",
variant="primary", # variant="primary",
interactive=False, # interactive=False,
) # )
with gr.Accordion( with gr.Accordion(
label="Enter Asset Attributes(optional)", open=False label="Enter Asset Attributes(optional)", open=False
): ):
@ -207,7 +207,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
) )
with gr.Row(): with gr.Row():
extract_urdf_btn = gr.Button( extract_urdf_btn = gr.Button(
"🧩 3. Extract URDF with physics(~1 mins)", "🧩 2. Extract URDF with physics(~1 mins)",
variant="primary", variant="primary",
interactive=False, interactive=False,
) )
@ -230,7 +230,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
) )
with gr.Row(): with gr.Row():
download_urdf = gr.DownloadButton( download_urdf = gr.DownloadButton(
label="⬇️ 4. Download URDF", label="⬇️ 3. Download URDF",
variant="primary", variant="primary",
interactive=False, interactive=False,
) )
@ -326,7 +326,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
image_prompt.change( image_prompt.change(
lambda: tuple( lambda: tuple(
[ [
gr.Button(interactive=False), # gr.Button(interactive=False),
gr.Button(interactive=False), gr.Button(interactive=False),
gr.Button(interactive=False), gr.Button(interactive=False),
None, None,
@ -344,7 +344,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
] ]
), ),
outputs=[ outputs=[
extract_rep3d_btn, # extract_rep3d_btn,
extract_urdf_btn, extract_urdf_btn,
download_urdf, download_urdf,
model_output_gs, model_output_gs,
@ -375,7 +375,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
image_prompt_sam.change( image_prompt_sam.change(
lambda: tuple( lambda: tuple(
[ [
gr.Button(interactive=False), # gr.Button(interactive=False),
gr.Button(interactive=False), gr.Button(interactive=False),
gr.Button(interactive=False), gr.Button(interactive=False),
None, None,
@ -394,7 +394,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
] ]
), ),
outputs=[ outputs=[
extract_rep3d_btn, # extract_rep3d_btn,
extract_urdf_btn, extract_urdf_btn,
download_urdf, download_urdf,
model_output_gs, model_output_gs,
@ -447,12 +447,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
], ],
outputs=[output_buf, video_output], outputs=[output_buf, video_output],
).success( ).success(
lambda: gr.Button(interactive=True), extract_3d_representations_v3,
outputs=[extract_rep3d_btn],
)
extract_rep3d_btn.click(
extract_3d_representations_v2,
inputs=[ inputs=[
output_buf, output_buf,
project_delight, project_delight,
@ -495,4 +490,4 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
if __name__ == "__main__": if __name__ == "__main__":
demo.launch() demo.launch(server_port=8081)

View File

@ -27,7 +27,7 @@ from common import (
VERSION, VERSION,
active_btn_by_text_content, active_btn_by_text_content,
end_session, end_session,
extract_3d_representations_v2, extract_3d_representations_v3,
extract_urdf, extract_urdf,
get_cached_image, get_cached_image,
get_seed, get_seed,
@ -178,17 +178,17 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
) )
generate_btn = gr.Button( generate_btn = gr.Button(
"🚀 2. Generate 3D(~0.5 mins)", "🚀 2. Generate 3D(~2 mins)",
variant="primary", variant="primary",
interactive=False, interactive=False,
) )
model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False) model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
with gr.Row(): # with gr.Row():
extract_rep3d_btn = gr.Button( # extract_rep3d_btn = gr.Button(
"🔍 3. Extract 3D Representation(~1 mins)", # "🔍 3. Extract 3D Representation(~1 mins)",
variant="primary", # variant="primary",
interactive=False, # interactive=False,
) # )
with gr.Accordion( with gr.Accordion(
label="Enter Asset Attributes(optional)", open=False label="Enter Asset Attributes(optional)", open=False
): ):
@ -206,13 +206,13 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
) )
with gr.Row(): with gr.Row():
extract_urdf_btn = gr.Button( extract_urdf_btn = gr.Button(
"🧩 4. Extract URDF with physics(~1 mins)", "🧩 3. Extract URDF with physics(~1 mins)",
variant="primary", variant="primary",
interactive=False, interactive=False,
) )
with gr.Row(): with gr.Row():
download_urdf = gr.DownloadButton( download_urdf = gr.DownloadButton(
label="⬇️ 5. Download URDF", label="⬇️ 4. Download URDF",
variant="primary", variant="primary",
interactive=False, interactive=False,
) )
@ -336,7 +336,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
generate_img_btn.click( generate_img_btn.click(
lambda: tuple( lambda: tuple(
[ [
gr.Button(interactive=False), # gr.Button(interactive=False),
gr.Button(interactive=False), gr.Button(interactive=False),
gr.Button(interactive=False), gr.Button(interactive=False),
gr.Button(interactive=False), gr.Button(interactive=False),
@ -358,7 +358,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
] ]
), ),
outputs=[ outputs=[
extract_rep3d_btn, # extract_rep3d_btn,
extract_urdf_btn, extract_urdf_btn,
download_urdf, download_urdf,
generate_btn, generate_btn,
@ -428,12 +428,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
], ],
outputs=[output_buf, video_output], outputs=[output_buf, video_output],
).success( ).success(
lambda: gr.Button(interactive=True), extract_3d_representations_v3,
outputs=[extract_rep3d_btn],
)
extract_rep3d_btn.click(
extract_3d_representations_v2,
inputs=[ inputs=[
output_buf, output_buf,
project_delight, project_delight,
@ -476,4 +471,4 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
if __name__ == "__main__": if __name__ == "__main__":
demo.launch() demo.launch(server_port=8082)

View File

@ -381,4 +381,4 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
if __name__ == "__main__": if __name__ == "__main__":
demo.launch() demo.launch(server_port=8083)

View File

@ -727,7 +727,6 @@ with gr.Blocks(
if __name__ == "__main__": if __name__ == "__main__":
demo.launch( demo.launch(
server_name="10.34.8.77",
server_port=8088, server_port=8088,
allowed_paths=[ allowed_paths=[
"/horizon-bucket/robot_lab/datasets/embodiedgen/assets" "/horizon-bucket/robot_lab/datasets/embodiedgen/assets"

View File

@ -14,6 +14,8 @@ conda activate embodiedgen
bash install.sh basic bash install.sh basic
``` ```
Please run `huggingface-cli login` first so that the checkpoints can be downloaded automatically afterwards.
## ✅ Starting from Docker ## ✅ Starting from Docker
We provide a pre-built Docker image on [Docker Hub](https://hub.docker.com/repository/docker/wangxinjie/embodiedgen) with a configured environment for your convenience. For more details, please refer to [Docker documentation](https://github.com/HorizonRobotics/EmbodiedGen/tree/master/docker). We provide a pre-built Docker image on [Docker Hub](https://hub.docker.com/repository/docker/wangxinjie/embodiedgen) with a configured environment for your convenience. For more details, please refer to [Docker documentation](https://github.com/HorizonRobotics/EmbodiedGen/tree/master/docker).

View File

@ -589,6 +589,8 @@ class MeshtoUSDConverter(AssetConverterBase):
stage = Usd.Stage.Open(usd_path) stage = Usd.Stage.Open(usd_path)
layer = stage.GetRootLayer() layer = stage.GetRootLayer()
with Usd.EditContext(stage, layer): with Usd.EditContext(stage, layer):
base_prim = stage.GetPseudoRoot().GetChildren()[0]
base_prim.SetMetadata("kind", "component")
for prim in stage.Traverse(): for prim in stage.Traverse():
# Change texture path to relative path. # Change texture path to relative path.
if prim.GetName() == "material_0": if prim.GetName() == "material_0":

View File

@ -34,6 +34,7 @@ from embodied_gen.data.utils import (
CameraSetting, CameraSetting,
get_images_from_grid, get_images_from_grid,
init_kal_camera, init_kal_camera,
kaolin_to_opencv_view,
normalize_vertices_array, normalize_vertices_array,
post_process_texture, post_process_texture,
save_mesh_with_mtl, save_mesh_with_mtl,
@ -306,28 +307,6 @@ class TextureBaker(object):
raise ValueError(f"Unknown mode: {mode}") raise ValueError(f"Unknown mode: {mode}")
def kaolin_to_opencv_view(raw_matrix):
    """Reorder the rotation columns of a batch of 4x4 view matrices.

    Output rotation columns (0, 1, 2) are taken from input rotation columns
    (2, 0, 1); the translation column is copied unchanged and the last row
    is that of the identity matrix.

    Args:
        raw_matrix (torch.Tensor): Batch of view matrices, indexed as
            (batch, 4, 4).

    Returns:
        torch.Tensor: New (batch, 4, 4) tensor on the same device as
            `raw_matrix`.
    """
    batch_size = raw_matrix.size(0)
    converted = (
        torch.eye(4, device=raw_matrix.device)
        .unsqueeze(0)
        .repeat(batch_size, 1, 1)
    )
    # Permute rotation columns: new[0] <- old[2], new[1] <- old[0],
    # new[2] <- old[1].
    converted[:, :3, :3] = raw_matrix[:, :3, :3][:, :, [2, 0, 1]]
    converted[:, :3, 3] = raw_matrix[:, :3, 3]

    return converted
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Render settings") parser = argparse.ArgumentParser(description="Render settings")

View File

@ -0,0 +1,558 @@
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import logging
import math
import os
from typing import Literal, Union
import cv2
import numpy as np
import nvdiffrast.torch as dr
import spaces
import torch
import trimesh
import utils3d
import xatlas
from PIL import Image
from tqdm import tqdm
from embodied_gen.data.mesh_operator import MeshFixer
from embodied_gen.data.utils import (
CameraSetting,
init_kal_camera,
kaolin_to_opencv_view,
normalize_vertices_array,
post_process_texture,
save_mesh_with_mtl,
)
from embodied_gen.models.delight_model import DelightingModel
from embodied_gen.models.gs_model import load_gs_model
from embodied_gen.models.sr_model import ImageRealESRGAN
# Configure root logging once at import time: timestamped messages at INFO.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
# Names exported by `from <module> import *`.
__all__ = [
    "TextureBaker",
]
class TextureBaker(object):
    """Baking textures onto a mesh from multiple observations.

    This class takes 3D mesh data, camera settings and texture baking
    parameters to generate a texture map by projecting images onto the mesh
    from different views. It supports both a fast texture baking approach and
    a more optimized method with total variation regularization.

    Attributes:
        vertices (torch.Tensor): The vertices of the mesh.
        faces (torch.Tensor): The faces of the mesh, defined by vertex indices.
        uvs (torch.Tensor): The UV coordinates of the mesh.
        camera_params (CameraSetting): Camera setting (intrinsics, extrinsics).
        device (str): The device to run computations on ("cpu" or "cuda").
        w2cs (torch.Tensor): World-to-camera transformation matrices.
        projections (torch.Tensor): Camera projection matrices.

    Example:
        >>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)  # noqa
        >>> texture_backer = TextureBaker(vertices, faces, uvs, camera_params)
        >>> images = get_images_from_grid(args.color_path, image_size)
        >>> texture = texture_backer.bake_texture(
        ...     images, texture_size=args.texture_size, mode=args.baker_mode
        ... )
        >>> texture = post_process_texture(texture)
    """

    def __init__(
        self,
        vertices: np.ndarray,
        faces: np.ndarray,
        uvs: np.ndarray,
        camera_params: CameraSetting,
        device: str = "cuda",
    ) -> None:
        """Move mesh data to `device` and build per-view camera matrices.

        Args:
            vertices: Mesh vertices (numpy array or torch tensor).
            faces: Mesh faces (numpy array or torch tensor); numpy input is
                cast to int32.
            uvs: Per-vertex UV coordinates (numpy array or torch tensor).
            camera_params: Camera setting used to create the cameras.
            device: Device holding the mesh data and camera matrices.
        """
        self.vertices = (
            torch.tensor(vertices, device=device)
            if isinstance(vertices, np.ndarray)
            else vertices.to(device)
        )
        self.faces = (
            torch.tensor(faces.astype(np.int32), device=device)
            if isinstance(faces, np.ndarray)
            else faces.to(device)
        )
        self.uvs = (
            torch.tensor(uvs, device=device)
            if isinstance(uvs, np.ndarray)
            else uvs.to(device)
        )
        self.camera_params = camera_params
        self.device = device

        camera = init_kal_camera(camera_params)
        matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
        matrix_mv = kaolin_to_opencv_view(matrix_mv)
        matrix_p = (
            camera.intrinsics.projection_matrix()
        )  # (n_cam 4 4) cam2pixel
        self.w2cs = matrix_mv.to(self.device)
        self.projections = matrix_p.to(self.device)

    @staticmethod
    def parametrize_mesh(
        vertices: np.ndarray, faces: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """UV-unwrap a mesh with xatlas.

        Args:
            vertices: Mesh vertices.
            faces: Mesh faces as vertex indices.

        Returns:
            The vertices re-indexed by the xatlas vertex mapping, the new
            face indices and the UV coordinates.
        """
        vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
        vertices = vertices[vmapping]
        faces = indices

        return vertices, faces, uvs

    def _bake_fast(self, observations, w2cs, projections, texture_size, masks):
        """Bake a texture by averaging the colors projected from every view.

        Each view is rasterized, its valid pixels are scattered into the
        texel addressed by their UV (nearest neighbor), and texels are
        averaged over all contributions. Texels that received no
        contribution are filled with OpenCV inpainting.

        Args:
            observations: Per-view color images as float tensors.
            w2cs: Per-view world-to-camera matrices.
            projections: Per-view projection matrices.
            texture_size: Side length of the square output texture.
            masks: Per-view boolean masks of valid image pixels.

        Returns:
            np.ndarray: uint8 texture of shape (texture_size, texture_size, 3).
        """
        texture = torch.zeros(
            (texture_size * texture_size, 3), dtype=torch.float32
        ).cuda()
        texture_weights = torch.zeros(
            (texture_size * texture_size), dtype=torch.float32
        ).cuda()
        rastctx = utils3d.torch.RastContext(backend="cuda")
        # Each view is paired with its own mask. Previously masks[0] was
        # applied to every view, unlike `_bake_opt`, which uses the mask of
        # the selected view.
        for observation, w2c, projection, view_mask in tqdm(
            zip(observations, w2cs, projections, masks),
            total=len(observations),
            desc="Texture baking (fast)",
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx,
                    self.vertices[None],
                    self.faces,
                    observation.shape[1],
                    observation.shape[0],
                    uv=self.uvs[None],
                    view=w2c,
                    projection=projection,
                )
                uv_map = rast["uv"][0].detach().flip(0)
                mask = rast["mask"][0].detach().bool() & view_mask

            # nearest neighbor interpolation
            uv_map = (uv_map * texture_size).floor().long()
            obs = observation[mask]
            uv_map = uv_map[mask]
            # Flatten (u, v) into a row-major texel index with v flipped.
            idx = (
                uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
            )
            texture = texture.scatter_add(
                0, idx.view(-1, 1).expand(-1, 3), obs
            )
            texture_weights = texture_weights.scatter_add(
                0,
                idx,
                torch.ones(
                    (obs.shape[0]), dtype=torch.float32, device=texture.device
                ),
            )

        # Average the accumulated colors of every texel that was hit.
        mask = texture_weights > 0
        texture[mask] /= texture_weights[mask][:, None]
        texture = np.clip(
            texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255,
            0,
            255,
        ).astype(np.uint8)

        # inpaint
        mask = (
            (texture_weights == 0)
            .cpu()
            .numpy()
            .astype(np.uint8)
            .reshape(texture_size, texture_size)
        )
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

        return texture

    def _bake_opt(
        self,
        observations,
        w2cs,
        projections,
        texture_size,
        lambda_tv,
        masks,
        total_steps,
    ):
        """Bake a texture by optimizing it against the observations.

        UV maps are rasterized once per view. A texture parameter is then
        optimized with Adam: each step samples one random view, renders the
        texture through that view's UVs and minimizes the masked L1 error,
        optionally with a total-variation term. Texels outside the UV
        layout are filled with OpenCV inpainting.

        Args:
            observations: Per-view color images as float tensors.
            w2cs: Per-view world-to-camera matrices.
            projections: Per-view projection matrices.
            texture_size: Side length of the square output texture.
            lambda_tv: Weight of the total-variation loss (disabled if <= 0).
            masks: Per-view boolean masks of valid image pixels.
            total_steps: Number of optimization steps.

        Returns:
            np.ndarray: uint8 texture of shape (texture_size, texture_size, 3).
        """
        rastctx = utils3d.torch.RastContext(backend="cuda")
        # Flip images and masks vertically before comparing with renders.
        observations = [obs.flip(0) for obs in observations]
        masks = [m.flip(0) for m in masks]
        _uv = []
        _uv_dr = []
        for observation, w2c, projection in tqdm(
            zip(observations, w2cs, projections),
            total=len(w2cs),
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx,
                    self.vertices[None],
                    self.faces,
                    observation.shape[1],
                    observation.shape[0],
                    uv=self.uvs[None],
                    view=w2c,
                    projection=projection,
                )
                _uv.append(rast["uv"].detach())
                _uv_dr.append(rast["uv_dr"].detach())

        texture = torch.nn.Parameter(
            torch.zeros(
                (1, texture_size, texture_size, 3), dtype=torch.float32
            ).cuda()
        )
        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

        def cosine_annealing(step, total_steps, start_lr, end_lr):
            # Cosine schedule from start_lr (step 0) down to end_lr.
            return end_lr + 0.5 * (start_lr - end_lr) * (
                1 + np.cos(np.pi * step / total_steps)
            )

        def tv_loss(texture):
            # L1 difference between vertically and horizontally adjacent
            # texels.
            return torch.nn.functional.l1_loss(
                texture[:, :-1, :, :], texture[:, 1:, :, :]
            ) + torch.nn.functional.l1_loss(
                texture[:, :, :-1, :], texture[:, :, 1:, :]
            )

        with tqdm(total=total_steps, desc="Texture baking") as pbar:
            for step in range(total_steps):
                optimizer.zero_grad()
                selected = np.random.randint(0, len(w2cs))
                uv, uv_dr, observation, mask = (
                    _uv[selected],
                    _uv_dr[selected],
                    observations[selected],
                    masks[selected],
                )
                render = dr.texture(texture, uv, uv_dr)[0]
                loss = torch.nn.functional.l1_loss(
                    render[mask], observation[mask]
                )
                if lambda_tv > 0:
                    loss += lambda_tv * tv_loss(texture)
                loss.backward()
                optimizer.step()

                optimizer.param_groups[0]["lr"] = cosine_annealing(
                    step, total_steps, 1e-2, 1e-5
                )
                pbar.set_postfix({"loss": loss.item()})
                pbar.update()

        texture = np.clip(
            texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
        ).astype(np.uint8)
        # Rasterize the UV layout itself; texels not covered by any face
        # are the ones to inpaint.
        mask = 1 - utils3d.torch.rasterize_triangle_faces(
            rastctx,
            (self.uvs * 2 - 1)[None],
            self.faces,
            texture_size,
            texture_size,
        )["mask"][0].detach().cpu().numpy().astype(np.uint8)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

        return texture

    def bake_texture(
        self,
        images: list[np.ndarray],
        texture_size: int = 1024,
        mode: Literal["fast", "opt"] = "opt",
        lambda_tv: float = 1e-2,
        opt_step: int = 2000,
    ) -> np.ndarray:
        """Bake a texture map from multi-view images.

        Args:
            images: Per-view color images with values in [0, 255]. Pixels
                whose channels are all zero are treated as background.
            texture_size: Side length of the square output texture.
            mode: "fast" for projection averaging, "opt" for optimization.
            lambda_tv: Total-variation weight, used by "opt" only.
            opt_step: Number of optimization steps, used by "opt" only.

        Returns:
            np.ndarray: uint8 texture of shape (texture_size, texture_size, 3).

        Raises:
            ValueError: If `mode` is neither "fast" nor "opt".
        """
        masks = [np.any(img > 0, axis=-1) for img in images]
        masks = [torch.tensor(m > 0).bool().to(self.device) for m in masks]
        images = [
            torch.tensor(obs / 255.0).float().to(self.device) for obs in images
        ]

        if mode == "fast":
            return self._bake_fast(
                images, self.w2cs, self.projections, texture_size, masks
            )
        elif mode == "opt":
            return self._bake_opt(
                images,
                self.w2cs,
                self.projections,
                texture_size,
                lambda_tv,
                masks,
                opt_step,
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")
def parse_args():
    """Parses command-line arguments for texture backprojection.

    Unrecognized command-line arguments are ignored rather than rejected.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Backproject texture")
    parser.add_argument(
        "--gs_path",
        type=str,
        help="Path to the GS.ply gaussian splatting model",
    )
    parser.add_argument(
        "--mesh_path",
        type=str,
        help="Mesh path, .obj, .glb or .ply",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Output mesh path with suffix",
    )
    parser.add_argument(
        "--num_images",
        type=int,
        default=180,
        help="Number of images to render.",
    )
    parser.add_argument(
        "--elevation",
        nargs="+",
        type=float,
        default=list(range(85, -90, -10)),
        help="Elevation angles for the camera",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(512, 512),
        help="Resolution of the render images (default: (512, 512))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )
    parser.add_argument(
        "--skip_fix_mesh",
        action="store_true",
        help="Skip fixing mesh geometry.",
    )
    parser.add_argument(
        "--texture_size",
        type=int,
        default=2048,
        help="Texture size for texture baking (default: 2048)",
    )
    parser.add_argument(
        "--baker_mode",
        type=str,
        default="opt",
        help="Texture baking mode, `fast` or `opt` (default: opt)",
    )
    parser.add_argument(
        "--opt_step",
        type=int,
        default=3000,
        help="Optimization steps for texture baking (default: 3000)",
    )
    # NOTE: the option name keeps its historical spelling ("sipmlify")
    # because callers read `args.mesh_sipmlify_ratio`.
    parser.add_argument(
        "--mesh_sipmlify_ratio",
        type=float,
        default=0.9,
        help="Mesh simplification ratio (default: 0.9)",
    )
    parser.add_argument(
        "--delight", action="store_true", help="Use delighting model."
    )
    parser.add_argument(
        "--no_smooth_texture",
        action="store_true",
        help="Do not smooth the texture.",
    )
    parser.add_argument(
        "--no_coor_trans",
        action="store_true",
        help="Do not transform the asset coordinate system.",
    )
    parser.add_argument(
        "--save_glb_path", type=str, default=None, help="Save glb path."
    )
    parser.add_argument("--n_max_faces", type=int, default=30000)

    # parse_known_args tolerates extra arguments on the command line.
    args, _ = parser.parse_known_args()

    return args
@spaces.GPU
def entrypoint(
    delight_model: DelightingModel = None,
    imagesr_model: ImageRealESRGAN = None,
    **kwargs,
) -> trimesh.Trimesh:
    """Entrypoint for texture backprojection from a Gaussian-splatting model.

    Renders multi-view images of the GS model, optionally delights them,
    loads and optionally fixes/simplifies the mesh, UV-unwraps it, bakes the
    rendered views into a texture and saves the textured mesh.

    Args:
        delight_model (DelightingModel, optional): Delighting model. Created
            on demand when delighting is enabled and none is given.
        imagesr_model (ImageRealESRGAN, optional): Super-resolution model.
            Accepted for interface compatibility; not used by this function.
        **kwargs: Additional arguments to override CLI. Keys unknown to the
            parser and `None` values are ignored.

    Returns:
        trimesh.Trimesh: Textured mesh.
    """
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    # Setup camera parameters.
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )

    # GS render.
    camera = init_kal_camera(camera_params, flip_az=True)
    matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
    matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
    w2cs = matrix_mv.to(camera_params.device)
    c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
    Ks = torch.tensor(camera_params.Ks).to(camera_params.device)
    gs_model = load_gs_model(args.gs_path, pre_quat=[0.0, 0.0, 1.0, 0.0])
    multiviews = []
    for c2w in tqdm(c2ws, desc="Rendering GS"):
        result = gs_model.render(
            c2w,
            Ks=Ks,
            image_width=camera_params.resolution_hw[1],
            image_height=camera_params.resolution_hw[0],
        )
        color = cv2.cvtColor(result.rgba, cv2.COLOR_BGRA2RGBA)
        multiviews.append(Image.fromarray(color))

    if args.delight and delight_model is None:
        delight_model = DelightingModel()

    if args.delight:
        multiviews = [delight_model(img) for img in multiviews]

    multiviews = [img.convert("RGB") for img in multiviews]

    mesh = trimesh.load(args.mesh_path)
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump(concatenate=True)

    vertices, scale, center = normalize_vertices_array(mesh.vertices)

    # Transform mesh coordinate system by default.
    if not args.no_coor_trans:
        x_rot = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
        z_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
        vertices = vertices @ x_rot
        vertices = vertices @ z_rot

    faces = mesh.faces.astype(np.int32)
    vertices = vertices.astype(np.float32)

    # Optional first fixing pass, only for very dense meshes.
    if not args.skip_fix_mesh and len(faces) > 10 * args.n_max_faces:
        mesh_fixer = MeshFixer(vertices, faces, args.device)
        vertices, faces = mesh_fixer(
            filter_ratio=args.mesh_sipmlify_ratio,
            max_hole_size=0.04,
            resolution=1024,
            num_views=1000,
            norm_mesh_ratio=0.5,
        )

    # Second pass whenever the face budget is still exceeded; this pass is
    # not gated by `skip_fix_mesh`.
    if len(faces) > args.n_max_faces:
        mesh_fixer = MeshFixer(vertices, faces, args.device)
        vertices, faces = mesh_fixer(
            filter_ratio=max(0.05, args.mesh_sipmlify_ratio - 0.2),
            max_hole_size=0.04,
            resolution=1024,
            num_views=1000,
            norm_mesh_ratio=0.5,
        )

    vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)
    texture_backer = TextureBaker(
        vertices,
        faces,
        uvs,
        camera_params,
    )

    multiviews = [np.array(img) for img in multiviews]
    texture = texture_backer.bake_texture(
        images=[img[..., :3] for img in multiviews],
        texture_size=args.texture_size,
        mode=args.baker_mode,
        opt_step=args.opt_step,
    )
    if not args.no_smooth_texture:
        texture = post_process_texture(texture)

    # Recover mesh original orientation, scale and center.
    if not args.no_coor_trans:
        vertices = vertices @ np.linalg.inv(z_rot)
        vertices = vertices @ np.linalg.inv(x_rot)
    vertices = vertices / scale
    vertices = vertices + center

    textured_mesh = save_mesh_with_mtl(
        vertices, faces, uvs, texture, args.output_path
    )
    if args.save_glb_path is not None:
        # A bare file name has an empty dirname, and os.makedirs("") raises,
        # so only create the directory when there is one.
        glb_dir = os.path.dirname(args.save_glb_path)
        if glb_dir:
            os.makedirs(glb_dir, exist_ok=True)
        textured_mesh.export(args.save_glb_path)

    return textured_mesh
# Script entry point: run texture back-projection using CLI arguments only.
if __name__ == "__main__":
    entrypoint()

View File

@ -66,6 +66,7 @@ __all__ = [
"resize_pil", "resize_pil",
"trellis_preprocess", "trellis_preprocess",
"delete_dir", "delete_dir",
"kaolin_to_opencv_view",
] ]
@ -373,10 +374,18 @@ def _compute_az_el_by_views(
def _compute_cam_pts_by_az_el( def _compute_cam_pts_by_az_el(
azs: np.ndarray, azs: np.ndarray,
els: np.ndarray, els: np.ndarray,
distance: float, distance: float | list[float] | np.ndarray,
extra_pts: np.ndarray = None, extra_pts: np.ndarray = None,
) -> np.ndarray: ) -> np.ndarray:
distances = np.array([distance for _ in range(len(azs))]) if np.isscalar(distance) or isinstance(distance, (float, int)):
distances = np.full(len(azs), distance)
else:
distances = np.array(distance)
if len(distances) != len(azs):
raise ValueError(
f"Length of distances ({len(distances)}) must match length of azs ({len(azs)})"
)
cam_pts = _az_el_to_points(azs, els) * distances[:, None] cam_pts = _az_el_to_points(azs, els) * distances[:, None]
if extra_pts is not None: if extra_pts is not None:
@ -710,7 +719,7 @@ class CameraSetting:
num_images: int num_images: int
elevation: list[float] elevation: list[float]
distance: float distance: float | list[float]
resolution_hw: tuple[int, int] resolution_hw: tuple[int, int]
fov: float fov: float
at: tuple[float, float, float] = field( at: tuple[float, float, float] = field(
@ -824,6 +833,28 @@ def import_kaolin_mesh(mesh_path: str, with_mtl: bool = False):
return mesh return mesh
def kaolin_to_opencv_view(raw_matrix):
    """Reorder the rotation columns of a batch of 4x4 view matrices.

    Output rotation columns (0, 1, 2) are taken from input rotation columns
    (2, 0, 1); the translation column is copied unchanged and the last row
    is that of the identity matrix.

    Args:
        raw_matrix (torch.Tensor): Batch of view matrices, indexed as
            (batch, 4, 4).

    Returns:
        torch.Tensor: New (batch, 4, 4) tensor on the same device as
            `raw_matrix`.
    """
    batch_size = raw_matrix.size(0)
    converted = (
        torch.eye(4, device=raw_matrix.device)
        .unsqueeze(0)
        .repeat(batch_size, 1, 1)
    )
    # Permute rotation columns: new[0] <- old[2], new[1] <- old[0],
    # new[2] <- old[1].
    converted[:, :3, :3] = raw_matrix[:, :3, :3][:, :, [2, 0, 1]]
    converted[:, :3, 3] = raw_matrix[:, :3, 3]

    return converted
def save_mesh_with_mtl( def save_mesh_with_mtl(
vertices: np.ndarray, vertices: np.ndarray,
faces: np.ndarray, faces: np.ndarray,

View File

@ -21,14 +21,18 @@ import struct
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
import cv2
import numpy as np import numpy as np
import torch import torch
from gsplat.cuda._wrapper import spherical_harmonics from gsplat.cuda._wrapper import spherical_harmonics
from gsplat.rendering import rasterization from gsplat.rendering import rasterization
from plyfile import PlyData from plyfile import PlyData
from scipy.spatial.transform import Rotation from scipy.spatial.transform import Rotation
from embodied_gen.data.utils import gamma_shs, quat_mult, quat_to_rotmat from embodied_gen.data.utils import (
gamma_shs,
normalize_vertices_array,
quat_mult,
quat_to_rotmat,
)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -494,6 +498,21 @@ class GaussianOperator(GaussianBase):
) )
def load_gs_model(
    input_gs: str, pre_quat: Optional[list[float]] = None
) -> GaussianOperator:
    """Load a Gaussian-splat model from a PLY file, then re-pose and rescale it.

    The model's means are passed to ``normalize_vertices_array`` to obtain a
    scale and a center. The center and ``pre_quat`` are concatenated into a
    single pose vector that is handed to ``get_gaussians``, and the result
    is rescaled by the obtained scale.

    Args:
        input_gs: Path of the PLY file to load.
        pre_quat: Four quaternion components appended after the center in
            the pose vector. Defaults to ``[0.0, 0.7071, 0.0, -0.7071]``.
            NOTE(review): component order (wxyz vs. xyzw) is defined by
            ``GaussianOperator.get_gaussians`` -- confirm there.

    Returns:
        The transformed and rescaled ``GaussianOperator``.
    """
    # Avoid a shared mutable default; build the default list per call.
    if pre_quat is None:
        pre_quat = [0.0, 0.7071, 0.0, -0.7071]

    gs_model = GaussianOperator.load_from_ply(input_gs)

    # Normalize vertices to [-1, 1], center to (0, 0, 0).
    _, scale, center = normalize_vertices_array(gs_model._means)
    scale, center = float(scale), center.tolist()

    # Pose vector layout: center components followed by the quaternion.
    pose_values = [*center, *pre_quat]
    instance_pose = torch.tensor(pose_values).to(gs_model.device)

    gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
    gs_model.rescale(scale)

    return gs_model
if __name__ == "__main__": if __name__ == "__main__":
input_gs = "outputs/layouts_gens_demo/task_0000/background/gs_model.ply" input_gs = "outputs/layouts_gens_demo/task_0000/background/gs_model.ply"
output_gs = "./gs_model.ply" output_gs = "./gs_model.ply"

View File

@ -26,12 +26,14 @@ import numpy as np
import torch import torch
import trimesh import trimesh
from PIL import Image from PIL import Image
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api from embodied_gen.data.backproject_v3 import entrypoint as backproject_api
from embodied_gen.data.utils import delete_dir, trellis_preprocess from embodied_gen.data.utils import delete_dir, trellis_preprocess
from embodied_gen.models.delight_model import DelightingModel
# from embodied_gen.models.delight_model import DelightingModel
from embodied_gen.models.gs_model import GaussianOperator from embodied_gen.models.gs_model import GaussianOperator
from embodied_gen.models.segment_model import RembgRemover from embodied_gen.models.segment_model import RembgRemover
from embodied_gen.models.sr_model import ImageRealESRGAN
# from embodied_gen.models.sr_model import ImageRealESRGAN
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
from embodied_gen.utils.gpt_clients import GPT_CLIENT from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.log import logger from embodied_gen.utils.log import logger
@ -59,8 +61,8 @@ os.environ["SPCONV_ALGO"] = "native"
random.seed(0) random.seed(0)
logger.info("Loading Image3D Models...") logger.info("Loading Image3D Models...")
DELIGHT = DelightingModel() # DELIGHT = DelightingModel()
IMAGESR_MODEL = ImageRealESRGAN(outscale=4) # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
RBG_REMOVER = RembgRemover() RBG_REMOVER = RembgRemover()
PIPELINE = TrellisImageTo3DPipeline.from_pretrained( PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
"microsoft/TRELLIS-image-large" "microsoft/TRELLIS-image-large"
@ -108,9 +110,7 @@ def parse_args():
default=2, default=2,
) )
parser.add_argument("--disable_decompose_convex", action="store_true") parser.add_argument("--disable_decompose_convex", action="store_true")
parser.add_argument( parser.add_argument("--texture_size", type=int, default=2048)
"--texture_wh", type=int, nargs=2, default=[2048, 2048]
)
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
return args return args
@ -248,16 +248,14 @@ def entrypoint(**kwargs):
mesh.export(mesh_obj_path) mesh.export(mesh_obj_path)
mesh = backproject_api( mesh = backproject_api(
delight_model=DELIGHT, # delight_model=DELIGHT,
imagesr_model=IMAGESR_MODEL, # imagesr_model=IMAGESR_MODEL,
color_path=color_path, gs_path=aligned_gs_path,
mesh_path=mesh_obj_path, mesh_path=mesh_obj_path,
output_path=mesh_obj_path, output_path=mesh_obj_path,
skip_fix_mesh=False, skip_fix_mesh=False,
delight=True, texture_size=args.texture_size,
texture_wh=args.texture_wh, delight=False,
elevation=[20, -10, 60, -50],
num_images=12,
) )
mesh_glb_path = os.path.join(output_root, f"{filename}.glb") mesh_glb_path = os.path.join(output_root, f"{filename}.glb")

View File

@ -29,7 +29,7 @@ from embodied_gen.data.utils import (
init_kal_camera, init_kal_camera,
normalize_vertices_array, normalize_vertices_array,
) )
from embodied_gen.models.gs_model import GaussianOperator from embodied_gen.models.gs_model import load_gs_model
from embodied_gen.utils.process_media import combine_images_to_grid from embodied_gen.utils.process_media import combine_images_to_grid
logging.basicConfig( logging.basicConfig(
@ -97,21 +97,6 @@ def parse_args():
return args return args
def load_gs_model(
    input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
) -> GaussianOperator:
    """Read a Gaussian-splat PLY file and return it re-posed and rescaled.

    Args:
        input_gs: Path of the PLY file to load.
        pre_quat: Four quaternion components placed after the center in the
            pose vector given to ``get_gaussians``.

    Returns:
        The ``GaussianOperator`` produced by ``get_gaussians`` after it has
        been rescaled.
    """
    loaded = GaussianOperator.load_from_ply(input_gs)

    # Normalize vertices to [-1, 1], center to (0, 0, 0).
    _, raw_scale, raw_center = normalize_vertices_array(loaded._means)
    center = raw_center.tolist()
    scale = float(raw_scale)

    # Pose vector: center components first, then the quaternion.
    pose_values = list(center)
    pose_values.extend(pre_quat)
    instance_pose = torch.tensor(pose_values).to(loaded.device)

    posed = loaded.get_gaussians(instance_pose=instance_pose)
    posed.rescale(scale)
    return posed
@spaces.GPU @spaces.GPU
def entrypoint(**kwargs) -> None: def entrypoint(**kwargs) -> None:
args = parse_args() args = parse_args()

View File

@ -94,7 +94,6 @@ plugins:
docstring_style: google docstring_style: google
show_source: true show_source: true
merge_init_into_class: true merge_init_into_class: true
show_inherited_members: true
show_root_heading: true show_root_heading: true
show_root_full_path: true show_root_full_path: true