feat(pipe): Update model auto-download and set delight by default. (#8)
Update model auto-download and set delight by default.
This commit is contained in:
parent
5f810f3574
commit
c2d4b506ae
@ -63,7 +63,7 @@ CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 &
|
|||||||
|
|
||||||
### API
|
### API
|
||||||
Generate a 3D model from an image using the command-line API.
|
Generate a 3D model from an image using the command-line API.
|
||||||
|
Models will be downloaded automatically on the first run; please wait for the download to finish.
|
||||||
```sh
|
```sh
|
||||||
python3 embodied_gen/scripts/imageto3d.py \
|
python3 embodied_gen/scripts/imageto3d.py \
|
||||||
--image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \
|
--image_path apps/assets/example_image/sample_04.jpg apps/assets/example_image/sample_19.jpg \
|
||||||
@ -87,7 +87,7 @@ python apps/text_to_3d.py
|
|||||||
```
|
```
|
||||||
|
|
||||||
### API
|
### API
|
||||||
|
Models will be downloaded automatically; see `download_kolors_weights`.
|
||||||
```sh
|
```sh
|
||||||
bash embodied_gen/scripts/textto3d.sh \
|
bash embodied_gen/scripts/textto3d.sh \
|
||||||
--prompts "small bronze figurine of a lion" "带木质底座,具有经纬线的地球仪" "橙色电动手钻,有磨损细节" \
|
--prompts "small bronze figurine of a lion" "带木质底座,具有经纬线的地球仪" "橙色电动手钻,有磨损细节" \
|
||||||
@ -109,8 +109,7 @@ python apps/texture_edit.py
|
|||||||
```
|
```
|
||||||
|
|
||||||
### API
|
### API
|
||||||
Generate textures for a 3D mesh using a text prompt.
|
Models will be downloaded automatically; see `download_kolors_weights` and `geo_cond_mv`.
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
bash embodied_gen/scripts/texture_gen.sh \
|
bash embodied_gen/scripts/texture_gen.sh \
|
||||||
--mesh_path "apps/assets/example_texture/meshes/robot_text.obj" \
|
--mesh_path "apps/assets/example_texture/meshes/robot_text.obj" \
|
||||||
|
|||||||
@ -136,32 +136,6 @@ def patched_setup_functions(self):
|
|||||||
Gaussian.setup_functions = patched_setup_functions
|
Gaussian.setup_functions = patched_setup_functions
|
||||||
|
|
||||||
|
|
||||||
def download_kolors_weights() -> None:
|
|
||||||
logger.info(f"Download kolors weights from huggingface...")
|
|
||||||
subprocess.run(
|
|
||||||
[
|
|
||||||
"huggingface-cli",
|
|
||||||
"download",
|
|
||||||
"--resume-download",
|
|
||||||
"Kwai-Kolors/Kolors",
|
|
||||||
"--local-dir",
|
|
||||||
"weights/Kolors",
|
|
||||||
],
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
subprocess.run(
|
|
||||||
[
|
|
||||||
"huggingface-cli",
|
|
||||||
"download",
|
|
||||||
"--resume-download",
|
|
||||||
"Kwai-Kolors/Kolors-IP-Adapter-Plus",
|
|
||||||
"--local-dir",
|
|
||||||
"weights/Kolors-IP-Adapter-Plus",
|
|
||||||
],
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if os.getenv("GRADIO_APP") == "imageto3d":
|
if os.getenv("GRADIO_APP") == "imageto3d":
|
||||||
RBG_REMOVER = RembgRemover()
|
RBG_REMOVER = RembgRemover()
|
||||||
RBG14_REMOVER = BMGG14Remover()
|
RBG14_REMOVER = BMGG14Remover()
|
||||||
@ -185,9 +159,6 @@ elif os.getenv("GRADIO_APP") == "textto3d":
|
|||||||
)
|
)
|
||||||
# PIPELINE.cuda()
|
# PIPELINE.cuda()
|
||||||
text_model_dir = "weights/Kolors"
|
text_model_dir = "weights/Kolors"
|
||||||
if not os.path.exists(text_model_dir):
|
|
||||||
download_kolors_weights()
|
|
||||||
|
|
||||||
PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
|
PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
|
||||||
PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
|
PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
|
||||||
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
|
||||||
@ -198,9 +169,6 @@ elif os.getenv("GRADIO_APP") == "textto3d":
|
|||||||
os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
|
os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
|
||||||
)
|
)
|
||||||
elif os.getenv("GRADIO_APP") == "texture_edit":
|
elif os.getenv("GRADIO_APP") == "texture_edit":
|
||||||
if not os.path.exists("weights/Kolors"):
|
|
||||||
download_kolors_weights()
|
|
||||||
|
|
||||||
PIPELINE_IP = build_texture_gen_pipe(
|
PIPELINE_IP = build_texture_gen_pipe(
|
||||||
base_ckpt_dir="./weights",
|
base_ckpt_dir="./weights",
|
||||||
ip_adapt_scale=0.7,
|
ip_adapt_scale=0.7,
|
||||||
|
|||||||
@ -153,8 +153,8 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
label="Randomize Seed", value=False
|
label="Randomize Seed", value=False
|
||||||
)
|
)
|
||||||
project_delight = gr.Checkbox(
|
project_delight = gr.Checkbox(
|
||||||
label="Backproject delighting",
|
label="Back-project Delight",
|
||||||
value=False,
|
value=True,
|
||||||
)
|
)
|
||||||
gr.Markdown("Geo Structure Generation")
|
gr.Markdown("Geo Structure Generation")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|||||||
@ -152,7 +152,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
label="Randomize Seed", value=False
|
label="Randomize Seed", value=False
|
||||||
)
|
)
|
||||||
project_delight = gr.Checkbox(
|
project_delight = gr.Checkbox(
|
||||||
label="backproject delight", value=False
|
label="Back-project Delight", value=True
|
||||||
)
|
)
|
||||||
gr.Markdown("Geo Structure Generation")
|
gr.Markdown("Geo Structure Generation")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|||||||
@ -215,7 +215,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
|
|||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
project_delight = gr.Checkbox(
|
project_delight = gr.Checkbox(
|
||||||
label="delight", value=True
|
label="Back-project delight", value=True
|
||||||
)
|
)
|
||||||
fix_mesh = gr.Checkbox(
|
fix_mesh = gr.Checkbox(
|
||||||
label="simplify mesh", value=False
|
label="simplify mesh", value=False
|
||||||
|
|||||||
@ -16,7 +16,9 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
|
import subprocess
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -47,14 +49,46 @@ __all__ = [
|
|||||||
"build_text2img_ip_pipeline",
|
"build_text2img_ip_pipeline",
|
||||||
"build_text2img_pipeline",
|
"build_text2img_pipeline",
|
||||||
"text2img_gen",
|
"text2img_gen",
|
||||||
|
"download_kolors_weights",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
    """Download the Kolors and Kolors-IP-Adapter-Plus weights from Hugging Face.

    Invokes ``huggingface-cli download --resume-download`` twice, in order:
    first ``Kwai-Kolors/Kolors`` into ``local_dir``, then
    ``Kwai-Kolors/Kolors-IP-Adapter-Plus`` into
    ``<local_dir>/../Kolors-IP-Adapter-Plus``.

    Args:
        local_dir: Destination directory for the Kolors base weights. It is
            created (including parents) if it does not exist yet.

    Raises:
        subprocess.CalledProcessError: If either ``huggingface-cli`` call
            exits with a non-zero status (``check=True``).
    """
    logger.info("Download kolors weights from huggingface...")
    os.makedirs(local_dir, exist_ok=True)

    # (repo id, destination) pairs, fetched in this order. The IP-Adapter
    # destination is expressed relative to ``local_dir`` so that it lands
    # next to the base weights.
    downloads = (
        ("Kwai-Kolors/Kolors", local_dir),
        (
            "Kwai-Kolors/Kolors-IP-Adapter-Plus",
            f"{local_dir}/../Kolors-IP-Adapter-Plus",
        ),
    )
    for repo_id, target_dir in downloads:
        # Argument list (no shell) so paths are passed through verbatim.
        subprocess.run(
            [
                "huggingface-cli",
                "download",
                "--resume-download",
                repo_id,
                "--local-dir",
                target_dir,
            ],
            check=True,
        )
|
||||||
|
|
||||||
|
|
||||||
def build_text2img_ip_pipeline(
|
def build_text2img_ip_pipeline(
|
||||||
ckpt_dir: str,
|
ckpt_dir: str,
|
||||||
ref_scale: float,
|
ref_scale: float,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> StableDiffusionXLPipelineIP:
|
) -> StableDiffusionXLPipelineIP:
|
||||||
|
download_kolors_weights(ckpt_dir)
|
||||||
|
|
||||||
text_encoder = ChatGLMModel.from_pretrained(
|
text_encoder = ChatGLMModel.from_pretrained(
|
||||||
f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
|
f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
|
||||||
).half()
|
).half()
|
||||||
@ -106,6 +140,8 @@ def build_text2img_pipeline(
|
|||||||
ckpt_dir: str,
|
ckpt_dir: str,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> StableDiffusionXLPipeline:
|
) -> StableDiffusionXLPipeline:
|
||||||
|
download_kolors_weights(ckpt_dir)
|
||||||
|
|
||||||
text_encoder = ChatGLMModel.from_pretrained(
|
text_encoder = ChatGLMModel.from_pretrained(
|
||||||
f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
|
f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
|
||||||
).half()
|
).half()
|
||||||
|
|||||||
@ -28,6 +28,7 @@ from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
|
|||||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||||
)
|
)
|
||||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||||
|
from embodied_gen.models.text_model import download_kolors_weights
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"build_texture_gen_pipe",
|
"build_texture_gen_pipe",
|
||||||
@ -40,6 +41,8 @@ def build_texture_gen_pipe(
|
|||||||
ip_adapt_scale: float = 0,
|
ip_adapt_scale: float = 0,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> DiffusionPipeline:
|
) -> DiffusionPipeline:
|
||||||
|
download_kolors_weights(f"{base_ckpt_dir}/Kolors")
|
||||||
|
|
||||||
tokenizer = ChatGLMTokenizer.from_pretrained(
|
tokenizer = ChatGLMTokenizer.from_pretrained(
|
||||||
f"{base_ckpt_dir}/Kolors/text_encoder"
|
f"{base_ckpt_dir}/Kolors/text_encoder"
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user