diff --git a/runners/inferece_server.py b/runners/inferece_server.py new file mode 100644 index 0000000..239a0e0 --- /dev/null +++ b/runners/inferece_server.py @@ -0,0 +1,109 @@ +import os +import json +import torch +import numpy as np +from flask import Flask, request, jsonify + +import PytorchBoot.namespace as namespace +import PytorchBoot.stereotype as stereotype +from PytorchBoot.factory import ComponentFactory + +from PytorchBoot.runners.runner import Runner +from PytorchBoot.utils import Log + +from utils.pts import PtsUtil + +@stereotype.runner("inferencer") +class InferencerServer(Runner): + def __init__(self, config_path): + super().__init__(config_path) + + ''' Web Server ''' + self.app = Flask(__name__) + ''' Pipeline ''' + self.pipeline_name = self.config[namespace.Stereotype.PIPELINE] + self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name) + self.pipeline = self.pipeline.to(self.device) + + ''' Experiment ''' + self.load_experiment("nbv_evaluator") + + def get_input_data(self, data): + input_data = {} + scanned_pts = data["scanned_pts"] + scanned_n_to_world_pose_9d = data["scanned_n_to_world_pose_9d"] + combined_scanned_views_pts = np.concatenate(scanned_pts, axis=0) + fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud( + combined_scanned_views_pts, self.pts_num, require_idx=True + ) + combined_scanned_views_pts_mask = np.zeros(len(scanned_pts), dtype=np.uint8) + start_idx = 0 + for i in range(len(scanned_pts)): + end_idx = start_idx + len(scanned_pts[i]) + combined_scanned_views_pts_mask[start_idx:end_idx] = i + start_idx = end_idx + + fps_downsampled_combined_scanned_pts_mask = combined_scanned_views_pts_mask[fps_idx] + + input_data["scanned_pts_mask"] = np.asarray(fps_downsampled_combined_scanned_pts_mask, dtype=np.uint8) + input_data["scanned_n_to_world_pose_9d"] = np.asarray(scanned_n_to_world_pose_9d, dtype=np.float32) + input_data["combined_scanned_pts"] = np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32) + return input_data + + def get_result(self, output_data): + + estimated_delta_rot_9d = output_data["pred_pose_9d"] + result = { + "estimated_delta_rot_9d": estimated_delta_rot_9d.tolist() + } + return result + + def run(self): + Log.info("Loading from epoch {}.".format(self.current_epoch)) + + @self.app.route("/inference", methods=["POST"]) + def inference(): + data = request.json + input_data = self.get_input_data(data) + output_data = self.pipeline.forward_test(input_data) + result = self.get_result(output_data) + return jsonify(result) + + + self.app.run(host="0.0.0.0", port=5000) + + def get_checkpoint_path(self, is_last=False): + return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME, + "Epoch_{}.pth".format( + self.current_epoch if self.current_epoch != -1 and not is_last else "last")) + + def load_checkpoint(self, is_last=False): + self.load(self.get_checkpoint_path(is_last)) + Log.success(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}") + if is_last: + checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME) + meta_path = os.path.join(checkpoint_root, "meta.json") + if not os.path.exists(meta_path): + raise FileNotFoundError( + "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"])) + file_path = os.path.join(checkpoint_root, "meta.json") + with open(file_path, "r") as f: + meta = json.load(f) + self.current_epoch = meta["last_epoch"] + self.current_iter = meta["last_iter"] + + def load_experiment(self, backup_name=None): + super().load_experiment(backup_name) + self.current_epoch = self.experiments_config["epoch"] + self.load_checkpoint(is_last=(self.current_epoch == -1)) + + def create_experiment(self, backup_name=None): + super().create_experiment(backup_name) + + + def load(self, path): + state_dict = torch.load(path) + self.pipeline.load_state_dict(state_dict) + + +