Release 3D scene generation pipeline and tag as v0.1.2. --------- Co-authored-by: xinjie.wang <xinjie.wang@gpu-4090-dev015.hogpu.cc>
192 lines
6.2 KiB
Python
192 lines
6.2 KiB
Python
import logging
|
|
import os
|
|
import random
|
|
import time
|
|
import warnings
|
|
from dataclasses import dataclass, field
|
|
from shutil import copy, rmtree
|
|
|
|
import torch
|
|
import tyro
|
|
from huggingface_hub import snapshot_download
|
|
from packaging import version
|
|
|
|
# Keep console output readable: drop FutureWarning noise and only let
# ERROR-level records through from the listed third-party loggers.
warnings.filterwarnings(action="ignore", category=FutureWarning)
for _noisy_lib in ("transformers", "diffusers"):
    logging.getLogger(_noisy_lib).setLevel(logging.ERROR)
|
|
|
|
# TorchVision monkey patch for >=0.16
#
# Registers a stub module under the name
# "torchvision.transforms.functional_tensor" that exposes only
# `rgb_to_grayscale`, forwarded from `torchvision.transforms.functional`.
# NOTE(review): presumably a downstream dependency still imports
# `rgb_to_grayscale` from that legacy module path -- confirm which one.
#
# The gate compares the *torchvision* version. Comparing `torch.__version__`
# against "0.16" is true for every practical torch release, which would
# replace the real module with this one-function stub even on older
# torchvision builds that still ship it.
import torchvision

if version.parse(torchvision.__version__) >= version.parse("0.16"):
    import sys
    import types

    import torchvision.transforms.functional as TF

    functional_tensor = types.ModuleType(
        "torchvision.transforms.functional_tensor"
    )
    functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
    # Must be registered before any later import resolves this module name.
    sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
|
|
|
|
from gsplat.distributed import cli
|
|
from txt2panoimg import Text2360PanoramaImagePipeline
|
|
from embodied_gen.trainer.gsplat_trainer import (
|
|
DefaultStrategy,
|
|
GsplatTrainConfig,
|
|
)
|
|
from embodied_gen.trainer.gsplat_trainer import entrypoint as gsplat_entrypoint
|
|
from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
|
|
from embodied_gen.utils.config import Pano2MeshSRConfig
|
|
from embodied_gen.utils.gaussian import restore_scene_scale_and_position
|
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT
|
|
from embodied_gen.utils.log import logger
|
|
from embodied_gen.utils.process_media import is_image_file, parse_text_prompts
|
|
from embodied_gen.validators.quality_checkers import (
|
|
PanoHeightEstimator,
|
|
PanoImageOccChecker,
|
|
)
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ["generate_pano_image", "entrypoint"]
|
|
|
|
|
|
def _default_gs3d_config() -> GsplatTrainConfig:
    """Build the default 3DGS training config for `Scene3DGenConfig.gs3d`."""
    return GsplatTrainConfig(
        strategy=DefaultStrategy(verbose=True),
        max_steps=4000,
        init_opa=0.9,
        opacity_reg=2e-3,
        sh_degree=0,
        means_lr=1e-4,
        scales_lr=1e-3,
    )


@dataclass
class Scene3DGenConfig:
    """Command-line options for the panorama -> mesh -> 3DGS scene pipeline.

    Parsed by `tyro.cli` in `entrypoint`. Field order is part of the
    interface and must not be changed.
    """

    prompts: list[str]  # Text desc of indoor room or style reference image.
    output_dir: str  # Root dir; each prompt gets a `scene_XXXX` sub-directory.
    seed: int | None = None  # Panorama seed; a random one is drawn when None.
    real_height: float | None = None  # The real height of the room in meters.
    pano_image_only: bool = False  # Stop after the panorama image is written.
    disable_pano_check: bool = False  # Skip the panorama checker and retries.
    keep_middle_result: bool = False  # Keep intermediate 3DGS training files.
    n_retry: int = 7  # Maximum panorama generation attempts per prompt.
    # A fresh GsplatTrainConfig is built per instance via the factory.
    gs3d: GsplatTrainConfig = field(default_factory=_default_gs3d_config)
|
|
|
|
|
|
def generate_pano_image(
    prompt: str,
    output_path: str,
    pipeline,
    seed: int,
    n_retry: int,
    checker=None,
    num_inference_steps: int = 40,
) -> None:
    """Generate a panorama for `prompt` and save it to `output_path`.

    Up to `n_retry` attempts are made. Every attempt overwrites
    `output_path`, so the file always holds the most recent image, including
    when all attempts were rejected by `checker`.

    Args:
        prompt: Text description of the scene. Values recognised by
            `is_image_file` are rejected (image mode is not implemented).
        output_path: Destination passed to the generated image's `save`.
        pipeline: Callable taking a dict with the keys `prompt`,
            `num_inference_steps`, `upscale` and `seed`; its result must
            provide `save(path)`.
        seed: Seed for the first attempt; later attempts draw a new one
            with `random.randint(0, 100000)`.
        n_retry: Maximum number of attempts.
        checker: Optional callable returning `(flag, response)` for an
            image. The image is accepted when `flag` is `True` or `None`.
            When omitted, the first image is accepted as is.
        num_inference_steps: Forwarded to `pipeline`.

    Raises:
        NotImplementedError: If `is_image_file(prompt)` is true.
    """
    for attempt in range(n_retry):
        logger.info(
            f"GEN Panorama: Retry {attempt+1}/{n_retry} for prompt: {prompt}, seed: {seed}"
        )
        if is_image_file(prompt):
            raise NotImplementedError("Image mode not implemented yet.")

        # Bias the generation towards an uncluttered room.
        txt_prompt = f"{prompt}, spacious, empty, wide open, open floor, minimal furniture"
        pano_image = pipeline(
            {
                "prompt": txt_prompt,
                "num_inference_steps": num_inference_steps,
                "upscale": False,
                "seed": seed,
            }
        )
        pano_image.save(output_path)

        # No checker: the first image is final.
        if checker is None:
            return

        flag, response = checker(pano_image)
        logger.warning(f"{response}, image saved in {output_path}")
        accepted = flag is None or flag is True
        if accepted:
            return

        # Rejected: try again with a different seed.
        seed = random.randint(0, 100000)
|
|
|
|
|
|
def entrypoint(*args, **kwargs):
    """Run the text -> panorama -> mesh -> 3DGS pipeline for each CLI prompt.

    Options are parsed from the command line with
    `tyro.cli(Scene3DGenConfig)`. `*args` and `**kwargs` are accepted for
    call-site compatibility but are not used.

    For every parsed prompt a `scene_XXXX` directory is created under
    `cfg.output_dir`, holding `prompt.txt` and `pano_image.png`. Unless
    `cfg.pano_image_only` is set, the pano-to-mesh pipeline and 3DGS
    training then run in that directory, and `gs_model.ply`, `video.mp4`
    and `gsplat_cfg.yml` are copied out of the training result directory.
    Intermediate files are removed unless `cfg.keep_middle_result` is set.
    Finally the scene is passed to `restore_scene_scale_and_position` with
    either `cfg.real_height` or an estimated height.
    """
    cfg = tyro.cli(Scene3DGenConfig)

    # Init global models.
    model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
    pano_pipe = Text2360PanoramaImagePipeline(
        model_path, torch_dtype=torch.float16, device="cuda"
    )
    panomesh_cfg = Pano2MeshSRConfig()
    pano2mesh_pipe = Pano2MeshSRPipeline(panomesh_cfg)
    pano_checker = PanoImageOccChecker(GPT_CLIENT, box_hw=[95, 1000])
    height_estimator = PanoHeightEstimator(GPT_CLIENT)

    # Apply the step scaler once for the whole run. `cfg.gs3d` is shared by
    # all scenes, so calling this inside the per-scene loop would re-apply
    # the scaler for every prompt.
    # NOTE(review): presumably `adjust_steps` rescales step counts in place
    # (as gsplat's reference trainer config does) -- confirm.
    if not cfg.pano_image_only:
        cfg.gs3d.adjust_steps(cfg.gs3d.steps_scaler)

    prompts = parse_text_prompts(cfg.prompts)
    for idx, prompt in enumerate(prompts):
        start_time = time.time()
        output_dir = os.path.join(cfg.output_dir, f"scene_{idx:04d}")
        os.makedirs(output_dir, exist_ok=True)
        pano_path = os.path.join(output_dir, "pano_image.png")
        # Explicit UTF-8 so non-ASCII prompts do not depend on the locale.
        with open(f"{output_dir}/prompt.txt", "w", encoding="utf-8") as f:
            f.write(prompt)

        generate_pano_image(
            prompt,
            pano_path,
            pano_pipe,
            cfg.seed if cfg.seed is not None else random.randint(0, 100000),
            cfg.n_retry,
            checker=None if cfg.disable_pano_check else pano_checker,
        )

        if cfg.pano_image_only:
            continue

        logger.info("GEN and REPAIR Mesh from Panorama...")
        pano2mesh_pipe(pano_path, output_dir)

        logger.info("TRAIN 3DGS from Mesh Init and Cube Image...")
        cfg.gs3d.data_dir = output_dir
        cfg.gs3d.result_dir = f"{output_dir}/gaussian"
        torch.set_default_device("cpu")  # recover default setting.
        cli(gsplat_entrypoint, cfg.gs3d, verbose=True)

        # Copy the final artifacts out of the training result directory.
        last_step = cfg.gs3d.max_steps - 1
        gs_path = f"{cfg.gs3d.result_dir}/ply/point_cloud_{last_step}.ply"
        copy(gs_path, f"{output_dir}/gs_model.ply")
        video_path = f"{cfg.gs3d.result_dir}/renders/video_step{last_step}.mp4"
        copy(video_path, f"{output_dir}/video.mp4")
        gs_cfg_path = f"{cfg.gs3d.result_dir}/cfg.yml"
        copy(gs_cfg_path, f"{output_dir}/gsplat_cfg.yml")

        # Clean up the middle results.
        if not cfg.keep_middle_result:
            rmtree(cfg.gs3d.result_dir, ignore_errors=True)
            os.remove(f"{output_dir}/{panomesh_cfg.gs_data_file}")

        # Use the user-supplied room height when given, else estimate it.
        if cfg.real_height is None:
            real_height = height_estimator(pano_path)
        else:
            real_height = cfg.real_height
        gs_path = os.path.join(output_dir, "gs_model.ply")
        # NOTE(review): `mesh_model.ply` is presumably written by the
        # pano-to-mesh pipeline above -- confirm.
        mesh_path = os.path.join(output_dir, "mesh_model.ply")
        restore_scene_scale_and_position(real_height, mesh_path, gs_path)

        elapsed_time = (time.time() - start_time) / 60
        logger.info(
            f"FINISHED 3D scene generation in {output_dir} in {elapsed_time:.2f} mins."
        )
|
|
|
|
|
|
# Script entry point: run the pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    entrypoint()
|