embodiedgen/image_to_3d_service.py
u201311 11712a81bc
Some checks failed
Deploy MkDocs Documentation / build (push) Has been cancelled
Deploy MkDocs Documentation / deploy (push) Has been cancelled
add docker image build script
2026-01-16 17:09:18 +08:00

169 lines
6.2 KiB
Python

import os
import sys
import shutil
import uuid
import logging
import time
from typing import List, Optional
# Set environment variable BEFORE importing apps.common
os.environ["GRADIO_APP"] = "imageto3d"
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from PIL import Image
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from apps.common import (
preprocess_image_fn,
image_to_3d,
extract_3d_representations_v3,
extract_urdf,
TMP_DIR,
)
# Configure root logging at INFO level; logging.basicConfig does nothing if
# the root logger already has handlers attached.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ASGI application object; the route handlers in this module register on it.
app = FastAPI(title="Image to 3D Service")
class ImageTo3DRequest(BaseModel):
    """Request body accepted by the ``/process`` endpoint.

    Each entry of ``image_paths`` is processed independently and its output
    is written to ``<output_root>/<image file name without extension>``.
    """

    image_paths: List[str] = Field(
        ...,
        description="Path to the input images (e.g. apps/assets/example_image/sample_00.jpg)",
    )
    output_root: str = Field(
        ...,
        description="Root directory for saving outputs (e.g. outputs/imageto3d)",
    )
    # The handler makes n_retry + 1 attempts per image, so a negative value
    # would result in no attempt at all; reject it at validation time.
    n_retry: int = Field(2, ge=0, description="Number of retries")

    # Optional parameters mirroring the CLI arguments.
    # Forwarded to extract_urdf as text; None is sent as an empty string.
    height_range: Optional[str] = None
    mass_range: Optional[str] = None
    # Only the first entry is forwarded to extract_urdf by process_images.
    asset_type: Optional[List[str]] = None
    # When True, images whose output directory already exists and is
    # non-empty are reported as "skipped" instead of being reprocessed.
    skip_exists: bool = False
    version: Optional[str] = None
    # NOTE(review): not read by process_images in this file.
    keep_intermediate: bool = False
    # Base seed; the attempt index is added to it on every retry.
    seed: int = 0
    # NOTE(review): not read by process_images in this file.
    disable_decompose_convex: bool = False
    # Forwarded to extract_3d_representations_v3.
    texture_size: int = 2048
class MockRequest:
    """Minimal request stand-in that carries only a ``session_hash``.

    The pipeline functions called by this service receive it as their
    ``req`` argument.
    """

    def __init__(self, session_hash):
        # Stored verbatim; this service passes a fresh UUID string per attempt.
        self.session_hash = session_hash
def app_init():
    """Log that the service is ready.

    Nothing is loaded here: the models are loaded as a side effect of
    importing ``apps.common`` while GRADIO_APP=imageto3d is set.
    """
    ready_message = "Service initialized, models loaded."
    logger.info(ready_message)
def _run_pipeline_attempt(request, image_path, dest_dir, seed):
    """Run the image -> 3D -> URDF pipeline once and publish its output.

    Args:
        request: The validated ``ImageTo3DRequest`` supplying the options.
        image_path: Path of the input image.
        dest_dir: Directory that receives the output file. Existing content
            is replaced, but only after the new output file has been
            confirmed to exist.
        seed: Seed forwarded to ``image_to_3d`` for this attempt.

    Returns:
        The "success" result entry for this image.

    Raises:
        FileNotFoundError: If the file reported by ``extract_urdf`` is not
            present in the session directory.
        Exception: Anything raised by the pipeline stages or the file
            operations; the caller's retry loop handles it.
    """
    session_id = str(uuid.uuid4())
    session_dir = os.path.join(TMP_DIR, session_id)
    req = MockRequest(session_hash=session_id)
    try:
        # 1. Preprocess (rembg is the default in the GUI).
        pil_image = Image.open(image_path)
        processed_image, image_cache = preprocess_image_fn(pil_image, rmbg_tag="rembg")

        # 2. Generate 3D.
        state, _ = image_to_3d(
            image=processed_image,
            seed=seed,
            ss_guidance_strength=7.5,
            ss_sampling_steps=12,
            slat_guidance_strength=3.0,
            slat_sampling_steps=12,
            raw_image_cache=image_cache,
            is_sam_image=False,  # auto segmentation via preprocess_image_fn
            req=req,
        )

        # 3. Extract representations.
        _, _, mesh_obj, aligned_gs = extract_3d_representations_v3(
            state=state,
            enable_delight=True,
            texture_size=request.texture_size,
            req=req,
        )

        # 4. Extract URDF. Only the first requested asset type is used.
        asset_cat = request.asset_type[0] if request.asset_type else ""
        urdf_zip, est_type, est_height, est_mass, est_mu = extract_urdf(
            gs_path=aligned_gs,
            mesh_obj_path=mesh_obj,
            asset_cat_text=asset_cat,
            height_range_text=request.height_range or "",
            mass_range_text=request.mass_range or "",
            asset_version_text=request.version or "",
            req=req,
        )

        # 5. Publish the result. Verify the new file exists BEFORE removing
        # any previous output, so a missing file cannot destroy old results.
        source_file = os.path.join(session_dir, urdf_zip)
        if not os.path.isfile(source_file):
            raise FileNotFoundError(f"Pipeline output not found: {source_file}")
        if os.path.exists(dest_dir):
            shutil.rmtree(dest_dir)
        os.makedirs(dest_dir, exist_ok=True)
        shutil.copy2(source_file, dest_dir)

        return {
            "image": image_path,
            "status": "success",
            "output_dir": dest_dir,
            "estimated_attrs": {
                "type": est_type,
                "height": est_height,
                "mass": est_mass,
                "mu": est_mu,
            },
        }
    finally:
        # Remove the per-session scratch directory on success and failure.
        shutil.rmtree(session_dir, ignore_errors=True)


@app.post("/process")
async def process_images(request: ImageTo3DRequest):
    """Process every image in the request and report one result per image.

    Each image is attempted up to ``n_retry + 1`` times, adding the attempt
    index to the seed on every retry. A failing image never aborts the
    batch; it is reported with status "failed" and the last error message.

    Args:
        request: Input paths, output root and pipeline options.

    Returns:
        A list of dicts, one per input image, whose "status" is "success",
        "skipped" or "failed".

    NOTE(review): this is an ``async def`` handler that calls the blocking
    pipeline functions directly, so the event loop (and therefore /health)
    cannot serve other requests while a batch is running.
    """
    start_time = time.perf_counter()
    results = []
    os.makedirs(request.output_root, exist_ok=True)

    # Always make at least one attempt, even if n_retry is negative.
    total_attempts = max(request.n_retry, 0) + 1

    for image_path in request.image_paths:
        if not os.path.exists(image_path):
            results.append({"image": image_path, "status": "failed", "error": "File not found"})
            continue

        # NOTE(review): the output directory is derived from the file name
        # only, so two inputs with the same base name share one directory.
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        dest_dir = os.path.join(request.output_root, image_name)

        # Only an existing, non-empty output directory counts as done.
        if request.skip_exists and os.path.isdir(dest_dir) and os.listdir(dest_dir):
            logger.info("Skipping %s as it already exists.", image_name)
            results.append({"image": image_path, "status": "skipped", "output_dir": dest_dir})
            continue

        last_error = None
        for attempt in range(total_attempts):
            logger.info(
                "Processing %s (Attempt %d/%d)", image_name, attempt + 1, total_attempts
            )
            try:
                # Vary the seed on each retry.
                outcome = _run_pipeline_attempt(
                    request, image_path, dest_dir, seed=request.seed + attempt
                )
            except Exception as exc:
                # logger.exception keeps the traceback for diagnosis.
                logger.exception(
                    "Error processing %s on attempt %d: %s", image_path, attempt + 1, exc
                )
                last_error = str(exc)
            else:
                results.append(outcome)
                break
        else:
            # Every attempt raised; report the most recent error.
            results.append({"image": image_path, "status": "failed", "error": last_error})

    elapsed_time = time.perf_counter() - start_time
    logger.info("Total processing time: %.2f seconds", elapsed_time)
    return results
@app.get("/health")
def health():
    """Liveness probe that always returns ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
if __name__ == "__main__":
    # Script entry point: log readiness, then serve the app on all
    # interfaces at port 9000.
    app_init()
    uvicorn.run(app, port=9000, host="0.0.0.0")