# Release 3D scene generation pipeline and tag as v0.1.2.
# Co-authored-by: xinjie.wang <xinjie.wang@gpu-4090-dev015.hogpu.cc>
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Part of the code comes from https://github.com/nerfstudio-project/gsplat
# Both under the Apache License, Version 2.0.

import math
import random
from io import BytesIO
from typing import Dict, Literal, Optional, Tuple

import numpy as np
import torch
import trimesh
from gsplat.optimizers import SelectiveAdam
from scipy.spatial.transform import Rotation
from sklearn.neighbors import NearestNeighbors
from torch import Tensor

from embodied_gen.models.gs_model import GaussianOperator

__all__ = [
    "set_random_seed",
    "export_splats",
    "create_splats_with_optimizers",
    "compute_pinhole_intrinsics",
    "resize_pinhole_intrinsics",
    "restore_scene_scale_and_position",
]

def knn(x: Tensor, K: int = 4) -> Tensor:
    """Return distances from each point in ``x`` to its ``K`` nearest neighbors.

    The query set equals the data set, so the first returned column is
    always the zero distance from each point to itself.

    Args:
        x: Point tensor of shape (N, D), on any device/dtype.
        K: Number of neighbors to query (including the point itself).

    Returns:
        Tensor of shape (N, K) with Euclidean distances, cast back to the
        dtype and device of ``x``.
    """
    # scipy's KD-tree covers this simple self-query; avoids requiring the
    # much heavier scikit-learn dependency for a single call.
    from scipy.spatial import cKDTree

    x_np = x.detach().cpu().numpy()
    distances, _ = cKDTree(x_np).query(x_np, k=K)
    # K == 1 yields a flat array; normalize to (N, K) for a stable shape.
    distances = np.asarray(distances).reshape(x_np.shape[0], -1)
    return torch.from_numpy(distances).to(x)
def rgb_to_sh(rgb: Tensor) -> Tensor:
    """Convert RGB values in [0, 1] to zeroth-order SH (DC) coefficients."""
    # 1 / (2 * sqrt(pi)): the constant zeroth spherical-harmonic basis value.
    SH_C0 = 0.28209479177387814
    return (rgb - 0.5) / SH_C0
def set_random_seed(seed: int):
    """Seed the Python, NumPy, and PyTorch RNGs for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
def splat2ply_bytes(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
) -> bytes:
    """Serialize Gaussian-splat attributes into a binary little-endian PLY blob.

    Args:
        means: (N, 3) splat centers.
        scales: (N, 3) per-axis scales.
        quats: (N, 4) rotation quaternions.
        opacities: (N,) opacities.
        sh0: (N, 3) DC spherical-harmonic coefficients.
        shN: (N, K*3) higher-order SH coefficients, already flattened.

    Returns:
        The complete PLY file contents as bytes.
    """
    out = BytesIO()

    # Header: one property per float column, in the exact order the binary
    # body is laid out below.
    header_lines = [
        "ply",
        "format binary_little_endian 1.0",
        f"element vertex {means.shape[0]}",
        "property float x",
        "property float y",
        "property float z",
    ]
    header_lines += [f"property float f_dc_{j}" for j in range(sh0.shape[1])]
    header_lines += [f"property float f_rest_{j}" for j in range(shN.shape[1])]
    header_lines.append("property float opacity")
    header_lines += [f"property float scale_{i}" for i in range(scales.shape[1])]
    header_lines += [f"property float rot_{i}" for i in range(quats.shape[1])]
    header_lines.append("end_header")
    out.write(("\n".join(header_lines) + "\n").encode())

    # Body: per-vertex records concatenated in header order, forced to
    # little-endian float32 regardless of the input dtype/device.
    records = torch.cat(
        [means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
    ).to(torch.float32)
    le_float32 = np.dtype(np.float32).newbyteorder("<")
    out.write(records.detach().cpu().numpy().astype(le_float32).tobytes())

    return out.getvalue()
def export_splats(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
    format: Literal["ply"] = "ply",
    save_to: Optional[str] = None,
) -> bytes:
    """Export a Gaussian Splats model to bytes in PLY file format.

    Splats containing NaN/Inf in any attribute are dropped before export.

    Args:
        means: (N, 3) splat centers.
        scales: (N, 3) per-axis scales.
        quats: (N, 4) rotation quaternions.
        opacities: (N,) opacities.
        sh0: (N, 1, 3) DC spherical-harmonic coefficients.
        shN: (N, K, 3) higher-order SH coefficients.
        format: Output format; only "ply" is supported.
        save_to: Optional path; when given, the bytes are also written there.

    Returns:
        The serialized splat file contents.

    Raises:
        ValueError: If ``format`` is not "ply".
    """
    total_splats = means.shape[0]
    assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
    assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
    assert quats.shape == (
        total_splats,
        4,
    ), "Quaternions must be of shape (N, 4)"
    assert opacities.shape == (
        total_splats,
    ), "Opacities must be of shape (N,)"
    assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
    assert (
        shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
    ), f"shN must be of shape (N, K, 3), got {shN.shape}"

    # Reshape spherical harmonics
    sh0 = sh0.squeeze(1)  # Shape (N, 3)
    shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1)  # Shape (N, K * 3)

    # Flag splats with NaN or Inf in any attribute. BUGFIX: the opacity
    # check must stay elementwise — the previous `.any(dim=0)` reduced the
    # (N,) tensor to a scalar, so a single invalid opacity caused EVERY
    # splat to be discarded.
    invalid_mask = (
        torch.isnan(means).any(dim=1)
        | torch.isinf(means).any(dim=1)
        | torch.isnan(scales).any(dim=1)
        | torch.isinf(scales).any(dim=1)
        | torch.isnan(quats).any(dim=1)
        | torch.isinf(quats).any(dim=1)
        | torch.isnan(opacities)
        | torch.isinf(opacities)
        | torch.isnan(sh0).any(dim=1)
        | torch.isinf(sh0).any(dim=1)
        | torch.isnan(shN).any(dim=1)
        | torch.isinf(shN).any(dim=1)
    )

    # Filter out invalid entries
    valid_mask = ~invalid_mask
    means = means[valid_mask]
    scales = scales[valid_mask]
    quats = quats[valid_mask]
    opacities = opacities[valid_mask]
    sh0 = sh0[valid_mask]
    shN = shN[valid_mask]

    if format == "ply":
        data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
    else:
        raise ValueError(f"Unsupported format: {format}")

    if save_to:
        with open(save_to, "wb") as binary_file:
            binary_file.write(data)

    return data
def create_splats_with_optimizers(
    points: np.ndarray = None,
    points_rgb: np.ndarray = None,
    init_num_pts: int = 100_000,
    init_extent: float = 3.0,
    init_opacity: float = 0.1,
    init_scale: float = 1.0,
    means_lr: float = 1.6e-4,
    scales_lr: float = 5e-3,
    opacities_lr: float = 5e-2,
    quats_lr: float = 1e-3,
    sh0_lr: float = 2.5e-3,
    shN_lr: float = 2.5e-3 / 20,
    scene_scale: float = 1.0,
    sh_degree: int = 3,
    sparse_grad: bool = False,
    visible_adam: bool = False,
    batch_size: int = 1,
    feature_dim: Optional[int] = None,
    device: str = "cuda",
    world_rank: int = 0,
    world_size: int = 1,
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
    """Build a Gaussian-splat ParameterDict plus one optimizer per parameter.

    When ``points``/``points_rgb`` are both given, splats are seeded from
    them; otherwise ``init_num_pts`` splats are sampled uniformly in a cube
    of half-extent ``init_extent * scene_scale`` with random colors. Splats
    are sharded across ranks via ``[world_rank::world_size]`` striding.

    Returns:
        ``(splats, optimizers)`` where ``optimizers`` maps each parameter
        name to its own optimizer instance.
    """
    if points is not None and points_rgb is not None:
        init_means = torch.from_numpy(points).float()
        rgbs = torch.from_numpy(points_rgb / 255.0).float()
    else:
        half_extent = init_extent * scene_scale
        init_means = half_extent * (torch.rand((init_num_pts, 3)) * 2 - 1)
        rgbs = torch.rand((init_num_pts, 3))

    # Initialize each GS size from the average distance to its 3 nearest
    # neighbors (isotropic, stored in log space).
    neighbor_d2 = (knn(init_means, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
    log_scales = (
        torch.log(torch.sqrt(neighbor_d2) * init_scale)
        .unsqueeze(-1)
        .repeat(1, 3)
    )  # [N, 3]

    # Shard the splats across ranks (a single rank keeps everything).
    init_means = init_means[world_rank::world_size]
    rgbs = rgbs[world_rank::world_size]
    log_scales = log_scales[world_rank::world_size]

    N = init_means.shape[0]
    params = [
        # (name, value, lr)
        ("means", torch.nn.Parameter(init_means), means_lr * scene_scale),
        ("scales", torch.nn.Parameter(log_scales), scales_lr),
        ("quats", torch.nn.Parameter(torch.rand((N, 4))), quats_lr),
        (
            "opacities",
            torch.nn.Parameter(torch.logit(torch.full((N,), init_opacity))),
            opacities_lr,
        ),
    ]

    if feature_dim is None:
        # Colors are SH coefficients: DC term from the RGBs, higher orders 0.
        sh_coeffs = torch.zeros((N, (sh_degree + 1) ** 2, 3))  # [N, K, 3]
        sh_coeffs[:, 0, :] = rgb_to_sh(rgbs)
        params.append(("sh0", torch.nn.Parameter(sh_coeffs[:, :1, :]), sh0_lr))
        params.append(("shN", torch.nn.Parameter(sh_coeffs[:, 1:, :]), shN_lr))
    else:
        # Features drive appearance / view-dependent shading instead of SH.
        feats = torch.rand(N, feature_dim)  # [N, feature_dim]
        params.append(("features", torch.nn.Parameter(feats), sh0_lr))
        params.append(
            ("colors", torch.nn.Parameter(torch.logit(rgbs)), sh0_lr)
        )

    splats = torch.nn.ParameterDict({name: p for name, p, _ in params})
    splats = splats.to(device)

    if sparse_grad:
        optimizer_class = torch.optim.SparseAdam
    elif visible_adam:
        optimizer_class = SelectiveAdam
    else:
        optimizer_class = torch.optim.Adam

    # Scale learning rate based on batch size, reference:
    # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
    # Note that this would not make the training exactly equivalent, see
    # https://arxiv.org/pdf/2402.18824v1
    BS = batch_size * world_size
    sqrt_bs = math.sqrt(BS)
    optimizers = {}
    for name, _, lr in params:
        optimizers[name] = optimizer_class(
            [{"params": splats[name], "lr": lr * sqrt_bs, "name": name}],
            eps=1e-15 / sqrt_bs,
            # TODO: check betas logic when BS is larger than 10 betas[0] will be zero.
            betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
        )
    return splats, optimizers
def compute_pinhole_intrinsics(
    image_w: int, image_h: int, fov_deg: float
) -> np.ndarray:
    """Build a 3x3 pinhole intrinsic matrix from image size and horizontal FOV.

    Args:
        image_w: Image width in pixels.
        image_h: Image height in pixels.
        fov_deg: Horizontal field of view in degrees.

    Returns:
        The 3x3 intrinsic matrix K, assuming square pixels (fy == fx) and
        a principal point at the image center.
    """
    half_fov_rad = np.deg2rad(fov_deg) / 2
    focal = image_w / (2 * np.tan(half_fov_rad))
    return np.array(
        [
            [focal, 0, image_w / 2],
            [0, focal, image_h / 2],
            [0, 0, 1],
        ]
    )
def resize_pinhole_intrinsics(
|
|
raw_K: np.ndarray | torch.Tensor,
|
|
raw_hw: tuple[int, int],
|
|
new_hw: tuple[int, int],
|
|
) -> np.ndarray:
|
|
raw_h, raw_w = raw_hw
|
|
new_h, new_w = new_hw
|
|
|
|
scale_x = new_w / raw_w
|
|
scale_y = new_h / raw_h
|
|
|
|
new_K = raw_K.copy() if isinstance(raw_K, np.ndarray) else raw_K.clone()
|
|
new_K[0, 0] *= scale_x # fx
|
|
new_K[0, 2] *= scale_x # cx
|
|
new_K[1, 1] *= scale_y # fy
|
|
new_K[1, 2] *= scale_y # cy
|
|
|
|
return new_K
|
|
|
|
|
|
def restore_scene_scale_and_position(
    real_height: float, mesh_path: str, gs_path: str
) -> None:
    """Rescale a mesh and its matching GS model to a real-world height, in place.

    Height is estimated robustly from the 1st/99th percentiles of vertex
    column 1. NOTE(review): the original docstring called this the Z axis,
    but column 1 is Y in the usual trimesh convention — confirm the up-axis
    convention with callers. The mesh is flipped 180 degrees about Y,
    grounded at height 0, scaled, and re-exported; the GS model receives
    the matching transform.

    Args:
        real_height (float): Target real-world height along the up axis.
        mesh_path (str): Path to the input mesh file (overwritten).
        gs_path (str): Path to the Gaussian Splatting model file (overwritten).
    """
    mesh = trimesh.load(mesh_path)
    up_coords = mesh.vertices[:, 1]
    low = np.percentile(up_coords, 1)
    high = np.percentile(up_coords, 99)
    scale = real_height / (high - low)

    # 180-degree flip about Y negates X and Z and leaves column 1 alone,
    # so the percentiles computed above remain valid after the rotation.
    flip = Rotation.from_quat([0, 1, 0, 0])
    mesh.vertices = flip.apply(mesh.vertices)
    mesh.vertices[:, 1] -= low
    mesh.vertices *= scale
    mesh.export(mesh_path)

    gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
    # Apply the same grounding translation and flip to the splats;
    # presumably the pose layout is (tx, ty, tz, qx?, qy?, ...) — verify
    # against GaussianOperator.get_gaussians.
    gs_model = gs_model.get_gaussians(
        instance_pose=torch.tensor([0.0, -low, 0, 0, 1, 0, 0])
    )
    gs_model.rescale(scale)
    gs_model.save_to_ply(gs_path)