2025-04-20 10:26:09 +08:00

182 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from PytorchBoot.stereotype import stereotype
@stereotype.module("nerf")
class NeRF(nn.Module):
    """NeRF model holding a coarse and a fine MLP.

    Configuration keys read from ``config``:
        pos_enc_dim (required): number of frequency bands for point encoding.
        dir_enc_dim (required): number of frequency bands for direction encoding.
        netdepth_coarse / netwidth_coarse: coarse MLP depth / width (default 8 / 256).
        netdepth_fine / netwidth_fine: fine MLP depth / width (default 8 / 256).
        skips: indices of point-MLP layers whose output is concatenated with the
            encoded input before being fed to the next layer (default [4]).
        use_viewdirs: if True, the colour head is conditioned on the encoded view
            direction (default True).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Output widths of positional_encoding for 3-D inputs: the raw input plus
        # one sin and one cos term per frequency band.
        pos_enc_out = 3 * (2 * config["pos_enc_dim"] + 1)
        dir_enc_out = 3 * (2 * config["dir_enc_dim"] + 1)
        # Configurable network depth and width.
        netdepth_coarse = config.get("netdepth_coarse", 8)
        netwidth_coarse = config.get("netwidth_coarse", 256)
        netdepth_fine = config.get("netdepth_fine", 8)
        netwidth_fine = config.get("netwidth_fine", 256)
        # Skip-connection layer indices. Cached once so that the layer shapes built
        # here and the concatenations performed in forward_mlp always agree.
        self.skips = tuple(config.get("skips", [4]))
        # Whether the colour head uses the view direction.
        self.use_viewdirs = config.get("use_viewdirs", True)
        # Build the coarse and fine networks.
        if self.use_viewdirs:
            # Coarse: encoded position -> density + feature vector.
            self.pts_linears_coarse = self._build_pts_mlp(
                input_dim=pos_enc_out,
                width=netwidth_coarse,
                depth=netdepth_coarse,
                skips=self.skips,
            )
            self.alpha_linear_coarse = nn.Linear(netwidth_coarse, 1)
            self.feature_linear_coarse = nn.Linear(netwidth_coarse, netwidth_coarse)
            # Coarse: feature + encoded direction -> RGB.
            self.views_linears_coarse = nn.ModuleList([
                nn.Linear(netwidth_coarse + dir_enc_out, netwidth_coarse // 2)
            ])
            self.rgb_linear_coarse = nn.Linear(netwidth_coarse // 2, 3)
            # Fine network, same layout.
            self.pts_linears_fine = self._build_pts_mlp(
                input_dim=pos_enc_out,
                width=netwidth_fine,
                depth=netdepth_fine,
                skips=self.skips,
            )
            self.alpha_linear_fine = nn.Linear(netwidth_fine, 1)
            self.feature_linear_fine = nn.Linear(netwidth_fine, netwidth_fine)
            self.views_linears_fine = nn.ModuleList([
                nn.Linear(netwidth_fine + dir_enc_out, netwidth_fine // 2)
            ])
            self.rgb_linear_fine = nn.Linear(netwidth_fine // 2, 3)
        else:
            # Simplified variant without view directions: a single linear head
            # emits 4 values (RGB + density) directly.
            self.pts_linears_coarse = self._build_pts_mlp(
                input_dim=pos_enc_out,
                width=netwidth_coarse,
                depth=netdepth_coarse,
                skips=self.skips,
            )
            self.output_linear_coarse = nn.Linear(netwidth_coarse, 4)
            self.pts_linears_fine = self._build_pts_mlp(
                input_dim=pos_enc_out,
                width=netwidth_fine,
                depth=netdepth_fine,
                skips=self.skips,
            )
            self.output_linear_fine = nn.Linear(netwidth_fine, 4)

    def _build_pts_mlp(self, input_dim, width, depth, skips):
        """Build the point MLP as a ModuleList of Linear layers.

        forward_mlp concatenates the encoded input onto the activation of every
        layer whose index is in ``skips``, so the layer *after* each skip index
        must accept ``input_dim + width`` features; all other hidden layers take
        ``width`` features. At least one layer (``input_dim -> width``) is built.
        """
        layers = nn.ModuleList([nn.Linear(input_dim, width)])
        for i in range(1, depth):
            # Widen layer i when the previous layer's output gets the input appended.
            in_features = input_dim + width if (i - 1) in skips else width
            layers.append(nn.Linear(in_features, width))
        return layers

    def positional_encoding(self, x, L):
        """Frequency-encode ``x`` along its last dimension.

        Returns ``cat([x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x),
        cos(2^(L-1) x)], dim=-1)``, so the last dimension grows from ``D`` to
        ``D * (2 * L + 1)``.
        """
        encodings = [x]
        for i in range(L):
            freq = 2 ** i
            encodings.append(torch.sin(freq * x))
            encodings.append(torch.cos(freq * x))
        return torch.cat(encodings, dim=-1)

    def forward_mlp(self, pts_embed, viewdirs_embed, is_coarse=True):
        """Run the coarse or fine MLP on already-encoded inputs.

        Args:
            pts_embed: encoded sample points; last dim ``3 * (2 * pos_enc_dim + 1)``.
            viewdirs_embed: encoded view directions (last dim
                ``3 * (2 * dir_enc_dim + 1)``), or ``None`` when ``use_viewdirs``
                is False.
            is_coarse: use the coarse network if True, else the fine network.

        Returns:
            (rgb, sigma): ``rgb`` has last dim 3 and is passed through a sigmoid
            (range [0, 1]); ``sigma`` has last dim 1 and is returned without any
            activation.
        """
        pts_linears = self.pts_linears_coarse if is_coarse else self.pts_linears_fine

        # Point MLP with skip connections.
        h = pts_embed
        last_index = len(pts_linears) - 1
        for i, layer in enumerate(pts_linears):
            h = F.relu(layer(h))
            # Re-inject the encoded input after layer i; layer i + 1 was built to
            # accept it. Never concatenate after the last layer, since the heads
            # below expect exactly ``width`` features.
            if i in self.skips and i < last_index:
                h = torch.cat([pts_embed, h], -1)

        if not self.use_viewdirs:
            # Single head emits RGB + sigma directly.
            output_linear = (
                self.output_linear_coarse if is_coarse else self.output_linear_fine
            )
            outputs = output_linear(h)
            rgb = torch.sigmoid(outputs[..., :3])  # [0, 1] range
            sigma = outputs[..., 3:]
            return rgb, sigma

        if is_coarse:
            alpha_linear = self.alpha_linear_coarse
            feature_linear = self.feature_linear_coarse
            views_linears = self.views_linears_coarse
            rgb_linear = self.rgb_linear_coarse
        else:
            alpha_linear = self.alpha_linear_fine
            feature_linear = self.feature_linear_fine
            views_linears = self.views_linears_fine
            rgb_linear = self.rgb_linear_fine

        # Branch 1: density depends only on position.
        sigma = alpha_linear(h)
        # Branch 2: colour feature, combined with the encoded view direction.
        feature = feature_linear(h)
        h = torch.cat([feature, viewdirs_embed], -1)
        for layer in views_linears:
            h = F.relu(layer(h))
        rgb = torch.sigmoid(rgb_linear(h))  # [0, 1] range
        return rgb, sigma

    def forward(self, pos, dir, coarse=True):
        """Evaluate the coarse or fine network.

        Args:
            pos: 3-D sample positions, shape [batch_size, ..., 3].
            dir: view directions, shape [batch_size, ..., 3]; L2-normalised here
                before encoding. Ignored when ``use_viewdirs`` is False.
            coarse: use the coarse network if True, otherwise the fine network.

        Returns:
            sigma: volume density (no activation applied), [batch_size, ..., 1].
            color: RGB colour in [0, 1], [batch_size, ..., 3].
        """
        pos_enc = self.positional_encoding(pos, self.config["pos_enc_dim"])
        # Directions are only encoded when the colour head consumes them.
        if self.use_viewdirs:
            dir_normalized = F.normalize(dir, dim=-1)
            dir_enc = self.positional_encoding(dir_normalized, self.config["dir_enc_dim"])
        else:
            dir_enc = None
        color, sigma = self.forward_mlp(pos_enc, dir_enc, coarse)
        return sigma, color