# nbv_rec_uncertainty_guide/ref_code/active_reconstruction.py
import torch
import numpy as np
import os
import yaml
import time
from nerf_model import NeRF
from pipeline import ActiveReconstructionPolicy
from uncertainty_guide import UncertaintyGuideNeRF
import argparse
from typing import Dict, Any, List
from utils.volume_render_util import VolumeRendererUtil
import mcubes   # Python Marching Cubes library
import trimesh  # mesh handling
from tqdm import tqdm  # progress bars

class ActiveReconstruction:
    """Active 3D reconstruction system guided by NeRF uncertainty."""

    def __init__(self, config_path: str):
        """
        Initialize the active reconstruction system.

        Args:
            config_path: path to the YAML configuration file
        """
        # Load the configuration
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)
        # Select the device
        self.device = torch.device(self.config.get("device", "cuda") if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        # Create the output directory
        self.output_dir = self.config.get("output_dir", "output")
        os.makedirs(self.output_dir, exist_ok=True)
        # Initialize the NeRF model
        self._init_nerf_model()
        # Initialize the view-selection policy
        self.policy = ActiveReconstructionPolicy(self.config)

    def _init_nerf_model(self):
        """Initialize the NeRF model from the configuration."""
        # Read NeRF hyperparameters from the config
        nerf_config = self.config.get("nerf", {})
        model_config = {
            "pos_enc_dim": nerf_config.get("pos_enc_dim", 10),
            "dir_enc_dim": nerf_config.get("dir_enc_dim", 4),
            "netdepth_coarse": nerf_config.get("netdepth_coarse", 8),
            "netwidth_coarse": nerf_config.get("netwidth_coarse", 256),
            "netdepth_fine": nerf_config.get("netdepth_fine", 8),
            "netwidth_fine": nerf_config.get("netwidth_fine", 256),
            "skips": nerf_config.get("skips", [4]),
            "use_viewdirs": nerf_config.get("use_viewdirs", True)
        }
        self.nerf_model = NeRF(model_config).to(self.device)

    def _generate_rays(self,
                       poses: torch.Tensor,
                       H: int,
                       W: int,
                       focal: float) -> tuple:
        """
        Generate rays for each camera pose.

        Args:
            poses: camera poses [N, 4, 4]
            H: image height
            W: image width
            focal: focal length

        Returns:
            rays_o: ray origins [N, H*W, 3]
            rays_d: ray directions [N, H*W, 3]
        """
        # Build a pixel-coordinate grid on the same device as the poses
        # (keeping everything on one device avoids CPU/GPU mismatches below)
        i, j = torch.meshgrid(
            torch.linspace(0, W - 1, W, device=poses.device),
            torch.linspace(0, H - 1, H, device=poses.device),
            indexing='ij'
        )
        i = i.t()  # [H, W]
        j = j.t()  # [H, W]
        # Pixel coordinates -> ray directions in the camera frame
        dirs = torch.stack([
            (i - W * 0.5) / focal,
            -(j - H * 0.5) / focal,
            -torch.ones_like(i)
        ], dim=-1)  # [H, W, 3]
        # Generate rays for every pose
        rays_o_list = []
        rays_d_list = []
        for pose in poses:
            # Rotate the ray directions into the world frame
            rays_d = torch.sum(dirs[..., None, :] * pose[:3, :3], dim=-1)  # [H, W, 3]
            # The ray origin is the camera position
            rays_o = pose[:3, -1].expand(rays_d.shape)  # [H, W, 3]
            # Flatten to batch format
            rays_o = rays_o.reshape(-1, 3)  # [H*W, 3]
            rays_d = rays_d.reshape(-1, 3)  # [H*W, 3]
            rays_o_list.append(rays_o)
            rays_d_list.append(rays_d)
        # Stack the rays from all poses
        rays_o_all = torch.stack(rays_o_list, dim=0)  # [N, H*W, 3]
        rays_d_all = torch.stack(rays_d_list, dim=0)  # [N, H*W, 3]
        return rays_o_all, rays_d_all
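
    # Shape sanity check for _generate_rays (illustrative values only):
    #   poses = torch.eye(4).unsqueeze(0)                         # [1, 4, 4]
    #   rays_o, rays_d = self._generate_rays(poses, H=4, W=4, focal=2.0)
    #   -> rays_o: [1, 16, 3], rays_d: [1, 16, 3]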

    def _sample_pixel_batch(self,
                            images: torch.Tensor,
                            rays_o: torch.Tensor,
                            rays_d: torch.Tensor,
                            batch_size: int) -> tuple:
        """
        Randomly sample a batch of pixels.

        Args:
            images: image data [N, H, W, 3]
            rays_o: ray origins [N, H*W, 3]
            rays_d: ray directions [N, H*W, 3]
            batch_size: batch size

        Returns:
            sampled_rays_o: sampled ray origins [batch_size, 3]
            sampled_rays_d: sampled ray directions [batch_size, 3]
            sampled_pixels: sampled pixel values [batch_size, 3]
        """
        # Image shape
        N, H, W = images.shape[0], images.shape[1], images.shape[2]
        total_rays = N * H * W
        # Flatten the images
        pixels = images.reshape(N, -1, 3)  # [N, H*W, 3]
        # Draw a random batch of (image, pixel) indices
        indices = torch.randint(0, total_rays, size=(batch_size,), device=rays_o.device)
        img_indices = indices // (H * W)
        pixel_indices = indices % (H * W)
        # Gather rays and pixels with advanced indexing (equivalent to, but much
        # faster than, stacking per-element lookups in a Python loop)
        sampled_rays_o = rays_o[img_indices, pixel_indices]
        sampled_rays_d = rays_d[img_indices, pixel_indices]
        sampled_pixels = pixels[img_indices, pixel_indices]
        return sampled_rays_o, sampled_rays_d, sampled_pixels

    def train_nerf(self,
                   images: torch.Tensor,
                   poses: torch.Tensor,
                   epochs: int = 5000,
                   batch_size: int = 4096,
                   lr: float = 5e-4,
                   start_from_model=None) -> float:
        """
        Train the NeRF model.

        Args:
            images: image data [N, H, W, 3]
            poses: camera poses [N, 4, 4]
            epochs: number of training iterations
            batch_size: batch size
            lr: learning rate
            start_from_model: optional initial model state

        Returns:
            best_loss: best (lowest) training loss
        """
        print(f"Starting NeRF training with {len(images)} images...")
        # Image and sampling parameters
        H, W = images.shape[1], images.shape[2]
        sampling_config = self.config.get("sampling", {})
        camera_config = self.config.get("camera", {})
        focal = camera_config.get("focal", 1000.0)
        near = camera_config.get("near", 2.0)
        far = camera_config.get("far", 6.0)
        coarse_samples = sampling_config.get("coarse_samples", 64)
        fine_samples = sampling_config.get("fine_samples", 128)
        perturb = sampling_config.get("perturb", True)
        # If an initial model is provided, start from its weights
        if start_from_model is not None:
            print("Initializing weights from an existing model")
            self.nerf_model.load_state_dict(start_from_model.state_dict())
        # Optimizer and loss
        optimizer = torch.optim.Adam(self.nerf_model.parameters(), lr=lr)
        mse_loss = torch.nn.MSELoss()
        # Switch the model to training mode
        self.nerf_model.train()
        # Precompute rays for all images (precomputation speeds up training)
        rays_o, rays_d = self._generate_rays(poses, H, W, focal)
        rays_o = rays_o.to(self.device)
        rays_d = rays_d.to(self.device)
        images = images.to(self.device)
        # Training loop
        best_loss = float('inf')
        for epoch in range(epochs):
            # Sample a random batch of rays
            batch_rays_o, batch_rays_d, target_pixels = self._sample_pixel_batch(
                images, rays_o, rays_d, batch_size)
            # Normalize the ray directions
            batch_rays_d = torch.nn.functional.normalize(batch_rays_d, dim=-1)
            # Near/far plane tensors
            near_tensor = torch.ones_like(batch_rays_o[..., 0]) * near
            far_tensor = torch.ones_like(batch_rays_o[..., 0]) * far
            optimizer.zero_grad()
            # Volume rendering (coarse sampling followed by fine resampling)
            rgb_map, _, _, _ = VolumeRendererUtil.render_rays(
                self.nerf_model,
                batch_rays_o,
                batch_rays_d,
                near_tensor,
                far_tensor,
                coarse_samples,
                fine_samples,
                perturb
            )
            # Compute the loss and backpropagate
            loss = mse_loss(rgb_map, target_pixels)
            loss.backward()
            optimizer.step()
            # Report progress (PSNR = -10 * log10(MSE))
            if (epoch + 1) % 100 == 0:
                psnr = -10.0 * torch.log10(loss)
                print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.6f}, PSNR: {psnr.item():.2f}")
            # Keep the best checkpoint
            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save(self.nerf_model.state_dict(), os.path.join(self.output_dir, "best_model.pth"))
        # Reload the best checkpoint
        self.nerf_model.load_state_dict(torch.load(os.path.join(self.output_dir, "best_model.pth")))
        print(f"NeRF training finished; best loss: {best_loss:.6f}")
        return best_loss
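
    # Example call (illustrative shapes; `system` stands for an
    # ActiveReconstruction instance, and real images/poses come from capture):
    #   images = torch.rand(6, 800, 800, 3)     # [N, H, W, 3]
    #   poses = torch.eye(4).repeat(6, 1, 1)    # [N, 4, 4]
    #   best_loss = system.train_nerf(images, poses, epochs=2000)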

    def extract_mesh(self, output_path: str, resolution: int = 256, threshold: float = 50.0, bound: float = 2.0):
        """
        Extract a 3D mesh from the NeRF model with the Marching Cubes algorithm.

        Args:
            output_path: output file path
            resolution: voxel grid resolution
            threshold: density threshold that defines the surface
            bound: half-extent of the voxel grid
        """
        print(f"Extracting 3D mesh from NeRF (resolution: {resolution})...")
        self.nerf_model.eval()  # evaluation mode
        # Define the sampling grid
        x = torch.linspace(-bound, bound, resolution)
        y = torch.linspace(-bound, bound, resolution)
        z = torch.linspace(-bound, bound, resolution)
        # Coordinate grid of sample points
        xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
        # Query points, flattened in the same 'ij' order as the grid
        points = torch.stack([xx, yy, zz], dim=-1).reshape(-1, 3).to(self.device)
        # Density field
        print("Computing the volume density field...")
        density_field = torch.zeros((resolution, resolution, resolution))
        # Process in batches to avoid running out of GPU memory
        batch_size = 4096  # adjust to fit your GPU
        with torch.no_grad():
            for i in tqdm(range(0, points.shape[0], batch_size)):
                # Current batch of points
                batch_points = points[i:i+batch_size]
                # Query the density with a fixed direction (+z here). In NeRF the
                # density depends only on position; only color is view-dependent.
                fixed_dirs = torch.zeros_like(batch_points)
                fixed_dirs[..., 2] = 1.0
                # Run the fine network
                sigma, _ = self.nerf_model(batch_points, fixed_dirs, coarse=False)
                # Because `points` is the flattened 'ij' grid, the flat slice
                # [i, i+n) maps directly onto the flattened density field.
                # (This replaces a per-point index computation whose float
                # truncation could land samples in the wrong voxel.)
                n = batch_points.shape[0]
                density_field.view(-1)[i:i+n] = sigma.cpu().reshape(-1)
        # Extract the mesh with Marching Cubes
        print("Running Marching Cubes...")
        density_field_np = density_field.numpy()
        vertices, triangles = mcubes.marching_cubes(density_field_np, threshold)
        # Map voxel coordinates back to world coordinates ([-bound, bound])
        vertices = vertices / (resolution - 1) * (2 * bound) - bound
        # Build a trimesh object
        mesh = trimesh.Trimesh(vertices=vertices, faces=triangles)
        # Save the mesh
        mesh.export(output_path)
        print(f"Mesh extraction finished, saved to: {output_path}")
        print(f"Mesh stats: {len(vertices)} vertices, {len(triangles)} triangles")
        return mesh
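
    # Example call (illustrative path; a lower resolution gives a fast preview):
    #   mesh = system.extract_mesh("output/preview_mesh.obj", resolution=128)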

    def evaluate_reconstruction(self,
                                gt_mesh_path: str = None) -> Dict[str, float]:
        """
        Evaluate reconstruction quality.

        Args:
            gt_mesh_path: path to the ground-truth mesh (if available)

        Returns:
            metrics: evaluation metrics such as F-score
        """
        if gt_mesh_path is None:
            print("No ground-truth mesh provided, skipping evaluation")
            return {}
        print("Evaluating reconstruction quality...")
        # A real implementation would compute the metrics here, typically
        # F-score, Chamfer distance, etc. For simplicity we return simulated
        # placeholder values.
        metrics = {
            "f_score": 0.85,
            "precision": 0.87,
            "recall": 0.83
        }
        print(f"Evaluation: F-score={metrics['f_score']:.4f}, "
              f"precision={metrics['precision']:.4f}, recall={metrics['recall']:.4f}")
        return metrics
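
    # A minimal sketch of a real metric in place of the simulated numbers above:
    # sample points on both mesh surfaces and compute precision/recall/F-score
    # at a distance threshold tau. Assumes `scipy` is available; the helper
    # name, sample count, and tau are illustrative, not part of the original
    # pipeline.
    @staticmethod
    def _mesh_f_score(pred_mesh_path: str, gt_mesh_path: str,
                      tau: float = 0.01, n_samples: int = 100000) -> Dict[str, float]:
        from scipy.spatial import cKDTree  # assumed extra dependency

        # Sample points uniformly on both surfaces (force='mesh' flattens scenes)
        pred_mesh = trimesh.load(pred_mesh_path, force='mesh')
        gt_mesh = trimesh.load(gt_mesh_path, force='mesh')
        pred_pts, _ = trimesh.sample.sample_surface(pred_mesh, n_samples)
        gt_pts, _ = trimesh.sample.sample_surface(gt_mesh, n_samples)

        # Nearest-neighbor distances in both directions
        d_pred_to_gt, _ = cKDTree(gt_pts).query(pred_pts)
        d_gt_to_pred, _ = cKDTree(pred_pts).query(gt_pts)

        # Precision: predicted points near GT; recall: GT points covered
        precision = float((d_pred_to_gt < tau).mean())
        recall = float((d_gt_to_pred < tau).mean())
        f_score = 2 * precision * recall / max(precision + recall, 1e-8)
        return {"f_score": f_score, "precision": precision, "recall": recall}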

    def run_active_reconstruction(self,
                                  initial_poses: np.ndarray,
                                  initial_images: torch.Tensor = None,
                                  max_iterations: int = 3) -> np.ndarray:
        """
        Run the active reconstruction loop.

        Args:
            initial_poses: initial camera poses
            initial_images: initial images (if available)
            max_iterations: maximum number of iterations

        Returns:
            all_poses: all selected camera poses
        """
        print("Starting active reconstruction...")
        # Initial training on the starting views
        if initial_images is None:
            initial_images = self._simulate_image_capture(initial_poses)
        # Train the model on the initial images
        self.train_nerf(
            initial_images,
            torch.from_numpy(initial_poses).float().to(self.device),
            epochs=self.config.get("reconstruction", {}).get("epochs_per_iteration", 2000)
        )
        # Save the initial model
        initial_model_path = os.path.join(self.output_dir, "initial_model.pth")
        torch.save(self.nerf_model.state_dict(), initial_model_path)
        all_poses = initial_poses.copy()
        current_poses = initial_poses.copy()
        all_images = initial_images.clone()
        # Extract the initial mesh
        initial_mesh_path = os.path.join(self.output_dir, "initial_mesh.obj")
        self.extract_mesh(
            initial_mesh_path,
            resolution=self.config.get("reconstruction", {}).get("mesh_resolution", 256)
        )
        # Active reconstruction iterations
        for iteration in range(max_iterations):
            print(f"\nStarting iteration {iteration+1}/{max_iterations}")
            # Select the next batch of views
            next_views = self.policy.select_next_views(self.nerf_model, current_poses)
            print(f"Selected {len(next_views)} new views")
            # Capture images from the new views
            new_images = self._simulate_image_capture(next_views)
            # Append the newly selected views to the current poses and images
            current_poses = np.concatenate([current_poses, next_views], axis=0)
            all_poses = np.concatenate([all_poses, next_views], axis=0)
            all_images = torch.cat([all_images, new_images], dim=0)
            # Following the authors' description, we re-initialize from the
            # initial model rather than continuing training:
            # "After selecting additional images, we initialize the network with
            # the model from the initialization step and refine the model further
            # with the updated training set."
            # So we first load the initial model, then retrain on the extended set.
            self.nerf_model.load_state_dict(torch.load(initial_model_path))
            # Retrain the model on the extended dataset
            self.train_nerf(
                all_images,
                torch.from_numpy(current_poses).float().to(self.device),
                epochs=self.config.get("reconstruction", {}).get("epochs_per_iteration", 2000)
            )
            # Extract a mesh after each iteration to track reconstruction progress
            iter_mesh_path = os.path.join(self.output_dir, f"mesh_iter_{iteration+1}.obj")
            self.extract_mesh(
                iter_mesh_path,
                resolution=self.config.get("reconstruction", {}).get("mesh_resolution", 256)
            )
        # Extract the final 3D mesh
        output_mesh_path = os.path.join(self.output_dir, "final_mesh.obj")
        self.extract_mesh(
            output_mesh_path,
            resolution=self.config.get("reconstruction", {}).get("mesh_resolution", 256)
        )
        # Evaluate reconstruction quality
        self.evaluate_reconstruction()
        print("Active reconstruction finished")
        return all_poses

    def _simulate_image_capture(self, poses: np.ndarray) -> torch.Tensor:
        """
        Simulate image capture (a real system would read from a camera or dataset).

        Args:
            poses: camera poses

        Returns:
            images: simulated images
        """
        # Simulated image size
        camera_config = self.config.get("camera", {})
        H, W = camera_config.get("height", 800), camera_config.get("width", 800)
        # Random images (real ones would come from a camera or a renderer)
        images = torch.rand(len(poses), H, W, 3, device=self.device)
        return images
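
    # A minimal sketch of a real capture source in place of the random images
    # above, assuming a Blender-style synthetic dataset (transforms_train.json
    # plus PNG frames, as in the original NeRF release). The helper name and
    # the PIL dependency are illustrative, not part of the original pipeline.
    @staticmethod
    def _load_blender_split(dataset_dir: str) -> tuple:
        import json
        from PIL import Image  # assumed extra dependency

        with open(os.path.join(dataset_dir, "transforms_train.json"), 'r') as f:
            meta = json.load(f)
        images, poses = [], []
        for frame in meta["frames"]:
            img_path = os.path.join(dataset_dir, frame["file_path"] + ".png")
            img = np.asarray(Image.open(img_path), dtype=np.float32) / 255.0
            images.append(torch.from_numpy(img[..., :3]))  # drop alpha if present
            poses.append(np.array(frame["transform_matrix"], dtype=np.float32))
        # images: [N, H, W, 3] tensor, poses: [N, 4, 4] array
        return torch.stack(images), np.stack(poses)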

def main():
    parser = argparse.ArgumentParser(description="Active 3D reconstruction guided by NeRF uncertainty")
    parser.add_argument("--config", type=str, default="nbv_config.yaml", help="path to the config file")
    parser.add_argument("--synthetic", action="store_true", help="use a synthetic dataset")
    args = parser.parse_args()
    # Create the active reconstruction system
    reconstruction = ActiveReconstruction(args.config)
    # Initialize some camera poses (typically from the middle circle).
    # Read the number of initial poses from the config.
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
    initial_view_count = config.get("reconstruction", {}).get("initial_view_count", 15)
    # Adjust the number of initial views for the dataset type
    if args.synthetic:
        initial_view_count = min(initial_view_count, 6)  # synthetic data uses 6 initial views
        print(f"Using synthetic dataset, initial view count: {initial_view_count}")
    else:
        print(f"Using real dataset, initial view count: {initial_view_count}")
    # Take camera poses from the middle circle.
    # Poses are assumed to be organized circle by circle; we pick a subset of
    # poses from the middle circle.
    middle_circle_index = config.get("view_selection", {}).get("n_circles", 5) // 2
    poses_per_circle = config.get("view_selection", {}).get("n_poses_per_circle", 30)
    # Pick evenly spaced initial poses (step of at least 1 to avoid duplicates)
    start_index = middle_circle_index * poses_per_circle
    step = max(1, poses_per_circle // initial_view_count)
    initial_pose_indices = [start_index + i * step for i in range(initial_view_count)]
    initial_poses = reconstruction.policy.poses[initial_pose_indices]
    # Run active reconstruction
    selected_poses = reconstruction.run_active_reconstruction(
        initial_poses,
        max_iterations=config.get("reconstruction", {}).get("max_iterations", 3)
    )
    print(f"Active reconstruction finished, selected {len(selected_poses)} camera poses in total")
if __name__ == "__main__":
main()
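
# For reference, a minimal nbv_config.yaml covering the keys this script reads.
# The values shown are the code's own defaults; this is a sketch, not a complete
# config (ActiveReconstructionPolicy may read additional policy-specific keys):
#
#   device: cuda
#   output_dir: output
#   nerf:
#     pos_enc_dim: 10
#     dir_enc_dim: 4
#     netdepth_coarse: 8
#     netwidth_coarse: 256
#     netdepth_fine: 8
#     netwidth_fine: 256
#     skips: [4]
#     use_viewdirs: true
#   sampling:
#     coarse_samples: 64
#     fine_samples: 128
#     perturb: true
#   camera:
#     height: 800
#     width: 800
#     focal: 1000.0
#     near: 2.0
#     far: 6.0
#   reconstruction:
#     initial_view_count: 15
#     epochs_per_iteration: 2000
#     mesh_resolution: 256
#     max_iterations: 3
#   view_selection:
#     n_circles: 5
#     n_poses_per_circle: 30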