upd
This commit is contained in:
@@ -13,7 +13,7 @@ from PytorchBoot.utils import Log
|
||||
|
||||
from utils.pts import PtsUtil
|
||||
|
||||
@stereotype.runner("inferencer")
|
||||
@stereotype.runner("inferencer_server")
|
||||
class InferencerServer(Runner):
|
||||
def __init__(self, config_path):
|
||||
super().__init__(config_path)
|
||||
@@ -24,9 +24,10 @@ class InferencerServer(Runner):
|
||||
self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
|
||||
self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
|
||||
self.pipeline = self.pipeline.to(self.device)
|
||||
self.pts_num = 8192
|
||||
|
||||
''' Experiment '''
|
||||
self.load_experiment("nbv_evaluator")
|
||||
self.load_experiment("inferencer_server")
|
||||
|
||||
def get_input_data(self, data):
|
||||
input_data = {}
|
||||
@@ -36,28 +37,36 @@ class InferencerServer(Runner):
|
||||
fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud(
|
||||
combined_scanned_views_pts, self.pts_num, require_idx=True
|
||||
)
|
||||
combined_scanned_views_pts_mask = np.zeros(len(scanned_pts), dtype=np.uint8)
|
||||
start_idx = 0
|
||||
for i in range(len(scanned_pts)):
|
||||
end_idx = start_idx + len(scanned_pts[i])
|
||||
combined_scanned_views_pts_mask[start_idx:end_idx] = i
|
||||
start_idx = end_idx
|
||||
# combined_scanned_views_pts_mask = np.zeros(len(scanned_pts), dtype=np.uint8)
|
||||
# start_idx = 0
|
||||
# for i in range(len(scanned_pts)):
|
||||
# end_idx = start_idx + len(scanned_pts[i])
|
||||
# combined_scanned_views_pts_mask[start_idx:end_idx] = i
|
||||
# start_idx = end_idx
|
||||
|
||||
fps_downsampled_combined_scanned_pts_mask = combined_scanned_views_pts_mask[fps_idx]
|
||||
# fps_downsampled_combined_scanned_pts_mask = combined_scanned_views_pts_mask[fps_idx]
|
||||
|
||||
input_data["scanned_pts_mask"] = np.asarray(fps_downsampled_combined_scanned_pts_mask, dtype=np.uint8)
|
||||
input_data["scanned_pts"] = scanned_pts
|
||||
# input_data["scanned_pts_mask"] = np.asarray(fps_downsampled_combined_scanned_pts_mask, dtype=np.uint8)
|
||||
input_data["scanned_n_to_world_pose_9d"] = np.asarray(scanned_n_to_world_pose_9d, dtype=np.float32)
|
||||
input_data["combined_scanned_pts"] = np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32)
|
||||
return input_data
|
||||
|
||||
def get_result(self, output_data):
|
||||
|
||||
estimated_delta_rot_9d = output_data["pred_pose_9d"]
|
||||
pred_pose_9d = output_data["pred_pose_9d"]
|
||||
result = {
|
||||
"estimated_delta_rot_9d": estimated_delta_rot_9d.tolist()
|
||||
"pred_pose_9d": pred_pose_9d.tolist()
|
||||
}
|
||||
return result
|
||||
|
||||
def collate_input(self, input_data):
|
||||
collated_input_data = {}
|
||||
collated_input_data["scanned_pts"] = [torch.tensor(input_data["scanned_pts"], dtype=torch.float32, device=self.device)]
|
||||
collated_input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(input_data["scanned_n_to_world_pose_9d"], dtype=torch.float32, device=self.device)]
|
||||
collated_input_data["combined_scanned_pts"] = torch.tensor(input_data["combined_scanned_pts"], dtype=torch.float32, device=self.device).unsqueeze(0)
|
||||
return collated_input_data
|
||||
|
||||
def run(self):
|
||||
Log.info("Loading from epoch {}.".format(self.current_epoch))
|
||||
|
||||
@@ -65,7 +74,8 @@ class InferencerServer(Runner):
|
||||
def inference():
|
||||
data = request.json
|
||||
input_data = self.get_input_data(data)
|
||||
output_data = self.pipeline.forward_test(input_data)
|
||||
collated_input_data = self.collate_input(input_data)
|
||||
output_data = self.pipeline.forward_test(collated_input_data)
|
||||
result = self.get_result(output_data)
|
||||
return jsonify(result)
|
||||
|
@@ -115,9 +115,12 @@ class Inferencer(Runner):
|
||||
retry = 0
|
||||
pred_cr_seq = [last_pred_cr]
|
||||
success = 0
|
||||
import time
|
||||
while len(pred_cr_seq) < max_iter and retry < max_retry:
|
||||
|
||||
start_time = time.time()
|
||||
output = self.pipeline(input_data)
|
||||
end_time = time.time()
|
||||
print(f"Time taken for inference: {end_time - start_time} seconds")
|
||||
pred_pose_9d = output["pred_pose_9d"]
|
||||
pred_pose = torch.eye(4, device=pred_pose_9d.device)
|
||||
|
||||
@@ -125,7 +128,10 @@ class Inferencer(Runner):
|
||||
pred_pose[:3,3] = pred_pose_9d[0,6:]
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
new_target_pts, new_target_normals = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
|
||||
end_time = time.time()
|
||||
print(f"Time taken for rendering: {end_time - start_time} seconds")
|
||||
except Exception as e:
|
||||
Log.warning(f"Error in scene {scene_path}, {e}")
|
||||
print("current pose: ", pred_pose)
|
||||
@@ -140,8 +146,10 @@ class Inferencer(Runner):
|
||||
retry += 1
|
||||
continue
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
pred_cr, new_added_pts_num = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
|
||||
end_time = time.time()
|
||||
print(f"Time taken for coverage rate computation: {end_time - start_time} seconds")
|
||||
print(pred_cr, last_pred_cr, " max: ", data["seq_max_coverage_rate"])
|
||||
if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
|
||||
print("max coverage rate reached!: ", pred_cr)
|
||||
|
Reference in New Issue
Block a user