diff --git a/configs/local/inference_config.yaml b/configs/local/inference_config.yaml
index 64bccbf..e43b478 100644
--- a/configs/local/inference_config.yaml
+++ b/configs/local/inference_config.yaml
@@ -6,7 +6,7 @@ runner:
   cuda_visible_devices: "0,1,2,3,4,5,6,7"

   experiment:
-    name: train_ab_global_only
+    name: train_ab_partial
     root_dir: "experiments"
     epoch: -1 # -1 stands for last epoch
@@ -15,7 +15,7 @@ runner:
       - OmniObject3d_test
     blender_script_path: "/media/hofee/data/project/python/nbv_reconstruction/blender/data_renderer.py"
-    output_dir: "/media/hofee/data/data/new_inference_test_output"
+    output_dir: "/media/hofee/data/data/new_partial_inference_test_output"
     pipeline: nbv_reconstruction_pipeline
     voxel_size: 0.003
     min_new_area: 1.0
@@ -66,7 +66,7 @@ module:
     global_feat: True
     feature_transform: False
   transformer_seq_encoder:
-    embed_dim: 256
+    embed_dim: 320
     num_heads: 4
     ffn_dim: 256
     num_layers: 3
diff --git a/core/pipeline.py b/core/pipeline.py
index a43d572..8bdb265 100644
--- a/core/pipeline.py
+++ b/core/pipeline.py
@@ -88,26 +88,49 @@ class NBVReconstructionPipeline(nn.Module):
         scanned_n_to_world_pose_9d_batch = data[
             "scanned_n_to_world_pose_9d"
         ] # List(B): Tensor(S x 9)
+        scanned_pts_mask_batch = data["scanned_pts_mask"] # List(B): Tensor(S x N)

         device = next(self.parameters()).device

         embedding_list_batch = []

         combined_scanned_pts_batch = data["combined_scanned_pts"] # Tensor(B x N x 3)
-        global_scanned_feat = self.pts_encoder.encode_points(
-            combined_scanned_pts_batch, require_per_point_feat=False
+        global_scanned_feat, per_point_feat_batch = self.pts_encoder.encode_points(
+            combined_scanned_pts_batch, require_per_point_feat=True
         ) # global_scanned_feat: Tensor(B x Dg)
-
-        for scanned_n_to_world_pose_9d in scanned_n_to_world_pose_9d_batch:
-            scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) # Tensor(S x 9)
+        batch_size = len(scanned_n_to_world_pose_9d_batch)
+        for i in range(batch_size):
+            seq_len = len(scanned_n_to_world_pose_9d_batch[i])
+            scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d_batch[i].to(device) # Tensor(S x 9)
+            scanned_pts_mask = scanned_pts_mask_batch[i] # Tensor(S x N)
+            per_point_feat = per_point_feat_batch[i] # Tensor(N x Dp)
+            partial_point_feat_seq = []
+            for j in range(seq_len):
+                partial_per_point_feat = per_point_feat[scanned_pts_mask[j]]
+                if partial_per_point_feat.shape[0] == 0:
+                    partial_point_feat = torch.zeros(per_point_feat.shape[1], device=device)
+                else:
+                    partial_point_feat = torch.mean(partial_per_point_feat, dim=0) # Tensor(Dp)
+                partial_point_feat_seq.append(partial_point_feat)
+            partial_point_feat_seq = torch.stack(partial_point_feat_seq, dim=0) # Tensor(S x Dp)
+
             pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) # Tensor(S x Dp)
-            seq_embedding = pose_feat_seq
+
+            seq_embedding = torch.cat([partial_point_feat_seq, pose_feat_seq], dim=-1)
+
             embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp))

         seq_feat = self.seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds)
         main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg))

         if torch.isnan(main_feat).any():
+            for i in range(len(main_feat)):
+                if torch.isnan(main_feat[i]).any():
+                    scanned_pts_mask = scanned_pts_mask_batch[i]
+                    Log.info(f"scanned_pts_mask shape: {scanned_pts_mask.shape}")
+                    Log.info(f"scanned_pts_mask sum: {scanned_pts_mask.sum()}")
+                    import ipdb
+                    ipdb.set_trace()
             Log.error("nan in main_feat", True)

-        return main_feat
+        return main_feat
\ No newline at end of file
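Note on the `embed_dim` change above: each `seq_embedding` is now the concatenation of a per-view partial point feature and the pose feature, so the `transformer_seq_encoder` input width grows from 256 to 320 (the individual feature widths are not visible in this diff). The per-view aggregation itself is plain masked mean pooling; below is a minimal standalone sketch of that step, with hypothetical shapes (`N`, `S`, `Dp` values are illustrative, not taken from the repo):

```python
import torch

def pool_partial_features(per_point_feat: torch.Tensor,
                          scanned_pts_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool per-point features over each view's boolean mask."""
    feats = []
    for view_mask in scanned_pts_mask:            # iterate S views, each (N,) bool
        selected = per_point_feat[view_mask]      # (n_visible, Dp)
        if selected.shape[0] == 0:
            # a view that selects no surviving points falls back to zeros
            feats.append(per_point_feat.new_zeros(per_point_feat.shape[1]))
        else:
            feats.append(selected.mean(dim=0))    # (Dp,)
    return torch.stack(feats, dim=0)              # (S, Dp)

per_point_feat = torch.randn(1024, 64)            # hypothetical N=1024, Dp=64
mask = torch.rand(4, 1024) > 0.5                  # hypothetical S=4 views
print(pool_partial_features(per_point_feat, mask).shape)  # torch.Size([4, 64])
```

The zero-vector fallback mirrors the diff's guard against views whose mask selects no points; an empty mean would produce NaNs, which is the likely motivation for the NaN debugging block added above.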
diff --git a/runners/inferencer.py b/runners/inferencer.py
index 5d0cf16..df79cef 100644
--- a/runners/inferencer.py
+++ b/runners/inferencer.py
@@ -90,7 +90,8 @@ class Inferencer(Runner):
                 output = self.predict_sequence(data)
                 self.save_inference_result(test_set_name, data["scene_name"], output)
             except Exception as e:
-                Log.error(f"Error in scene {scene_name}, {e}")
+                print(e)
+                Log.error(f"Error, {e}")
                 continue

         status_manager.set_progress("inference", "inferencer", f"dataset", len(self.test_set_list), len(self.test_set_list))
@@ -114,7 +115,9 @@ class Inferencer(Runner):
         ''' data for inference '''
         input_data = {}
+        input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device).unsqueeze(0)
+        input_data["scanned_pts_mask"] = [torch.zeros(input_data["combined_scanned_pts"].shape[1], dtype=torch.bool).to(self.device).unsqueeze(0)]
         input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(data["first_scanned_n_to_world_pose_9d"], dtype=torch.float32).to(self.device)]
         input_data["mode"] = namespace.Mode.TEST
         input_pts_N = input_data["combined_scanned_pts"].shape[1]
@@ -187,11 +190,30 @@ class Inferencer(Runner):
                 scanned_view_pts.append(new_target_pts)
                 input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
-
+                start_indices = [0]
+                total_points = 0
+                for pts in scanned_view_pts:
+                    total_points += pts.shape[0]
+                    start_indices.append(total_points)
                 combined_scanned_pts = np.vstack(scanned_view_pts)
-                voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, voxel_threshold)
-                random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
+                voxel_downsampled_combined_scanned_pts_np, inverse = self.voxel_downsample_with_mapping(combined_scanned_pts, voxel_threshold)
+                random_downsampled_combined_scanned_pts_np, random_downsample_idx = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N, require_idx=True)
+                all_idx_unique = np.arange(len(voxel_downsampled_combined_scanned_pts_np))
+                all_random_downsample_idx = all_idx_unique[random_downsample_idx]
+                scanned_pts_mask = []
+                for idx, start_idx in enumerate(start_indices):
+                    if idx == len(start_indices) - 1:
+                        break
+                    end_idx = start_indices[idx+1]
+                    view_inverse = inverse[start_idx:end_idx]
+                    view_unique_downsampled_idx = np.unique(view_inverse)
+                    view_unique_downsampled_idx_set = set(view_unique_downsampled_idx)
+                    mask = np.array([idx in view_unique_downsampled_idx_set for idx in all_random_downsample_idx])
+                    scanned_pts_mask.append(mask)
                 input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)
+                #import ipdb; ipdb.set_trace()
+                input_data["scanned_pts_mask"] = [torch.tensor(scanned_pts_mask, dtype=torch.bool)]

                 last_pred_cr = pred_cr
@@ -232,6 +254,14 @@ class Inferencer(Runner):

         return result

+    def voxel_downsample_with_mapping(self, point_cloud, voxel_size=0.003):
+        voxel_indices = np.floor(point_cloud / voxel_size).astype(np.int32)
+        unique_voxels, inverse, counts = np.unique(voxel_indices, axis=0, return_inverse=True, return_counts=True)
+        idx_sort = np.argsort(inverse)
+        idx_unique = idx_sort[np.cumsum(counts)-counts]
+        downsampled_points = point_cloud[idx_unique]
+        return downsampled_points, inverse
+
     def compute_coverage_rate(self, scanned_view_pts, new_pts, model_pts, threshold=0.005):
         if new_pts is not None:
             new_scanned_view_pts = scanned_view_pts + [new_pts]
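For reference, `voxel_downsample_with_mapping` keeps one representative point per occupied voxel and returns the `inverse` array that `predict_sequence` uses to attribute downsampled points back to the views that produced them. A self-contained sanity check with toy data follows (point counts, the 300-point "view", and the voxel size are illustrative, not values from the repo):

```python
import numpy as np

def voxel_downsample_with_mapping(point_cloud, voxel_size=0.003):
    # same logic as the helper added above: one representative point per voxel,
    # plus an inverse array mapping every input point to its voxel's row
    voxel_indices = np.floor(point_cloud / voxel_size).astype(np.int32)
    unique_voxels, inverse, counts = np.unique(
        voxel_indices, axis=0, return_inverse=True, return_counts=True)
    inverse = inverse.reshape(-1)  # ensure a flat (N,) index array across NumPy versions
    idx_sort = np.argsort(inverse)
    idx_unique = idx_sort[np.cumsum(counts) - counts]  # first occurrence per voxel
    return point_cloud[idx_unique], inverse

rng = np.random.default_rng(0)
pts = rng.random((1000, 3))               # toy cloud; first 300 points = "view 0"
down, inverse = voxel_downsample_with_mapping(pts, voxel_size=0.05)
assert inverse.shape == (1000,) and inverse.max() + 1 == len(down)

# which downsampled points did view 0 contribute to?
view_voxels = np.unique(inverse[:300])
mask = np.isin(np.arange(len(down)), view_voxels)  # vectorized set membership
print(len(down), mask.sum())
```

Inside `predict_sequence`, `np.isin(all_random_downsample_idx, view_unique_downsampled_idx)` would be a vectorized equivalent of the set-membership list comprehension that builds each view's mask.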