import os
import json

import torch
import numpy as np
from flask import Flask, request, jsonify

from PytorchBoot.config import ConfigManager
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log

from utils.pts import PtsUtil
from beans.predict_result import PredictResult


@stereotype.runner("heuristic_inferencer_server")
class HeuristicInferencerServer(Runner):
    """Runner that serves a loaded pipeline over HTTP (Flask, ``POST /inference``).

    It also provides heuristic camera-pose sequence generators (hemisphere
    random sampling or circular trajectories), selected by the runner's
    ``heuristic_method`` config entry and parameterised by the matching entry
    under ``heuristic_methods``.
    """

    def __init__(self, config_path):
        super().__init__(config_path)

        self.heuristic_method = ConfigManager.get(namespace.Stereotype.RUNNER, "heuristic_method")
        self.heuristic_method_config = ConfigManager.get("heuristic_methods", self.heuristic_method)

        # --- Web server ---
        self.app = Flask(__name__)

        # --- Pipeline ---
        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.pipeline: torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
        self.pipeline = self.pipeline.to(self.device)
        # Point-cloud preprocessing parameters used by get_input_data.
        self.pts_num = 8192
        self.voxel_size = 0.002

        # --- Experiment (also loads the checkpoint into self.pipeline) ---
        self.load_experiment("inferencer_server")

    @staticmethod
    def _look_at_pose(position, z_axis):
        """Build a 4x4 homogeneous pose located at ``position`` whose local Z axis is ``z_axis``.

        The rotation columns are (x, y, z) axes forming a right-handed frame.
        ``z_axis`` must already be a unit vector (both callers pass one).
        World +Z is used as the "up" reference. When ``z_axis`` is (nearly)
        parallel to it, the cross product vanishes and normalising it would
        yield NaNs, so world +Y is used as the reference instead.
        """
        x_axis = np.cross(np.array([0, 0, 1]), z_axis)
        x_norm = np.linalg.norm(x_axis)
        if x_norm < 1e-8:
            # Looking straight up/down: +Z gives no horizontal direction, use +Y.
            x_axis = np.cross(np.array([0, 1, 0]), z_axis)
            x_norm = np.linalg.norm(x_axis)
        x_axis = x_axis / x_norm
        y_axis = np.cross(z_axis, x_axis)

        pose = np.eye(4)
        pose[:3, :3] = np.stack([x_axis, y_axis, z_axis], axis=1)
        pose[:3, 3] = position
        return pose

    def generate_hemisphere_random_sequence(self, max_iter, config):
        """Generate ``max_iter`` random poses on the upper hemisphere around ``config["center"]``.

        Args:
            max_iter: number of poses to generate.
            config: dict with keys ``radius_fixed``, ``fixed_radius``,
                ``min_radius``, ``max_radius`` and ``center``.

        Returns:
            list of 4x4 numpy poses. Each one sits at ``center + direction * radius``
            with its Z axis pointing back at the center. The radius is
            ``fixed_radius`` when ``radius_fixed`` is truthy. Otherwise it is drawn
            uniformly from ``[min_radius, max_radius)``.
        """
        radius_fixed = config["radius_fixed"]
        fixed_radius = config["fixed_radius"]
        min_radius = config["min_radius"]
        max_radius = config["max_radius"]
        center = np.array(config["center"])

        poses = []
        for _ in range(max_iter):
            # Random direction, folded into the upper hemisphere (z >= 0).
            direction = np.random.randn(3)
            direction[2] = abs(direction[2])
            direction = direction / np.linalg.norm(direction)

            radius = fixed_radius if radius_fixed else np.random.uniform(min_radius, max_radius)

            position = center + direction * radius
            # The camera Z axis looks from the sampled position toward the center.
            poses.append(self._look_at_pose(position, -direction))
        return poses

    def generate_hemisphere_circle_sequence(self, config):
        """Generate poses on horizontal circles around ``config["center"]``.

        Args:
            config: dict with keys ``radius_fixed``, ``fixed_radius``,
                ``min_radius``, ``max_radius``, ``center``, ``phi_list``
                (polar angles in degrees, measured from +Z) and ``circle_times``
                (number of evenly spaced azimuths per circle).

        Returns:
            list of ``len(phi_list) * circle_times`` 4x4 numpy poses, each with
            its Z axis pointing at the center.
        """
        radius_fixed = config["radius_fixed"]
        fixed_radius = config["fixed_radius"]
        min_radius = config["min_radius"]
        max_radius = config["max_radius"]
        phi_list = config["phi_list"]
        circle_times = config["circle_times"]
        center = np.array(config["center"])

        poses = []
        for phi in phi_list:
            # Polar angle from +Z (z = r*cos(phi)); phi = 0 is directly above the center.
            phi_rad = np.deg2rad(phi)
            for i in range(circle_times):
                # Azimuth, evenly spaced around the circle.
                theta = i * (2 * np.pi / circle_times)

                radius = fixed_radius if radius_fixed else np.random.uniform(min_radius, max_radius)

                # Spherical -> Cartesian, relative to the center.
                x = radius * np.cos(theta) * np.sin(phi_rad)
                y = radius * np.sin(theta) * np.sin(phi_rad)
                z = radius * np.cos(phi_rad)
                position = center + np.array([x, y, z])

                to_center = center - position
                direction = to_center / np.linalg.norm(to_center)
                poses.append(self._look_at_pose(position, direction))
        return poses

    def generate_seq(self, max_iter=50):
        """Return the pose sequence for the configured ``heuristic_method``.

        ``max_iter`` only affects ``"hemisphere_random"``.

        Raises:
            ValueError: if ``heuristic_method`` is not a supported method name.
        """
        if self.heuristic_method == "hemisphere_random":
            pose_sequence = self.generate_hemisphere_random_sequence(
                max_iter,
                self.heuristic_method_config
            )
        elif self.heuristic_method == "hemisphere_circle_trajectory":
            pose_sequence = self.generate_hemisphere_circle_sequence(
                self.heuristic_method_config
            )
        else:
            raise ValueError(f"Unknown heuristic method: {self.heuristic_method}")
        return pose_sequence

    def get_input_data(self, data):
        """Convert a decoded JSON request payload into numpy inputs.

        Args:
            data: dict with ``"scanned_pts"`` (one point list per scanned view;
                presumably each N_i x 3 — TODO confirm against the client) and
                ``"scanned_n_to_world_pose_9d"`` (one 9D pose per view; format
                not visible here).

        Returns:
            dict with:
                ``scanned_pts``: the raw per-view points, unchanged.
                ``scanned_n_to_world_pose_9d``: the poses as a float32 array.
                ``combined_scanned_pts``: all views concatenated, voxel-downsampled
                with ``self.voxel_size``, then FPS-downsampled via
                ``PtsUtil.fps_downsample_point_cloud`` with ``self.pts_num``,
                as float32.

        Raises:
            ValueError: if ``scanned_pts`` contains no views.
        """
        scanned_pts = data["scanned_pts"]
        scanned_n_to_world_pose_9d = data["scanned_n_to_world_pose_9d"]
        if len(scanned_pts) == 0:
            raise ValueError("scanned_pts must contain at least one view")

        combined_scanned_views_pts = np.concatenate(scanned_pts, axis=0)
        voxel_downsampled_combined_scanned_pts = PtsUtil.voxel_downsample_point_cloud(
            combined_scanned_views_pts, self.voxel_size
        )
        # The FPS indices are not needed here; only the sampled points are used.
        fps_downsampled_combined_scanned_pts, _ = PtsUtil.fps_downsample_point_cloud(
            voxel_downsampled_combined_scanned_pts, self.pts_num, require_idx=True
        )

        return {
            "scanned_pts": scanned_pts,
            "scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose_9d, dtype=np.float32),
            "combined_scanned_pts": np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32),
        }

    def get_result(self, output_data):
        """Turn raw pipeline output into a JSON-serialisable response.

        Reads the ``"pred_pose_9d"`` tensor and passes it to ``PredictResult``
        (clustering with eps=0.25, min_samples=3). Returns
        ``{"pred_pose_9d": <candidate 9D poses as nested float lists>}``.
        """
        # detach() so .numpy() also works on tensors that track gradients.
        pred_pose_9d = output_data["pred_pose_9d"].detach().cpu().numpy()
        candidates = PredictResult(
            pred_pose_9d, None, cluster_params=dict(eps=0.25, min_samples=3)
        ).candidate_9d_poses
        return {
            "pred_pose_9d": np.asarray(candidates, dtype=np.float32).tolist()
        }

    def collate_input(self, input_data):
        """Convert ``get_input_data`` output into float32 tensors on ``self.device`` as a batch of one.

        ``scanned_pts`` and ``scanned_n_to_world_pose_9d`` become single-element
        lists of tensors. ``combined_scanned_pts`` gets a leading batch dim via
        ``unsqueeze(0)``.
        """
        return {
            "scanned_pts": [
                torch.tensor(input_data["scanned_pts"], dtype=torch.float32, device=self.device)
            ],
            "scanned_n_to_world_pose_9d": [
                torch.tensor(input_data["scanned_n_to_world_pose_9d"], dtype=torch.float32, device=self.device)
            ],
            "combined_scanned_pts": torch.tensor(
                input_data["combined_scanned_pts"], dtype=torch.float32, device=self.device
            ).unsqueeze(0),
        }

    def do_inference(self, input_data):
        """Run inference on collated input; must return a dict with a ``"pred_pose_9d"`` tensor.

        TODO: implement. The ``/inference`` route passes this method's return
        value straight to ``get_result``, which reads ``output_data["pred_pose_9d"]``
        as a tensor. Whether that should come from ``self.pipeline`` or from the
        heuristic ``generate_seq`` is not decided in this file. Until it is, fail
        loudly instead of returning ``None`` (which surfaced as an opaque
        ``TypeError`` inside ``get_result``).
        """
        raise NotImplementedError(
            "HeuristicInferencerServer.do_inference is not implemented"
        )

    def run(self):
        """Register the ``POST /inference`` route and start the Flask server on 0.0.0.0:5000 (blocking)."""
        Log.info("Loading from epoch {}.".format(self.current_epoch))

        @self.app.route("/inference", methods=["POST"])
        def inference():
            data = request.json
            input_data = self.get_input_data(data)
            collated_input_data = self.collate_input(input_data)
            output_data = self.do_inference(collated_input_data)
            result = self.get_result(output_data)
            return jsonify(result)

        self.app.run(host="0.0.0.0", port=5000)

    def get_checkpoint_path(self, is_last=False):
        """Return ``<experiment>/<checkpoint dir>/Epoch_<N>.pth``, or ``Epoch_last.pth``.

        The "last" name is used when ``is_last`` is true or ``current_epoch == -1``.
        """
        use_last = is_last or self.current_epoch == -1
        epoch_tag = "last" if use_last else self.current_epoch
        return os.path.join(
            self.experiment_path,
            namespace.Direcotry.CHECKPOINT_DIR_NAME,
            "Epoch_{}.pth".format(epoch_tag),
        )

    def load_checkpoint(self, is_last=False):
        """Load pipeline weights; for the "last" checkpoint also restore epoch/iter from ``meta.json``.

        Raises:
            FileNotFoundError: if ``is_last`` and the checkpoint directory has no ``meta.json``.
        """
        checkpoint_path = self.get_checkpoint_path(is_last)
        self.load(checkpoint_path)
        Log.success(f"Loaded checkpoint from {checkpoint_path}")
        if is_last:
            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
            meta_path = os.path.join(checkpoint_root, "meta.json")
            if not os.path.exists(meta_path):
                raise FileNotFoundError(
                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self.current_epoch = meta["last_epoch"]
            self.current_iter = meta["last_iter"]

    def load_experiment(self, backup_name=None):
        """Load the experiment, then its checkpoint; epoch ``-1`` in the config means "last"."""
        super().load_experiment(backup_name)
        self.current_epoch = self.experiments_config["epoch"]
        self.load_checkpoint(is_last=(self.current_epoch == -1))

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load(self, path):
        """Load a state dict from ``path`` into ``self.pipeline``."""
        # map_location lets checkpoints saved on another device (e.g. GPU)
        # load on the device this runner uses (e.g. a CPU-only host).
        state_dict = torch.load(path, map_location=self.device)
        self.pipeline.load_state_dict(state_dict)