feat(pipe): Faster texture back projection and refine quality checkers. (#29)

This commit is contained in:
Xinjie 2025-07-31 19:53:56 +08:00 committed by GitHub
parent 87ff24dbd4
commit c258ff8666
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 94 additions and 54 deletions

View File

@ -251,6 +251,7 @@ class TextureBacker:
during rendering. Defaults to 0.5. during rendering. Defaults to 0.5.
smooth_texture (bool, optional): If True, apply post-processing (e.g., smooth_texture (bool, optional): If True, apply post-processing (e.g.,
blurring) to the final texture. Defaults to True. blurring) to the final texture. Defaults to True.
inpaint_smooth (bool, optional): If True, apply inpainting-based smoothing to the texture. Defaults to False.
""" """
def __init__( def __init__(
@ -262,6 +263,7 @@ class TextureBacker:
bake_angle_thresh: int = 75, bake_angle_thresh: int = 75,
mask_thresh: float = 0.5, mask_thresh: float = 0.5,
smooth_texture: bool = True, smooth_texture: bool = True,
inpaint_smooth: bool = False,
) -> None: ) -> None:
self.camera_params = camera_params self.camera_params = camera_params
self.renderer = None self.renderer = None
@ -271,6 +273,7 @@ class TextureBacker:
self.texture_wh = texture_wh self.texture_wh = texture_wh
self.mask_thresh = mask_thresh self.mask_thresh = mask_thresh
self.smooth_texture = smooth_texture self.smooth_texture = smooth_texture
self.inpaint_smooth = inpaint_smooth
self.bake_angle_thresh = bake_angle_thresh self.bake_angle_thresh = bake_angle_thresh
self.bake_unreliable_kernel_size = int( self.bake_unreliable_kernel_size = int(
@ -446,11 +449,12 @@ class TextureBacker:
def uv_inpaint( def uv_inpaint(
self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
) -> np.ndarray: ) -> np.ndarray:
vertices, faces, uv_map = self.get_mesh_np_attrs(mesh) if self.inpaint_smooth:
vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
texture, mask = _texture_inpaint_smooth(
texture, mask, vertices, faces, uv_map
)
texture, mask = _texture_inpaint_smooth(
texture, mask, vertices, faces, uv_map
)
texture = texture.clip(0, 1) texture = texture.clip(0, 1)
texture = cv2.inpaint( texture = cv2.inpaint(
(texture * 255).astype(np.uint8), (texture * 255).astype(np.uint8),

View File

@ -54,7 +54,7 @@ __all__ = [
PROMPT_APPEND = ( PROMPT_APPEND = (
"Angled 3D view of one {object}, centered, no cropping, no occlusion, isolated product photo, " "Angled 3D view of one {object}, centered, no cropping, no occlusion, isolated product photo, "
"no surroundings, matte, on a plain clean surface, 3D style revealing multiple surfaces" "no surroundings, high-quality appearance, vivid colors, on a plain clean surface, 3D style revealing multiple surfaces"
) )
PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality" PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality"

View File

@ -19,6 +19,7 @@ import os
import random import random
from collections import defaultdict from collections import defaultdict
import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from embodied_gen.models.image_comm_model import build_hf_image_pipeline from embodied_gen.models.image_comm_model import build_hf_image_pipeline
@ -27,7 +28,10 @@ from embodied_gen.models.text_model import PROMPT_APPEND
from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api
from embodied_gen.utils.gpt_clients import GPT_CLIENT from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.log import logger from embodied_gen.utils.log import logger
from embodied_gen.utils.process_media import render_asset3d from embodied_gen.utils.process_media import (
check_object_edge_truncated,
render_asset3d,
)
from embodied_gen.validators.quality_checkers import ( from embodied_gen.validators.quality_checkers import (
ImageSegChecker, ImageSegChecker,
SemanticConsistChecker, SemanticConsistChecker,
@ -38,6 +42,13 @@ from embodied_gen.validators.quality_checkers import (
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
random.seed(0) random.seed(0)
logger.info("Loading Models...")
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
PIPE_IMG = build_hf_image_pipeline(os.environ.get("TEXT_MODEL", "sd35"))
BG_REMOVER = RembgRemover()
__all__ = [ __all__ = [
"text_to_image", "text_to_image",
@ -69,6 +80,7 @@ def text_to_image(
f"Image GEN for {os.path.basename(save_path)}\n" f"Image GEN for {os.path.basename(save_path)}\n"
f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}" f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}"
) )
torch.cuda.empty_cache()
images = PIPE_IMG.run( images = PIPE_IMG.run(
f_prompt, f_prompt,
num_inference_steps=img_denoise_step, num_inference_steps=img_denoise_step,
@ -93,16 +105,20 @@ def text_to_image(
seg_flag, seg_result = SEG_CHECKER( seg_flag, seg_result = SEG_CHECKER(
[raw_image, image.convert("RGB")] [raw_image, image.convert("RGB")]
) )
image_mask = np.array(image)[..., -1]
edge_flag = check_object_edge_truncated(image_mask)
logger.warning(
f"SEMANTIC: {semantic_result}. SEG: {seg_result}. EDGE: {edge_flag}"
)
if ( if (
(semantic_flag and seg_flag) (edge_flag and semantic_flag and seg_flag)
or semantic_flag is None or (edge_flag and semantic_flag is None)
or seg_flag is None or (edge_flag and seg_flag is None)
): ):
select_image = [raw_image, image] select_image = [raw_image, image]
success_flag = True success_flag = True
break break
torch.cuda.empty_cache()
seed = random.randint(0, 100000) if seed is not None else None seed = random.randint(0, 100000) if seed is not None else None
return success_flag return success_flag
@ -114,14 +130,6 @@ def text_to_3d(**kwargs) -> dict:
if hasattr(args, k) and v is not None: if hasattr(args, k) and v is not None:
setattr(args, k, v) setattr(args, k, v)
logger.info("Loading Models...")
global SEMANTIC_CHECKER, SEG_CHECKER, TXTGEN_CHECKER, PIPE_IMG, BG_REMOVER
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
PIPE_IMG = build_hf_image_pipeline(args.text_model)
BG_REMOVER = RembgRemover()
if args.asset_names is None or len(args.asset_names) == 0: if args.asset_names is None or len(args.asset_names) == 0:
args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))] args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
img_save_dir = os.path.join(args.output_root, "images") img_save_dir = os.path.join(args.output_root, "images")
@ -261,11 +269,6 @@ def parse_args():
default=0, default=0,
help="Random seed for 3D generation", help="Random seed for 3D generation",
) )
parser.add_argument(
"--text_model",
type=str,
default="sd35",
)
parser.add_argument("--keep_intermediate", action="store_true") parser.add_argument("--keep_intermediate", action="store_true")
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()

View File

@ -48,6 +48,7 @@ __all__ = [
"SceneTreeVisualizer", "SceneTreeVisualizer",
"is_image_file", "is_image_file",
"parse_text_prompts", "parse_text_prompts",
"check_object_edge_truncated",
] ]
@ -376,6 +377,28 @@ def parse_text_prompts(prompts: list[str]) -> list[str]:
return prompts return prompts
def check_object_edge_truncated(
mask: np.ndarray, edge_threshold: int = 5
) -> bool:
"""Checks if a binary object mask is truncated at the image edges.
Args:
mask: A 2D binary NumPy array where nonzero values indicate the object region.
edge_threshold: Number of pixels from each image edge to consider for truncation.
Defaults to 5.
Returns:
True if the object is fully enclosed (not truncated).
False if the object touches or crosses any image boundary.
"""
top = mask[:edge_threshold, :].any()
bottom = mask[-edge_threshold:, :].any()
left = mask[:, :edge_threshold].any()
right = mask[:, -edge_threshold:].any()
return not (top or bottom or left or right)
if __name__ == "__main__": if __name__ == "__main__":
merge_video_video( merge_video_video(
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa

View File

@ -113,8 +113,8 @@ class MeshGeoChecker(BaseChecker):
Your task is to evaluate the quality of the 3D asset generation, Your task is to evaluate the quality of the 3D asset generation,
including geometry, structure, and appearance, based on the rendered views. including geometry, structure, and appearance, based on the rendered views.
Criteria: Criteria:
- Is the geometry complete and well-formed, without missing parts or redundant structures? - Is the object in the image a single, complete, and well-formed instance,
- Is the geometric structure of the object complete? without truncation, missing parts, overlapping duplicates, or redundant geometry?
- Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back, - Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
soft edges) are acceptable if the object is structurally sound and recognizable. soft edges) are acceptable if the object is structurally sound and recognizable.
- Only evaluate geometry. Do not assess texture quality. - Only evaluate geometry. Do not assess texture quality.
@ -241,10 +241,11 @@ class SemanticConsistChecker(BaseChecker):
Criteria: Criteria:
- The image must visually match the text description in terms of object type, structure, geometry, and color. - The image must visually match the text description in terms of object type, structure, geometry, and color.
- The object must appear realistic, with reasonable geometry (e.g., a table must have a stable number - The object must appear realistic, with reasonable geometry (e.g., a table must have a stable number
of legs with a reasonable distribution. Count the number of legs visible in the image. (strict) Tables
with fewer than four legs, or with unevenly distributed legs, are not allowed. Do not assume
hidden legs unless they are clearly visible.)
- Geometric completeness is required: the object must not have missing, truncated, or cropped parts. - Geometric completeness is required: the object must not have missing, truncated, or cropped parts.
- The object must be centered in the image frame with clear margins on all sides,
it should not touch or nearly touch any image edge.
- The image must contain exactly one object. Multiple distinct objects are not allowed. - The image must contain exactly one object. Multiple distinct objects are not allowed.
A single composite object (e.g., a chair with legs) is acceptable. A single composite object (e.g., a chair with legs) is acceptable.
- The object should be shown from a slightly angled (three-quarter) perspective, - The object should be shown from a slightly angled (three-quarter) perspective,

View File

@ -101,34 +101,42 @@ class URDFGenerator(object):
prompt_template = ( prompt_template = (
view_desc view_desc
+ """of the 3D object asset, + """of the 3D object asset,
category: {category}. category: {category}.
You are an expert in 3D object analysis and physical property estimation. You are an expert in 3D object analysis and physical property estimation.
Give the category of this object asset (within 3 words), Give the category of this object asset (within 3 words), (if category is
(if category is already provided, use it directly), already provided, use it directly), accurately describe this 3D object asset (within 15 words),
accurately describe this 3D object asset (within 15 words), Determine the pose of the object in the first image and estimate the true vertical height
and give the recommended geometric height range (unit: meter), (vertical projection) range of the object (in meters), i.e., how tall the object appears from top
weight range (unit: kilogram), the average static friction to bottom in the front view (first) image. also weight range (unit: kilogram), the average
coefficient of the object relative to rubber and the average static friction coefficient of the object relative to rubber and the average dynamic friction
dynamic friction coefficient of the object relative to rubber. coefficient of the object relative to rubber. Return response format as shown in Output Example.
Return response format as shown in Output Example.
IMPORTANT: Output Example:
Inputed images are orthographic projection showing the front, left, right and back views, Category: cup
the first image is always the front view. Use the object's pose and orientation in the Description: shiny golden cup with floral design
rendered images to estimate its **true vertical height as it appears in the image**, Height: 0.1-0.15 m
not the real-world length or width of the object. Weight: 0.3-0.6 kg
For example: Static friction coefficient: 0.6
- A pen standing upright in the front view vertical height: 0.15-0.2 m Dynamic friction coefficient: 0.5
- A pen lying horizontally in the front view vertical height: 0.01-0.02 m
(based on its thickness in the image)
Output Example: IMPORTANT: Estimating Vertical Height from the First (Front View) Image.
Category: cup - The "vertical height" refers to the real-world vertical size of the object
Description: shiny golden cup with floral design as projected in the first image, aligned with the image's vertical axis.
Height: 0.1-0.15 m - For flat objects like plates or disks or book, if their face is visible in the front view,
Weight: 0.3-0.6 kg use the diameter as the vertical height. If the edge is visible, use the thickness instead.
Static friction coefficient: 1.1 - This is not necessarily the full length of the object, but how tall it appears
Dynamic friction coefficient: 0.9 in the first image vertically, based on its pose and orientation.
- For objects (e.g., spoons, forks, writing instruments, etc.) shown at an angle in
the first image: an object tilted at 45° will appear shorter vertically than when upright.
Estimate the vertical projection of their real length based on its pose.
For example:
- A pen standing upright in the first view (aligned with the image's vertical axis)
full body visible in the first image: vertical height 0.14-0.20 m
- A pen lying flat in the front view (showing thickness) vertical height 0.018-0.025 m
- Tilted pen in the first image (e.g., ~45° angle): vertical height 0.07-0.12 m
- Use the remaining views (all except the first image) to help determine the object's 3D pose and orientation.
Assume the object is in real-world scale and estimate the approximate vertical height
(in meters) based on how large it appears vertically in the first image.
""" """
) )
@ -374,6 +382,7 @@ class URDFGenerator(object):
) )
response = self.gpt_client.query(text_prompt, image_path) response = self.gpt_client.query(text_prompt, image_path)
# logger.info(response)
if response is None: if response is None:
asset_attrs = { asset_attrs = {
"category": category.lower(), "category": category.lower(),