32 lines
999 B
Python
32 lines
999 B
Python
import os
|
|
import torch
|
|
|
|
from cloud_helper import Server, time_it
|
|
|
|
from moge.model.v2 import MoGeModel # Let's try MoGe-2
|
|
|
|
# Expose only GPU index 0 to this process.
# NOTE(review): torch is imported before this assignment; it only takes
# effect if nothing has initialised CUDA yet — confirm no earlier import does.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

device = torch.device("cuda")

# Large ("vitl") MoGe-2 checkpoint. The smaller "Ruicheng/moge-2-vits-normal"
# checkpoint can be swapped in here to trade accuracy for speed/memory.
model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal")
model = model.to(device)
|
|
|
|
def get_depth(image, normalize=True, *, invalid_value=1.0):
    """Run the MoGe-2 model on one image and return its depth map.

    Args:
        image: Array indexed as (H, W, C). It is divided by 255 before
            inference, so presumably 8-bit RGB — TODO confirm what callers
            send.
        normalize: If true, min-max scale the finite depth values to [0, 1].
        invalid_value: Only used when ``normalize`` is true. Written to every
            pixel whose predicted depth is not finite; MoGe can mark pixels
            it masks out (e.g. sky) with ``inf``. The default of 1.0 places
            such pixels at the far end of the normalized range.

    Returns:
        ``model.infer(...)["depth"]`` as a NumPy array, normalized if requested.
    """
    # (H, W, C) -> (C, H, W) float32 tensor with values scaled into [0, 1].
    tensor = torch.tensor(image / 255, dtype=torch.float32, device=device).permute(2, 0, 1)
    depth = model.infer(tensor)["depth"]
    if normalize:
        depth = _normalize_depth(depth, invalid_value)
    return depth.cpu().numpy()


def _normalize_depth(depth, invalid_value):
    """Min-max scale the finite entries of *depth*; fill the rest with *invalid_value*."""
    finite = torch.isfinite(depth)
    if not finite.any():
        # No valid depth to scale against (e.g. an all-sky frame).
        return torch.full_like(depth, invalid_value)
    # Take min/max over finite pixels only. A single inf would otherwise make
    # the range infinite, collapsing valid pixels to 0 and invalid ones to NaN.
    valid = depth[finite]
    lo, hi = valid.min(), valid.max()
    scaled = (depth - lo) / (hi - lo + 1e-8)
    return torch.where(finite, scaled, torch.full_like(depth, invalid_value))
|
|
|
|
import logging

# Root logging setup: emit INFO and above, each record prefixed with a
# timestamp and its level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Logger for this module.
logger = logging.getLogger(__name__)
|
|
|
|
# Register get_depth as the "get_depth" endpoint on cloud_helper's Server.
# NOTE(review): host "0.0.0.0" binds every network interface, so the endpoint
# is reachable from any host that can route here — confirm that is intended.
server = Server(host="0.0.0.0", port=50001)
server.register_endpoint("get_depth", get_depth)

# Use the module logger rather than print() so the startup line carries the
# same timestamp/level format that logging.basicConfig set up for this script.
logger.info("MoGe-2 Server is running at %s:%s...", server.host, server.port)

# Presumably blocks here serving requests for the life of the process.
server.loop_forever()
|