import torch
import torch.nn.functional as F
from typing import Tuple


class VolumeRendererUtil:
    """Static helpers for NeRF volume rendering: stratified and importance
    sampling along rays, compositing, and per-ray entropy (uncertainty)."""

    @staticmethod
    def render_rays(
        nerf_model,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
        near: torch.Tensor,
        far: torch.Tensor,
        coarse_samples: int = 64,
        fine_samples: int = 128,
        perturb: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Render rays and compute per-ray uncertainty (entropy).

        Args:
            nerf_model: NeRF model (must implement forward)
            rays_o: ray origins [N_rays, 3]
            rays_d: ray directions (normalized) [N_rays, 3]
            near: near-plane distances [N_rays]
            far: far-plane distances [N_rays]
            coarse_samples: number of coarse samples per ray
            fine_samples: number of fine samples per ray
            perturb: whether to jitter the sample positions

        Returns:
            rgb_map: rendered colors [N_rays, 3]
            weights: weight distribution [N_rays, N_samples]
            t_vals: sample parameters [N_rays, N_samples]
            entropy: per-ray entropy [N_rays]
        """
        # Coarse sampling
        t_vals_coarse, points_coarse = VolumeRendererUtil.sample_along_ray(
            rays_o, rays_d, near, far, coarse_samples, perturb)

        # Importance (fine) sampling driven by the coarse weights
        with torch.no_grad():
            sigma_coarse, _ = nerf_model(points_coarse[..., :3], rays_d.unsqueeze(1))
            weights_coarse = VolumeRendererUtil.compute_weights(sigma_coarse, t_vals_coarse, rays_d)
            t_vals_fine = VolumeRendererUtil.importance_sampling(t_vals_coarse, weights_coarse, fine_samples)

        # Merge coarse and fine samples and sort them along each ray
        t_vals = torch.sort(torch.cat([t_vals_coarse, t_vals_fine], -1)).values
        points = rays_o[..., None, :] + t_vals[..., None] * rays_d[..., None, :]

        # Fine rendering
        sigma, color = nerf_model(points[..., :3], rays_d.unsqueeze(1))
        rgb_map, weights = VolumeRendererUtil.volume_rendering(sigma, color, t_vals, rays_d)
        entropy = VolumeRendererUtil.calculate_entropy(weights)

        return rgb_map, weights, t_vals, entropy

    @staticmethod
    def importance_sampling(
        t_vals: torch.Tensor,
        weights: torch.Tensor,
        n_samples: int
    ) -> torch.Tensor:
        """
        Importance sampling: draw new samples from the weight distribution.

        Args:
            t_vals: original sample parameters [N_rays, N_coarse]
            weights: weight distribution [N_rays, N_coarse]
            n_samples: number of new samples to draw

        Returns:
            samples: new sample parameters [N_rays, N_fine]
        """
        weights = weights + 1e-5  # avoid division by zero
        pdf = weights / torch.sum(weights, -1, keepdim=True)
        cdf = torch.cumsum(pdf, -1)

        # Inverse-transform sampling
        u = torch.linspace(0, 1, n_samples, device=weights.device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples]).contiguous()
        indices = torch.searchsorted(cdf, u, right=True)

        # Interpolate between the bracketing CDF bins to get new samples
        below = torch.max(torch.zeros_like(indices), indices - 1)
        above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(indices), indices)
        indices_g = torch.stack([below, above], -1)  # [N_rays, N_fine, 2]

        matched_shape = list(indices_g.shape[:-1]) + [cdf.shape[-1]]
        cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), -1, indices_g)
        t_vals_g = torch.gather(t_vals.unsqueeze(1).expand(matched_shape), -1, indices_g)

        denom = cdf_g[..., 1] - cdf_g[..., 0]
        denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom
        samples = t_vals_g[..., 0] + t * (t_vals_g[..., 1] - t_vals_g[..., 0])

        return samples

    @staticmethod
    def sample_along_ray(
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
        near: torch.Tensor,
        far: torch.Tensor,
        n_samples: int,
        perturb: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Stratified sampling of points along rays.

        Args:
            rays_o: ray origins [N_rays, 3]
            rays_d: ray directions [N_rays, 3]
            near: near-plane distances [N_rays]
            far: far-plane distances [N_rays]
            n_samples: number of samples per ray
            perturb: whether to jitter the sample positions

        Returns:
            t_vals: sample parameters [N_rays, N_samples]
            points: sampled 3D coordinates [N_rays, N_samples, 3]
        """
        # Base stratified sampling: broadcast near/far over the sample dimension
        t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)
        t_vals = near.unsqueeze(-1) + (far - near).unsqueeze(-1) * t_vals.unsqueeze(0)

        if perturb:
            # Jitter each sample within its stratum
            mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1])
            upper = torch.cat([mids, t_vals[..., -1:]], -1)
            lower = torch.cat([t_vals[..., :1], mids], -1)
            t_rand = torch.rand(t_vals.shape, device=rays_o.device)
            t_vals = lower + (upper - lower) * t_rand

        # Generate 3D points
        points = rays_o.unsqueeze(1) + t_vals.unsqueeze(-1) * rays_d.unsqueeze(1)
        return t_vals, points

    @staticmethod
    def volume_rendering(
        sigma: torch.Tensor,
        color: torch.Tensor,
        t_vals: torch.Tensor,
        rays_d: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform volume rendering.

        Args:
            sigma: volume density [N_rays, N_samples, 1]
            color: RGB colors [N_rays, N_samples, 3]
            t_vals: sample parameters [N_rays, N_samples]
            rays_d: ray directions [N_rays, 3]

        Returns:
            rgb_map: rendered colors [N_rays, 3]
            weights: weight distribution [N_rays, N_samples]
        """
        # Distances between adjacent samples; the last interval is effectively infinite
        dists = t_vals[..., 1:] - t_vals[..., :-1]
        dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], -1)
        dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

        sigma = sigma.squeeze(-1)  # [N_rays, N_samples]
        alpha = 1. - torch.exp(-sigma * dists)
        # Transmittance: exp of the negative optical depth accumulated before each sample
        trans = torch.exp(-torch.cat([
            torch.zeros_like(dists[..., :1]),
            torch.cumsum(sigma[..., :-1] * dists[..., :-1], dim=-1)
        ], dim=-1))
        weights = alpha * trans

        rgb_map = torch.sum(weights.unsqueeze(-1) * color, dim=-2)
        return rgb_map, weights

    @staticmethod
    def calculate_entropy(weights: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
        """
        Compute the entropy of the weight distribution.

        Args:
            weights: weight distribution [N_rays, N_samples]
            eps: small constant to avoid log(0)

        Returns:
            entropy: per-ray entropy [N_rays]
        """
        norm_weights = weights / (torch.sum(weights, dim=-1, keepdim=True) + eps)
        entropy = -torch.sum(norm_weights * torch.log(norm_weights + eps), dim=-1)
        return entropy

    @staticmethod
    def compute_weights(sigma: torch.Tensor, t_vals: torch.Tensor, rays_d: torch.Tensor) -> torch.Tensor:
        """Compute rendering weights (used for importance sampling)."""
        dists = t_vals[..., 1:] - t_vals[..., :-1]
        dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], -1)
        dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

        sigma = sigma.squeeze(-1)  # [N_rays, N_samples]
        alpha = 1. - torch.exp(-sigma * dists)
        trans = torch.exp(-torch.cat([
            torch.zeros_like(dists[..., :1]),
            torch.cumsum(sigma[..., :-1] * dists[..., :-1], dim=-1)
        ], dim=-1))
        return alpha * trans
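

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original utility): a minimal example of how
# render_rays might be called. TinyNeRF below is a hypothetical stand-in for
# a real NeRF model; it only mimics the expected interface, mapping sampled
# points and view directions to (sigma, color) with the documented shapes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class TinyNeRF(torch.nn.Module):
        """Dummy model returning random densities and colors of the expected shapes."""

        def forward(self, points: torch.Tensor, view_dirs: torch.Tensor):
            n_rays, n_samples, _ = points.shape
            sigma = torch.rand(n_rays, n_samples, 1)   # [N_rays, N_samples, 1]
            color = torch.rand(n_rays, n_samples, 3)   # [N_rays, N_samples, 3]
            return sigma, color

    n_rays = 4
    rays_o = torch.zeros(n_rays, 3)                       # all rays start at the origin
    rays_d = F.normalize(torch.randn(n_rays, 3), dim=-1)  # random unit directions
    near = torch.full((n_rays,), 0.1)
    far = torch.full((n_rays,), 4.0)

    rgb, weights, t_vals, entropy = VolumeRendererUtil.render_rays(
        TinyNeRF(), rays_o, rays_d, near, far,
        coarse_samples=32, fine_samples=64)
    # Expected shapes: [4, 3], [4, 96], [4, 96], [4]
    print(rgb.shape, weights.shape, t_vals.shape, entropy.shape)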