diff --git a/configs/server/server_train_config.yaml b/configs/server/server_train_config.yaml index 87ba6e3..79ea4a1 100644 --- a/configs/server/server_train_config.yaml +++ b/configs/server/server_train_config.yaml @@ -7,19 +7,19 @@ runner: parallel: False experiment: - name: train_ab_global_only_with_wp_p++_dense + name: train_ab_global_only_with_wp_p++_strong root_dir: "experiments" use_checkpoint: False epoch: -1 # -1 stands for last epoch max_epochs: 5000 save_checkpoint_interval: 1 - test_first: True + test_first: False train: optimizer: type: Adam lr: 0.0001 - losses: + losses: - gf_loss dataset: OmniObject3d_train test: @@ -39,7 +39,7 @@ dataset: type: train cache: True ratio: 1 - batch_size: 80 + batch_size: 64 num_workers: 128 pts_num: 8192 load_from_preprocess: True @@ -98,7 +98,7 @@ module: pointnet++_encoder: in_dim: 3 - params_name: dense + params_name: strong transformer_seq_encoder: embed_dim: 256 @@ -110,7 +110,7 @@ module: gf_view_finder: t_feat_dim: 128 pose_feat_dim: 256 - main_feat_dim: 2048 + main_feat_dim: 5120 regression_head: Rx_Ry_and_T pose_mode: rot_matrix per_point_feature: False diff --git a/core/pipeline.py b/core/pipeline.py index ae04d9e..1050628 100644 --- a/core/pipeline.py +++ b/core/pipeline.py @@ -75,7 +75,7 @@ class NBVReconstructionPipeline(nn.Module): def forward_test(self, data): main_feat = self.get_main_feat(data) - repeat_num = data.get("repeat_num", 100) + repeat_num = data.get("repeat_num", 1) main_feat = main_feat.repeat(repeat_num, 1) estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view( main_feat