# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

# Text-to-Image generation models from Hugging Face community.
|
import os
from abc import ABC, abstractmethod

import torch
from diffusers import (
    ChromaPipeline,
    Cosmos2TextToImagePipeline,
    DPMSolverMultistepScheduler,
    FluxPipeline,
    KolorsPipeline,
    StableDiffusion3Pipeline,
)
from diffusers.quantizers import PipelineQuantizationConfig
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModelForCausalLM, SiglipProcessor
__all__ = [
|
|
"build_hf_image_pipeline",
|
|
]
|
|
|
|
|
|
class BasePipelineLoader(ABC):
    """Abstract base class for text-to-image pipeline loaders.

    Subclasses implement :meth:`load` to build and return a ready-to-use
    diffusers pipeline targeting ``self.device``.
    """

    def __init__(self, device="cuda"):
        # Target device for the loaded pipeline (e.g. "cuda", "cpu").
        self.device = device

    @abstractmethod
    def load(self):
        """Build and return the underlying pipeline object."""
        ...
class BasePipelineRunner(ABC):
    """Abstract base class for text-to-image pipeline runners.

    Wraps a loaded diffusers pipeline and exposes a uniform :meth:`run`
    interface across the different model families.
    """

    def __init__(self, pipe):
        # The loaded diffusers pipeline this runner drives.
        self.pipe = pipe

    @abstractmethod
    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        """Generate images for ``prompt``.

        Returns:
            The list of generated PIL images. Concrete runners return
            ``pipe(...).images``, which is a list — the previous
            ``Image.Image`` annotation was incorrect.
        """
        ...
# ===== SD3.5-medium =====
class SD35Loader(BasePipelineLoader):
    """Loader for Stable Diffusion 3.5 medium (fp16, memory-optimized)."""

    def load(self):
        """Load SD3.5-medium with fp16 weights and memory optimizations."""
        pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium",
            torch_dtype=torch.float16,
        )
        # BUG FIX: do not call `pipe.to(device)` before CPU offload.
        # `enable_model_cpu_offload` manages device placement itself;
        # moving the whole pipeline to GPU first defeats the offloading
        # and conflicts with the offload hooks in recent diffusers.
        pipe.enable_model_cpu_offload(device=self.device)
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        return pipe
class SD35Runner(BasePipelineRunner):
    """Runner for Stable Diffusion 3.5 medium."""

    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        """Generate images for ``prompt``; returns ``pipe(...).images`` (a list)."""
        return self.pipe(prompt=prompt, **kwargs).images
# ===== Cosmos2 =====
class CosmosLoader(BasePipelineLoader):
    """Loader for NVIDIA Cosmos-Predict2 text-to-image.

    Downloads the checkpoint into a local directory, patches the
    text-encoder model/processor factories for flash-attention + bf16,
    then loads the pipeline with 4-bit (nf4) quantization.
    """

    def __init__(
        self,
        model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
        local_dir="weights/cosmos2",
        device="cuda",
    ):
        super().__init__(device)
        self.model_id = model_id    # Hugging Face repo id to download.
        self.local_dir = local_dir  # Where the checkpoint is materialized.

    def _patch(self):
        """Monkey-patch ``from_pretrained`` on the text-encoder classes.

        The Cosmos pipeline instantiates its text encoder internally, so
        intercepting the class constructors is the only way to force
        flash-attention / bf16 weights / fast processors.
        """

        def patch_model(cls):
            # Guard so repeated load() calls don't stack wrappers.
            if getattr(cls.from_pretrained, "_cosmos_patched", False):
                return
            orig = cls.from_pretrained

            def patched(*args, **kwargs):
                kwargs.setdefault("attn_implementation", "flash_attention_2")
                kwargs.setdefault("torch_dtype", torch.bfloat16)
                return orig(*args, **kwargs)

            patched._cosmos_patched = True
            cls.from_pretrained = patched

        def patch_processor(cls):
            if getattr(cls.from_pretrained, "_cosmos_patched", False):
                return
            orig = cls.from_pretrained

            def patched(*args, **kwargs):
                kwargs.setdefault("use_fast", True)
                return orig(*args, **kwargs)

            patched._cosmos_patched = True
            cls.from_pretrained = patched

        patch_model(AutoModelForCausalLM)
        patch_processor(SiglipProcessor)

    def load(self):
        """Download (if needed) and build the quantized Cosmos2 pipeline."""
        self._patch()
        # NOTE: `resume_download` and `local_dir_use_symlinks` are
        # deprecated in recent huggingface_hub — downloads resume
        # automatically and `local_dir` receives real files by default.
        snapshot_download(repo_id=self.model_id, local_dir=self.local_dir)

        config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
                "bnb_4bit_use_double_quant": True,
            },
            # Components absent from the pipeline are simply skipped.
            components_to_quantize=["text_encoder", "transformer", "unet"],
        )

        # BUG FIX: load from the directory we just downloaded to.
        # Loading from `self.model_id` would re-download everything into
        # the HF cache and leave `local_dir` entirely unused.
        pipe = Cosmos2TextToImagePipeline.from_pretrained(
            self.local_dir,
            torch_dtype=torch.bfloat16,
            quantization_config=config,
            use_safetensors=True,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(self.device)
        return pipe
class CosmosRunner(BasePipelineRunner):
    """Runner for the Cosmos2 text-to-image pipeline."""

    def run(
        self, prompt: str, negative_prompt=None, **kwargs
    ) -> list[Image.Image]:
        """Generate images; ``negative_prompt`` steers generation away from it."""
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images
# ===== Kolors =====
class KolorsLoader(BasePipelineLoader):
    """Loader for Kwai Kolors (fp16 variant) with a Karras DPM++ scheduler."""

    def load(self):
        """Load Kolors with fp16 weights, CPU offload, and a DPM++ scheduler."""
        pipe = KolorsPipeline.from_pretrained(
            "Kwai-Kolors/Kolors-diffusers",
            torch_dtype=torch.float16,
            variant="fp16",
        )
        # BUG FIX: let CPU offload manage device placement; calling
        # `.to(device)` first defeats (and can conflict with) the
        # offload hooks.
        pipe.enable_model_cpu_offload(device=self.device)
        pipe.enable_xformers_memory_efficient_attention()
        # Karras sigma schedule tends to improve quality at low step counts.
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_karras_sigmas=True
        )
        return pipe
class KolorsRunner(BasePipelineRunner):
    """Runner for the Kolors text-to-image pipeline."""

    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        """Generate images for ``prompt``; returns ``pipe(...).images`` (a list)."""
        return self.pipe(prompt=prompt, **kwargs).images
# ===== Flux =====
class FluxLoader(BasePipelineLoader):
    """Loader for FLUX.1-schnell with aggressive memory savings."""

    def load(self):
        """Load FLUX.1-schnell in bf16 with CPU offload and attention slicing."""
        # Reduce CUDA allocator fragmentation for this large model.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
        )
        # BUG FIX: the original returned `pipe.to(self.device)` AFTER
        # enabling CPU offload — moving an offloaded pipeline breaks the
        # offload hooks (recent diffusers raises an error for it). The
        # offload call owns device placement, so pass the device here.
        pipe.enable_model_cpu_offload(device=self.device)
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        return pipe
class FluxRunner(BasePipelineRunner):
    """Runner for the FLUX.1-schnell text-to-image pipeline."""

    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        """Generate images for ``prompt``; returns ``pipe(...).images`` (a list)."""
        return self.pipe(prompt=prompt, **kwargs).images
# ===== Chroma =====
class ChromaLoader(BasePipelineLoader):
    """Loader for the lodestones/Chroma pipeline (bf16, no offloading)."""

    def load(self):
        """Load Chroma in bf16 and move it to the target device."""
        pipe = ChromaPipeline.from_pretrained(
            "lodestones/Chroma", torch_dtype=torch.bfloat16
        )
        return pipe.to(self.device)
class ChromaRunner(BasePipelineRunner):
    """Runner for the Chroma text-to-image pipeline."""

    def run(
        self, prompt: str, negative_prompt=None, **kwargs
    ) -> list[Image.Image]:
        """Generate images; ``negative_prompt`` steers generation away from it."""
        return self.pipe(
            prompt=prompt, negative_prompt=negative_prompt, **kwargs
        ).images
# Registry of supported models: short name -> (loader class, runner class).
PIPELINE_REGISTRY = dict(
    sd35=(SD35Loader, SD35Runner),
    cosmos=(CosmosLoader, CosmosRunner),
    kolors=(KolorsLoader, KolorsRunner),
    flux=(FluxLoader, FluxRunner),
    chroma=(ChromaLoader, ChromaRunner),
)
def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
    """Build a ready-to-use runner for a registered text-to-image model.

    Args:
        name: Key in ``PIPELINE_REGISTRY`` (e.g. "sd35", "flux").
        device: Target device passed to the loader.

    Returns:
        A runner wrapping the fully loaded pipeline.

    Raises:
        ValueError: If ``name`` is not a registered model; the message now
            lists the supported names to aid debugging.
    """
    try:
        loader_cls, runner_cls = PIPELINE_REGISTRY[name]
    except KeyError:
        supported = ", ".join(sorted(PIPELINE_REGISTRY))
        raise ValueError(
            f"Unsupported model: {name}. Supported models: {supported}"
        ) from None
    pipe = loader_cls(device=device).load()
    return runner_cls(pipe)
if __name__ == "__main__":
    model_name = "sd35"
    runner = build_hf_image_pipeline(model_name)
    # NOTE: Just for pipeline testing, generation quality at low resolution is poor.
    generation_kwargs = dict(
        height=512,
        width=512,
        num_inference_steps=10,
        guidance_scale=6,
        num_images_per_prompt=1,
    )
    images = runner.run(
        prompt="A robot holding a sign that says 'Hello'",
        **generation_kwargs,
    )

    for idx, image in enumerate(images):
        image.save(f"image_{model_name}_{idx}.jpg")