feat(pipe): Faster texture back projection and refine quality checkers. (#29)
This commit is contained in:
parent
87ff24dbd4
commit
c258ff8666
@ -251,6 +251,7 @@ class TextureBacker:
|
|||||||
during rendering. Defaults to 0.5.
|
during rendering. Defaults to 0.5.
|
||||||
smooth_texture (bool, optional): If True, apply post-processing (e.g.,
|
smooth_texture (bool, optional): If True, apply post-processing (e.g.,
|
||||||
blurring) to the final texture. Defaults to True.
|
blurring) to the final texture. Defaults to True.
|
||||||
|
inpaint_smooth (bool, optional): If True, apply inpainting to smooth.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -262,6 +263,7 @@ class TextureBacker:
|
|||||||
bake_angle_thresh: int = 75,
|
bake_angle_thresh: int = 75,
|
||||||
mask_thresh: float = 0.5,
|
mask_thresh: float = 0.5,
|
||||||
smooth_texture: bool = True,
|
smooth_texture: bool = True,
|
||||||
|
inpaint_smooth: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.camera_params = camera_params
|
self.camera_params = camera_params
|
||||||
self.renderer = None
|
self.renderer = None
|
||||||
@ -271,6 +273,7 @@ class TextureBacker:
|
|||||||
self.texture_wh = texture_wh
|
self.texture_wh = texture_wh
|
||||||
self.mask_thresh = mask_thresh
|
self.mask_thresh = mask_thresh
|
||||||
self.smooth_texture = smooth_texture
|
self.smooth_texture = smooth_texture
|
||||||
|
self.inpaint_smooth = inpaint_smooth
|
||||||
|
|
||||||
self.bake_angle_thresh = bake_angle_thresh
|
self.bake_angle_thresh = bake_angle_thresh
|
||||||
self.bake_unreliable_kernel_size = int(
|
self.bake_unreliable_kernel_size = int(
|
||||||
@ -446,11 +449,12 @@ class TextureBacker:
|
|||||||
def uv_inpaint(
|
def uv_inpaint(
|
||||||
self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
|
self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
if self.inpaint_smooth:
|
||||||
vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
|
vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
|
||||||
|
|
||||||
texture, mask = _texture_inpaint_smooth(
|
texture, mask = _texture_inpaint_smooth(
|
||||||
texture, mask, vertices, faces, uv_map
|
texture, mask, vertices, faces, uv_map
|
||||||
)
|
)
|
||||||
|
|
||||||
texture = texture.clip(0, 1)
|
texture = texture.clip(0, 1)
|
||||||
texture = cv2.inpaint(
|
texture = cv2.inpaint(
|
||||||
(texture * 255).astype(np.uint8),
|
(texture * 255).astype(np.uint8),
|
||||||
|
|||||||
@ -54,7 +54,7 @@ __all__ = [
|
|||||||
|
|
||||||
PROMPT_APPEND = (
|
PROMPT_APPEND = (
|
||||||
"Angled 3D view of one {object}, centered, no cropping, no occlusion, isolated product photo, "
|
"Angled 3D view of one {object}, centered, no cropping, no occlusion, isolated product photo, "
|
||||||
"no surroundings, matte, on a plain clean surface, 3D style revealing multiple surfaces"
|
"no surroundings, high-quality appearance, vivid colors, on a plain clean surface, 3D style revealing multiple surfaces"
|
||||||
)
|
)
|
||||||
PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality"
|
PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality"
|
||||||
|
|
||||||
|
|||||||
@ -19,6 +19,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from embodied_gen.models.image_comm_model import build_hf_image_pipeline
|
from embodied_gen.models.image_comm_model import build_hf_image_pipeline
|
||||||
@ -27,7 +28,10 @@ from embodied_gen.models.text_model import PROMPT_APPEND
|
|||||||
from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api
|
from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api
|
||||||
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
||||||
from embodied_gen.utils.log import logger
|
from embodied_gen.utils.log import logger
|
||||||
from embodied_gen.utils.process_media import render_asset3d
|
from embodied_gen.utils.process_media import (
|
||||||
|
check_object_edge_truncated,
|
||||||
|
render_asset3d,
|
||||||
|
)
|
||||||
from embodied_gen.validators.quality_checkers import (
|
from embodied_gen.validators.quality_checkers import (
|
||||||
ImageSegChecker,
|
ImageSegChecker,
|
||||||
SemanticConsistChecker,
|
SemanticConsistChecker,
|
||||||
@ -38,6 +42,13 @@ from embodied_gen.validators.quality_checkers import (
|
|||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
random.seed(0)
|
random.seed(0)
|
||||||
|
|
||||||
|
logger.info("Loading Models...")
|
||||||
|
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
|
||||||
|
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
||||||
|
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
|
||||||
|
PIPE_IMG = build_hf_image_pipeline(os.environ.get("TEXT_MODEL", "sd35"))
|
||||||
|
BG_REMOVER = RembgRemover()
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"text_to_image",
|
"text_to_image",
|
||||||
@ -69,6 +80,7 @@ def text_to_image(
|
|||||||
f"Image GEN for {os.path.basename(save_path)}\n"
|
f"Image GEN for {os.path.basename(save_path)}\n"
|
||||||
f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}"
|
f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}"
|
||||||
)
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
images = PIPE_IMG.run(
|
images = PIPE_IMG.run(
|
||||||
f_prompt,
|
f_prompt,
|
||||||
num_inference_steps=img_denoise_step,
|
num_inference_steps=img_denoise_step,
|
||||||
@ -93,16 +105,20 @@ def text_to_image(
|
|||||||
seg_flag, seg_result = SEG_CHECKER(
|
seg_flag, seg_result = SEG_CHECKER(
|
||||||
[raw_image, image.convert("RGB")]
|
[raw_image, image.convert("RGB")]
|
||||||
)
|
)
|
||||||
|
image_mask = np.array(image)[..., -1]
|
||||||
|
edge_flag = check_object_edge_truncated(image_mask)
|
||||||
|
logger.warning(
|
||||||
|
f"SEMANTIC: {semantic_result}. SEG: {seg_result}. EDGE: {edge_flag}"
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
(semantic_flag and seg_flag)
|
(edge_flag and semantic_flag and seg_flag)
|
||||||
or semantic_flag is None
|
or (edge_flag and semantic_flag is None)
|
||||||
or seg_flag is None
|
or (edge_flag and seg_flag is None)
|
||||||
):
|
):
|
||||||
select_image = [raw_image, image]
|
select_image = [raw_image, image]
|
||||||
success_flag = True
|
success_flag = True
|
||||||
break
|
break
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
seed = random.randint(0, 100000) if seed is not None else None
|
seed = random.randint(0, 100000) if seed is not None else None
|
||||||
|
|
||||||
return success_flag
|
return success_flag
|
||||||
@ -114,14 +130,6 @@ def text_to_3d(**kwargs) -> dict:
|
|||||||
if hasattr(args, k) and v is not None:
|
if hasattr(args, k) and v is not None:
|
||||||
setattr(args, k, v)
|
setattr(args, k, v)
|
||||||
|
|
||||||
logger.info("Loading Models...")
|
|
||||||
global SEMANTIC_CHECKER, SEG_CHECKER, TXTGEN_CHECKER, PIPE_IMG, BG_REMOVER
|
|
||||||
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
|
|
||||||
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
|
||||||
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
|
|
||||||
PIPE_IMG = build_hf_image_pipeline(args.text_model)
|
|
||||||
BG_REMOVER = RembgRemover()
|
|
||||||
|
|
||||||
if args.asset_names is None or len(args.asset_names) == 0:
|
if args.asset_names is None or len(args.asset_names) == 0:
|
||||||
args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
|
args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
|
||||||
img_save_dir = os.path.join(args.output_root, "images")
|
img_save_dir = os.path.join(args.output_root, "images")
|
||||||
@ -261,11 +269,6 @@ def parse_args():
|
|||||||
default=0,
|
default=0,
|
||||||
help="Random seed for 3D generation",
|
help="Random seed for 3D generation",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--text_model",
|
|
||||||
type=str,
|
|
||||||
default="sd35",
|
|
||||||
)
|
|
||||||
parser.add_argument("--keep_intermediate", action="store_true")
|
parser.add_argument("--keep_intermediate", action="store_true")
|
||||||
|
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
|
|||||||
@ -48,6 +48,7 @@ __all__ = [
|
|||||||
"SceneTreeVisualizer",
|
"SceneTreeVisualizer",
|
||||||
"is_image_file",
|
"is_image_file",
|
||||||
"parse_text_prompts",
|
"parse_text_prompts",
|
||||||
|
"check_object_edge_truncated",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -376,6 +377,28 @@ def parse_text_prompts(prompts: list[str]) -> list[str]:
|
|||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
|
|
||||||
|
def check_object_edge_truncated(
|
||||||
|
mask: np.ndarray, edge_threshold: int = 5
|
||||||
|
) -> bool:
|
||||||
|
"""Checks if a binary object mask is truncated at the image edges.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask: A 2D binary NumPy array where nonzero values indicate the object region.
|
||||||
|
edge_threshold: Number of pixels from each image edge to consider for truncation.
|
||||||
|
Defaults to 5.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the object is fully enclosed (not truncated).
|
||||||
|
False if the object touches or crosses any image boundary.
|
||||||
|
"""
|
||||||
|
top = mask[:edge_threshold, :].any()
|
||||||
|
bottom = mask[-edge_threshold:, :].any()
|
||||||
|
left = mask[:, :edge_threshold].any()
|
||||||
|
right = mask[:, -edge_threshold:].any()
|
||||||
|
|
||||||
|
return not (top or bottom or left or right)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
merge_video_video(
|
merge_video_video(
|
||||||
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
|
"outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
|
||||||
|
|||||||
@ -113,8 +113,8 @@ class MeshGeoChecker(BaseChecker):
|
|||||||
Your task is to evaluate the quality of the 3D asset generation,
|
Your task is to evaluate the quality of the 3D asset generation,
|
||||||
including geometry, structure, and appearance, based on the rendered views.
|
including geometry, structure, and appearance, based on the rendered views.
|
||||||
Criteria:
|
Criteria:
|
||||||
- Is the geometry complete and well-formed, without missing parts or redundant structures?
|
- Is the object in the image a single, complete, and well-formed instance,
|
||||||
- Is the geometric structure of the object complete?
|
without truncation, missing parts, overlapping duplicates, or redundant geometry?
|
||||||
- Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
|
- Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
|
||||||
soft edges) are acceptable if the object is structurally sound and recognizable.
|
soft edges) are acceptable if the object is structurally sound and recognizable.
|
||||||
- Only evaluate geometry. Do not assess texture quality.
|
- Only evaluate geometry. Do not assess texture quality.
|
||||||
@ -241,10 +241,11 @@ class SemanticConsistChecker(BaseChecker):
|
|||||||
|
|
||||||
Criteria:
|
Criteria:
|
||||||
- The image must visually match the text description in terms of object type, structure, geometry, and color.
|
- The image must visually match the text description in terms of object type, structure, geometry, and color.
|
||||||
- The object must appear realistic, with reasonable geometry (e.g., a table must have a stable number of legs).
|
- The object must appear realistic, with reasonable geometry (e.g., a table must have a stable number
|
||||||
|
of legs with a reasonable distribution. Count the number of legs visible in the image. (strict) For tables,
|
||||||
|
fewer than four legs or if the legs are unevenly distributed, are not allowed. Do not assume
|
||||||
|
hidden legs unless they are clearly visible.)
|
||||||
- Geometric completeness is required: the object must not have missing, truncated, or cropped parts.
|
- Geometric completeness is required: the object must not have missing, truncated, or cropped parts.
|
||||||
- The object must be centered in the image frame with clear margins on all sides,
|
|
||||||
it should not touch or nearly touch any image edge.
|
|
||||||
- The image must contain exactly one object. Multiple distinct objects are not allowed.
|
- The image must contain exactly one object. Multiple distinct objects are not allowed.
|
||||||
A single composite object (e.g., a chair with legs) is acceptable.
|
A single composite object (e.g., a chair with legs) is acceptable.
|
||||||
- The object should be shown from a slightly angled (three-quarter) perspective,
|
- The object should be shown from a slightly angled (three-quarter) perspective,
|
||||||
|
|||||||
@ -103,32 +103,40 @@ class URDFGenerator(object):
|
|||||||
+ """of the 3D object asset,
|
+ """of the 3D object asset,
|
||||||
category: {category}.
|
category: {category}.
|
||||||
You are an expert in 3D object analysis and physical property estimation.
|
You are an expert in 3D object analysis and physical property estimation.
|
||||||
Give the category of this object asset (within 3 words),
|
Give the category of this object asset (within 3 words), (if category is
|
||||||
(if category is already provided, use it directly),
|
already provided, use it directly), accurately describe this 3D object asset (within 15 words),
|
||||||
accurately describe this 3D object asset (within 15 words),
|
Determine the pose of the object in the first image and estimate the true vertical height
|
||||||
and give the recommended geometric height range (unit: meter),
|
(vertical projection) range of the object (in meters), i.e., how tall the object appears from top
|
||||||
weight range (unit: kilogram), the average static friction
|
to bottom in the front view (first) image. also weight range (unit: kilogram), the average
|
||||||
coefficient of the object relative to rubber and the average
|
static friction coefficient of the object relative to rubber and the average dynamic friction
|
||||||
dynamic friction coefficient of the object relative to rubber.
|
coefficient of the object relative to rubber. Return response format as shown in Output Example.
|
||||||
Return response format as shown in Output Example.
|
|
||||||
|
|
||||||
IMPORTANT:
|
|
||||||
Inputed images are orthographic projection showing the front, left, right and back views,
|
|
||||||
the first image is always the front view. Use the object's pose and orientation in the
|
|
||||||
rendered images to estimate its **true vertical height as it appears in the image**,
|
|
||||||
not the real-world length or width of the object.
|
|
||||||
For example:
|
|
||||||
- A pen standing upright in the front view → vertical height: 0.15-0.2 m
|
|
||||||
- A pen lying horizontally in the front view → vertical height: 0.01-0.02 m
|
|
||||||
(based on its thickness in the image)
|
|
||||||
|
|
||||||
Output Example:
|
Output Example:
|
||||||
Category: cup
|
Category: cup
|
||||||
Description: shiny golden cup with floral design
|
Description: shiny golden cup with floral design
|
||||||
Height: 0.1-0.15 m
|
Height: 0.1-0.15 m
|
||||||
Weight: 0.3-0.6 kg
|
Weight: 0.3-0.6 kg
|
||||||
Static friction coefficient: 1.1
|
Static friction coefficient: 0.6
|
||||||
Dynamic friction coefficient: 0.9
|
Dynamic friction coefficient: 0.5
|
||||||
|
|
||||||
|
IMPORTANT: Estimating Vertical Height from the First (Front View) Image.
|
||||||
|
- The "vertical height" refers to the real-world vertical size of the object
|
||||||
|
as projected in the first image, aligned with the image's vertical axis.
|
||||||
|
- For flat objects like plates or disks or book, if their face is visible in the front view,
|
||||||
|
use the diameter as the vertical height. If the edge is visible, use the thickness instead.
|
||||||
|
- This is not necessarily the full length of the object, but how tall it appears
|
||||||
|
in the first image vertically, based on its pose and orientation.
|
||||||
|
- For objects(e.g., spoons, forks, writing instruments etc.) at an angle showing in
|
||||||
|
the first image, tilted at 45° will appear shorter vertically than when upright.
|
||||||
|
Estimate the vertical projection of their real length based on its pose.
|
||||||
|
For example:
|
||||||
|
- A pen standing upright in the first view (aligned with the image's vertical axis)
|
||||||
|
full body visible in the first image: → vertical height ≈ 0.14-0.20 m
|
||||||
|
- A pen lying flat in the front view (showing thickness) → vertical height ≈ 0.018-0.025 m
|
||||||
|
- Tilted pen in the first image (e.g., ~45° angle): vertical height ≈ 0.07-0.12 m
|
||||||
|
- Use the rest views(except the first image) to help determine the object's 3D pose and orientation.
|
||||||
|
Assume the object is in real-world scale and estimate the approximate vertical height
|
||||||
|
(in meters) based on how large it appears vertically in the first image.
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -374,6 +382,7 @@ class URDFGenerator(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
response = self.gpt_client.query(text_prompt, image_path)
|
response = self.gpt_client.query(text_prompt, image_path)
|
||||||
|
# logger.info(response)
|
||||||
if response is None:
|
if response is None:
|
||||||
asset_attrs = {
|
asset_attrs = {
|
||||||
"category": category.lower(),
|
"category": category.lower(),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user