Lerobot/docker/marigold_server.py
2025-12-11 14:11:41 +08:00

26 lines
691 B
Python

import os
from diffusers import MarigoldDepthPipeline
import torch
from cloud_helper import Server, time_it
# Device the pipeline runs on. The default is still "cuda:1"; set the
# MARIGOLD_DEVICE environment variable (e.g. "cuda:0" or "cpu") to run on a
# host with a different GPU layout without editing this file.
device = torch.device(os.environ.get("MARIGOLD_DEVICE", "cuda:1"))

# Load the pretrained Marigold depth checkpoint once, at import time, and move
# it to the selected device so every request reuses the same pipeline object.
model = MarigoldDepthPipeline.from_pretrained("prs-eth/marigold-depth-v1-1").to(device)
@time_it
def get_depth(image):
    """Estimate depth for ``image`` with the module-level Marigold pipeline.

    Args:
        image: Input forwarded unchanged to ``model``; the accepted formats
            are whatever the pipeline's call accepts (not validated here).

    Returns:
        The pipeline output's ``prediction`` with all size-1 dimensions
        removed via ``squeeze()``.
        NOTE(review): presumably an array-like depth map -- confirm the
        concrete type against the pipeline's configured output type.
    """
    # Single forward pass through the pipeline; only the prediction is kept.
    return model(image).prediction.squeeze()
import logging

# Configure root logging: INFO level, timestamped "time - level - message" lines.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Create the cloud_helper server and register get_depth under the endpoint
# name "get_depth".
# NOTE(review): host "0.0.0.0" presumably binds every network interface --
# confirm this service is only reachable from trusted networks.
server = Server(host="0.0.0.0", port=50001)
server.register_endpoint("get_depth", get_depth)

# Startup banner goes to stdout (not through the logger), as before.
startup_banner = f"Marigold Depth Pipeline is running at {server.host}:{server.port}..."
print(startup_banner)

# Hand control to the server loop; presumably blocks until the process stops.
server.loop_forever()