# Exported from repository web view; original page metadata preserved below:
#   feat(pipe): Add layout generation pipe. * tag v0.1.3 * chore: update assets * update
#   Co-authored-by: xinjie.wang <xinjie.wang@gpu-4090-dev015.hogpu.cc>
#   (679 lines, 25 KiB, Python)
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Part of the code comes from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
# Both under the Apache License, Version 2.0.

import json
import os
import time
from collections import defaultdict
from typing import Dict, Optional, Tuple

import cv2
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import tyro
import yaml
from fused_ssim import fused_ssim
from gsplat.distributed import cli
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from typing_extensions import Literal, assert_never

from embodied_gen.data.datasets import PanoGSplatDataset
from embodied_gen.utils.config import GsplatTrainConfig
from embodied_gen.utils.gaussian import (
    create_splats_with_optimizers,
    export_splats,
    resize_pinhole_intrinsics,
    set_random_seed,
)

class Runner:
    """Engine for training and testing from gsplat example.

    Code from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
    """

    def __init__(
        self,
        local_rank: int,
        world_rank: int,
        world_size: int,
        cfg: GsplatTrainConfig,
    ) -> None:
        """Set up output dirs, datasets, splat model, strategy, and metrics.

        Args:
            local_rank: GPU index on this node; selects the CUDA device and
                offsets the random seed so ranks diverge deterministically.
            world_rank: Global rank of this process.
            world_size: Total number of distributed processes.
            cfg: Full training configuration.
        """
        set_random_seed(42 + local_rank)

        self.cfg = cfg
        self.world_rank = world_rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.device = f"cuda:{local_rank}"

        # Where to dump results.
        os.makedirs(cfg.result_dir, exist_ok=True)

        # Setup output directories.
        self.ckpt_dir = f"{cfg.result_dir}/ckpts"
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.stats_dir = f"{cfg.result_dir}/stats"
        os.makedirs(self.stats_dir, exist_ok=True)
        self.render_dir = f"{cfg.result_dir}/renders"
        os.makedirs(self.render_dir, exist_ok=True)
        self.ply_dir = f"{cfg.result_dir}/ply"
        os.makedirs(self.ply_dir, exist_ok=True)

        # Tensorboard
        self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
        self.trainset = PanoGSplatDataset(cfg.data_dir, split="train")
        # NOTE(review): the validation set deliberately re-samples 6 views
        # from the *train* split; held-out views live in `self.testset`.
        self.valset = PanoGSplatDataset(
            cfg.data_dir, split="train", max_sample_num=6
        )
        self.testset = PanoGSplatDataset(cfg.data_dir, split="eval")
        self.scene_scale = cfg.scene_scale

        # Model: Gaussian parameters (means/scales/quats/opacities/SH) plus
        # one optimizer per parameter group.
        self.splats, self.optimizers = create_splats_with_optimizers(
            self.trainset.points,
            self.trainset.points_rgb,
            init_num_pts=cfg.init_num_pts,
            init_extent=cfg.init_extent,
            init_opacity=cfg.init_opa,
            init_scale=cfg.init_scale,
            means_lr=cfg.means_lr,
            scales_lr=cfg.scales_lr,
            opacities_lr=cfg.opacities_lr,
            quats_lr=cfg.quats_lr,
            sh0_lr=cfg.sh0_lr,
            shN_lr=cfg.shN_lr,
            scene_scale=self.scene_scale,
            sh_degree=cfg.sh_degree,
            sparse_grad=cfg.sparse_grad,
            visible_adam=cfg.visible_adam,
            batch_size=cfg.batch_size,
            feature_dim=None,
            device=self.device,
            world_rank=world_rank,
            world_size=world_size,
        )
        print("Model initialized. Number of GS:", len(self.splats["means"]))

        # Densification Strategy
        self.cfg.strategy.check_sanity(self.splats, self.optimizers)

        if isinstance(self.cfg.strategy, DefaultStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state(
                scene_scale=self.scene_scale
            )
        elif isinstance(self.cfg.strategy, MCMCStrategy):
            self.strategy_state = self.cfg.strategy.initialize_state()
        else:
            assert_never(self.cfg.strategy)

        # Losses & Metrics.
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(
            self.device
        )
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)

        if cfg.lpips_net == "alex":
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="alex", normalize=True
            ).to(self.device)
        elif cfg.lpips_net == "vgg":
            # The 3DGS official repo uses lpips vgg, which is equivalent with the following:
            self.lpips = LearnedPerceptualImagePatchSimilarity(
                net_type="vgg", normalize=False
            ).to(self.device)
        else:
            raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")
def rasterize_splats(
|
|
self,
|
|
camtoworlds: Tensor,
|
|
Ks: Tensor,
|
|
width: int,
|
|
height: int,
|
|
masks: Optional[Tensor] = None,
|
|
rasterize_mode: Optional[Literal["classic", "antialiased"]] = None,
|
|
camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None,
|
|
**kwargs,
|
|
) -> Tuple[Tensor, Tensor, Dict]:
|
|
means = self.splats["means"] # [N, 3]
|
|
# quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4]
|
|
# rasterization does normalization internally
|
|
quats = self.splats["quats"] # [N, 4]
|
|
scales = torch.exp(self.splats["scales"]) # [N, 3]
|
|
opacities = torch.sigmoid(self.splats["opacities"]) # [N,]
|
|
image_ids = kwargs.pop("image_ids", None)
|
|
|
|
colors = torch.cat(
|
|
[self.splats["sh0"], self.splats["shN"]], 1
|
|
) # [N, K, 3]
|
|
|
|
if rasterize_mode is None:
|
|
rasterize_mode = (
|
|
"antialiased" if self.cfg.antialiased else "classic"
|
|
)
|
|
if camera_model is None:
|
|
camera_model = self.cfg.camera_model
|
|
|
|
render_colors, render_alphas, info = rasterization(
|
|
means=means,
|
|
quats=quats,
|
|
scales=scales,
|
|
opacities=opacities,
|
|
colors=colors,
|
|
viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
|
|
Ks=Ks, # [C, 3, 3]
|
|
width=width,
|
|
height=height,
|
|
packed=self.cfg.packed,
|
|
absgrad=(
|
|
self.cfg.strategy.absgrad
|
|
if isinstance(self.cfg.strategy, DefaultStrategy)
|
|
else False
|
|
),
|
|
sparse_grad=self.cfg.sparse_grad,
|
|
rasterize_mode=rasterize_mode,
|
|
distributed=self.world_size > 1,
|
|
camera_model=self.cfg.camera_model,
|
|
with_ut=self.cfg.with_ut,
|
|
with_eval3d=self.cfg.with_eval3d,
|
|
**kwargs,
|
|
)
|
|
if masks is not None:
|
|
render_colors[~masks] = 0
|
|
return render_colors, render_alphas, info
|
|
|
|
    def train(self):
        """Run the full optimization loop.

        Per step: sample a batch, rasterize, compute L1+SSIM (optionally
        depth/regularization) losses, backprop, run the densification
        strategy's pre/post hooks, and step optimizers/schedulers. Also
        handles periodic tensorboard logging, checkpoint saving, PLY export,
        and evaluation at the configured steps.
        """
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank

        # Dump cfg.
        if world_rank == 0:
            with open(f"{cfg.result_dir}/cfg.yml", "w") as f:
                yaml.dump(vars(cfg), f)

        max_steps = cfg.max_steps
        init_step = 0

        schedulers = [
            # means has a learning rate schedule, that end at 0.01 of the initial value
            torch.optim.lr_scheduler.ExponentialLR(
                self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
            ),
        ]
        trainloader = torch.utils.data.DataLoader(
            self.trainset,
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=4,
            persistent_workers=True,
            pin_memory=True,
        )
        trainloader_iter = iter(trainloader)

        # Training loop.
        global_tic = time.time()
        pbar = tqdm.tqdm(range(init_step, max_steps))
        for step in pbar:
            # Cycle the dataloader indefinitely: steps, not epochs, drive training.
            try:
                data = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = iter(trainloader)
                data = next(trainloader_iter)

            camtoworlds = data["camtoworld"].to(device)  # [1, 4, 4]
            Ks = data["K"].to(device)  # [1, 3, 3]
            pixels = data["image"].to(device) / 255.0  # [1, H, W, 3]
            image_ids = data["image_id"].to(device)
            masks = (
                data["mask"].to(device) if "mask" in data else None
            )  # [1, H, W]
            if cfg.depth_loss:
                points = data["points"].to(device)  # [1, M, 2]
                depths_gt = data["depths"].to(device)  # [1, M]

            height, width = pixels.shape[1:3]

            # sh schedule: unlock one SH band every sh_degree_interval steps.
            sh_degree_to_use = min(
                step // cfg.sh_degree_interval, cfg.sh_degree
            )

            # forward
            renders, alphas, info = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=sh_degree_to_use,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                image_ids=image_ids,
                render_mode="RGB+ED" if cfg.depth_loss else "RGB",
                masks=masks,
            )
            # "RGB+ED" mode appends an expected-depth channel (4 channels total).
            if renders.shape[-1] == 4:
                colors, depths = renders[..., 0:3], renders[..., 3:4]
            else:
                colors, depths = renders, None

            if cfg.random_bkgd:
                # Composite over a random background to discourage opacity hacks.
                bkgd = torch.rand(1, 3, device=device)
                colors = colors + bkgd * (1.0 - alphas)

            # Strategy hook that must run before backward (e.g. grad retention).
            self.cfg.strategy.step_pre_backward(
                params=self.splats,
                optimizers=self.optimizers,
                state=self.strategy_state,
                step=step,
                info=info,
            )

            # loss
            l1loss = F.l1_loss(colors, pixels)
            ssimloss = 1.0 - fused_ssim(
                colors.permute(0, 3, 1, 2),
                pixels.permute(0, 3, 1, 2),
                padding="valid",
            )
            loss = (
                l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
            )
            if cfg.depth_loss:
                # query depths from depth map
                points = torch.stack(
                    [
                        points[:, :, 0] / (width - 1) * 2 - 1,
                        points[:, :, 1] / (height - 1) * 2 - 1,
                    ],
                    dim=-1,
                )  # normalize to [-1, 1]
                grid = points.unsqueeze(2)  # [1, M, 1, 2]
                depths = F.grid_sample(
                    depths.permute(0, 3, 1, 2), grid, align_corners=True
                )  # [1, 1, M, 1]
                depths = depths.squeeze(3).squeeze(1)  # [1, M]
                # calculate loss in disparity space
                disp = torch.where(
                    depths > 0.0, 1.0 / depths, torch.zeros_like(depths)
                )
                disp_gt = 1.0 / depths_gt  # [1, M]
                depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
                loss += depthloss * cfg.depth_lambda

            # regularizations
            if cfg.opacity_reg > 0.0:
                loss += (
                    cfg.opacity_reg
                    * torch.sigmoid(self.splats["opacities"]).mean()
                )
            if cfg.scale_reg > 0.0:
                loss += cfg.scale_reg * torch.exp(self.splats["scales"]).mean()

            loss.backward()

            desc = (
                f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
            )
            if cfg.depth_loss:
                desc += f"depth loss={depthloss.item():.6f}| "
            pbar.set_description(desc)

            # write images (gt and render)
            # if world_rank == 0 and step % 800 == 0:
            #     canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
            #     canvas = canvas.reshape(-1, *canvas.shape[2:])
            #     imageio.imwrite(
            #         f"{self.render_dir}/train_rank{self.world_rank}.png",
            #         (canvas * 255).astype(np.uint8),
            #     )

            # Tensorboard logging (rank 0 only).
            if (
                world_rank == 0
                and cfg.tb_every > 0
                and step % cfg.tb_every == 0
            ):
                mem = torch.cuda.max_memory_allocated() / 1024**3
                self.writer.add_scalar("train/loss", loss.item(), step)
                self.writer.add_scalar("train/l1loss", l1loss.item(), step)
                self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
                self.writer.add_scalar(
                    "train/num_GS", len(self.splats["means"]), step
                )
                self.writer.add_scalar("train/mem", mem, step)
                if cfg.depth_loss:
                    self.writer.add_scalar(
                        "train/depthloss", depthloss.item(), step
                    )
                if cfg.tb_save_image:
                    canvas = (
                        torch.cat([pixels, colors], dim=2)
                        .detach()
                        .cpu()
                        .numpy()
                    )
                    canvas = canvas.reshape(-1, *canvas.shape[2:])
                    self.writer.add_image("train/render", canvas, step)
                self.writer.flush()

            # save checkpoint before updating the model
            if (
                step in [i - 1 for i in cfg.save_steps]
                or step == max_steps - 1
            ):
                mem = torch.cuda.max_memory_allocated() / 1024**3
                stats = {
                    "mem": mem,
                    "ellipse_time": time.time() - global_tic,
                    "num_GS": len(self.splats["means"]),
                }
                print("Step: ", step, stats)
                with open(
                    f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json",
                    "w",
                ) as f:
                    json.dump(stats, f)
                data = {"step": step, "splats": self.splats.state_dict()}
                torch.save(
                    data,
                    f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt",
                )
            if (
                step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1
            ) and cfg.save_ply:
                sh0 = self.splats["sh0"]
                shN = self.splats["shN"]
                means = self.splats["means"]
                scales = self.splats["scales"]
                quats = self.splats["quats"]
                opacities = self.splats["opacities"]
                export_splats(
                    means=means,
                    scales=scales,
                    quats=quats,
                    opacities=opacities,
                    sh0=sh0,
                    shN=shN,
                    format="ply",
                    save_to=f"{self.ply_dir}/point_cloud_{step}.ply",
                )

            # Turn Gradients into Sparse Tensor before running optimizer
            if cfg.sparse_grad:
                assert (
                    cfg.packed
                ), "Sparse gradients only work with packed mode."
                gaussian_ids = info["gaussian_ids"]
                for k in self.splats.keys():
                    grad = self.splats[k].grad
                    if grad is None or grad.is_sparse:
                        continue
                    self.splats[k].grad = torch.sparse_coo_tensor(
                        indices=gaussian_ids[None],  # [1, nnz]
                        values=grad[gaussian_ids],  # [nnz, ...]
                        size=self.splats[k].size(),  # [N, ...]
                        is_coalesced=len(Ks) == 1,
                    )

            # Restrict the Adam update to Gaussians visible in this view.
            if cfg.visible_adam:
                gaussian_cnt = self.splats.means.shape[0]  # NOTE(review): unused local
                if cfg.packed:
                    visibility_mask = torch.zeros_like(
                        self.splats["opacities"], dtype=bool
                    )
                    visibility_mask.scatter_(0, info["gaussian_ids"], 1)
                else:
                    visibility_mask = (info["radii"] > 0).all(-1).any(0)

            # optimize
            for optimizer in self.optimizers.values():
                if cfg.visible_adam:
                    optimizer.step(visibility_mask)
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            for scheduler in schedulers:
                scheduler.step()

            # Run post-backward steps after backward and optimizer
            # (densification / pruning; MCMC also needs the current means lr).
            if isinstance(self.cfg.strategy, DefaultStrategy):
                self.cfg.strategy.step_post_backward(
                    params=self.splats,
                    optimizers=self.optimizers,
                    state=self.strategy_state,
                    step=step,
                    info=info,
                    packed=cfg.packed,
                )
            elif isinstance(self.cfg.strategy, MCMCStrategy):
                self.cfg.strategy.step_post_backward(
                    params=self.splats,
                    optimizers=self.optimizers,
                    state=self.strategy_state,
                    step=step,
                    info=info,
                    lr=schedulers[0].get_last_lr()[0],
                )
            else:
                assert_never(self.cfg.strategy)

            # eval the full set
            if step in [i - 1 for i in cfg.eval_steps]:
                self.eval(step)
                self.render_video(step)
    @torch.no_grad()
    def eval(
        self,
        step: int,
        stage: str = "val",
        canvas_h: int = 512,
        canvas_w: int = 1024,
    ):
        """Entry for evaluation.

        Renders each validation view, writes a side-by-side GT/render canvas
        to disk, and (on rank 0) aggregates PSNR/SSIM/LPIPS plus timing into
        a JSON stats file and tensorboard.

        Args:
            step: Training step being evaluated (used in file names/tags).
            stage: Label prefix for outputs (e.g. "val").
            canvas_h: Height both GT and render are resized to.
            canvas_w: Full canvas width; each half (GT, render) gets
                canvas_w // 2.
        """
        print("Running evaluation...")
        cfg = self.cfg
        device = self.device
        world_rank = self.world_rank

        valloader = torch.utils.data.DataLoader(
            self.valset, batch_size=1, shuffle=False, num_workers=1
        )
        ellipse_time = 0
        metrics = defaultdict(list)
        for i, data in enumerate(valloader):
            camtoworlds = data["camtoworld"].to(device)
            Ks = data["K"].to(device)
            pixels = data["image"].to(device) / 255.0
            height, width = pixels.shape[1:3]
            masks = data["mask"].to(device) if "mask" in data else None

            pixels = pixels.permute(0, 3, 1, 2)  # NHWC -> NCHW
            pixels = F.interpolate(pixels, size=(canvas_h, canvas_w // 2))

            # Time only the rasterization; synchronize so CUDA work is counted.
            torch.cuda.synchronize()
            tic = time.time()
            colors, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=cfg.sh_degree,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                masks=masks,
            )  # [1, H, W, 3]
            torch.cuda.synchronize()
            ellipse_time += max(time.time() - tic, 1e-10)

            # Resize the render to the same canvas size as GT before comparing.
            colors = colors.permute(0, 3, 1, 2)  # NHWC -> NCHW
            colors = F.interpolate(colors, size=(canvas_h, canvas_w // 2))
            colors = torch.clamp(colors, 0.0, 1.0)
            canvas_list = [pixels, colors]

            if world_rank == 0:
                # Side-by-side GT | render image, written as BGR for OpenCV.
                canvas = torch.cat(canvas_list, dim=2).squeeze(0)
                canvas = canvas.permute(1, 2, 0)  # CHW -> HWC
                canvas = (canvas * 255).to(torch.uint8).cpu().numpy()
                cv2.imwrite(
                    f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
                    canvas[..., ::-1],
                )
                metrics["psnr"].append(self.psnr(colors, pixels))
                metrics["ssim"].append(self.ssim(colors, pixels))
                metrics["lpips"].append(self.lpips(colors, pixels))

        if world_rank == 0:
            ellipse_time /= len(valloader)

            stats = {
                k: torch.stack(v).mean().item() for k, v in metrics.items()
            }
            stats.update(
                {
                    "ellipse_time": ellipse_time,
                    "num_GS": len(self.splats["means"]),
                }
            )
            print(
                f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
                f"Time: {stats['ellipse_time']:.3f}s/image "
                f"Number of GS: {stats['num_GS']}"
            )
            # save stats as json
            with open(
                f"{self.stats_dir}/{stage}_step{step:04d}.json", "w"
            ) as f:
                json.dump(stats, f)
            # save stats to tensorboard
            for k, v in stats.items():
                self.writer.add_scalar(f"{stage}/{k}", v, step)
            self.writer.flush()
    @torch.no_grad()
    def render_video(
        self, step: int, canvas_h: int = 512, canvas_w: int = 1024
    ):
        """Render the eval trajectory to an RGB|depth video.

        Two passes: first render every test view and track the global depth
        range, then write frames where the depth half is normalized by that
        global range (so the colormap is consistent across the video).

        Args:
            step: Training step, used in the output file name.
            canvas_h: Frame height.
            canvas_w: Full frame width; RGB and depth each take canvas_w // 2.
        """
        testloader = torch.utils.data.DataLoader(
            self.testset, batch_size=1, shuffle=False, num_workers=1
        )

        images_cache = []
        depth_global_min, depth_global_max = float("inf"), -float("inf")
        for data in testloader:
            camtoworlds = data["camtoworld"].to(self.device)
            # Rescale intrinsics from the dataset resolution to the canvas.
            Ks = resize_pinhole_intrinsics(
                data["K"].squeeze(),
                raw_hw=(data["image_h"].item(), data["image_w"].item()),
                new_hw=(canvas_h, canvas_w // 2),
            ).to(self.device)
            renders, _, _ = self.rasterize_splats(
                camtoworlds=camtoworlds,
                Ks=Ks[None, ...],
                width=canvas_w // 2,
                height=canvas_h,
                sh_degree=self.cfg.sh_degree,
                near_plane=self.cfg.near_plane,
                far_plane=self.cfg.far_plane,
                render_mode="RGB+ED",
            )  # [1, H, W, 4]
            colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0)  # [H, W, 3]
            colors = (colors * 255).to(torch.uint8).cpu().numpy()
            depths = renders[0, ..., 3:4]  # [H, W, 1], tensor in device.
            images_cache.append([colors, depths])
            depth_global_min = min(depth_global_min, depths.min().item())
            depth_global_max = max(depth_global_max, depths.max().item())

        video_path = f"{self.render_dir}/video_step{step}.mp4"
        writer = imageio.get_writer(video_path, fps=30)
        for rgb, depth in images_cache:
            # Normalize with the global range; epsilon guards a flat scene.
            depth_normalized = torch.clip(
                (depth - depth_global_min)
                / (depth_global_max - depth_global_min + 1e-8),
                0,
                1,
            )
            depth_normalized = (
                (depth_normalized * 255).to(torch.uint8).cpu().numpy()
            )
            depth_map = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_JET)
            image = np.concatenate([rgb, depth_map], axis=1)
            writer.append_data(image)

        writer.close()
def entrypoint(
    local_rank: int, world_rank, world_size: int, cfg: GsplatTrainConfig
):
    """Per-process entry point: train from scratch, or eval a checkpoint.

    If ``cfg.ckpt`` is set, the per-rank checkpoints are concatenated into
    the model and only evaluation/video rendering runs; otherwise a full
    training run is performed, followed by a final video render.
    """
    runner = Runner(local_rank, world_rank, world_size, cfg)

    if cfg.ckpt is None:
        # Fresh training run, then render a video at the last step.
        runner.train()
        runner.render_video(step=cfg.max_steps - 1)
        return

    # Eval-only path: load every rank's checkpoint and merge the splats.
    snapshots = []
    for path in cfg.ckpt:
        snapshots.append(
            torch.load(path, map_location=runner.device, weights_only=True)
        )
    for name in runner.splats.keys():
        parts = [snap["splats"][name] for snap in snapshots]
        runner.splats[name].data = torch.cat(parts)
    step = snapshots[0]["step"]
    runner.eval(step=step)
    runner.render_video(step=step)
if __name__ == "__main__":
    # Named presets selectable on the command line via tyro; each entry is
    # (help text, base config). CLI flags can override individual fields.
    configs = {
        "default": (
            "Gaussian splatting training using densification heuristics from the original paper.",
            GsplatTrainConfig(
                strategy=DefaultStrategy(verbose=True),
            ),
        ),
        "mcmc": (
            "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
            GsplatTrainConfig(
                init_scale=0.1,
                opacity_reg=0.01,
                scale_reg=0.01,
                strategy=MCMCStrategy(verbose=True),
            ),
        ),
    }
    cfg = tyro.extras.overridable_config_cli(configs)
    # Rescale all step-based schedules (save/eval/ply/max steps) by the
    # configured factor.
    cfg.adjust_steps(cfg.steps_scaler)

    # Launch one process per GPU via gsplat's distributed CLI helper.
    cli(entrypoint, cfg, verbose=True)