169 lines
6.2 KiB
Python
169 lines
6.2 KiB
Python
import os
|
|
import sys
|
|
import shutil
|
|
import uuid
|
|
import logging
|
|
import time
|
|
from typing import List, Optional
|
|
|
|
# Set environment variable BEFORE importing apps.common
|
|
os.environ["GRADIO_APP"] = "imageto3d"
|
|
|
|
import uvicorn
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel, Field
|
|
from PIL import Image
|
|
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
from apps.common import (
|
|
preprocess_image_fn,
|
|
image_to_3d,
|
|
extract_3d_representations_v3,
|
|
extract_urdf,
|
|
TMP_DIR,
|
|
)
|
|
|
|
# Module-level wiring: configure root logging, create this module's logger,
# and build the ASGI application object that the route decorators attach to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="Image to 3D Service")
|
|
|
|
class ImageTo3DRequest(BaseModel):
    # Request body accepted by the POST /process endpoint.
    #
    # Documentation lives in comments on purpose: a class docstring would be
    # picked up as the model description in the generated schema.

    # --- Required inputs -------------------------------------------------
    image_paths: List[str] = Field(
        ...,
        description="Path to the input images (e.g. apps/assets/example_image/sample_00.jpg)",
    )
    output_root: str = Field(
        ...,
        description="Root directory for saving outputs (e.g. outputs/imageto3d)",
    )

    # Number of additional attempts per image after the first one fails.
    n_retry: int = Field(default=2, description="Number of retries")

    # --- Optional parameters mirroring the CLI arguments -----------------
    # Forwarded as text to the URDF extraction step ("" when omitted).
    height_range: Optional[str] = Field(default=None)
    mass_range: Optional[str] = Field(default=None)
    # Only the first entry is forwarded as the asset category.
    asset_type: Optional[List[str]] = Field(default=None)
    # When true, images whose output directory already has content are skipped.
    skip_exists: bool = Field(default=False)
    version: Optional[str] = Field(default=None)
    # NOTE(review): not read by the /process handler in this file.
    keep_intermediate: bool = Field(default=False)
    # Base seed; the handler adds the attempt index to it on each retry.
    seed: int = Field(default=0)
    # NOTE(review): not read by the /process handler in this file.
    disable_decompose_convex: bool = Field(default=False)
    texture_size: int = Field(default=2048)
|
|
|
|
class MockRequest:
    """Minimal stand-in for a request object that only exposes ``session_hash``.

    It is passed as ``req`` to the pipeline functions so each attempt can be
    tied to its own session identifier.
    """

    session_hash: str

    def __init__(self, session_hash: str) -> None:
        # Store the identifier verbatim; no validation or normalisation.
        self.session_hash = session_hash
|
|
|
|
def app_init() -> None:
    """Log that the service is ready.

    No loading happens here: per the module's own note, the models are loaded
    as a side effect of importing ``apps.common`` with ``GRADIO_APP=imageto3d``.
    """
    logger.info("Service initialized, models loaded.")
|
|
|
|
def _run_pipeline_once(request, image_path, dest_dir, seed):
    """Run one image -> 3D -> URDF attempt and publish its output to ``dest_dir``.

    A fresh session id is generated for the attempt. The session directory
    ``TMP_DIR/<session_id>`` is removed when the attempt ends, whether it
    succeeded or raised.

    Args:
        request: The ``ImageTo3DRequest`` carrying the generation options.
        image_path: Path of the input image on the server's filesystem.
        dest_dir: Directory that receives the file reported by
            ``extract_urdf``. An existing directory at this path is deleted
            first, but only after generation has succeeded.
        seed: Seed forwarded to ``image_to_3d``.

    Returns:
        dict: The estimated attributes returned by ``extract_urdf`` under the
        keys ``type``, ``height``, ``mass`` and ``mu``.

    Raises:
        Exception: Anything raised by image loading, the pipeline functions,
            or the file operations propagates to the caller.
    """
    session_id = str(uuid.uuid4())
    # The pipeline functions take a request-like object; only session_hash is
    # provided. Outputs are read back from TMP_DIR/<session_id> below.
    req = MockRequest(session_hash=session_id)
    session_dir = os.path.join(TMP_DIR, session_id)

    try:
        # 1. Preprocess
        # Load the pixel data eagerly so the file handle is released here
        # instead of lingering until garbage collection.
        with Image.open(image_path) as pil_image:
            pil_image.load()
        # Rembg is default in GUI
        processed_image, image_cache = preprocess_image_fn(pil_image, rmbg_tag="rembg")

        # 2. Generate 3D
        state, _ = image_to_3d(
            image=processed_image,
            seed=seed,
            ss_guidance_strength=7.5,
            ss_sampling_steps=12,
            slat_guidance_strength=3.0,
            slat_sampling_steps=12,
            raw_image_cache=image_cache,
            is_sam_image=False,  # We use auto segmentation (preprocess_image_fn)
            req=req,
        )

        # 3. Extract representations
        _, _, mesh_obj, aligned_gs = extract_3d_representations_v3(
            state=state,
            enable_delight=True,
            texture_size=request.texture_size,
            req=req,
        )

        # 4. Extract URDF
        asset_cat = request.asset_type[0] if request.asset_type else ""
        urdf_zip, est_type, est_height, est_mass, est_mu = extract_urdf(
            gs_path=aligned_gs,
            mesh_obj_path=mesh_obj,
            asset_cat_text=asset_cat,
            height_range_text=request.height_range or "",
            mass_range_text=request.mass_range or "",
            asset_version_text=request.version or "",
            req=req,
        )

        # 5. Publish: replace any previous output only now that a new result
        # exists, so a failed attempt never destroys earlier output.
        source_file = os.path.join(session_dir, urdf_zip)
        if os.path.exists(dest_dir):
            shutil.rmtree(dest_dir)
        os.makedirs(dest_dir, exist_ok=True)
        shutil.copy2(source_file, dest_dir)

        return {
            "type": est_type,
            "height": est_height,
            "mass": est_mass,
            "mu": est_mu,
        }
    finally:
        # Single cleanup point for both the success and the failure path.
        shutil.rmtree(session_dir, ignore_errors=True)


@app.post("/process")
async def process_images(request: ImageTo3DRequest):
    """Convert each requested image into a 3D/URDF asset under ``output_root``.

    Images are processed sequentially. Each image gets up to ``n_retry + 1``
    attempts (at least one), with the seed incremented by the attempt index on
    every retry. A failure of one image never aborts the others.

    Args:
        request: Validated request body; see ``ImageTo3DRequest``.

    Returns:
        list[dict]: One entry per input path, in input order, with ``image``
        and ``status`` (``"success"``, ``"skipped"`` or ``"failed"``).
        Successful and skipped entries carry ``output_dir``; successful ones
        also carry ``estimated_attrs``; failed ones carry ``error``.

    NOTE(review): this coroutine runs blocking work without awaiting, so the
    event loop (including ``/health``) is unresponsive while a request is
    processed. It is kept ``async`` because that also serialises requests;
    switching to a plain ``def`` would let requests run concurrently in the
    thread pool — confirm the pipeline is safe for that before changing it.
    """
    start_time = time.time()
    results = []
    os.makedirs(request.output_root, exist_ok=True)

    # A negative n_retry would otherwise yield zero attempts and a "failed"
    # entry whose error is None.
    total_attempts = max(request.n_retry, 0) + 1

    for image_path in request.image_paths:
        # isfile (not exists) so a directory is rejected up front instead of
        # burning every retry on an image-open error.
        if not os.path.isfile(image_path):
            results.append({"image": image_path, "status": "failed", "error": "File not found"})
            continue

        image_name = os.path.splitext(os.path.basename(image_path))[0]
        dest_dir = os.path.join(request.output_root, image_name)

        # Skip only when a previous run left a non-empty output directory.
        if request.skip_exists and os.path.isdir(dest_dir) and os.listdir(dest_dir):
            logger.info(f"Skipping {image_name} as it already exists.")
            results.append({"image": image_path, "status": "skipped", "output_dir": dest_dir})
            continue

        last_error = None
        for attempt in range(total_attempts):
            logger.info(f"Processing {image_name} (Attempt {attempt + 1}/{total_attempts})")
            try:
                estimated_attrs = _run_pipeline_once(
                    request,
                    image_path,
                    dest_dir,
                    seed=request.seed + attempt,  # Varying seed on retry
                )
            except Exception as e:
                # logger.exception keeps the traceback, which str(e) alone loses.
                logger.exception(
                    "Error processing %s on attempt %d: %s", image_path, attempt + 1, e
                )
                last_error = str(e)
                continue

            results.append({
                "image": image_path,
                "status": "success",
                "output_dir": dest_dir,
                "estimated_attrs": estimated_attrs,
            })
            break
        else:
            # Every attempt raised; report the last error seen.
            results.append({"image": image_path, "status": "failed", "error": last_error})

    elapsed_time = time.time() - start_time
    logger.info(f"Total processing time: {elapsed_time:.2f} seconds")
    return results
|
|
|
|
@app.get("/health")
def health():
    """Health-check endpoint: returns a fixed ``{"status": "ok"}`` payload."""
    payload = dict(status="ok")
    return payload
|
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: log readiness, then serve the app on all
    # interfaces at port 9000.
    app_init()
    server_options = {"host": "0.0.0.0", "port": 9000}
    uvicorn.run(app, **server_options)