upd
This commit is contained in:
216
runners/inference_heuristic_server.py
Normal file
216
runners/inference_heuristic_server.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import numpy as np
|
||||
from flask import Flask, request, jsonify
|
||||
|
||||
from PytorchBoot.config import ConfigManager
|
||||
import PytorchBoot.namespace as namespace
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
from PytorchBoot.factory import ComponentFactory
|
||||
|
||||
from PytorchBoot.runners.runner import Runner
|
||||
from PytorchBoot.utils import Log
|
||||
|
||||
from utils.pts import PtsUtil
|
||||
from beans.predict_result import PredictResult
|
||||
|
||||
@stereotype.runner("heuristic_inferencer_server")
|
||||
class HeuristicInferencerServer(Runner):
|
||||
def __init__(self, config_path):
|
||||
super().__init__(config_path)
|
||||
self.heuristic_method = ConfigManager.get(namespace.Stereotype.RUNNER, "heuristic_method")
|
||||
self.heuristic_method_config = ConfigManager.get("heuristic_methods", self.heuristic_method)
|
||||
''' Web Server '''
|
||||
self.app = Flask(__name__)
|
||||
''' Pipeline '''
|
||||
self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
|
||||
self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
|
||||
self.pipeline = self.pipeline.to(self.device)
|
||||
self.pts_num = 8192
|
||||
self.voxel_size = 0.002
|
||||
|
||||
''' Experiment '''
|
||||
self.load_experiment("inferencer_server")
|
||||
|
||||
def generate_hemisphere_random_sequence(self, max_iter, config):
|
||||
"""Generate a random hemisphere sampling sequence"""
|
||||
radius_fixed = config["radius_fixed"]
|
||||
fixed_radius = config["fixed_radius"]
|
||||
min_radius = config["min_radius"]
|
||||
max_radius = config["max_radius"]
|
||||
poses = []
|
||||
center = np.array(config["center"])
|
||||
|
||||
for _ in range(max_iter):
|
||||
# 随机采样方向
|
||||
direction = np.random.randn(3)
|
||||
direction[2] = abs(direction[2]) # 确保在上半球
|
||||
direction = direction / np.linalg.norm(direction)
|
||||
|
||||
# 确定半径
|
||||
if radius_fixed:
|
||||
radius = fixed_radius
|
||||
else:
|
||||
radius = np.random.uniform(min_radius, max_radius)
|
||||
|
||||
# 计算位置和朝向
|
||||
position = center + direction * radius
|
||||
z_axis = -direction
|
||||
y_axis = np.array([0, 0, 1])
|
||||
x_axis = np.cross(y_axis, z_axis)
|
||||
x_axis = x_axis / np.linalg.norm(x_axis)
|
||||
y_axis = np.cross(z_axis, x_axis)
|
||||
|
||||
pose = np.eye(4)
|
||||
pose[:3,:3] = np.stack([x_axis, y_axis, z_axis], axis=1)
|
||||
pose[:3,3] = position
|
||||
poses.append(pose)
|
||||
|
||||
return poses
|
||||
|
||||
def generate_hemisphere_circle_sequence(self, config):
|
||||
"""Generate a circular trajectory sampling sequence"""
|
||||
radius_fixed = config["radius_fixed"]
|
||||
fixed_radius = config["fixed_radius"]
|
||||
min_radius = config["min_radius"]
|
||||
max_radius = config["max_radius"]
|
||||
phi_list = config["phi_list"]
|
||||
circle_times = config["circle_times"]
|
||||
|
||||
poses = []
|
||||
center = np.array(config["center"])
|
||||
|
||||
for phi in phi_list: # 仰角
|
||||
phi_rad = np.deg2rad(phi)
|
||||
for i in range(circle_times): # 方位角
|
||||
theta = i * (2 * np.pi / circle_times)
|
||||
|
||||
# 确定半径
|
||||
if radius_fixed:
|
||||
radius = fixed_radius
|
||||
else:
|
||||
radius = np.random.uniform(min_radius, max_radius)
|
||||
|
||||
# 球坐标转笛卡尔坐标
|
||||
x = radius * np.cos(theta) * np.sin(phi_rad)
|
||||
y = radius * np.sin(theta) * np.sin(phi_rad)
|
||||
z = radius * np.cos(phi_rad)
|
||||
position = center + np.array([x, y, z])
|
||||
|
||||
# 计算朝向
|
||||
direction = (center - position) / np.linalg.norm(center - position)
|
||||
z_axis = direction
|
||||
y_axis = np.array([0, 0, 1])
|
||||
x_axis = np.cross(y_axis, z_axis)
|
||||
x_axis = x_axis / np.linalg.norm(x_axis)
|
||||
y_axis = np.cross(z_axis, x_axis)
|
||||
|
||||
pose = np.eye(4)
|
||||
pose[:3,:3] = np.stack([x_axis, y_axis, z_axis], axis=1)
|
||||
pose[:3,3] = position
|
||||
poses.append(pose)
|
||||
|
||||
return poses
|
||||
|
||||
|
||||
def generate_seq(self, max_iter=50):
|
||||
if self.heuristic_method == "hemisphere_random":
|
||||
pose_sequence = self.generate_hemisphere_random_sequence(
|
||||
max_iter,
|
||||
self.heuristic_method_config
|
||||
)
|
||||
elif self.heuristic_method == "hemisphere_circle_trajectory":
|
||||
pose_sequence = self.generate_hemisphere_circle_sequence(
|
||||
self.heuristic_method_config
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown heuristic method: {self.heuristic_method}")
|
||||
return pose_sequence
|
||||
|
||||
def get_input_data(self, data):
|
||||
input_data = {}
|
||||
scanned_pts = data["scanned_pts"]
|
||||
scanned_n_to_world_pose_9d = data["scanned_n_to_world_pose_9d"]
|
||||
combined_scanned_views_pts = np.concatenate(scanned_pts, axis=0)
|
||||
voxel_downsampled_combined_scanned_pts = PtsUtil.voxel_downsample_point_cloud(
|
||||
combined_scanned_views_pts, self.voxel_size
|
||||
)
|
||||
fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud(
|
||||
voxel_downsampled_combined_scanned_pts, self.pts_num, require_idx=True
|
||||
)
|
||||
|
||||
input_data["scanned_pts"] = scanned_pts
|
||||
input_data["scanned_n_to_world_pose_9d"] = np.asarray(scanned_n_to_world_pose_9d, dtype=np.float32)
|
||||
input_data["combined_scanned_pts"] = np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32)
|
||||
return input_data
|
||||
|
||||
def get_result(self, output_data):
|
||||
|
||||
pred_pose_9d = output_data["pred_pose_9d"]
|
||||
pred_pose_9d = np.asarray(PredictResult(pred_pose_9d.cpu().numpy(), None, cluster_params=dict(eps=0.25, min_samples=3)).candidate_9d_poses, dtype=np.float32)
|
||||
result = {
|
||||
"pred_pose_9d": pred_pose_9d.tolist()
|
||||
}
|
||||
return result
|
||||
|
||||
def collate_input(self, input_data):
|
||||
collated_input_data = {}
|
||||
collated_input_data["scanned_pts"] = [torch.tensor(input_data["scanned_pts"], dtype=torch.float32, device=self.device)]
|
||||
collated_input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(input_data["scanned_n_to_world_pose_9d"], dtype=torch.float32, device=self.device)]
|
||||
collated_input_data["combined_scanned_pts"] = torch.tensor(input_data["combined_scanned_pts"], dtype=torch.float32, device=self.device).unsqueeze(0)
|
||||
return collated_input_data
|
||||
|
||||
def do_inference(self, input_data):
|
||||
scanned_pts = input_data["scanned_pts"]
|
||||
|
||||
def run(self):
|
||||
Log.info("Loading from epoch {}.".format(self.current_epoch))
|
||||
|
||||
@self.app.route("/inference", methods=["POST"])
|
||||
def inference():
|
||||
data = request.json
|
||||
input_data = self.get_input_data(data)
|
||||
collated_input_data = self.collate_input(input_data)
|
||||
output_data = self.do_inference(collated_input_data)
|
||||
result = self.get_result(output_data)
|
||||
return jsonify(result)
|
||||
|
||||
|
||||
self.app.run(host="0.0.0.0", port=5000)
|
||||
|
||||
def get_checkpoint_path(self, is_last=False):
|
||||
return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
|
||||
"Epoch_{}.pth".format(
|
||||
self.current_epoch if self.current_epoch != -1 and not is_last else "last"))
|
||||
|
||||
def load_checkpoint(self, is_last=False):
|
||||
self.load(self.get_checkpoint_path(is_last))
|
||||
Log.success(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}")
|
||||
if is_last:
|
||||
checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
|
||||
meta_path = os.path.join(checkpoint_root, "meta.json")
|
||||
if not os.path.exists(meta_path):
|
||||
raise FileNotFoundError(
|
||||
"No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
|
||||
file_path = os.path.join(checkpoint_root, "meta.json")
|
||||
with open(file_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
self.current_epoch = meta["last_epoch"]
|
||||
self.current_iter = meta["last_iter"]
|
||||
|
||||
def load_experiment(self, backup_name=None):
|
||||
super().load_experiment(backup_name)
|
||||
self.current_epoch = self.experiments_config["epoch"]
|
||||
self.load_checkpoint(is_last=(self.current_epoch == -1))
|
||||
|
||||
def create_experiment(self, backup_name=None):
|
||||
super().create_experiment(backup_name)
|
||||
|
||||
|
||||
def load(self, path):
|
||||
state_dict = torch.load(path)
|
||||
self.pipeline.load_state_dict(state_dict)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user