update inference server

This commit is contained in:
hofee 2025-01-07 19:32:02 +08:00
parent fca984e76b
commit 5a03659112
3 changed files with 12 additions and 4 deletions

View File

@ -6,16 +6,16 @@ runner:
cuda_visible_devices: "0,1,2,3,4,5,6,7" cuda_visible_devices: "0,1,2,3,4,5,6,7"
experiment: experiment:
name: train_ab_global_only_p++_wp name: train_ab_global_only_dense
root_dir: "experiments" root_dir: "experiments"
epoch: 922 # -1 stands for last epoch epoch: 441 # -1 stands for last epoch
test: test:
dataset_list: dataset_list:
- OmniObject3d_test - OmniObject3d_test
blender_script_path: "/media/hofee/data/project/python/nbv_reconstruction/blender/data_renderer.py" blender_script_path: "/media/hofee/data/project/python/nbv_reconstruction/blender/data_renderer.py"
output_dir: "/media/hofee/data/data/p++_wp_temp_cluster" output_dir: "/media/hofee/data/data/p++_dense"
pipeline: nbv_reconstruction_pipeline pipeline: nbv_reconstruction_pipeline
voxel_size: 0.003 voxel_size: 0.003
min_new_area: 1.0 min_new_area: 1.0
@ -62,6 +62,7 @@ pipeline:
module: module:
pointnet++_encoder: pointnet++_encoder:
in_dim: 3 in_dim: 3
params_name: dense
pointnet_encoder: pointnet_encoder:
in_dim: 3 in_dim: 3

View File

@ -12,6 +12,7 @@ from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log from PytorchBoot.utils import Log
from utils.pts import PtsUtil from utils.pts import PtsUtil
from beans.predict_result import PredictResult
@stereotype.runner("inferencer_server") @stereotype.runner("inferencer_server")
class InferencerServer(Runner): class InferencerServer(Runner):
@ -50,6 +51,7 @@ class InferencerServer(Runner):
def get_result(self, output_data): def get_result(self, output_data):
pred_pose_9d = output_data["pred_pose_9d"] pred_pose_9d = output_data["pred_pose_9d"]
pred_pose_9d = np.asarray(PredictResult(pred_pose_9d.cpu().numpy(), None, cluster_params=dict(eps=0.25, min_samples=3)).candidate_9d_poses, dtype=np.float32)
result = { result = {
"pred_pose_9d": pred_pose_9d.tolist() "pred_pose_9d": pred_pose_9d.tolist()
} }

View File

@ -156,7 +156,12 @@ class Inferencer(Runner):
# np.save(pred_9d_path, pred_pose_9d.cpu().numpy()) # np.save(pred_9d_path, pred_pose_9d.cpu().numpy())
# np.savetxt(pts_path, np_combined_scanned_pts) # np.savetxt(pts_path, np_combined_scanned_pts)
# # ----- ----- ----- # # ----- ----- -----
pred_pose_9d_candidates = PredictResult(pred_pose_9d.cpu().numpy(), input_pts=input_data["combined_scanned_pts"][0].cpu().numpy(), cluster_params=dict(eps=0.25, min_samples=3)).candidate_9d_poses predict_result = PredictResult(pred_pose_9d.cpu().numpy(), input_pts=input_data["combined_scanned_pts"][0].cpu().numpy(), cluster_params=dict(eps=0.25, min_samples=3))
# -----------------------
# import ipdb; ipdb.set_trace()
# predict_result.visualize()
# -----------------------
pred_pose_9d_candidates = predict_result.candidate_9d_poses
for pred_pose_9d in pred_pose_9d_candidates: for pred_pose_9d in pred_pose_9d_candidates:
#import ipdb; ipdb.set_trace() #import ipdb; ipdb.set_trace()
pred_pose_9d = torch.tensor(pred_pose_9d, dtype=torch.float32).to(self.device).unsqueeze(0) pred_pose_9d = torch.tensor(pred_pose_9d, dtype=torch.float32).to(self.device).unsqueeze(0)