243 lines
8.5 KiB
Python
243 lines
8.5 KiB
Python
# Project EmbodiedGen
|
|
#
|
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
# implied. See the License for the specific language governing
|
|
# permissions and limitations under the License.
|
|
|
|
|
|
import logging
|
|
import os
|
|
|
|
from tqdm import tqdm
|
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
|
|
from embodied_gen.utils.process_media import render_asset3d
|
|
from embodied_gen.validators.aesthetic_predictor import AestheticPredictor
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BaseChecker:
|
|
def __init__(self, prompt: str = None, verbose: bool = False) -> None:
|
|
self.prompt = prompt
|
|
self.verbose = verbose
|
|
|
|
def query(self, *args, **kwargs):
|
|
raise NotImplementedError(
|
|
"Subclasses must implement the query method."
|
|
)
|
|
|
|
def __call__(self, *args, **kwargs) -> bool:
|
|
response = self.query(*args, **kwargs)
|
|
if response is None:
|
|
response = "Error when calling gpt api."
|
|
|
|
if self.verbose and response != "YES":
|
|
logger.info(response)
|
|
|
|
flag = "YES" in response
|
|
response = "YES" if flag else response
|
|
|
|
return flag, response
|
|
|
|
@staticmethod
|
|
def validate(
|
|
checkers: list["BaseChecker"], images_list: list[list[str]]
|
|
) -> list:
|
|
assert len(checkers) == len(images_list)
|
|
results = []
|
|
overall_result = True
|
|
for checker, images in zip(checkers, images_list):
|
|
qa_flag, qa_info = checker(images)
|
|
if isinstance(qa_info, str):
|
|
qa_info = qa_info.replace("\n", ".")
|
|
results.append([checker.__class__.__name__, qa_info])
|
|
if qa_flag is False:
|
|
overall_result = False
|
|
|
|
results.append(["overall", "YES" if overall_result else "NO"])
|
|
|
|
return results
|
|
|
|
|
|
class MeshGeoChecker(BaseChecker):
|
|
"""A geometry quality checker for 3D mesh assets using GPT-based reasoning.
|
|
|
|
This class leverages a multi-modal GPT client to analyze rendered images
|
|
of a 3D object and determine if its geometry is complete.
|
|
|
|
Attributes:
|
|
gpt_client (GPTclient): The GPT client used for multi-modal querying.
|
|
prompt (str): The prompt sent to the GPT model. If not provided, a default one is used.
|
|
verbose (bool): Whether to print debug information during evaluation.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
gpt_client: GPTclient,
|
|
prompt: str = None,
|
|
verbose: bool = False,
|
|
) -> None:
|
|
super().__init__(prompt, verbose)
|
|
self.gpt_client = gpt_client
|
|
if self.prompt is None:
|
|
self.prompt = """
|
|
Refer to the provided multi-view rendering images to evaluate
|
|
whether the geometry of the 3D object asset is complete and
|
|
whether the asset can be placed stably on the ground.
|
|
Return "YES" only if reach the requirments,
|
|
otherwise "NO" and explain the reason very briefly.
|
|
"""
|
|
|
|
def query(self, image_paths: str) -> str:
|
|
# Hardcode tmp because of the openrouter can't input multi images.
|
|
if "openrouter" in self.gpt_client.endpoint:
|
|
from embodied_gen.utils.process_media import (
|
|
combine_images_to_base64,
|
|
)
|
|
|
|
image_paths = combine_images_to_base64(image_paths)
|
|
|
|
return self.gpt_client.query(
|
|
text_prompt=self.prompt,
|
|
image_base64=image_paths,
|
|
)
|
|
|
|
|
|
class ImageSegChecker(BaseChecker):
|
|
"""A segmentation quality checker for 3D assets using GPT-based reasoning.
|
|
|
|
This class compares an original image with its segmented version to
|
|
evaluate whether the segmentation successfully isolates the main object
|
|
with minimal truncation and correct foreground extraction.
|
|
|
|
Attributes:
|
|
gpt_client (GPTclient): GPT client used for multi-modal image analysis.
|
|
prompt (str): The prompt used to guide the GPT model for evaluation.
|
|
verbose (bool): Whether to enable verbose logging.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
gpt_client: GPTclient,
|
|
prompt: str = None,
|
|
verbose: bool = False,
|
|
) -> None:
|
|
super().__init__(prompt, verbose)
|
|
self.gpt_client = gpt_client
|
|
if self.prompt is None:
|
|
self.prompt = """
|
|
The first image is the original, and the second image is the
|
|
result after segmenting the main object. Evaluate the segmentation
|
|
quality to ensure the main object is clearly segmented without
|
|
significant truncation. Note that the foreground of the object
|
|
needs to be extracted instead of the background.
|
|
Minor imperfections can be ignored. If segmentation is acceptable,
|
|
return "YES" only; otherwise, return "NO" with
|
|
very brief explanation.
|
|
"""
|
|
|
|
def query(self, image_paths: list[str]) -> str:
|
|
if len(image_paths) != 2:
|
|
raise ValueError(
|
|
"ImageSegChecker requires exactly two images: [raw_image, seg_image]." # noqa
|
|
)
|
|
# Hardcode tmp because of the openrouter can't input multi images.
|
|
if "openrouter" in self.gpt_client.endpoint:
|
|
from embodied_gen.utils.process_media import (
|
|
combine_images_to_base64,
|
|
)
|
|
|
|
image_paths = combine_images_to_base64(image_paths)
|
|
|
|
return self.gpt_client.query(
|
|
text_prompt=self.prompt,
|
|
image_base64=image_paths,
|
|
)
|
|
|
|
|
|
class ImageAestheticChecker(BaseChecker):
|
|
"""A class for evaluating the aesthetic quality of images.
|
|
|
|
Attributes:
|
|
clip_model_dir (str): Path to the CLIP model directory.
|
|
sac_model_path (str): Path to the aesthetic predictor model weights.
|
|
thresh (float): Threshold above which images are considered aesthetically acceptable.
|
|
verbose (bool): Whether to print detailed log messages.
|
|
predictor (AestheticPredictor): The model used to predict aesthetic scores.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
clip_model_dir: str = None,
|
|
sac_model_path: str = None,
|
|
thresh: float = 4.50,
|
|
verbose: bool = False,
|
|
) -> None:
|
|
super().__init__(verbose=verbose)
|
|
self.clip_model_dir = clip_model_dir
|
|
self.sac_model_path = sac_model_path
|
|
self.thresh = thresh
|
|
self.predictor = AestheticPredictor(clip_model_dir, sac_model_path)
|
|
|
|
def query(self, image_paths: list[str]) -> float:
|
|
scores = [self.predictor.predict(img_path) for img_path in image_paths]
|
|
return sum(scores) / len(scores)
|
|
|
|
def __call__(self, image_paths: list[str], **kwargs) -> bool:
|
|
avg_score = self.query(image_paths)
|
|
if self.verbose:
|
|
logger.info(f"Average aesthetic score: {avg_score}")
|
|
return avg_score > self.thresh, avg_score
|
|
|
|
|
|
if __name__ == "__main__":
|
|
geo_checker = MeshGeoChecker(GPT_CLIENT)
|
|
seg_checker = ImageSegChecker(GPT_CLIENT)
|
|
aesthetic_checker = ImageAestheticChecker()
|
|
|
|
checkers = [geo_checker, seg_checker, aesthetic_checker]
|
|
|
|
output_root = "outputs/test_gpt"
|
|
|
|
fails = []
|
|
for idx in tqdm(range(150)):
|
|
mesh_path = f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}.obj" # noqa
|
|
if not os.path.exists(mesh_path):
|
|
continue
|
|
image_paths = render_asset3d(
|
|
mesh_path,
|
|
f"{output_root}/{idx}",
|
|
num_images=8,
|
|
elevation=(30, -30),
|
|
distance=5.5,
|
|
)
|
|
|
|
for cid, checker in enumerate(checkers):
|
|
if isinstance(checker, ImageSegChecker):
|
|
images = [
|
|
f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_raw.png", # noqa
|
|
f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_cond.png", # noqa
|
|
]
|
|
else:
|
|
images = image_paths
|
|
result, info = checker(images)
|
|
logger.info(
|
|
f"Checker {checker.__class__.__name__}: {result}, {info}, mesh {mesh_path}" # noqa
|
|
)
|
|
|
|
if result is False:
|
|
fails.append((idx, cid, info))
|
|
|
|
break
|