182 lines
7.0 KiB
Python
Raw Permalink Normal View History

2025-04-20 10:26:09 +08:00
import torch
import torch.nn as nn
import torch.nn.functional as F
from PytorchBoot.stereotype import stereotype
@stereotype.module("nerf")
class NeRF(nn.Module):
    """NeRF MLP with separate coarse and fine networks.

    Config keys read by this class:
        pos_enc_dim (required): number of frequency bands used to encode positions.
        dir_enc_dim: number of frequency bands used to encode view directions
            (required only when ``use_viewdirs`` is True).
        netdepth_coarse, netwidth_coarse: coarse network depth / width (default 8 / 256).
        netdepth_fine, netwidth_fine: fine network depth / width (default 8 / 256).
        skips: indices of position-MLP layers whose output is concatenated with the
            encoded position before being fed to the next layer (default [4]).
        use_viewdirs: if True (default), RGB is predicted from a feature vector plus
            the encoded view direction; otherwise one linear head predicts RGB + sigma.

    Layers are registered as ``<name>_coarse`` / ``<name>_fine`` attributes
    (e.g. ``pts_linears_coarse``, ``rgb_linear_fine``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.use_viewdirs = config.get("use_viewdirs", True)
        # Encoded size: the raw 3-vector plus one sin and one cos term per band.
        pos_enc_out = 3 * (2 * config["pos_enc_dim"] + 1)
        dir_enc_out = 3 * (2 * config["dir_enc_dim"] + 1) if self.use_viewdirs else None
        # Read once so layer construction and the forward pass agree on skip indices.
        self.skips = tuple(config.get("skips", [4]))
        # Coarse first, then fine (same registration order as before).
        self._build_branch(
            "coarse",
            pos_enc_out,
            dir_enc_out,
            width=config.get("netwidth_coarse", 256),
            depth=config.get("netdepth_coarse", 8),
        )
        self._build_branch(
            "fine",
            pos_enc_out,
            dir_enc_out,
            width=config.get("netwidth_fine", 256),
            depth=config.get("netdepth_fine", 8),
        )

    def _build_branch(self, suffix, pos_enc_out, dir_enc_out, width, depth):
        """Register one network's layers as ``<name>_<suffix>`` attributes."""
        setattr(
            self,
            f"pts_linears_{suffix}",
            self._build_pts_mlp(pos_enc_out, width, depth, self.skips),
        )
        if self.use_viewdirs:
            # Position features -> density (alpha) and a feature vector.
            setattr(self, f"alpha_linear_{suffix}", nn.Linear(width, 1))
            setattr(self, f"feature_linear_{suffix}", nn.Linear(width, width))
            # Feature vector + encoded direction -> RGB.
            setattr(
                self,
                f"views_linears_{suffix}",
                nn.ModuleList([nn.Linear(width + dir_enc_out, width // 2)]),
            )
            setattr(self, f"rgb_linear_{suffix}", nn.Linear(width // 2, 3))
        else:
            # Single head producing 3 RGB logits + 1 raw density.
            setattr(self, f"output_linear_{suffix}", nn.Linear(width, 4))

    def _build_pts_mlp(self, input_dim, width, depth, skips):
        """Build the MLP that processes the encoded position.

        Layer 0 maps ``input_dim -> width``. Layer ``i >= 1`` maps
        ``input_dim + width -> width`` when ``i - 1`` is in ``skips`` (its input is
        the previous layer's output concatenated with the encoded position, as
        done in ``forward_mlp``), and ``width -> width`` otherwise.
        """
        layers = nn.ModuleList([nn.Linear(input_dim, width)])
        for i in range(1, depth):
            in_features = input_dim + width if (i - 1) in skips else width
            layers.append(nn.Linear(in_features, width))
        return layers

    def positional_encoding(self, x, L):
        """Return ``[x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)]``.

        Terms are concatenated along the last dimension, so an input whose last
        dimension is D yields a last dimension of ``D * (2 * L + 1)``.
        """
        encodings = [x]
        for i in range(L):
            encodings.append(torch.sin(2**i * x))
            encodings.append(torch.cos(2**i * x))
        return torch.cat(encodings, dim=-1)

    def forward_mlp(self, pts_embed, viewdirs_embed, is_coarse=True):
        """Run the coarse or fine MLP on already-encoded inputs.

        Args:
            pts_embed: encoded positions.
            viewdirs_embed: encoded view directions (unused when ``use_viewdirs``
                is False).
            is_coarse: use the coarse network if True, else the fine one.

        Returns:
            (rgb, sigma): rgb passed through a sigmoid (range [0, 1]) with last
            dimension 3, and the raw, un-activated density with last dimension 1.
        """
        suffix = "coarse" if is_coarse else "fine"
        pts_linears = getattr(self, f"pts_linears_{suffix}")
        last = len(pts_linears) - 1

        h = pts_embed
        for i, layer in enumerate(pts_linears):
            h = F.relu(layer(h))
            # Skip connection: the next layer sees [encoded position, h]. Never
            # applied after the last layer, whose output feeds the heads directly.
            if i in self.skips and i < last:
                h = torch.cat([pts_embed, h], -1)

        if self.use_viewdirs:
            sigma = getattr(self, f"alpha_linear_{suffix}")(h)
            feature = getattr(self, f"feature_linear_{suffix}")(h)
            h = torch.cat([feature, viewdirs_embed], -1)
            for layer in getattr(self, f"views_linears_{suffix}"):
                h = F.relu(layer(h))
            rgb = torch.sigmoid(getattr(self, f"rgb_linear_{suffix}")(h))
        else:
            outputs = getattr(self, f"output_linear_{suffix}")(h)
            rgb = torch.sigmoid(outputs[..., :3])
            sigma = outputs[..., 3:]
        return rgb, sigma

    def forward(self, pos, dir, coarse=True):
        """Evaluate the network at sample points.

        Args:
            pos: 3D positions, shape [..., 3].
            dir: view directions, shape [..., 3]; normalized before encoding and
                ignored when ``use_viewdirs`` is False. (The name shadows the
                builtin but is kept for keyword-argument compatibility.)
            coarse: evaluate the coarse network if True, else the fine one.

        Returns:
            (sigma, color): raw density [..., 1] and RGB in [0, 1] [..., 3].
        """
        pos_enc = self.positional_encoding(pos, self.config["pos_enc_dim"])
        dir_enc = None
        if self.use_viewdirs:
            # Only the direction matters, not its magnitude.
            dir_normalized = F.normalize(dir, dim=-1)
            dir_enc = self.positional_encoding(dir_normalized, self.config["dir_enc_dim"])
        color, sigma = self.forward_mlp(pos_enc, dir_enc, coarse)
        return sigma, color