Release 3D scene generation pipeline and tag as v0.1.2. --------- Co-authored-by: xinjie.wang <xinjie.wang@gpu-4090-dev015.hogpu.cc>
278 lines
8.6 KiB
Python
278 lines
8.6 KiB
Python
# Project EmbodiedGen
|
|
#
|
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
# implied. See the License for the specific language governing
|
|
# permissions and limitations under the License.
|
|
|
|
import argparse
|
|
import os
|
|
import random
|
|
from collections import defaultdict
|
|
|
|
import torch
|
|
from PIL import Image
|
|
from embodied_gen.models.image_comm_model import build_hf_image_pipeline
|
|
from embodied_gen.models.segment_model import RembgRemover
|
|
from embodied_gen.models.text_model import PROMPT_APPEND
|
|
from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api
|
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
|
from embodied_gen.utils.log import logger
|
|
from embodied_gen.utils.process_media import render_asset3d
|
|
from embodied_gen.validators.quality_checkers import (
|
|
ImageSegChecker,
|
|
SemanticConsistChecker,
|
|
TextGenAlignChecker,
|
|
)
|
|
|
|
# Avoid huggingface/tokenizers: The current process just got forked.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Seed the module-level RNG so the retry-seed draws below are reproducible
# from run to run.
random.seed(0)


# Public API of this module.
__all__ = [
    "text_to_image",
    "text_to_3d",
]
def text_to_image(
    prompt: str,
    save_path: str,
    n_retry: int,
    img_denoise_step: int,
    text_guidance_scale: float,
    n_img_sample: int,
    image_hw: tuple[int, int] = (1024, 1024),
    seed: "int | None" = None,
) -> bool:
    """Generate a quality-checked object image for ``prompt`` and save it.

    Runs the text-to-image pipeline for up to ``n_retry`` rounds. Every
    candidate has its background removed, then must pass both the semantic
    consistency check and the segmentation check. A ``None`` verdict from
    either checker is treated as a pass (best-effort when the GPT checker
    cannot answer — mirrors the original acceptance rule).

    Args:
        prompt: Object description, formatted into ``PROMPT_APPEND``.
        save_path: Output path; must end with ``.png``. The raw
            (pre-background-removal) image is saved alongside it with a
            ``_raw.png`` suffix.
        n_retry: Maximum number of generation rounds.
        img_denoise_step: Diffusion denoising steps per round.
        text_guidance_scale: Classifier-free guidance scale.
        n_img_sample: Number of candidates generated per round.
        image_hw: Output (height, width) in pixels.
        seed: Optional RNG seed; when given, it is re-randomized between
            rounds so each retry explores a different sample.

    Returns:
        True if some candidate passed the checks, else False.

    Note:
        Relies on module globals ``PIPE_IMG``, ``BG_REMOVER``,
        ``SEMANTIC_CHECKER`` and ``SEG_CHECKER`` initialized by
        ``text_to_3d()``.
    """
    assert save_path.endswith(".png"), "Image save path must end with `.png`."
    select_image = None
    success_flag = False
    for try_idx in range(n_retry):
        f_prompt = PROMPT_APPEND.format(object=prompt)
        logger.info(
            f"Image GEN for {os.path.basename(save_path)}\n"
            f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}"
        )
        images = PIPE_IMG.run(
            f_prompt,
            num_inference_steps=img_denoise_step,
            guidance_scale=text_guidance_scale,
            num_images_per_prompt=n_img_sample,
            height=image_hw[0],
            width=image_hw[1],
            generator=(
                torch.Generator().manual_seed(seed)
                if seed is not None
                else None
            ),
        )

        for raw_image in images:
            image = BG_REMOVER(raw_image)
            image.save(save_path)
            semantic_flag, semantic_result = SEMANTIC_CHECKER(
                prompt, [image.convert("RGB")]
            )
            seg_flag, seg_result = SEG_CHECKER(
                [raw_image, image.convert("RGB")]
            )
            # `None` means the checker produced no verdict; accept in that
            # case so an unavailable checker does not block the pipeline.
            if (
                (semantic_flag and seg_flag)
                or semantic_flag is None
                or seg_flag is None
            ):
                select_image = [raw_image, image]
                success_flag = True
                break

        torch.cuda.empty_cache()
        if success_flag:
            break
        # Re-roll the seed for the next round (keep None meaning "random").
        seed = random.randint(0, 100000) if seed is not None else None

    # BUGFIX: the original code saved the selected pair only at the TOP of
    # the *next* loop iteration, so a success on the final retry round never
    # wrote the `_raw.png` companion file. Saving once after the loop fixes
    # that while keeping `save_path` contents identical.
    if select_image is not None:
        select_image[0].save(save_path.replace(".png", "_raw.png"))
        select_image[1].save(save_path)

    return success_flag
def text_to_3d(**kwargs) -> dict:
    """End-to-end text -> image -> 3D asset generation pipeline.

    Defaults come from the CLI (see ``parse_args``); any keyword argument
    whose name matches an argparse field and is not None overrides it, so
    the function can also be driven programmatically.

    Returns:
        dict with two sub-dicts keyed by asset name:
        ``"assets"`` (result directory per node) and ``"quality"``
        (QA verdict text per node).
    """
    args = parse_args()
    # Keyword overrides: only known argparse names, and None never overrides.
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    logger.info("Loading Models...")
    # Heavy models/checkers are published as module globals so that
    # text_to_image() (and repeated calls) can reuse them without reloading.
    global SEMANTIC_CHECKER, SEG_CHECKER, TXTGEN_CHECKER, PIPE_IMG, BG_REMOVER
    SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
    TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
    PIPE_IMG = build_hf_image_pipeline(args.text_model)
    BG_REMOVER = RembgRemover()

    # Auto-name assets when none were given; the "sample3d_" prefix also
    # signals below that no explicit asset type is known for the node.
    if args.asset_names is None or len(args.asset_names) == 0:
        args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
    img_save_dir = os.path.join(args.output_root, "images")
    asset_save_dir = os.path.join(args.output_root, "asset3d")
    os.makedirs(img_save_dir, exist_ok=True)
    os.makedirs(asset_save_dir, exist_ok=True)
    results = defaultdict(dict)
    for prompt, node in zip(args.prompts, args.asset_names):
        success_flag = False
        n_pipe_retry = args.n_pipe_retry
        seed_img = args.seed_img
        seed_3d = args.seed_3d
        # Whole-pipeline retry loop: image GEN + 3D GEN + QA per round.
        while success_flag is False and n_pipe_retry > 0:
            logger.info(
                f"GEN pipeline for node {node}\n"
                f"Try round: {args.n_pipe_retry-n_pipe_retry+1}/{args.n_pipe_retry}, Prompt: {prompt}"
            )
            # Text-to-image GEN
            save_node = node.replace(" ", "_")
            gen_image_path = f"{img_save_dir}/{save_node}.png"
            # NOTE(review): textgen_flag is not checked — a failed image GEN
            # still proceeds to 3D generation. Confirm this is intentional.
            textgen_flag = text_to_image(
                prompt,
                gen_image_path,
                args.n_image_retry,
                args.img_denoise_step,
                args.text_guidance_scale,
                args.n_img_sample,
                seed=seed_img,
            )

            # Asset 3D GEN
            node_save_dir = f"{asset_save_dir}/{save_node}"
            # Auto-named nodes carry no semantic type for the 3D stage.
            asset_type = node if "sample3d_" not in node else None
            imageto3d_api(
                image_path=[gen_image_path],
                output_root=node_save_dir,
                asset_type=[asset_type],
                seed=random.randint(0, 100000) if seed_3d is None else seed_3d,
                n_retry=args.n_asset_retry,
                keep_intermediate=args.keep_intermediate,
            )
            # Render multi-view images of the produced mesh for the QA check.
            # Presumably imageto3d_api writes the mesh to this path — verify
            # against embodied_gen.scripts.imageto3d.
            mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj"
            image_path = render_asset3d(
                mesh_path,
                output_root=f"{node_save_dir}/result",
                num_images=6,
                elevation=(30, -30),
                output_subdir="renders",
                no_index_file=True,
            )

            # QA the renders against the asset type (or the raw prompt when
            # no type is known). A None verdict counts as a pass.
            check_text = asset_type if asset_type is not None else prompt
            qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
            logger.warning(
                f"Node {node}, {TXTGEN_CHECKER.__class__.__name__}: {qa_result}"
            )
            # Later rounds overwrite earlier results for the same node.
            results["assets"][node] = f"{node_save_dir}/result"
            results["quality"][node] = qa_result

            if qa_flag is None or qa_flag is True:
                success_flag = True
                break

            n_pipe_retry -= 1
            # Re-roll seeds for the next round; None stays None ("random").
            seed_img = (
                random.randint(0, 100000) if seed_img is not None else None
            )
            seed_3d = (
                random.randint(0, 100000) if seed_3d is not None else None
            )

        torch.cuda.empty_cache()

    return results
def parse_args():
    """Build and parse the CLI options for the text-to-3D pipeline.

    Parsing uses ``parse_known_args`` so unrecognized arguments are
    tolerated (e.g. when this script is embedded in a larger launcher).

    Returns:
        argparse.Namespace: the parsed, known arguments.
    """
    parser = argparse.ArgumentParser(description="3D Layout Generation Config")
    parser.add_argument("--prompts", nargs="+", help="text descriptions")
    parser.add_argument(
        "--output_root", type=str, help="Directory to save outputs"
    )
    parser.add_argument(
        "--asset_names",
        type=str,
        nargs="+",
        default=None,
        help="Asset names to generate",
    )
    # Scalar knobs, registered in order: (flag, converter, default, help).
    scalar_specs = [
        ("--n_img_sample", int, 3, "Number of image samples to generate"),
        ("--text_guidance_scale", float, 7, "Text-to-image guidance scale"),
        ("--img_denoise_step", int, 25, "Denoising steps for image generation"),
        ("--n_image_retry", int, 2, "Max retry count for image generation"),
        ("--n_asset_retry", int, 2, "Max retry count for 3D generation"),
        ("--n_pipe_retry", int, 1, "Max retry count for 3D asset generation"),
        ("--seed_img", int, None, "Random seed for image generation"),
        ("--seed_3d", int, 0, "Random seed for 3D generation"),
    ]
    for flag, converter, default, help_text in scalar_specs:
        parser.add_argument(
            flag, type=converter, default=default, help=help_text
        )
    parser.add_argument(
        "--text_model",
        type=str,
        default="sd35",
    )
    parser.add_argument("--keep_intermediate", action="store_true")

    known_args, _unknown = parser.parse_known_args()

    return known_args
if __name__ == "__main__":
    # CLI entrypoint: arguments are read from sys.argv inside text_to_3d()
    # via parse_args().
    text_to_3d()