global: upd inference

This commit is contained in:
2024-11-01 08:43:13 +00:00
parent b221036e8b
commit 985a08d89c
4 changed files with 74 additions and 66 deletions

View File

@@ -27,6 +27,7 @@ class Inferencer(Runner):
self.script_path = ConfigManager.get(namespace.Stereotype.RUNNER, "blender_script_path")
self.output_dir = ConfigManager.get(namespace.Stereotype.RUNNER, "output_dir")
self.voxel_size = ConfigManager.get(namespace.Stereotype.RUNNER, "voxel_size")
''' Pipeline '''
self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
@@ -65,16 +66,11 @@ class Inferencer(Runner):
for dataset_idx, test_set in enumerate(self.test_set_list):
status_manager.set_progress("inference", "inferencer", f"dataset", dataset_idx, len(self.test_set_list))
test_set_name = test_set.get_name()
test_loader = test_set.get_loader()
if test_loader.batch_size > 1:
Log.error("Batch size should be 1 for inference, found {} in {}".format(test_loader.batch_size, test_set_name), terminate=True)
total=int(len(test_loader))
loop = tqdm(enumerate(test_loader), total=total)
for i, data in loop:
total=int(len(test_set))
for i in range(total):
data = test_set.__getitem__(i)
status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
test_set.process_batch(data, self.device)
output = self.predict_sequence(data)
self.save_inference_result(test_set_name, data["scene_name"][0], output)
@@ -88,26 +84,23 @@ class Inferencer(Runner):
''' data for rendering '''
scene_path = data["scene_path"][0]
O_to_L_pose = data["O_to_L_pose"][0]
voxel_threshold = data["voxel_threshold"][0]
filter_degree = data["filter_degree"][0]
model_points_normals = data["model_points_normals"][0]
model_pts = model_points_normals[:,:3]
down_sampled_model_pts = PtsUtil.voxel_downsample_point_cloud(model_pts, voxel_threshold)
first_frame_to_world_9d = data["first_to_world_9d"][0]
first_frame_to_world = torch.eye(4, device=first_frame_to_world_9d.device)
first_frame_to_world[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(first_frame_to_world_9d[:,:6])[0]
first_frame_to_world[:3,3] = first_frame_to_world_9d[0,6:]
first_frame_to_world = first_frame_to_world.to(self.device)
voxel_threshold = self.voxel_size
filter_degree = 75
down_sampled_model_pts = data["gt_pts"]
import ipdb; ipdb.set_trace()
first_frame_to_world_9d = data["first_scanned_n_to_world_pose_9d"][0]
first_frame_to_world = np.eye(4)
first_frame_to_world[:3,:3] = PoseUtil.rotation_6d_to_matrix_numpy(first_frame_to_world_9d[:6])
first_frame_to_world[:3,3] = first_frame_to_world_9d[6:]
''' data for inference '''
input_data = {}
input_data["scanned_pts"] = [data["first_pts"][0].to(self.device)]
input_data["scanned_n_to_world_pose_9d"] = [data["first_to_world_9d"][0].to(self.device)]
input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device)
input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(data["first_scanned_n_to_world_pose_9d"], dtype=torch.float32).to(self.device)]
input_data["mode"] = namespace.Mode.TEST
input_data["combined_scanned_pts"] = data["combined_scanned_pts"]
input_pts_N = input_data["scanned_pts"][0].shape[1]
input_pts_N = input_data["combined_scanned_pts"].shape[1]
first_frame_target_pts, _ = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
first_frame_target_pts, _ = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, down_sampled_model_pts, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
scanned_view_pts = [first_frame_target_pts]
last_pred_cr = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)