import os
import json

import torch
import numpy as np
from flask import Flask, request, jsonify

import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log

from utils.pts import PtsUtil
from beans.predict_result import PredictResult


@stereotype.runner("ug_inference_server")
class UGInferencerServer(Runner):
    """Runner that serves the configured pipeline over HTTP.

    ``run()`` starts a Flask app exposing ``POST /inference``: the JSON request
    body is converted to model input, passed through
    ``pipeline.forward_test``, and the candidate 9D poses chosen by
    ``PredictResult`` are returned as JSON.
    """

    def __init__(self, config_path):
        super().__init__(config_path)

        # --- Web server ---
        self.app = Flask(__name__)

        # --- Pipeline ---
        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.pipeline: torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
        self.pipeline = self.pipeline.to(self.device)
        # NOTE(review): presumably consumed by input preprocessing
        # (e.g. get_input_data) — TODO confirm; units of voxel_size are not
        # established here.
        self.pts_num = 8192
        self.voxel_size = 0.002

        # --- Experiment: resolves the epoch to serve and loads its checkpoint ---
        self.load_experiment("ug_inference_server")

    def get_result(self, output_data):
        """Convert pipeline output into a JSON-serializable response dict.

        Args:
            output_data: dict returned by ``pipeline.forward_test``; must
                contain the tensor ``"pred_pose_9d"``.

        Returns:
            ``{"pred_pose_9d": [...]}`` holding ``PredictResult``'s
            ``candidate_9d_poses`` as nested float32 lists.
        """
        pred_pose_9d = output_data["pred_pose_9d"]
        # detach() so .numpy() cannot fail on a tensor that still tracks gradients.
        pred_pose_9d_np = pred_pose_9d.detach().cpu().numpy()
        candidates = PredictResult(
            pred_pose_9d_np,
            None,
            # NOTE(review): eps/min_samples look like density-clustering
            # parameters consumed by PredictResult — confirm their meaning there.
            cluster_params=dict(eps=0.25, min_samples=3),
        ).candidate_9d_poses
        pred_pose_9d = np.asarray(candidates, dtype=np.float32)
        return {"pred_pose_9d": pred_pose_9d.tolist()}

    def run(self, host="0.0.0.0", port=5000):
        """Register ``POST /inference`` on the Flask app and start serving.

        Blocks inside ``Flask.run`` until the server stops.

        Args:
            host: interface to bind (defaults to all interfaces).
            port: TCP port to listen on.
        """
        Log.info("Loading from epoch {}.".format(self.current_epoch))
        # Inference-mode behaviour for layers such as dropout / batch-norm.
        self.pipeline.eval()

        @self.app.route("/inference", methods=["POST"])
        def inference():
            # silent=True: a missing or malformed JSON body yields None instead of raising.
            data = request.get_json(silent=True)
            if data is None:
                return jsonify({"error": "request body must be valid JSON"}), 400
            # NOTE(review): get_input_data / collate_input are not defined in the
            # methods of this class shown here — confirm they are provided
            # elsewhere, otherwise every request fails with AttributeError.
            input_data = self.get_input_data(data)
            collated_input_data = self.collate_input(input_data)
            # Serving needs no autograd graph: saves memory per request and
            # keeps outputs convertible to numpy.
            with torch.no_grad():
                output_data = self.pipeline.forward_test(collated_input_data)
            result = self.get_result(output_data)
            return jsonify(result)

        self.app.run(host=host, port=port)

    def get_checkpoint_path(self, is_last=False):
        """Return the checkpoint file path inside the experiment's checkpoint dir.

        The file is ``Epoch_last.pth`` when ``is_last`` is True or
        ``current_epoch`` is -1, otherwise ``Epoch_<current_epoch>.pth``.
        """
        epoch_tag = self.current_epoch if self.current_epoch != -1 and not is_last else "last"
        return os.path.join(
            self.experiment_path,
            namespace.Direcotry.CHECKPOINT_DIR_NAME,
            "Epoch_{}.pth".format(epoch_tag),
        )

    def load_checkpoint(self, is_last=False):
        """Load pipeline weights; for the "last" checkpoint also restore counters.

        Args:
            is_last: load ``Epoch_last.pth`` and then read ``meta.json`` to set
                ``current_epoch`` / ``current_iter``.

        Raises:
            FileNotFoundError: ``is_last`` is True and ``meta.json`` is missing.
        """
        checkpoint_path = self.get_checkpoint_path(is_last)
        self.load(checkpoint_path)
        Log.success(f"Loaded checkpoint from {checkpoint_path}")
        if is_last:
            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
            meta_path = os.path.join(checkpoint_root, "meta.json")
            if not os.path.exists(meta_path):
                raise FileNotFoundError(
                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f)
            self.current_epoch = meta["last_epoch"]
            self.current_iter = meta["last_iter"]

    def load_experiment(self, backup_name=None):
        """Load the experiment, then the checkpoint for its configured epoch.

        An ``epoch`` of -1 in the experiment config selects the "last" checkpoint.
        """
        super().load_experiment(backup_name)
        self.current_epoch = self.experiments_config["epoch"]
        self.load_checkpoint(is_last=(self.current_epoch == -1))

    def create_experiment(self, backup_name=None):
        """Delegate to ``Runner.create_experiment`` unchanged."""
        super().create_experiment(backup_name)

    def load(self, path):
        """Load a state dict from ``path`` into the pipeline.

        ``map_location`` places tensors on the device the pipeline was moved
        to, so a checkpoint saved on another device (e.g. GPU) still loads.
        """
        # NOTE(review): torch.load unpickles; only point this at trusted checkpoint files.
        state_dict = torch.load(path, map_location=self.device)
        self.pipeline.load_state_dict(state_dict)