feat(texture): Refine texture back-projection. (#58)
This commit is contained in:
parent
fb637f9afc
commit
74c3c52a23
@ -331,4 +331,4 @@ EmbodiedGen builds upon the following amazing projects and models:
|
||||
|
||||
## ⚖️ License
|
||||
|
||||
This project is licensed under the [Apache License 2.0](LICENSE). See the `LICENSE` file for details.
|
||||
This project is licensed under the [Apache License 2.0](docs/LICENSE). See the `LICENSE` file for details.
|
||||
|
||||
@ -4,7 +4,7 @@ from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
|
||||
lighting_css = """
|
||||
<style>
|
||||
#lighter_mesh canvas {
|
||||
filter: brightness(1.9) !important;
|
||||
filter: brightness(2.0) !important;
|
||||
}
|
||||
</style>
|
||||
"""
|
||||
|
||||
@ -32,8 +32,9 @@ import trimesh
|
||||
from easydict import EasyDict as edict
|
||||
from PIL import Image
|
||||
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
|
||||
from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
|
||||
from embodied_gen.data.differentiable_render import entrypoint as render_api
|
||||
from embodied_gen.data.utils import trellis_preprocess, zip_files
|
||||
from embodied_gen.data.utils import resize_pil, trellis_preprocess, zip_files
|
||||
from embodied_gen.models.delight_model import DelightingModel
|
||||
from embodied_gen.models.gs_model import GaussianOperator
|
||||
from embodied_gen.models.segment_model import (
|
||||
@ -131,8 +132,8 @@ def patched_setup_functions(self):
|
||||
Gaussian.setup_functions = patched_setup_functions
|
||||
|
||||
|
||||
DELIGHT = DelightingModel()
|
||||
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
||||
# DELIGHT = DelightingModel()
|
||||
# IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
||||
# IMAGESR_MODEL = ImageStableSR()
|
||||
if os.getenv("GRADIO_APP") == "imageto3d":
|
||||
RBG_REMOVER = RembgRemover()
|
||||
@ -169,6 +170,8 @@ elif os.getenv("GRADIO_APP") == "textto3d":
|
||||
)
|
||||
os.makedirs(TMP_DIR, exist_ok=True)
|
||||
elif os.getenv("GRADIO_APP") == "texture_edit":
|
||||
DELIGHT = DelightingModel()
|
||||
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
||||
PIPELINE_IP = build_texture_gen_pipe(
|
||||
base_ckpt_dir="./weights",
|
||||
ip_adapt_scale=0.7,
|
||||
@ -205,7 +208,7 @@ def preprocess_image_fn(
|
||||
elif isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
|
||||
image_cache = image.copy().resize((512, 512))
|
||||
image_cache = resize_pil(image.copy(), 1024)
|
||||
|
||||
bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
|
||||
image = bg_remover(image)
|
||||
@ -221,7 +224,7 @@ def preprocess_sam_image_fn(
|
||||
image = Image.fromarray(image)
|
||||
|
||||
sam_image = SAM_PREDICTOR.preprocess_image(image)
|
||||
image_cache = Image.fromarray(sam_image).resize((512, 512))
|
||||
image_cache = sam_image.copy()
|
||||
SAM_PREDICTOR.predictor.set_image(sam_image)
|
||||
|
||||
return sam_image, image_cache
|
||||
@ -512,6 +515,60 @@ def extract_3d_representations_v2(
|
||||
return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
|
||||
|
||||
|
||||
def extract_3d_representations_v3(
    state: dict,
    enable_delight: bool,
    texture_size: int,
    req: gr.Request,
):
    """Export the Gaussian-splat and textured-mesh assets of a session.

    The GS model and mesh are unpacked from ``state``, re-oriented with fixed
    axis-aligned rotations, written to the per-session directory and the mesh
    texture is back-projected from renderings of the aligned GS model.

    Args:
        state (dict): Packed generation state consumed by ``unpack_state``.
        enable_delight (bool): Kept for signature compatibility with the UI
            callback inputs. NOTE(review): it is not forwarded to
            ``backproject_api_v3`` here, so it has no effect in this function.
        texture_size (int): Texture resolution passed to the back-projection.
        req (gr.Request): Gradio request; its ``session_hash`` names the
            per-session output directory under ``TMP_DIR``.

    Returns:
        tuple[str, str, str, str]: ``(mesh_glb_path, gs_path, mesh_obj_path,
        aligned_gs_path)``.
    """
    output_root = TMP_DIR
    user_dir = os.path.join(output_root, str(req.session_hash))
    # Make sure the session directory exists before any file is written.
    os.makedirs(user_dir, exist_ok=True)
    gs_model, mesh_model = unpack_state(state, device="cpu")

    filename = "sample"
    gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
    gs_model.save_ply(gs_path)

    # Shared 90-degree rotation about the Y-axis applied to both mesh and GS.
    rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
    # 180-degree rotation about the X-axis, applied to the GS only.
    gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
    # 90-degree rotation about the X-axis, applied to the mesh only.
    mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

    # Additional rotation for GS to align with the mesh.
    gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
    pose = GaussianOperator.trans_to_quatpose(gs_rot)
    aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
    GaussianOperator.resave_ply(
        in_ply=gs_path,
        out_ply=aligned_gs_path,
        instance_pose=pose,
        device="cpu",
    )

    mesh = trimesh.Trimesh(
        vertices=mesh_model.vertices.cpu().numpy(),
        faces=mesh_model.faces.cpu().numpy(),
    )
    # Row-vector convention: vertices are right-multiplied by each matrix.
    mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
    mesh.vertices = mesh.vertices @ np.array(rot_matrix)

    mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
    mesh.export(mesh_obj_path)

    # The untextured .obj is overwritten in place by the textured result.
    mesh = backproject_api_v3(
        gs_path=aligned_gs_path,
        mesh_path=mesh_obj_path,
        output_path=mesh_obj_path,
        skip_fix_mesh=False,
        texture_size=texture_size,
    )

    mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
    mesh.export(mesh_glb_path)

    return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
|
||||
|
||||
|
||||
def extract_urdf(
|
||||
gs_path: str,
|
||||
mesh_obj_path: str,
|
||||
|
||||
@ -27,7 +27,7 @@ from common import (
|
||||
VERSION,
|
||||
active_btn_by_content,
|
||||
end_session,
|
||||
extract_3d_representations_v2,
|
||||
extract_3d_representations_v3,
|
||||
extract_urdf,
|
||||
get_seed,
|
||||
image_to_3d,
|
||||
@ -179,17 +179,17 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
)
|
||||
|
||||
generate_btn = gr.Button(
|
||||
"🚀 1. Generate(~0.5 mins)",
|
||||
"🚀 1. Generate(~2 mins)",
|
||||
variant="primary",
|
||||
interactive=False,
|
||||
)
|
||||
model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
|
||||
with gr.Row():
|
||||
extract_rep3d_btn = gr.Button(
|
||||
"🔍 2. Extract 3D Representation(~2 mins)",
|
||||
variant="primary",
|
||||
interactive=False,
|
||||
)
|
||||
# with gr.Row():
|
||||
# extract_rep3d_btn = gr.Button(
|
||||
# "🔍 2. Extract 3D Representation(~2 mins)",
|
||||
# variant="primary",
|
||||
# interactive=False,
|
||||
# )
|
||||
with gr.Accordion(
|
||||
label="Enter Asset Attributes(optional)", open=False
|
||||
):
|
||||
@ -207,7 +207,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
)
|
||||
with gr.Row():
|
||||
extract_urdf_btn = gr.Button(
|
||||
"🧩 3. Extract URDF with physics(~1 mins)",
|
||||
"🧩 2. Extract URDF with physics(~1 mins)",
|
||||
variant="primary",
|
||||
interactive=False,
|
||||
)
|
||||
@ -230,7 +230,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
)
|
||||
with gr.Row():
|
||||
download_urdf = gr.DownloadButton(
|
||||
label="⬇️ 4. Download URDF",
|
||||
label="⬇️ 3. Download URDF",
|
||||
variant="primary",
|
||||
interactive=False,
|
||||
)
|
||||
@ -326,7 +326,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
image_prompt.change(
|
||||
lambda: tuple(
|
||||
[
|
||||
gr.Button(interactive=False),
|
||||
# gr.Button(interactive=False),
|
||||
gr.Button(interactive=False),
|
||||
gr.Button(interactive=False),
|
||||
None,
|
||||
@ -344,7 +344,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
]
|
||||
),
|
||||
outputs=[
|
||||
extract_rep3d_btn,
|
||||
# extract_rep3d_btn,
|
||||
extract_urdf_btn,
|
||||
download_urdf,
|
||||
model_output_gs,
|
||||
@ -375,7 +375,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
image_prompt_sam.change(
|
||||
lambda: tuple(
|
||||
[
|
||||
gr.Button(interactive=False),
|
||||
# gr.Button(interactive=False),
|
||||
gr.Button(interactive=False),
|
||||
gr.Button(interactive=False),
|
||||
None,
|
||||
@ -394,7 +394,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
]
|
||||
),
|
||||
outputs=[
|
||||
extract_rep3d_btn,
|
||||
# extract_rep3d_btn,
|
||||
extract_urdf_btn,
|
||||
download_urdf,
|
||||
model_output_gs,
|
||||
@ -447,12 +447,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
],
|
||||
outputs=[output_buf, video_output],
|
||||
).success(
|
||||
lambda: gr.Button(interactive=True),
|
||||
outputs=[extract_rep3d_btn],
|
||||
)
|
||||
|
||||
extract_rep3d_btn.click(
|
||||
extract_3d_representations_v2,
|
||||
extract_3d_representations_v3,
|
||||
inputs=[
|
||||
output_buf,
|
||||
project_delight,
|
||||
@ -495,4 +490,4 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
demo.launch(server_port=8081)
|
||||
|
||||
@ -27,7 +27,7 @@ from common import (
|
||||
VERSION,
|
||||
active_btn_by_text_content,
|
||||
end_session,
|
||||
extract_3d_representations_v2,
|
||||
extract_3d_representations_v3,
|
||||
extract_urdf,
|
||||
get_cached_image,
|
||||
get_seed,
|
||||
@ -178,17 +178,17 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
)
|
||||
|
||||
generate_btn = gr.Button(
|
||||
"🚀 2. Generate 3D(~0.5 mins)",
|
||||
"🚀 2. Generate 3D(~2 mins)",
|
||||
variant="primary",
|
||||
interactive=False,
|
||||
)
|
||||
model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
|
||||
with gr.Row():
|
||||
extract_rep3d_btn = gr.Button(
|
||||
"🔍 3. Extract 3D Representation(~1 mins)",
|
||||
variant="primary",
|
||||
interactive=False,
|
||||
)
|
||||
# with gr.Row():
|
||||
# extract_rep3d_btn = gr.Button(
|
||||
# "🔍 3. Extract 3D Representation(~1 mins)",
|
||||
# variant="primary",
|
||||
# interactive=False,
|
||||
# )
|
||||
with gr.Accordion(
|
||||
label="Enter Asset Attributes(optional)", open=False
|
||||
):
|
||||
@ -206,13 +206,13 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
)
|
||||
with gr.Row():
|
||||
extract_urdf_btn = gr.Button(
|
||||
"🧩 4. Extract URDF with physics(~1 mins)",
|
||||
"🧩 3. Extract URDF with physics(~1 mins)",
|
||||
variant="primary",
|
||||
interactive=False,
|
||||
)
|
||||
with gr.Row():
|
||||
download_urdf = gr.DownloadButton(
|
||||
label="⬇️ 5. Download URDF",
|
||||
label="⬇️ 4. Download URDF",
|
||||
variant="primary",
|
||||
interactive=False,
|
||||
)
|
||||
@ -336,7 +336,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
generate_img_btn.click(
|
||||
lambda: tuple(
|
||||
[
|
||||
gr.Button(interactive=False),
|
||||
# gr.Button(interactive=False),
|
||||
gr.Button(interactive=False),
|
||||
gr.Button(interactive=False),
|
||||
gr.Button(interactive=False),
|
||||
@ -358,7 +358,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
]
|
||||
),
|
||||
outputs=[
|
||||
extract_rep3d_btn,
|
||||
# extract_rep3d_btn,
|
||||
extract_urdf_btn,
|
||||
download_urdf,
|
||||
generate_btn,
|
||||
@ -428,12 +428,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
],
|
||||
outputs=[output_buf, video_output],
|
||||
).success(
|
||||
lambda: gr.Button(interactive=True),
|
||||
outputs=[extract_rep3d_btn],
|
||||
)
|
||||
|
||||
extract_rep3d_btn.click(
|
||||
extract_3d_representations_v2,
|
||||
extract_3d_representations_v3,
|
||||
inputs=[
|
||||
output_buf,
|
||||
project_delight,
|
||||
@ -476,4 +471,4 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
demo.launch(server_port=8082)
|
||||
|
||||
@ -381,4 +381,4 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
demo.launch(server_port=8083)
|
||||
|
||||
@ -727,7 +727,6 @@ with gr.Blocks(
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(
|
||||
server_name="10.34.8.77",
|
||||
server_port=8088,
|
||||
allowed_paths=[
|
||||
"/horizon-bucket/robot_lab/datasets/embodiedgen/assets"
|
||||
|
||||
@ -14,6 +14,8 @@ conda activate embodiedgen
|
||||
bash install.sh basic
|
||||
```
|
||||
|
||||
Please run `huggingface-cli login` first so that the model checkpoints can be downloaded automatically afterwards.
|
||||
|
||||
## ✅ Starting from Docker
|
||||
|
||||
We provide a pre-built Docker image on [Docker Hub](https://hub.docker.com/repository/docker/wangxinjie/embodiedgen) with a configured environment for your convenience. For more details, please refer to [Docker documentation](https://github.com/HorizonRobotics/EmbodiedGen/tree/master/docker).
|
||||
|
||||
@ -589,6 +589,8 @@ class MeshtoUSDConverter(AssetConverterBase):
|
||||
stage = Usd.Stage.Open(usd_path)
|
||||
layer = stage.GetRootLayer()
|
||||
with Usd.EditContext(stage, layer):
|
||||
base_prim = stage.GetPseudoRoot().GetChildren()[0]
|
||||
base_prim.SetMetadata("kind", "component")
|
||||
for prim in stage.Traverse():
|
||||
# Change texture path to relative path.
|
||||
if prim.GetName() == "material_0":
|
||||
|
||||
@ -34,6 +34,7 @@ from embodied_gen.data.utils import (
|
||||
CameraSetting,
|
||||
get_images_from_grid,
|
||||
init_kal_camera,
|
||||
kaolin_to_opencv_view,
|
||||
normalize_vertices_array,
|
||||
post_process_texture,
|
||||
save_mesh_with_mtl,
|
||||
@ -306,28 +307,6 @@ class TextureBaker(object):
|
||||
raise ValueError(f"Unknown mode: {mode}")
|
||||
|
||||
|
||||
def kaolin_to_opencv_view(raw_matrix):
    """Re-order the rotation columns of a batch of 4x4 view matrices.

    Output rotation columns (0, 1, 2) are taken from input rotation columns
    (2, 0, 1); the translation column is copied unchanged and the last row is
    ``[0, 0, 0, 1]``.

    Args:
        raw_matrix (torch.Tensor): Batch of view matrices, shape (N, 4, 4).

    Returns:
        torch.Tensor: New (N, 4, 4) tensor on the same device as the input,
        built from ``torch.eye`` (default floating dtype).
    """
    batch_size = raw_matrix.size(0)
    rotation = raw_matrix[:, :3, :3]

    converted = torch.eye(4, device=raw_matrix.device)
    converted = converted.unsqueeze(0).repeat(batch_size, 1, 1)

    # Destination column j receives source column (2, 0, 1)[j].
    for dst_col, src_col in enumerate((2, 0, 1)):
        converted[:, :3, dst_col] = rotation[:, :, src_col]
    converted[:, :3, 3] = raw_matrix[:, :3, 3]

    return converted
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Render settings")
|
||||
|
||||
|
||||
558
embodied_gen/data/backproject_v3.py
Normal file
558
embodied_gen/data/backproject_v3.py
Normal file
@ -0,0 +1,558 @@
|
||||
# Project EmbodiedGen
|
||||
#
|
||||
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Literal, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import nvdiffrast.torch as dr
|
||||
import spaces
|
||||
import torch
|
||||
import trimesh
|
||||
import utils3d
|
||||
import xatlas
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from embodied_gen.data.mesh_operator import MeshFixer
|
||||
from embodied_gen.data.utils import (
|
||||
CameraSetting,
|
||||
init_kal_camera,
|
||||
kaolin_to_opencv_view,
|
||||
normalize_vertices_array,
|
||||
post_process_texture,
|
||||
save_mesh_with_mtl,
|
||||
)
|
||||
from embodied_gen.models.delight_model import DelightingModel
|
||||
from embodied_gen.models.gs_model import load_gs_model
|
||||
from embodied_gen.models.sr_model import ImageRealESRGAN
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TextureBaker",
|
||||
]
|
||||
|
||||
|
||||
class TextureBaker(object):
    """Baking textures onto a mesh from multiple observations.

    This class takes 3D mesh data, camera settings and texture baking
    parameters to generate a texture map by projecting images onto the mesh
    from different views. It supports both a fast texture baking approach and
    an optimization-based method with total variation regularization.

    Attributes:
        vertices (torch.Tensor): The vertices of the mesh.
        faces (torch.Tensor): The faces of the mesh, defined by vertex indices.
        uvs (torch.Tensor): The UV coordinates of the mesh.
        camera_params (CameraSetting): Camera setting (intrinsics, extrinsics).
        device (str): The device tensors are placed on. Rasterization uses the
            CUDA backend, so a CUDA device is expected for baking.
        w2cs (torch.Tensor): World-to-camera transformation matrices.
        projections (torch.Tensor): Camera projection matrices.

    Example:
        >>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)  # noqa
        >>> texture_backer = TextureBaker(vertices, faces, uvs, camera_params)
        >>> images = get_images_from_grid(args.color_path, image_size)
        >>> texture = texture_backer.bake_texture(
        ...     images, texture_size=args.texture_size, mode=args.baker_mode
        ... )
        >>> texture = post_process_texture(texture)
    """

    def __init__(
        self,
        vertices: np.ndarray,
        faces: np.ndarray,
        uvs: np.ndarray,
        camera_params: "CameraSetting",
        device: str = "cuda",
    ) -> None:
        """Move mesh data to ``device`` and precompute camera matrices.

        Args:
            vertices: Mesh vertices (numpy array or torch tensor).
            faces: Mesh faces (numpy array or torch tensor); numpy input is
                cast to int32.
            uvs: Per-vertex UV coordinates (numpy array or torch tensor).
            camera_params: Camera setting used to build the view and
                projection matrices.
            device: Target device for all tensors.
        """
        self.vertices = (
            torch.tensor(vertices, device=device)
            if isinstance(vertices, np.ndarray)
            else vertices.to(device)
        )
        self.faces = (
            torch.tensor(faces.astype(np.int32), device=device)
            if isinstance(faces, np.ndarray)
            else faces.to(device)
        )
        self.uvs = (
            torch.tensor(uvs, device=device)
            if isinstance(uvs, np.ndarray)
            else uvs.to(device)
        )
        self.camera_params = camera_params
        self.device = device

        camera = init_kal_camera(camera_params)
        matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
        matrix_mv = kaolin_to_opencv_view(matrix_mv)
        matrix_p = (
            camera.intrinsics.projection_matrix()
        )  # (n_cam 4 4) cam2pixel
        self.w2cs = matrix_mv.to(self.device)
        self.projections = matrix_p.to(self.device)

    @staticmethod
    def parametrize_mesh(
        vertices: np.ndarray, faces: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """UV-unwrap a mesh with xatlas.

        Args:
            vertices: Mesh vertices.
            faces: Mesh faces.

        Returns:
            ``(vertices, faces, uvs)`` where vertices are re-indexed by the
            xatlas vertex mapping and faces are the new index buffer.
        """
        vmapping, indices, uvs = xatlas.parametrize(vertices, faces)

        vertices = vertices[vmapping]
        faces = indices

        return vertices, faces, uvs

    def _bake_fast(self, observations, w2cs, projections, texture_size, masks):
        """Bake by averaging observed colors scattered into texels.

        Each view is rasterized to obtain per-pixel UVs; valid pixels are
        accumulated into the nearest texel and averaged. Texels that receive
        no observation are filled by OpenCV inpainting.

        Returns:
            np.ndarray: uint8 texture of shape (texture_size, texture_size, 3).
        """
        n_texels = texture_size * texture_size
        texture = torch.zeros(
            (n_texels, 3), dtype=torch.float32, device=self.device
        )
        texture_weights = torch.zeros(
            (n_texels), dtype=torch.float32, device=self.device
        )
        rastctx = utils3d.torch.RastContext(backend="cuda")
        for observation, w2c, projection, view_mask in tqdm(
            zip(observations, w2cs, projections, masks),
            total=len(observations),
            desc="Texture baking (fast)",
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx,
                    self.vertices[None],
                    self.faces,
                    observation.shape[1],
                    observation.shape[0],
                    uv=self.uvs[None],
                    view=w2c,
                    projection=projection,
                )
                uv_map = rast["uv"][0].detach().flip(0)
                # Use the foreground mask of THIS view (previously the mask
                # of the first view was applied to every observation).
                mask = rast["mask"][0].detach().bool() & view_mask

            # nearest neighbor interpolation; clamp so that uv == 1.0 does
            # not produce an out-of-range texel index.
            uv_map = (uv_map * texture_size).floor().long()
            uv_map = uv_map.clamp(0, texture_size - 1)
            obs = observation[mask]
            uv_map = uv_map[mask]
            idx = (
                uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
            )
            texture = texture.scatter_add(
                0, idx.view(-1, 1).expand(-1, 3), obs
            )
            texture_weights = texture_weights.scatter_add(
                0,
                idx,
                torch.ones(
                    (obs.shape[0]), dtype=torch.float32, device=texture.device
                ),
            )

        mask = texture_weights > 0
        texture[mask] /= texture_weights[mask][:, None]
        texture = np.clip(
            texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255,
            0,
            255,
        ).astype(np.uint8)

        # inpaint texels that received no observation
        mask = (
            (texture_weights == 0)
            .cpu()
            .numpy()
            .astype(np.uint8)
            .reshape(texture_size, texture_size)
        )
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

        return texture

    def _bake_opt(
        self,
        observations,
        w2cs,
        projections,
        texture_size,
        lambda_tv,
        masks,
        total_steps,
    ):
        """Bake by optimizing the texture against the observations.

        UV maps are rasterized once per view; the texture is then optimized
        with Adam on an L1 photometric loss over a randomly selected view per
        step, plus an optional total-variation term weighted by ``lambda_tv``.
        Texels outside the UV islands are filled by OpenCV inpainting.

        Returns:
            np.ndarray: uint8 texture of shape (texture_size, texture_size, 3).
        """
        rastctx = utils3d.torch.RastContext(backend="cuda")
        observations = [obs.flip(0) for obs in observations]
        masks = [m.flip(0) for m in masks]
        _uv = []
        _uv_dr = []
        for observation, w2c, projection in tqdm(
            zip(observations, w2cs, projections),
            total=len(w2cs),
        ):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx,
                    self.vertices[None],
                    self.faces,
                    observation.shape[1],
                    observation.shape[0],
                    uv=self.uvs[None],
                    view=w2c,
                    projection=projection,
                )
                _uv.append(rast["uv"].detach())
                _uv_dr.append(rast["uv_dr"].detach())

        texture = torch.nn.Parameter(
            torch.zeros(
                (1, texture_size, texture_size, 3),
                dtype=torch.float32,
                device=self.device,
            )
        )
        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

        def cosine_annealing(step, total_steps, start_lr, end_lr):
            # Cosine decay from start_lr (step 0) towards end_lr.
            return end_lr + 0.5 * (start_lr - end_lr) * (
                1 + np.cos(np.pi * step / total_steps)
            )

        def tv_loss(texture):
            # L1 difference between vertically and horizontally adjacent texels.
            return torch.nn.functional.l1_loss(
                texture[:, :-1, :, :], texture[:, 1:, :, :]
            ) + torch.nn.functional.l1_loss(
                texture[:, :, :-1, :], texture[:, :, 1:, :]
            )

        with tqdm(total=total_steps, desc="Texture baking") as pbar:
            for step in range(total_steps):
                optimizer.zero_grad()
                selected = np.random.randint(0, len(w2cs))
                uv, uv_dr, observation, mask = (
                    _uv[selected],
                    _uv_dr[selected],
                    observations[selected],
                    masks[selected],
                )
                render = dr.texture(texture, uv, uv_dr)[0]
                loss = torch.nn.functional.l1_loss(
                    render[mask], observation[mask]
                )
                if lambda_tv > 0:
                    loss += lambda_tv * tv_loss(texture)
                loss.backward()
                optimizer.step()

                # Learning rate for the next step.
                optimizer.param_groups[0]["lr"] = cosine_annealing(
                    step, total_steps, 1e-2, 1e-5
                )
                pbar.set_postfix({"loss": loss.item()})
                pbar.update()

        texture = np.clip(
            texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
        ).astype(np.uint8)
        # 1 where no UV triangle covers the texel -> region to inpaint.
        mask = 1 - utils3d.torch.rasterize_triangle_faces(
            rastctx,
            (self.uvs * 2 - 1)[None],
            self.faces,
            texture_size,
            texture_size,
        )["mask"][0].detach().cpu().numpy().astype(np.uint8)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

        return texture

    def bake_texture(
        self,
        images: list[np.ndarray],
        texture_size: int = 1024,
        mode: Literal["fast", "opt"] = "opt",
        lambda_tv: float = 1e-2,
        opt_step: int = 2000,
    ):
        """Bake a texture from per-view images.

        Args:
            images: One image per camera; values are divided by 255, and
                pixels whose channels are all zero are treated as background.
            texture_size: Side length of the square output texture.
            mode: ``"fast"`` for scatter-average baking, ``"opt"`` for
                optimization-based baking.
            lambda_tv: Total-variation weight (``"opt"`` mode only).
            opt_step: Number of optimization steps (``"opt"`` mode only).

        Returns:
            np.ndarray: uint8 texture of shape (texture_size, texture_size, 3).

        Raises:
            ValueError: If ``mode`` is neither ``"fast"`` nor ``"opt"``.
        """
        # Foreground = any non-zero channel.
        masks = [
            torch.tensor(np.any(img > 0, axis=-1)).bool().to(self.device)
            for img in images
        ]
        images = [
            torch.tensor(obs / 255.0).float().to(self.device) for obs in images
        ]

        if mode == "fast":
            return self._bake_fast(
                images, self.w2cs, self.projections, texture_size, masks
            )
        elif mode == "opt":
            return self._bake_opt(
                images,
                self.w2cs,
                self.projections,
                texture_size,
                lambda_tv,
                masks,
                opt_step,
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")
|
||||
|
||||
|
||||
def parse_args():
    """Parses command-line arguments for texture backprojection.

    Unknown arguments are ignored (``parse_known_args``), so this can be
    called from processes launched with unrelated command-line flags.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Backproject texture")
    parser.add_argument(
        "--gs_path",
        type=str,
        help="Path to the GS.ply gaussian splatting model",
    )
    parser.add_argument(
        "--mesh_path",
        type=str,
        help="Mesh path, .obj, .glb or .ply",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Output mesh path with suffix",
    )
    parser.add_argument(
        "--num_images",
        type=int,
        default=180,
        help="Number of images to render.",
    )
    parser.add_argument(
        "--elevation",
        nargs="+",
        type=float,
        default=list(range(85, -90, -10)),
        help="Elevation angles for the camera",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(512, 512),
        help="Resolution of the render images (default: (512, 512))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )
    parser.add_argument(
        "--skip_fix_mesh",
        action="store_true",
        help="Skip fixing the mesh geometry.",
    )
    parser.add_argument(
        "--texture_size",
        type=int,
        default=2048,
        help="Texture size for texture baking (default: 2048)",
    )
    parser.add_argument(
        "--baker_mode",
        type=str,
        choices=["fast", "opt"],
        default="opt",
        help="Texture baking mode, `fast` or `opt` (default: opt)",
    )
    parser.add_argument(
        "--opt_step",
        type=int,
        default=3000,
        help="Optimization steps for texture baking (default: 3000)",
    )
    # The misspelled flag and attribute name are kept for backward
    # compatibility; the correctly spelled flag is accepted as an alias.
    parser.add_argument(
        "--mesh_sipmlify_ratio",
        "--mesh_simplify_ratio",
        dest="mesh_sipmlify_ratio",
        type=float,
        default=0.9,
        help="Mesh simplification ratio (default: 0.9)",
    )
    parser.add_argument(
        "--delight", action="store_true", help="Use delighting model."
    )
    parser.add_argument(
        "--no_smooth_texture",
        action="store_true",
        help="Do not smooth the texture.",
    )
    parser.add_argument(
        "--no_coor_trans",
        action="store_true",
        help="Do not transform the asset coordinate system.",
    )
    parser.add_argument(
        "--save_glb_path", type=str, default=None, help="Save glb path."
    )
    parser.add_argument("--n_max_faces", type=int, default=30000)
    args, _ = parser.parse_known_args()

    return args
|
||||
|
||||
|
||||
@spaces.GPU
def entrypoint(
    delight_model: DelightingModel = None,
    imagesr_model: ImageRealESRGAN = None,
    **kwargs,
) -> trimesh.Trimesh:
    """Entrypoint for texture backprojection from multi-view GS renderings.

    The Gaussian-splat model is rendered from a ring of cameras, the
    renderings are optionally delighted, and the resulting images are baked
    onto the (optionally fixed/simplified) mesh.

    Args:
        delight_model (DelightingModel, optional): Delighting model; created
            on demand when ``delight`` is enabled and none is given.
        imagesr_model (ImageRealESRGAN, optional): Super-resolution model.
            NOTE(review): accepted but not used in this function.
        **kwargs: Additional arguments to override CLI. Keys that are not
            known CLI arguments, or whose value is ``None``, are ignored.

    Returns:
        trimesh.Trimesh: Textured mesh.

    Raises:
        ValueError: If ``gs_path``, ``mesh_path`` or ``output_path`` is not
            provided via CLI or ``kwargs``.
    """
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    # Fail early with a clear message instead of deep inside the loaders.
    missing = [
        name
        for name in ("gs_path", "mesh_path", "output_path")
        if getattr(args, name) is None
    ]
    if missing:
        raise ValueError(
            f"Missing required argument(s): {', '.join(missing)}"
        )

    # Setup camera parameters.
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )

    # GS render.
    camera = init_kal_camera(camera_params, flip_az=True)
    matrix_mv = camera.view_matrix()  # (n_cam 4 4) world2cam
    matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
    w2cs = matrix_mv.to(camera_params.device)
    c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
    Ks = torch.tensor(camera_params.Ks).to(camera_params.device)
    gs_model = load_gs_model(args.gs_path, pre_quat=[0.0, 0.0, 1.0, 0.0])
    multiviews = []
    for c2w in tqdm(c2ws, desc="Rendering GS"):
        result = gs_model.render(
            c2w,
            Ks=Ks,
            image_width=camera_params.resolution_hw[1],
            image_height=camera_params.resolution_hw[0],
        )
        color = cv2.cvtColor(result.rgba, cv2.COLOR_BGRA2RGBA)
        multiviews.append(Image.fromarray(color))

    if args.delight:
        if delight_model is None:
            delight_model = DelightingModel()
        multiviews = [delight_model(img) for img in multiviews]

    multiviews = [img.convert("RGB") for img in multiviews]

    mesh = trimesh.load(args.mesh_path)
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump(concatenate=True)

    vertices, scale, center = normalize_vertices_array(mesh.vertices)

    # Transform mesh coordinate system by default.
    if not args.no_coor_trans:
        x_rot = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
        z_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
        vertices = vertices @ x_rot
        vertices = vertices @ z_rot

    faces = mesh.faces.astype(np.int32)
    vertices = vertices.astype(np.float32)

    # Two-stage simplification: a first pass only for very dense meshes,
    # then a more aggressive pass if the face budget is still exceeded.
    if not args.skip_fix_mesh and len(faces) > 10 * args.n_max_faces:
        mesh_fixer = MeshFixer(vertices, faces, args.device)
        vertices, faces = mesh_fixer(
            filter_ratio=args.mesh_sipmlify_ratio,
            max_hole_size=0.04,
            resolution=1024,
            num_views=1000,
            norm_mesh_ratio=0.5,
        )
        if len(faces) > args.n_max_faces:
            mesh_fixer = MeshFixer(vertices, faces, args.device)
            vertices, faces = mesh_fixer(
                filter_ratio=max(0.05, args.mesh_sipmlify_ratio - 0.2),
                max_hole_size=0.04,
                resolution=1024,
                num_views=1000,
                norm_mesh_ratio=0.5,
            )

    vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)
    texture_backer = TextureBaker(
        vertices,
        faces,
        uvs,
        camera_params,
    )

    multiviews = [np.array(img) for img in multiviews]
    texture = texture_backer.bake_texture(
        images=[img[..., :3] for img in multiviews],
        texture_size=args.texture_size,
        mode=args.baker_mode,
        opt_step=args.opt_step,
    )
    if not args.no_smooth_texture:
        texture = post_process_texture(texture)

    # Recover mesh original orientation, scale and center.
    if not args.no_coor_trans:
        vertices = vertices @ np.linalg.inv(z_rot)
        vertices = vertices @ np.linalg.inv(x_rot)
    vertices = vertices / scale
    vertices = vertices + center

    textured_mesh = save_mesh_with_mtl(
        vertices, faces, uvs, texture, args.output_path
    )
    if args.save_glb_path is not None:
        # dirname is "" for a bare filename; os.makedirs("") would raise.
        glb_dir = os.path.dirname(args.save_glb_path)
        if glb_dir:
            os.makedirs(glb_dir, exist_ok=True)
        textured_mesh.export(args.save_glb_path)

    return textured_mesh
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
entrypoint()
|
||||
@ -66,6 +66,7 @@ __all__ = [
|
||||
"resize_pil",
|
||||
"trellis_preprocess",
|
||||
"delete_dir",
|
||||
"kaolin_to_opencv_view",
|
||||
]
|
||||
|
||||
|
||||
@ -373,10 +374,18 @@ def _compute_az_el_by_views(
|
||||
def _compute_cam_pts_by_az_el(
|
||||
azs: np.ndarray,
|
||||
els: np.ndarray,
|
||||
distance: float,
|
||||
distance: float | list[float] | np.ndarray,
|
||||
extra_pts: np.ndarray = None,
|
||||
) -> np.ndarray:
|
||||
distances = np.array([distance for _ in range(len(azs))])
|
||||
if np.isscalar(distance) or isinstance(distance, (float, int)):
|
||||
distances = np.full(len(azs), distance)
|
||||
else:
|
||||
distances = np.array(distance)
|
||||
if len(distances) != len(azs):
|
||||
raise ValueError(
|
||||
f"Length of distances ({len(distances)}) must match length of azs ({len(azs)})"
|
||||
)
|
||||
|
||||
cam_pts = _az_el_to_points(azs, els) * distances[:, None]
|
||||
|
||||
if extra_pts is not None:
|
||||
@ -710,7 +719,7 @@ class CameraSetting:
|
||||
|
||||
num_images: int
|
||||
elevation: list[float]
|
||||
distance: float
|
||||
distance: float | list[float]
|
||||
resolution_hw: tuple[int, int]
|
||||
fov: float
|
||||
at: tuple[float, float, float] = field(
|
||||
@ -824,6 +833,28 @@ def import_kaolin_mesh(mesh_path: str, with_mtl: bool = False):
|
||||
return mesh
|
||||
|
||||
|
||||
def kaolin_to_opencv_view(raw_matrix):
    """Build view matrices whose rotation columns are cyclically re-ordered.

    For every matrix in the batch, the upper-left 3x3 block of the result
    holds the input's rotation columns in the order (2, 0, 1); the
    translation column (rows 0-2 of column 3) is copied unchanged and the
    last row is always ``[0, 0, 0, 1]``.

    NOTE(review): judging by the name, this is meant to convert a Kaolin
    camera view matrix to the OpenCV axis convention -- confirm with callers.

    Args:
        raw_matrix: Batched tensor indexed as ``[batch, row, col]`` with at
            least 3 rows and 4 columns.

    Returns:
        A new ``(batch, 4, 4)`` tensor on ``raw_matrix.device``. It is built
        from ``torch.eye`` and therefore uses torch's default dtype, not
        necessarily the dtype of ``raw_matrix``.
    """
    batch_size = raw_matrix.size(0)

    # Start from a stack of identity matrices so the bottom row is [0, 0, 0, 1].
    identity = torch.eye(4, device=raw_matrix.device)
    view = identity.unsqueeze(0).repeat(batch_size, 1, 1)

    # New columns (0, 1, 2) are taken from the original columns (2, 0, 1).
    rotation = raw_matrix[:, :3, :3]
    view[:, :3, :3] = rotation[:, :, [2, 0, 1]]

    # Translation is passed through as-is.
    view[:, :3, 3] = raw_matrix[:, :3, 3]

    return view
|
||||
|
||||
|
||||
def save_mesh_with_mtl(
|
||||
vertices: np.ndarray,
|
||||
faces: np.ndarray,
|
||||
|
||||
@ -21,14 +21,18 @@ import struct
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from gsplat.cuda._wrapper import spherical_harmonics
|
||||
from gsplat.rendering import rasterization
|
||||
from plyfile import PlyData
|
||||
from scipy.spatial.transform import Rotation
|
||||
from embodied_gen.data.utils import gamma_shs, quat_mult, quat_to_rotmat
|
||||
from embodied_gen.data.utils import (
|
||||
gamma_shs,
|
||||
normalize_vertices_array,
|
||||
quat_mult,
|
||||
quat_to_rotmat,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -494,6 +498,21 @@ class GaussianOperator(GaussianBase):
|
||||
)
|
||||
|
||||
|
||||
def load_gs_model(
    input_gs: str, pre_quat: Optional[list[float]] = None
) -> GaussianOperator:
    """Load a Gaussian-splat model from a PLY file and normalize it.

    The model is loaded with ``GaussianOperator.load_from_ply``, the scale
    and center returned by ``normalize_vertices_array`` for its means are
    computed, and the model is then re-posed and rescaled with them.

    Args:
        input_gs: Path of the ``.ply`` file to load.
        pre_quat: Four quaternion components appended after the center to
            form the instance pose. Defaults to
            ``[0.0, 0.7071, 0.0, -0.7071]`` when ``None``.
            NOTE(review): component ordering (wxyz vs. xyzw) is defined by
            ``GaussianOperator.get_gaussians`` -- confirm there.

    Returns:
        The posed and rescaled ``GaussianOperator``.
    """
    # A ``None`` sentinel avoids sharing one mutable list across all calls.
    if pre_quat is None:
        pre_quat = [0.0, 0.7071, 0.0, -0.7071]

    gs_model = GaussianOperator.load_from_ply(input_gs)
    # Normalize vertices to [-1, 1], center to (0, 0, 0).
    _, scale, center = normalize_vertices_array(gs_model._means)
    scale, center = float(scale), center.tolist()
    # Pose layout: the center values followed by the quaternion components.
    instance_pose = torch.tensor([*center, *pre_quat]).to(gs_model.device)
    gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
    gs_model.rescale(scale)

    return gs_model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_gs = "outputs/layouts_gens_demo/task_0000/background/gs_model.ply"
|
||||
output_gs = "./gs_model.ply"
|
||||
|
||||
@ -26,12 +26,14 @@ import numpy as np
|
||||
import torch
|
||||
import trimesh
|
||||
from PIL import Image
|
||||
from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
|
||||
from embodied_gen.data.backproject_v3 import entrypoint as backproject_api
|
||||
from embodied_gen.data.utils import delete_dir, trellis_preprocess
|
||||
from embodied_gen.models.delight_model import DelightingModel
|
||||
|
||||
# from embodied_gen.models.delight_model import DelightingModel
|
||||
from embodied_gen.models.gs_model import GaussianOperator
|
||||
from embodied_gen.models.segment_model import RembgRemover
|
||||
from embodied_gen.models.sr_model import ImageRealESRGAN
|
||||
|
||||
# from embodied_gen.models.sr_model import ImageRealESRGAN
|
||||
from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
|
||||
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
||||
from embodied_gen.utils.log import logger
|
||||
@ -59,8 +61,8 @@ os.environ["SPCONV_ALGO"] = "native"
|
||||
random.seed(0)
|
||||
|
||||
logger.info("Loading Image3D Models...")
|
||||
DELIGHT = DelightingModel()
|
||||
IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
||||
# DELIGHT = DelightingModel()
|
||||
# IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
|
||||
RBG_REMOVER = RembgRemover()
|
||||
PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
|
||||
"microsoft/TRELLIS-image-large"
|
||||
@ -108,9 +110,7 @@ def parse_args():
|
||||
default=2,
|
||||
)
|
||||
parser.add_argument("--disable_decompose_convex", action="store_true")
|
||||
parser.add_argument(
|
||||
"--texture_wh", type=int, nargs=2, default=[2048, 2048]
|
||||
)
|
||||
parser.add_argument("--texture_size", type=int, default=2048)
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
return args
|
||||
@ -248,16 +248,14 @@ def entrypoint(**kwargs):
|
||||
mesh.export(mesh_obj_path)
|
||||
|
||||
mesh = backproject_api(
|
||||
delight_model=DELIGHT,
|
||||
imagesr_model=IMAGESR_MODEL,
|
||||
color_path=color_path,
|
||||
# delight_model=DELIGHT,
|
||||
# imagesr_model=IMAGESR_MODEL,
|
||||
gs_path=aligned_gs_path,
|
||||
mesh_path=mesh_obj_path,
|
||||
output_path=mesh_obj_path,
|
||||
skip_fix_mesh=False,
|
||||
delight=True,
|
||||
texture_wh=args.texture_wh,
|
||||
elevation=[20, -10, 60, -50],
|
||||
num_images=12,
|
||||
texture_size=args.texture_size,
|
||||
delight=False,
|
||||
)
|
||||
|
||||
mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
|
||||
|
||||
@ -29,7 +29,7 @@ from embodied_gen.data.utils import (
|
||||
init_kal_camera,
|
||||
normalize_vertices_array,
|
||||
)
|
||||
from embodied_gen.models.gs_model import GaussianOperator
|
||||
from embodied_gen.models.gs_model import load_gs_model
|
||||
from embodied_gen.utils.process_media import combine_images_to_grid
|
||||
|
||||
logging.basicConfig(
|
||||
@ -97,21 +97,6 @@ def parse_args():
|
||||
return args
|
||||
|
||||
|
||||
def load_gs_model(
    input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
) -> GaussianOperator:
    """Read a Gaussian-splat PLY file and return it posed and rescaled.

    Args:
        input_gs: Path of the ``.ply`` file to load.
        pre_quat: Quaternion components placed after the center values in
            the instance pose. The default list is only read, never mutated.

    Returns:
        The ``GaussianOperator`` produced by ``get_gaussians`` after
        ``rescale`` has been applied to it.
    """
    loaded = GaussianOperator.load_from_ply(input_gs)

    # Normalize vertices to [-1, 1], center to (0, 0, 0).
    _, raw_scale, raw_center = normalize_vertices_array(loaded._means)
    scale_factor = float(raw_scale)
    center_xyz = raw_center.tolist()

    # Pose values: center first, then the quaternion components.
    pose_values = list(center_xyz) + list(pre_quat)
    pose = torch.tensor(pose_values).to(loaded.device)

    posed = loaded.get_gaussians(instance_pose=pose)
    posed.rescale(scale_factor)

    return posed
|
||||
|
||||
|
||||
@spaces.GPU
|
||||
def entrypoint(**kwargs) -> None:
|
||||
args = parse_args()
|
||||
|
||||
@ -94,7 +94,6 @@ plugins:
|
||||
docstring_style: google
|
||||
show_source: true
|
||||
merge_init_into_class: true
|
||||
show_inherited_members: true
|
||||
show_root_heading: true
|
||||
show_root_full_path: true
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user