From 81bf2678acefc62806c2c5e64271c1f2f4579b05 Mon Sep 17 00:00:00 2001 From: hofee Date: Mon, 28 Apr 2025 06:16:03 +0000 Subject: [PATCH] ablation study --- configs/server/server_train_config.yaml | 41 +++++++---- core/ab_global_only_pts_pipeline.py | 81 ++++++++++++++++++++++ core/ab_local_only_pts_pipeline.py | 91 +++++++++++++++++++++++++ core/global_pts_pipeline.py | 10 +-- core/local_pts_pipeline.py | 19 ++---- core/nbv_dataset.py | 36 +++++----- runners/simulator.py | 4 +- 7 files changed, 232 insertions(+), 50 deletions(-) create mode 100644 core/ab_global_only_pts_pipeline.py create mode 100644 core/ab_local_only_pts_pipeline.py diff --git a/configs/server/server_train_config.yaml b/configs/server/server_train_config.yaml index 79ea4a1..90870a4 100644 --- a/configs/server/server_train_config.yaml +++ b/configs/server/server_train_config.yaml @@ -3,11 +3,11 @@ runner: general: seed: 0 device: cuda - cuda_visible_devices: "0" + cuda_visible_devices: "2" parallel: False experiment: - name: train_ab_global_only_with_wp_p++_strong + name: newtrain_real_global_only root_dir: "experiments" use_checkpoint: False epoch: -1 # -1 stands for last epoch @@ -28,18 +28,18 @@ runner: - OmniObject3d_test - OmniObject3d_val - pipeline: nbv_reconstruction_pipeline + pipeline: nbv_reconstruction_pipeline_global_only dataset: OmniObject3d_train: root_dir: "/data/hofee/data/new_full_data" model_dir: "../data/scaled_object_meshes" source: nbv_reconstruction_dataset - split_file: "/data/hofee/data/new_full_data_list/OmniObject3d_train.txt" + split_file: "/data/hofee/data/new_full_data_list/new_OmniObject3d_train.txt" type: train cache: True ratio: 1 - batch_size: 64 + batch_size: 24 num_workers: 128 pts_num: 8192 load_from_preprocess: True @@ -48,14 +48,14 @@ dataset: root_dir: "/data/hofee/data/new_full_data" model_dir: "../data/scaled_object_meshes" source: nbv_reconstruction_dataset - split_file: "/data/hofee/data/new_full_data_list/OmniObject3d_test.txt" + split_file: "/data/hofee/data/new_full_data_list/new_OmniObject3d_test.txt" type: test cache: True filter_degree: 75 eval_list: - pose_diff ratio: 1 - batch_size: 80 + batch_size: 32 num_workers: 12 pts_num: 8192 load_from_preprocess: True @@ -64,21 +64,37 @@ dataset: root_dir: "/data/hofee/data/new_full_data" model_dir: "../data/scaled_object_meshes" source: nbv_reconstruction_dataset - split_file: "/data/hofee/data/new_full_data_list/OmniObject3d_train.txt" + split_file: "/data/hofee/data/new_full_data_list/new_OmniObject3d_train.txt" type: test cache: True filter_degree: 75 eval_list: - pose_diff ratio: 0.1 - batch_size: 80 + batch_size: 32 num_workers: 12 pts_num: 8192 load_from_preprocess: True pipeline: - nbv_reconstruction_pipeline: + nbv_reconstruction_pipeline_local: + modules: + pts_encoder: pointnet++_encoder + seq_encoder: transformer_seq_encoder + pose_encoder: pose_encoder + view_finder: gf_view_finder + eps: 1e-5 + global_scanned_feat: True + nbv_reconstruction_pipeline_global: + modules: + pts_encoder: pointnet++_encoder + seq_encoder: transformer_seq_encoder + pose_encoder: pose_encoder + view_finder: gf_view_finder + eps: 1e-5 + global_scanned_feat: True + nbv_reconstruction_pipeline_local_only: modules: pts_encoder: pointnet++_encoder seq_encoder: transformer_seq_encoder @@ -98,10 +114,9 @@ module: pointnet++_encoder: in_dim: 3 - params_name: strong transformer_seq_encoder: - embed_dim: 256 + embed_dim: 1280 num_heads: 4 ffn_dim: 256 num_layers: 3 @@ -110,7 +125,7 @@ module: gf_view_finder: t_feat_dim: 128 pose_feat_dim: 256 - main_feat_dim: 5120 + main_feat_dim: 1024 regression_head: Rx_Ry_and_T pose_mode: rot_matrix per_point_feature: False diff --git a/core/ab_global_only_pts_pipeline.py b/core/ab_global_only_pts_pipeline.py new file mode 100644 index 0000000..f5603e2 --- /dev/null +++ b/core/ab_global_only_pts_pipeline.py @@ -0,0 +1,81 @@ +import torch +from torch import nn +import PytorchBoot.namespace as namespace +import PytorchBoot.stereotype as stereotype +from PytorchBoot.factory.component_factory import ComponentFactory +from PytorchBoot.utils import Log + + +@stereotype.pipeline("nbv_reconstruction_pipeline_global_only") +class NBVReconstructionGlobalPointsOnlyPipeline(nn.Module): + def __init__(self, config): + super(NBVReconstructionGlobalPointsOnlyPipeline, self).__init__() + self.config = config + self.module_config = config["modules"] + self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"]) + self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_encoder"]) + self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["view_finder"]) + self.eps = float(self.config["eps"]) + self.enable_global_scanned_feat = self.config["global_scanned_feat"] + + def forward(self, data): + mode = data["mode"] + + if mode == namespace.Mode.TRAIN: + return self.forward_train(data) + elif mode == namespace.Mode.TEST: + return self.forward_test(data) + else: + Log.error("Unknown mode: {}".format(mode), True) + + def pertube_data(self, gt_delta_9d): + bs = gt_delta_9d.shape[0] + random_t = torch.rand(bs, device=gt_delta_9d.device) * (1. - self.eps) + self.eps + random_t = random_t.unsqueeze(-1) + mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t) + std = std.view(-1, 1) + z = torch.randn_like(gt_delta_9d) + perturbed_x = mu + z * std + target_score = - z * std / (std ** 2) + return perturbed_x, random_t, target_score, std + + def forward_train(self, data): + main_feat = self.get_main_feat(data) + ''' get std ''' + best_to_world_pose_9d_batch = data["best_to_world_pose_9d"] + perturbed_x, random_t, target_score, std = self.pertube_data(best_to_world_pose_9d_batch) + input_data = { + "sampled_pose": perturbed_x, + "t": random_t, + "main_feat": main_feat, + } + estimated_score = self.view_finder(input_data) + output = { + "estimated_score": estimated_score, + "target_score": target_score, + "std": std + } + return output + + def forward_test(self,data): + main_feat = self.get_main_feat(data) + estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(main_feat) + result = { + "pred_pose_9d": estimated_delta_rot_9d, + "in_process_sample": in_process_sample + } + return result + + + def get_main_feat(self, data): + + combined_scanned_pts_batch = data['combined_scanned_pts'] + global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch) + main_feat = global_scanned_feat + + + if torch.isnan(main_feat).any(): + Log.error("nan in main_feat", True) + + return main_feat + diff --git a/core/ab_local_only_pts_pipeline.py b/core/ab_local_only_pts_pipeline.py new file mode 100644 index 0000000..b06c78f --- /dev/null +++ b/core/ab_local_only_pts_pipeline.py @@ -0,0 +1,91 @@ +import torch +from torch import nn +import PytorchBoot.namespace as namespace +import PytorchBoot.stereotype as stereotype +from PytorchBoot.factory.component_factory import ComponentFactory +from PytorchBoot.utils import Log + +@stereotype.pipeline("nbv_reconstruction_pipeline_local_only") +class NBVReconstructionLocalPointsOnlyPipeline(nn.Module): + def __init__(self, config): + super(NBVReconstructionLocalPointsOnlyPipeline, self).__init__() + self.config = config + self.module_config = config["modules"] + self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"]) + self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_encoder"]) + self.seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["seq_encoder"]) + self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["view_finder"]) + self.eps = float(self.config["eps"]) + self.enable_global_scanned_feat = self.config["global_scanned_feat"] + + def forward(self, data): + mode = data["mode"] + + if mode == namespace.Mode.TRAIN: + return self.forward_train(data) + elif mode == namespace.Mode.TEST: + return self.forward_test(data) + else: + Log.error("Unknown mode: {}".format(mode), True) + + def pertube_data(self, gt_delta_9d): + bs = gt_delta_9d.shape[0] + random_t = torch.rand(bs, device=gt_delta_9d.device) * (1. - self.eps) + self.eps + random_t = random_t.unsqueeze(-1) + mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t) + std = std.view(-1, 1) + z = torch.randn_like(gt_delta_9d) + perturbed_x = mu + z * std + target_score = - z * std / (std ** 2) + return perturbed_x, random_t, target_score, std + + def forward_train(self, data): + main_feat = self.get_main_feat(data) + ''' get std ''' + best_to_world_pose_9d_batch = data["best_to_world_pose_9d"] + perturbed_x, random_t, target_score, std = self.pertube_data(best_to_world_pose_9d_batch) + input_data = { + "sampled_pose": perturbed_x, + "t": random_t, + "main_feat": main_feat, + } + estimated_score = self.view_finder(input_data) + output = { + "estimated_score": estimated_score, + "target_score": target_score, + "std": std + } + return output + + def forward_test(self,data): + main_feat = self.get_main_feat(data) + estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(main_feat) + result = { + "pred_pose_9d": estimated_delta_rot_9d, + "in_process_sample": in_process_sample + } + return result + + + def get_main_feat(self, data): + scanned_pts_batch = data['scanned_pts'] + scanned_n_to_world_pose_9d_batch = data['scanned_n_to_world_pose_9d'] + device = next(self.parameters()).device + feat_seq_list = [] + + for scanned_pts,scanned_n_to_world_pose_9d in zip(scanned_pts_batch,scanned_n_to_world_pose_9d_batch): + + scanned_pts = scanned_pts.to(device) + scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) + pts_feat = self.pts_encoder.encode_points(scanned_pts) + pose_feat = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) + seq_feat = torch.cat([pts_feat, pose_feat], dim=-1) + feat_seq_list.append(seq_feat) + main_feat = self.seq_encoder.encode_sequence(feat_seq_list) + + + if torch.isnan(main_feat).any(): + Log.error("nan in main_feat", True) + + return main_feat + diff --git a/core/global_pts_pipeline.py b/core/global_pts_pipeline.py index 31b8ad4..a4d9c9b 100644 --- a/core/global_pts_pipeline.py +++ b/core/global_pts_pipeline.py @@ -6,7 +6,7 @@ from PytorchBoot.factory.component_factory import ComponentFactory from PytorchBoot.utils import Log -@stereotype.pipeline("nbv_reconstruction_global_pts_pipeline") +@stereotype.pipeline("nbv_reconstruction_pipeline_global") class NBVReconstructionGlobalPointsPipeline(nn.Module): def __init__(self, config): super(NBVReconstructionGlobalPointsPipeline, self).__init__() @@ -14,7 +14,7 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module): self.module_config = config["modules"] self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"]) self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_encoder"]) - self.pose_seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_seq_encoder"]) + self.seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["seq_encoder"]) self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["view_finder"]) self.eps = float(self.config["eps"]) self.enable_global_scanned_feat = self.config["global_scanned_feat"] @@ -73,13 +73,13 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module): device = next(self.parameters()).device - pose_feat_seq_list = [] + feat_seq_list = [] for scanned_n_to_world_pose_9d in scanned_n_to_world_pose_9d_batch: scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) - pose_feat_seq_list.append(self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)) + feat_seq_list.append(self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)) - main_feat = self.pose_seq_encoder.encode_sequence(pose_feat_seq_list) + main_feat = self.seq_encoder.encode_sequence(feat_seq_list) combined_scanned_pts_batch = data['combined_scanned_pts'] diff --git a/core/local_pts_pipeline.py b/core/local_pts_pipeline.py index 8827dc9..acf0e24 100644 --- a/core/local_pts_pipeline.py +++ b/core/local_pts_pipeline.py @@ -5,7 +5,7 @@ import PytorchBoot.stereotype as stereotype from PytorchBoot.factory.component_factory import ComponentFactory from PytorchBoot.utils import Log -@stereotype.pipeline("nbv_reconstruction_local_pts_pipeline") +@stereotype.pipeline("nbv_reconstruction_pipeline_local") class NBVReconstructionLocalPointsPipeline(nn.Module): def __init__(self, config): super(NBVReconstructionLocalPointsPipeline, self).__init__() @@ -70,23 +70,18 @@ class NBVReconstructionLocalPointsPipeline(nn.Module): def get_main_feat(self, data): scanned_pts_batch = data['scanned_pts'] scanned_n_to_world_pose_9d_batch = data['scanned_n_to_world_pose_9d'] - - device = next(self.parameters()).device - - - - pts_feat_seq_list = [] - pose_feat_seq_list = [] + feat_seq_list = [] for scanned_pts,scanned_n_to_world_pose_9d in zip(scanned_pts_batch,scanned_n_to_world_pose_9d_batch): scanned_pts = scanned_pts.to(device) scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device) - pts_feat_seq_list.append(self.pts_encoder.encode_points(scanned_pts)) - pose_feat_seq_list.append(self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)) - - main_feat = self.seq_encoder.encode_sequence(pts_feat_seq_list, pose_feat_seq_list) + pts_feat = self.pts_encoder.encode_points(scanned_pts) + pose_feat = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d) + seq_feat = torch.cat([pts_feat, pose_feat], dim=-1) + feat_seq_list.append(seq_feat) + main_feat = self.seq_encoder.encode_sequence(feat_seq_list) if self.enable_global_scanned_feat: combined_scanned_pts_batch = data['combined_scanned_pts'] diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py index 6583e5b..40294bc 100644 --- a/core/nbv_dataset.py +++ b/core/nbv_dataset.py @@ -135,7 +135,7 @@ class NBVReconstructionDataset(BaseDataset): scanned_coverages_rate, scanned_n_to_world_pose, ) = ([], [], []) - start_time = time.time() + #start_time = time.time() start_indices = [0] total_points = 0 for view in scanned_views: @@ -163,7 +163,7 @@ class NBVReconstructionDataset(BaseDataset): start_indices.append(total_points) - end_time = time.time() + #end_time = time.time() #Log.info(f"load data time: {end_time - start_time}") nbv_idx, nbv_coverage_rate = nbv[0], nbv[1] nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx) @@ -182,22 +182,22 @@ class NBVReconstructionDataset(BaseDataset): voxel_downsampled_combined_scanned_pts_np, inverse = self.voxel_downsample_with_mapping(combined_scanned_views_pts, 0.003) random_downsampled_combined_scanned_pts_np, random_downsample_idx = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, self.pts_num, require_idx=True) - all_idx_unique = np.arange(len(voxel_downsampled_combined_scanned_pts_np)) - all_random_downsample_idx = all_idx_unique[random_downsample_idx] - scanned_pts_mask = [] - for idx, start_idx in enumerate(start_indices): - if idx == len(start_indices) - 1: - break - end_idx = start_indices[idx+1] - view_inverse = inverse[start_idx:end_idx] - view_unique_downsampled_idx = np.unique(view_inverse) - view_unique_downsampled_idx_set = set(view_unique_downsampled_idx) - mask = np.array([idx in view_unique_downsampled_idx_set for idx in all_random_downsample_idx]) - scanned_pts_mask.append(mask) + # all_idx_unique = np.arange(len(voxel_downsampled_combined_scanned_pts_np)) + # all_random_downsample_idx = all_idx_unique[random_downsample_idx] + # scanned_pts_mask = [] + # for idx, start_idx in enumerate(start_indices): + # if idx == len(start_indices) - 1: + # break + # end_idx = start_indices[idx+1] + # view_inverse = inverse[start_idx:end_idx] + # view_unique_downsampled_idx = np.unique(view_inverse) + # view_unique_downsampled_idx_set = set(view_unique_downsampled_idx) + # mask = np.array([idx in view_unique_downsampled_idx_set for idx in all_random_downsample_idx]) + # #scanned_pts_mask.append(mask) data_item = { "scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), # Ndarray(S x Nv x 3) "combined_scanned_pts": np.asarray(random_downsampled_combined_scanned_pts_np, dtype=np.float32), # Ndarray(N x 3) - "scanned_pts_mask": np.asarray(scanned_pts_mask, dtype=np.bool), # Ndarray(N) + #"scanned_pts_mask": np.asarray(scanned_pts_mask, dtype=np.bool), # Ndarray(N) "scanned_coverage_rate": scanned_coverages_rate, # List(S): Float, range(0, 1) "scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose, dtype=np.float32), # Ndarray(S x 9) "best_coverage_rate": nbv_coverage_rate, # Float, range(0, 1) @@ -223,9 +223,9 @@ class NBVReconstructionDataset(BaseDataset): collate_data["scanned_n_to_world_pose_9d"] = [ torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch ] - collate_data["scanned_pts_mask"] = [ - torch.tensor(item["scanned_pts_mask"]) for item in batch - ] + # collate_data["scanned_pts_mask"] = [ + # torch.tensor(item["scanned_pts_mask"]) for item in batch + # ] ''' ------ Fixed Length ------ ''' collate_data["best_to_world_pose_9d"] = torch.stack( diff --git a/runners/simulator.py b/runners/simulator.py index fbba793..c38fe5d 100644 --- a/runners/simulator.py +++ b/runners/simulator.py @@ -1,5 +1,5 @@ -import pybullet as p -import pybullet_data +# import pybullet as p +# import pybullet_data import numpy as np import os import time