# NOTE: import order is kept as in the original file; some of the project
# modules may have import-time side effects, so they are not regrouped.
import os
import json
import traceback
from utils.render import RenderUtil
from utils.pose import PoseUtil
from utils.pts import PtsUtil
from utils.reconstruction import ReconstructionUtil

import torch
from tqdm import tqdm
import numpy as np
import pickle

from PytorchBoot.config import ConfigManager
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory

from PytorchBoot.dataset import BaseDataset
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log
from PytorchBoot.status import status_manager
from utils.data_load import DataLoadUtil


@stereotype.runner("heuristic")
class Heuristic(Runner):
    """Runner that evaluates hand-crafted view-planning heuristics.

    For every scene of every configured test dataset a fixed sequence of
    candidate poses is generated (random hemisphere samples or circular
    trajectories), each pose is rendered through ``RenderUtil.render_pts``,
    and the resulting coverage-rate sequence is stored as a pickle per scene
    plus an aggregated ``stat.json``.
    """

    def __init__(self, config_path):
        """Read the runner configuration and build the test datasets.

        Args:
            config_path: Path of the configuration file, forwarded to
                ``Runner.__init__``.

        Raises:
            ValueError: If a dataset name appears twice in the test
                ``dataset_list``.
        """
        super().__init__(config_path)

        self.script_path = ConfigManager.get(namespace.Stereotype.RUNNER, "blender_script_path")
        self.output_dir = ConfigManager.get(namespace.Stereotype.RUNNER, "output_dir")
        self.voxel_size = ConfigManager.get(namespace.Stereotype.RUNNER, "voxel_size")
        self.min_new_area = ConfigManager.get(namespace.Stereotype.RUNNER, "min_new_area")
        self.heuristic_method = ConfigManager.get(namespace.Stereotype.RUNNER, "heuristic_method")
        self.heuristic_method_config = ConfigManager.get("heuristic_methods", self.heuristic_method)

        # Convert the minimum newly covered area into a voxel count.
        # NOTE(review): presumably min_new_area is in cm^2 and voxel_size in
        # metres (CM = 0.01) -- confirm against the config files.
        CM = 0.01
        self.min_new_pts_num = self.min_new_area * (CM / self.voxel_size) ** 2

        # --- Experiment ---
        self.load_experiment("nbv_evaluator")
        self.stat_result_path = os.path.join(self.output_dir, "stat.json")
        if os.path.exists(self.stat_result_path):
            # Resume from statistics written by a previous run.
            with open(self.stat_result_path, "r") as f:
                self.stat_result = json.load(f)
        else:
            self.stat_result = {}

        # --- Test ---
        self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
        self.test_dataset_name_list = self.test_config["dataset_list"]
        self.test_set_list = []
        self.test_writer_list = []
        seen_name = set()
        for test_dataset_name in self.test_dataset_name_list:
            if test_dataset_name not in seen_name:
                seen_name.add(test_dataset_name)
            else:
                raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
            test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
            self.test_set_list.append(test_set)
        self.print_info()

    def run(self):
        """Entry point of the runner: process all test datasets."""
        Log.info("Loading from epoch {}.".format(self.current_epoch))
        self.run_heuristic()
        Log.success("Inference finished.")

    def run_heuristic(self):
        """Run the heuristic on every scene of every test dataset.

        Scenes whose result pickle already exists are skipped. A failure in
        one scene is logged and does not stop the remaining scenes.
        """
        test_set: BaseDataset
        for dataset_idx, test_set in enumerate(self.test_set_list):
            status_manager.set_progress("heuristic", "heuristic", "dataset", dataset_idx, len(self.test_set_list))
            test_set_name = test_set.get_name()

            total = int(len(test_set))
            for i in tqdm(range(total), desc=f"Processing {test_set_name}", ncols=100):
                try:
                    data = test_set[i]
                    scene_name = data["scene_name"]
                    inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")

                    if os.path.exists(inference_result_path):
                        Log.info(f"Inference result already exists for scene: {scene_name}")
                        continue

                    status_manager.set_progress("heuristic", "heuristic", f"Batch[{test_set_name}]", i + 1, total)
                    output = self.predict_sequence(data)
                    self.save_inference_result(test_set_name, data["scene_name"], output)
                except Exception as e:
                    # Deliberate best-effort: one broken scene must not abort
                    # the whole evaluation.
                    print(e)
                    Log.error(f"Error, {e}")
                    continue

            status_manager.set_progress("heuristic", "heuristic", "dataset", len(self.test_set_list), len(self.test_set_list))

    def predict_sequence(self, data, cr_increase_threshold=0, overlap_area_threshold=25, scan_points_threshold=10, max_iter=5000, max_retry=5000, max_success=5000):
        """Render a heuristic pose sequence for one scene and track coverage.

        Args:
            data: Dataset item. The keys read here are ``scene_name``,
                ``scene_path``, ``O_to_L_pose``, ``gt_pts``,
                ``first_scanned_n_to_world_pose_9d``,
                ``seq_max_coverage_rate`` and ``best_seq_len``.
            cr_increase_threshold: Unused; kept for interface compatibility.
            overlap_area_threshold: Overlap threshold passed to
                ``ReconstructionUtil.check_overlap``; halved when the
                scan-point overlap check succeeds.
            scan_points_threshold: Threshold passed to
                ``ReconstructionUtil.check_scan_points_overlap``.
            max_iter: Number of poses generated by the random hemisphere
                method (ignored by the circle method).
            max_retry: Stop once this many poses have been rejected.
            max_success: Stop once this many "success" events were counted.

        Returns:
            dict with the accepted pose sequence, the per-view point clouds,
            the coverage-rate sequence, the rejected poses grouped by reason
            and a few values copied from ``data``.

        Raises:
            ValueError: If ``self.heuristic_method`` is not a known method.
        """
        scene_name = data["scene_name"]
        Log.info(f"Processing scene: {scene_name}")
        status_manager.set_status("heuristic", "heuristic", "scene", scene_name)

        # --- Data for rendering ---
        scene_path = data["scene_path"]
        O_to_L_pose = data["O_to_L_pose"]
        voxel_threshold = self.voxel_size
        filter_degree = 75
        down_sampled_model_pts = data["gt_pts"]
        max_coverage_rate = data["seq_max_coverage_rate"]

        # Expand the 9D pose (6D rotation + translation) of the first frame
        # into a 4x4 matrix.
        first_frame_to_world_9d = data["first_scanned_n_to_world_pose_9d"][0]
        first_frame_to_world = np.eye(4)
        first_frame_to_world[:3, :3] = PoseUtil.rotation_6d_to_matrix_numpy(first_frame_to_world_9d[:6])
        first_frame_to_world[:3, 3] = first_frame_to_world_9d[6:]

        # Scan points on the display table.
        root = os.path.dirname(scene_path)
        display_table_info = DataLoadUtil.get_display_table_info(root, scene_name)
        radius = display_table_info["radius"]
        scan_points = np.asarray(ReconstructionUtil.generate_scan_points(display_table_top=0, display_table_radius=radius))

        # Generate the candidate pose sequence.
        if self.heuristic_method == "hemisphere_random":
            pose_sequence = self.generate_hemisphere_random_sequence(
                max_iter, self.heuristic_method_config
            )
        elif self.heuristic_method == "hemisphere_circle_trajectory":
            pose_sequence = self.generate_hemisphere_circle_sequence(
                self.heuristic_method_config
            )
        else:
            raise ValueError(f"Unknown heuristic method: {self.heuristic_method}")

        # Render the first frame.
        first_frame_target_pts, _, first_frame_scan_points_indices = RenderUtil.render_pts(
            first_frame_to_world, scene_path, self.script_path, scan_points,
            voxel_threshold=voxel_threshold, filter_degree=filter_degree,
            nO_to_nL_pose=O_to_L_pose
        )

        # Result containers.
        scanned_view_pts = [first_frame_target_pts]
        history_indices = [first_frame_scan_points_indices]
        pred_cr_seq = []
        retry_duplication_pose = []
        retry_no_pts_pose = []
        retry_overlap_pose = []
        pose_9d_seq = [first_frame_to_world_9d]

        last_pred_cr, _ = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)
        pred_cr_seq.append(last_pred_cr)
        last_pts_num = PtsUtil.voxel_downsample_point_cloud(first_frame_target_pts, voxel_threshold).shape[0]

        # Walk through the pose sequence.
        retry = 0
        success = 0
        # Placeholder returned when no candidate view is ever accepted.
        combined_scanned_pts_tensor = torch.tensor([0, 0, 0])

        for cnt, pred_pose in enumerate(pose_sequence, start=1):
            if retry >= max_retry or success >= max_success:
                break

            Log.green(f"迭代: {cnt}/{len(pose_sequence)}, 重试: {retry}/{max_retry}, 成功: {success}/{max_success}")

            try:
                new_target_pts, _, new_scan_points_indices = RenderUtil.render_pts(
                    pred_pose, scene_path, self.script_path, scan_points,
                    voxel_threshold=voxel_threshold, filter_degree=filter_degree,
                    nO_to_nL_pose=O_to_L_pose
                )

                # Reject empty views before any further processing, so that
                # they are recorded as "no points" rather than reaching the
                # downsampling / overlap checks with an empty array.
                if new_target_pts.shape[0] == 0:
                    Log.red("新视角无点云")
                    retry_no_pts_pose.append(pred_pose.tolist())
                    retry += 1
                    continue

                # Scan-point overlap decides how strict the point-cloud
                # overlap check is.
                if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold):
                    curr_overlap_area_threshold = overlap_area_threshold
                else:
                    curr_overlap_area_threshold = overlap_area_threshold * 0.5

                # Point-cloud overlap check.
                downsampled_new_target_pts = PtsUtil.voxel_downsample_point_cloud(new_target_pts, voxel_threshold)
                overlap, _ = ReconstructionUtil.check_overlap(
                    downsampled_new_target_pts, down_sampled_model_pts,
                    overlap_area_threshold=curr_overlap_area_threshold,
                    voxel_size=voxel_threshold,
                    require_new_added_pts_num=True
                )
                if not overlap:
                    Log.yellow("no overlap!")
                    retry += 1
                    retry_overlap_pose.append(pred_pose.tolist())
                    continue

                history_indices.append(new_scan_points_indices)

                # Coverage rate including the new view.
                pred_cr, _ = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
                Log.yellow(f"覆盖率: {pred_cr}, 上一次: {last_pred_cr}, 最大: {max_coverage_rate}")

                # Record the accepted view.
                pred_cr_seq.append(pred_cr)
                scanned_view_pts.append(new_target_pts)

                pose_6d = PoseUtil.matrix_to_rotation_6d_numpy(pred_pose[:3, :3])
                pose_9d = np.concatenate([pose_6d, pred_pose[:3, 3]])
                pose_9d_seq.append(pose_9d)

                # Merge all accepted views into one downsampled cloud.
                combined_scanned_pts = np.vstack(scanned_view_pts)
                voxel_downsampled_pts, _ = self.voxel_downsample_with_mapping(combined_scanned_pts, voxel_threshold)
                random_downsampled_pts, _ = PtsUtil.random_downsample_point_cloud(voxel_downsampled_pts, 8192, require_idx=True)
                combined_scanned_pts_tensor = torch.tensor(random_downsampled_pts, dtype=torch.float32)

                # Check how many new voxels the view contributed.
                pts_num = voxel_downsampled_pts.shape[0]
                Log.info(f"点数增量: {pts_num - last_pts_num}, 当前: {pts_num}, 上一次: {last_pts_num}")

                if pts_num - last_pts_num < self.min_new_pts_num:
                    if pred_cr <= max_coverage_rate - 1e-2:
                        # Little new surface and still far from the target.
                        retry += 1
                        retry_duplication_pose.append(pred_pose.tolist())
                        Log.red(f"点数增量过小 < {self.min_new_pts_num}")
                    else:
                        # Little new surface because the target is reached.
                        success += 1
                        Log.success("达到目标覆盖率")

                last_pts_num = pts_num
                last_pred_cr = pred_cr

                if pred_cr >= max_coverage_rate - 1e-3:
                    Log.success(f"达到最大覆盖率: {pred_cr}")

            except Exception as e:
                # Best-effort: a failing pose counts as a retry.
                traceback.print_exc()
                Log.error(f"场景 {scene_path} 处理出错: {e}")
                retry_no_pts_pose.append(pred_pose.tolist())
                retry += 1
                continue

        # Assemble the result.
        result = {
            "pred_pose_9d_seq": pose_9d_seq,
            "combined_scanned_pts_tensor": combined_scanned_pts_tensor,
            "target_pts_seq": scanned_view_pts,
            "coverage_rate_seq": pred_cr_seq,
            "max_coverage_rate": max_coverage_rate,
            "pred_max_coverage_rate": max(pred_cr_seq),
            "scene_name": scene_name,
            "retry_no_pts_pose": retry_no_pts_pose,
            "retry_duplication_pose": retry_duplication_pose,
            "retry_overlap_pose": retry_overlap_pose,
            "best_seq_len": data["best_seq_len"],
        }

        self.stat_result[scene_name] = {
            "coverage_rate_seq": pred_cr_seq,
            "pred_max_coverage_rate": max(pred_cr_seq),
            "pred_seq_len": len(pred_cr_seq),
        }

        print('success rate: ', max(pred_cr_seq))
        return result

    def voxel_downsample_with_mapping(self, point_cloud, voxel_size=0.003):
        """Keep one point per occupied voxel and return the voxel mapping.

        Args:
            point_cloud: Array of points; rows are quantised by
                ``floor(point / voxel_size)``.
            voxel_size: Edge length of a voxel.

        Returns:
            Tuple ``(downsampled_points, inverse)`` where
            ``downsampled_points`` holds one original point per unique voxel
            and ``inverse[i]`` is the voxel index of input point ``i``.
        """
        voxel_indices = np.floor(point_cloud / voxel_size).astype(np.int32)
        unique_voxels, inverse, counts = np.unique(voxel_indices, axis=0, return_inverse=True, return_counts=True)
        # Some NumPy releases return `inverse` with an extra axis when `axis`
        # is given; flatten so argsort works on point indices.
        inverse = np.asarray(inverse).reshape(-1)
        idx_sort = np.argsort(inverse)
        # Start offset of every voxel group inside the sorted order.
        idx_unique = idx_sort[np.cumsum(counts) - counts]
        downsampled_points = point_cloud[idx_unique]
        return downsampled_points, inverse

    def compute_coverage_rate(self, scanned_view_pts, new_pts, model_pts, threshold=0.005):
        """Compute the coverage of ``model_pts`` by the scanned views.

        Args:
            scanned_view_pts: List of point arrays already accepted.
            new_pts: Optional extra point array to include (``None`` to use
                only ``scanned_view_pts``). The input list is not modified.
            model_pts: Reference model points.
            threshold: Voxel size for downsampling, also passed on to
                ``ReconstructionUtil.compute_coverage_rate``.

        Returns:
            Whatever ``ReconstructionUtil.compute_coverage_rate`` returns
            (callers unpack it as a 2-tuple whose first item is the rate).
        """
        if new_pts is not None:
            new_scanned_view_pts = scanned_view_pts + [new_pts]
        else:
            new_scanned_view_pts = scanned_view_pts
        combined_point_cloud = np.vstack(new_scanned_view_pts)
        down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud, threshold)
        return ReconstructionUtil.compute_coverage_rate(model_pts, down_sampled_combined_point_cloud, threshold)

    def save_inference_result(self, dataset_name, scene_name, output):
        """Pickle one scene result and rewrite the aggregated ``stat.json``.

        Args:
            dataset_name: Sub-directory of ``self.output_dir`` to write to.
            scene_name: Base name of the pickle file.
            output: Object to pickle (the dict from ``predict_sequence``).
        """
        dataset_dir = os.path.join(self.output_dir, dataset_name)
        os.makedirs(dataset_dir, exist_ok=True)
        output_path = os.path.join(dataset_dir, f"{scene_name}.pkl")
        # Use a context manager so the file handle is closed deterministically.
        with open(output_path, "wb") as f:
            pickle.dump(output, f)
        with open(self.stat_result_path, "w") as f:
            json.dump(self.stat_result, f)

    def get_checkpoint_path(self, is_last=False):
        """Return the checkpoint file path for the current or last epoch."""
        return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
                            "Epoch_{}.pth".format(
                                self.current_epoch if self.current_epoch != -1 and not is_last else "last"))

    def load_checkpoint(self, is_last=False):
        """Load a checkpoint and, for the last one, restore epoch counters.

        Raises:
            FileNotFoundError: If ``is_last`` is set and the checkpoint
                directory has no ``meta.json``.
        """
        self.load(self.get_checkpoint_path(is_last))
        Log.success(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}")
        if is_last:
            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
            meta_path = os.path.join(checkpoint_root, "meta.json")
            if not os.path.exists(meta_path):
                raise FileNotFoundError(
                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self.current_epoch = meta["last_epoch"]
            self.current_iter = meta["last_iter"]

    def load_experiment(self, backup_name=None):
        """Load the experiment and take the epoch from its configuration."""
        super().load_experiment(backup_name)
        self.current_epoch = self.experiments_config["epoch"]

    def create_experiment(self, backup_name=None):
        """Create the experiment; delegates to the base runner."""
        super().create_experiment(backup_name)

    def print_info(self):
        """Log the runner information followed by every test dataset config."""
        def print_dataset(dataset: BaseDataset):
            config = dataset.get_config()
            name = dataset.get_name()
            Log.blue(f"Dataset: {name}")
            for k, v in config.items():
                Log.blue(f"\t{k}: {v}")

        super().print_info()
        table_size = 70
        Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
        for i, test_set in enumerate(self.test_set_list):
            Log.blue(f"test dataset {i}: ")
            print_dataset(test_set)
        Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')

    @staticmethod
    def _pose_looking_along(position, z_axis):
        """Build a 4x4 pose at ``position`` whose third rotation column is ``z_axis``.

        The x-axis is ``up x z_axis`` with world up ``[0, 0, 1]``. When
        ``z_axis`` is parallel to world up that cross product vanishes, so
        ``[0, 1, 0]`` is used as the reference instead of dividing by zero.
        """
        x_axis = np.cross(np.array([0, 0, 1]), z_axis)
        x_norm = np.linalg.norm(x_axis)
        if x_norm < 1e-8:
            x_axis = np.cross(np.array([0, 1, 0]), z_axis)
            x_norm = np.linalg.norm(x_axis)
        x_axis = x_axis / x_norm
        y_axis = np.cross(z_axis, x_axis)

        pose = np.eye(4)
        pose[:3, :3] = np.stack([x_axis, y_axis, z_axis], axis=1)
        pose[:3, 3] = position
        return pose

    def generate_hemisphere_random_sequence(self, max_iter, config):
        """Generate a random hemisphere sampling sequence.

        Args:
            max_iter: Number of poses to generate.
            config: Mapping with the keys ``radius_fixed``, ``fixed_radius``,
                ``min_radius``, ``max_radius`` and ``center``.

        Returns:
            List of 4x4 pose matrices located on the upper hemisphere around
            ``center`` with the z-axis pointing towards ``center``.
        """
        radius_fixed = config["radius_fixed"]
        fixed_radius = config["fixed_radius"]
        min_radius = config["min_radius"]
        max_radius = config["max_radius"]
        center = np.array(config["center"])

        poses = []
        for _ in range(max_iter):
            # Random direction, folded into the upper hemisphere.
            direction = np.random.randn(3)
            direction[2] = abs(direction[2])
            direction = direction / np.linalg.norm(direction)

            # Viewing distance.
            if radius_fixed:
                radius = fixed_radius
            else:
                radius = np.random.uniform(min_radius, max_radius)

            # Position on the sphere, looking back along the direction.
            position = center + direction * radius
            poses.append(self._pose_looking_along(position, -direction))

        return poses

    def generate_hemisphere_circle_sequence(self, config):
        """Generate a circular trajectory sampling sequence.

        Args:
            config: Mapping with the keys ``radius_fixed``, ``fixed_radius``,
                ``min_radius``, ``max_radius``, ``phi_list`` (angles in
                degrees measured from the +z axis), ``circle_times`` (poses
                per circle) and ``center``.

        Returns:
            List of 4x4 pose matrices, one circle per entry of ``phi_list``,
            with the z-axis pointing towards ``center``.
        """
        radius_fixed = config["radius_fixed"]
        fixed_radius = config["fixed_radius"]
        min_radius = config["min_radius"]
        max_radius = config["max_radius"]
        phi_list = config["phi_list"]
        circle_times = config["circle_times"]
        center = np.array(config["center"])

        poses = []
        for phi in phi_list:
            phi_rad = np.deg2rad(phi)
            for i in range(circle_times):
                # Azimuth of this sample on the circle.
                theta = i * (2 * np.pi / circle_times)

                # Viewing distance.
                if radius_fixed:
                    radius = fixed_radius
                else:
                    radius = np.random.uniform(min_radius, max_radius)

                # Spherical to Cartesian coordinates.
                x = radius * np.cos(theta) * np.sin(phi_rad)
                y = radius * np.sin(theta) * np.sin(phi_rad)
                z = radius * np.cos(phi_rad)
                position = center + np.array([x, y, z])

                # Orientation: look from the position towards the center.
                direction = (center - position) / np.linalg.norm(center - position)
                poses.append(self._pose_looking_along(position, direction))

        return poses