import torch
import torch.nn as nn
import torch.nn.functional as F
from PytorchBoot.stereotype import stereotype


@stereotype.module("nerf")
class NeRF(nn.Module):
    """NeRF MLP with separate coarse and fine sub-networks.

    Config keys read:
        pos_enc_dim (required): number of frequency bands for position encoding.
        dir_enc_dim (required): number of frequency bands for direction encoding.
        netdepth_coarse / netwidth_coarse (default 8 / 256).
        netdepth_fine / netwidth_fine (default 8 / 256).
        skips (default [4]): trunk layer indices whose output is concatenated
            with the encoded position before being fed to the next layer.
        use_viewdirs (default True): condition colour on the view direction.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Width after positional_encoding: the raw 3-vector plus one sin/cos
        # pair per frequency band.
        pos_enc_out = 3 * (2 * config["pos_enc_dim"] + 1)
        dir_enc_out = 3 * (2 * config["dir_enc_dim"] + 1)

        # Configurable depth and width of each network.
        netdepth_coarse = config.get("netdepth_coarse", 8)
        netwidth_coarse = config.get("netwidth_coarse", 256)
        netdepth_fine = config.get("netdepth_fine", 8)
        netwidth_fine = config.get("netwidth_fine", 256)

        # Trunk layers after which the encoded position is re-injected.
        skips = config.get("skips", [4])
        # Stored so forward_mlp uses exactly the skips the layers were sized for.
        self.skips = list(skips)

        # Whether colour is conditioned on the view direction.
        self.use_viewdirs = config.get("use_viewdirs", True)

        # Coarse is built before fine, so module creation / registration order
        # is coarse-first.
        self._build_network(
            "coarse", pos_enc_out, dir_enc_out, netdepth_coarse, netwidth_coarse, skips
        )
        self._build_network(
            "fine", pos_enc_out, dir_enc_out, netdepth_fine, netwidth_fine, skips
        )

    def _build_network(self, suffix, pos_enc_out, dir_enc_out, depth, width, skips):
        """Create and register one sub-network under attribute names ``<part>_<suffix>``.

        Args:
            suffix: "coarse" or "fine"; appended to every registered attribute name.
            pos_enc_out: width of the encoded position.
            dir_enc_out: width of the encoded view direction.
            depth: number of linear layers in the position trunk.
            width: hidden width of the position trunk.
            skips: trunk layer indices followed by a skip connection.
        """
        setattr(
            self,
            f"pts_linears_{suffix}",
            self._build_pts_mlp(
                input_dim=pos_enc_out, width=width, depth=depth, skips=skips
            ),
        )
        if self.use_viewdirs:
            # Position trunk -> density and a feature vector.
            setattr(self, f"alpha_linear_{suffix}", nn.Linear(width, 1))
            setattr(self, f"feature_linear_{suffix}", nn.Linear(width, width))
            # Feature + encoded direction -> RGB.
            setattr(
                self,
                f"views_linears_{suffix}",
                nn.ModuleList([nn.Linear(width + dir_enc_out, width // 2)]),
            )
            setattr(self, f"rgb_linear_{suffix}", nn.Linear(width // 2, 3))
        else:
            # Without view directions a single head predicts RGB + density.
            setattr(self, f"output_linear_{suffix}", nn.Linear(width, 4))

    def _build_pts_mlp(self, input_dim, width, depth, skips):
        """Build the position-trunk MLP, with input widths sized for skip connections.

        Layer 0 maps ``input_dim -> width``. Each later layer ``i`` maps
        ``width -> width``, except when layer ``i - 1`` is a skip layer. In that
        case ``forward_mlp`` has concatenated the encoded position onto layer
        ``i - 1``'s output, so layer ``i`` takes ``input_dim + width`` features.
        """
        layers = nn.ModuleList([nn.Linear(input_dim, width)])
        for i in range(1, depth):
            # The concat happens AFTER a skip layer, so it is the following
            # layer whose input is widened.
            in_features = input_dim + width if (i - 1) in skips else width
            layers.append(nn.Linear(in_features, width))
        return layers

    def positional_encoding(self, x, L):
        """Return ``[x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)]``.

        Features are concatenated along the last dimension, so an input whose
        last dimension is ``C`` yields last dimension ``C * (2 * L + 1)``.
        """
        encodings = [x]
        for i in range(L):
            freq = 2 ** i
            encodings.append(torch.sin(freq * x))
            encodings.append(torch.cos(freq * x))
        return torch.cat(encodings, dim=-1)

    def forward_mlp(self, pts_embed, viewdirs_embed, is_coarse=True):
        """Run the coarse or fine MLP on already-encoded inputs.

        Args:
            pts_embed: encoded positions (output of ``positional_encoding``).
            viewdirs_embed: encoded view directions. Only used when
                ``use_viewdirs`` is True; may be None otherwise.
            is_coarse: use the coarse network if True, else the fine one.

        Returns:
            (rgb, sigma). ``rgb`` has last dim 3 and is passed through a sigmoid,
            so it lies in [0, 1]. ``sigma`` has last dim 1 and is the raw
            density output; no activation is applied here.
        """
        suffix = "coarse" if is_coarse else "fine"
        pts_linears = getattr(self, f"pts_linears_{suffix}")
        last_index = len(pts_linears) - 1

        # Position trunk.
        h = pts_embed
        for i, layer in enumerate(pts_linears):
            h = F.relu(layer(h))
            # Skip connection: re-inject the encoded position for the next layer.
            # Never after the final trunk layer, since the heads expect exactly
            # `width` features.
            if i in self.skips and i < last_index:
                h = torch.cat([pts_embed, h], -1)

        if not self.use_viewdirs:
            # A single head outputs RGBA directly.
            outputs = getattr(self, f"output_linear_{suffix}")(h)
            rgb = torch.sigmoid(outputs[..., :3])
            sigma = outputs[..., 3:]
            return rgb, sigma

        # Branch 1: density from the trunk features.
        sigma = getattr(self, f"alpha_linear_{suffix}")(h)
        # Branch 2: colour from trunk features combined with the view direction.
        feature = getattr(self, f"feature_linear_{suffix}")(h)
        h = torch.cat([feature, viewdirs_embed], -1)
        for layer in getattr(self, f"views_linears_{suffix}"):
            h = F.relu(layer(h))
        rgb = torch.sigmoid(getattr(self, f"rgb_linear_{suffix}")(h))
        return rgb, sigma

    def forward(self, pos, dir, coarse=True):
        """Predict density and colour for query points.

        Args:
            pos: 3D positions, [batch_size, ..., 3].
            dir: view directions, [batch_size, ..., 3]. Normalised to unit
                length before encoding; ignored when ``use_viewdirs`` is False.
            coarse: use the coarse network if True, otherwise the fine one.

        Returns:
            sigma: raw volume density, [batch_size, ..., 1].
            color: RGB colour in [0, 1], [batch_size, ..., 3].
        """
        pos_enc = self.positional_encoding(pos, self.config["pos_enc_dim"])

        # Directions are only encoded when the model uses them.
        dir_enc = None
        if self.use_viewdirs:
            dir_enc = self.positional_encoding(
                F.normalize(dir, dim=-1), self.config["dir_enc_dim"]
            )

        color, sigma = self.forward_mlp(pos_enc, dir_enc, coarse)
        return sigma, color