add multi seq training
@@ -92,12 +92,16 @@ class Inferencer(Runner):
         model_points_normals = data["model_points_normals"][0]
         model_pts = model_points_normals[:,:3]
         down_sampled_model_pts = PtsUtil.voxel_downsample_point_cloud(model_pts, voxel_threshold)
-        first_frame_to_world = data["first_frame_to_world"][0]
+        first_frame_to_world_9d = data["first_to_world_9d"][0]
+        first_frame_to_world = torch.eye(4, device=first_frame_to_world_9d.device)
+        first_frame_to_world[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(first_frame_to_world_9d[:,:6])[0]
+        first_frame_to_world[:3,3] = first_frame_to_world_9d[0,6:]
         first_frame_to_world = first_frame_to_world.to(self.device)
+
         ''' data for inference '''
         input_data = {}
         input_data["scanned_pts"] = [data["first_pts"][0].to(self.device)]
-        input_data["scanned_n_to_world_pose_9d"] = [data["first_frame_to_world"][0].to(self.device)]
+        input_data["scanned_n_to_world_pose_9d"] = [data["first_to_world_9d"][0].to(self.device)]
         input_data["mode"] = namespace.Mode.TEST
         input_pts_N = input_data["scanned_pts"][0].shape[1]

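For reference, the 9D pose used throughout this file packs a 6D rotation (first six values) and a 3D translation (last three). The sketch below shows the standard Gram-Schmidt decoding of the continuous 6D rotation representation (Zhou et al., CVPR 2019), which is what a helper like PoseUtil.rotation_6d_to_matrix_tensor_batch is assumed to implement; the function names and the row-versus-column convention here are illustrative, not the repository's actual code.

import torch
import torch.nn.functional as F

def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    # d6: (B, 6). Gram-Schmidt the two 3-vectors into an orthonormal basis,
    # then complete a right-handed frame with a cross product.
    a1, a2 = d6[:, :3], d6[:, 3:6]
    b1 = F.normalize(a1, dim=-1)
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)          # (B, 3, 3)

def pose_9d_to_matrix(pose_9d: torch.Tensor) -> torch.Tensor:
    # pose_9d: (B, 9) = 6D rotation followed by 3D translation -> (B, 4, 4).
    mat = torch.eye(4, device=pose_9d.device).repeat(pose_9d.shape[0], 1, 1)
    mat[:, :3, :3] = rotation_6d_to_matrix(pose_9d[:, :6])
    mat[:, :3, 3] = pose_9d[:, 6:]
    return mat

With these hypothetical helpers, pose_9d_to_matrix(first_frame_to_world_9d)[0] would reproduce the three assignments in the hunk above.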
@@ -113,20 +117,19 @@ class Inferencer(Runner):
         while len(pred_cr_seq) < max_iter and retry < max_retry:

             output = self.pipeline(input_data)
-            next_pose_9d = output["pred_pose_9d"]
-            pred_pose = torch.eye(4, device=next_pose_9d.device)
+            pred_pose_9d = output["pred_pose_9d"]
+            pred_pose = torch.eye(4, device=pred_pose_9d.device)

-            pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(next_pose_9d[:,:6])[0]
-            pred_pose[:3,3] = next_pose_9d[0,6:]
-            pred_n_to_world_pose_mat = torch.matmul(first_frame_to_world, pred_pose)
+            pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9d[:,:6])[0]
+            pred_pose[:3,3] = pred_pose_9d[0,6:]

             try:
-                new_target_pts_world, new_pts_world = RenderUtil.render_pts(pred_n_to_world_pose_mat, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose, require_full_scene=True)
+                new_target_pts_world, new_pts_world = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose, require_full_scene=True)
             except Exception as e:
                 Log.warning(f"Error in scene {scene_path}, {e}")
                 print("current pose: ", pred_pose)
                 print("curr_pred_cr: ", last_pred_cr)
-                retry_no_pts_pose.append(pred_n_to_world_pose_mat.cpu().numpy().tolist())
+                retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
                 retry += 1
                 continue

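When the render call in this loop fails, the except branch prints the offending pose before retrying. A quick, self-contained check that a decoded 4x4 pose is actually a proper rigid transform can help rule out a bad rotation decode during such debugging; this is a hypothetical helper, not part of the repository.

import torch

def is_rigid_transform(mat: torch.Tensor, atol: float = 1e-5) -> bool:
    # mat: (4, 4). A proper rigid transform has an orthonormal rotation block
    # with determinant +1 and a [0, 0, 0, 1] bottom row.
    R = mat[:3, :3]
    ok_rot = torch.allclose(R @ R.T, torch.eye(3, device=mat.device), atol=atol)
    ok_det = bool(torch.isclose(torch.det(R), torch.tensor(1.0, device=mat.device), atol=atol))
    ok_row = torch.allclose(mat[3], torch.tensor([0.0, 0.0, 0.0, 1.0], device=mat.device), atol=atol)
    return ok_rot and ok_det and ok_row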
@@ -138,7 +141,7 @@ class Inferencer(Runner):
                 break
             if pred_cr <= last_pred_cr + cr_increase_threshold:
                 retry += 1
-                retry_duplication_pose.append(pred_n_to_world_pose_mat.cpu().numpy().tolist())
+                retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
                 continue

             retry = 0
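The acceptance test in this hunk compares the new coverage rate pred_cr against the previous one plus cr_increase_threshold. As a rough sketch of what a point-cloud coverage rate typically measures (an assumption about pred_cr, not the repository's exact metric): the fraction of downsampled model points lying within a distance threshold of the accumulated scan.

import torch

def coverage_rate(model_pts: torch.Tensor, scanned_pts: torch.Tensor, dist_thresh: float = 0.005) -> float:
    # model_pts:   (M, 3) downsampled model points
    # scanned_pts: (N, 3) accumulated scanned points
    # Fraction of model points within dist_thresh of any scanned point.
    d = torch.cdist(model_pts, scanned_pts)           # (M, N) pairwise distances
    covered = d.min(dim=1).values < dist_thresh
    return covered.float().mean().item()

Under that reading, the pred_cr <= last_pred_cr + cr_increase_threshold branch rejects a view whose marginal coverage gain is below the threshold and spends one of the allowed retries instead.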
@@ -151,7 +154,7 @@ class Inferencer(Runner):
             new_pts_tensor = torch.tensor(new_pts, dtype=torch.float32).unsqueeze(0).to(self.device)

             input_data["scanned_pts"] = [torch.cat([input_data["scanned_pts"][0], new_pts_tensor], dim=0)]
-            input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], next_pose_9d], dim=0)]
+            input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]

             last_pred_cr = pred_cr
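After each accepted view, the two torch.cat calls above grow the per-scene sequence along the frame dimension, which is what the multi-sequence inference consumes. A small sanity-check sketch of the expected layouts (the shapes are assumptions based on how the tensors are built earlier in this file):

# Assumed layouts after a few accepted views (one batch element):
scanned_pts = input_data["scanned_pts"][0]                   # (n_frames, input_pts_N, 3)
scanned_poses = input_data["scanned_n_to_world_pose_9d"][0]  # (n_frames, 9)
assert scanned_pts.dim() == 3 and scanned_pts.shape[-1] == 3
assert scanned_poses.dim() == 2 and scanned_poses.shape[-1] == 9
assert scanned_pts.shape[0] == scanned_poses.shape[0]        # one pose per frame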