- Added intelligent quality checkers and auto-retry pipeline for `image-to-3d` and `text-to-3d`. - Added unit tests for quality checkers. - `text-to-3d` now supports more `text-to-image` models, pipeline success rate improved to 94%.
169 lines
4.6 KiB
Python
169 lines
4.6 KiB
Python
# Project EmbodiedGen
|
|
#
|
|
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
# implied. See the License for the specific language governing
|
|
# permissions and limitations under the License.
|
|
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
|
|
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
|
|
StableDiffusionXLPipeline,
|
|
)
|
|
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
|
|
StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
|
|
)
|
|
from tqdm import tqdm
|
|
from embodied_gen.models.text_model import (
|
|
build_text2img_ip_pipeline,
|
|
build_text2img_pipeline,
|
|
text2img_gen,
|
|
)
|
|
|
|
# Configure the root logger at INFO level at import time and create the
# module-level logger used by `entrypoint` to report where images are saved.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser(description="Text to Image.")
|
|
parser.add_argument(
|
|
"--prompts",
|
|
type=str,
|
|
nargs="+",
|
|
help="List of prompts (space-separated).",
|
|
)
|
|
parser.add_argument(
|
|
"--ref_image",
|
|
type=str,
|
|
nargs="+",
|
|
help="List of ref_image paths (space-separated).",
|
|
)
|
|
parser.add_argument(
|
|
"--output_root",
|
|
type=str,
|
|
help="Root directory for saving outputs.",
|
|
)
|
|
parser.add_argument(
|
|
"--guidance_scale",
|
|
type=float,
|
|
default=12.0,
|
|
help="Guidance scale for the diffusion model.",
|
|
)
|
|
parser.add_argument(
|
|
"--ref_scale",
|
|
type=float,
|
|
default=0.3,
|
|
help="Reference image scale for the IP adapter.",
|
|
)
|
|
parser.add_argument(
|
|
"--n_sample",
|
|
type=int,
|
|
default=1,
|
|
)
|
|
parser.add_argument(
|
|
"--resolution",
|
|
type=int,
|
|
default=1024,
|
|
)
|
|
parser.add_argument(
|
|
"--infer_step",
|
|
type=int,
|
|
default=50,
|
|
)
|
|
parser.add_argument(
|
|
"--seed",
|
|
type=int,
|
|
default=None,
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
return args
|
|
|
|
|
|
def entrypoint(
    pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP = None,
    **kwargs,
) -> list[str]:
    """Generate images for each prompt and save them as PNG files.

    Options are read from the command line via ``parse_args()``; any keyword
    argument whose name matches a parsed option and whose value is not
    ``None`` then overrides the command-line value.

    Args:
        pipeline: Pipeline to reuse for generation. When ``None``, one is
            built from ``"weights/Kolors"``: the IP-adapter variant if
            ``ref_scale > 0``, otherwise the plain text-to-image one.
        **kwargs: Per-call overrides for the parsed options (for example
            ``prompts``, ``output_root``, ``seed``). Unknown names and
            ``None`` values are ignored.

    Returns:
        Paths of every image saved, across all prompts, in generation order.
        Empty when there are no prompts.

    Raises:
        ValueError: If ``ref_image`` is neither ``None``, a string nor a list.
        AssertionError: If the number of reference images does not match the
            number of prompts.
    """
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    prompts = args.prompts
    # A single "*.txt" argument is treated as a file with one prompt per
    # line; blank lines are dropped.
    if len(prompts) == 1 and prompts[0].endswith(".txt"):
        with open(prompts[0], "r") as f:
            prompts = f.readlines()
        prompts = [
            prompt.strip() for prompt in prompts if prompt.strip() != ""
        ]

    os.makedirs(args.output_root, exist_ok=True)

    # Normalize the reference images to one entry per prompt.
    ip_img_paths = args.ref_image
    if ip_img_paths is None or len(ip_img_paths) == 0:
        # No reference image: disable the IP adapter so the plain pipeline
        # is selected below.
        args.ref_scale = 0
        ip_img_paths = [None] * len(prompts)
    elif isinstance(ip_img_paths, str):
        ip_img_paths = [ip_img_paths] * len(prompts)
    elif isinstance(ip_img_paths, list):
        # A single reference image is shared by every prompt.
        if len(ip_img_paths) == 1:
            ip_img_paths = ip_img_paths * len(prompts)
    else:
        raise ValueError("Invalid ref_image paths.")
    assert len(ip_img_paths) == len(
        prompts
    ), f"Number of ref images does not match prompts, {len(ip_img_paths)} != {len(prompts)}"  # noqa

    if pipeline is None:
        if args.ref_scale > 0:
            pipeline = build_text2img_ip_pipeline(
                "weights/Kolors",
                ref_scale=args.ref_scale,
            )
        else:
            pipeline = build_text2img_pipeline("weights/Kolors")

    # Collect paths across ALL prompts. Initializing this inside the loop
    # would return only the last prompt's images and leave the name unbound
    # when there are no prompts.
    save_paths = []
    for idx, (prompt, ip_img_path) in tqdm(
        enumerate(zip(prompts, ip_img_paths)),
        desc="Generating images",
        total=len(prompts),
    ):
        images = text2img_gen(
            prompt=prompt,
            n_sample=args.n_sample,
            guidance_scale=args.guidance_scale,
            pipeline=pipeline,
            ip_image=ip_img_path,
            image_wh=[args.resolution, args.resolution],
            infer_step=args.infer_step,
            seed=args.seed,
        )

        for sub_idx, image in enumerate(images):
            # Running index over all prompts, n_sample slots per prompt.
            save_path = (
                f"{args.output_root}/sample_{idx*args.n_sample+sub_idx}.png"
            )
            image.save(save_path)
            save_paths.append(save_path)

    logger.info(f"Images saved to {args.output_root}")

    return save_paths
|
|
|
|
|
|
# Run the generator only when this file is executed as a script; importing
# the module must not trigger generation.
if __name__ == "__main__":
    entrypoint()
|