debug pose_diff

This commit is contained in:
hofee 2024-09-27 08:06:49 +00:00
parent 030bf55192
commit 92250aeb62
3 changed files with 11 additions and 11 deletions

View File

@ -3,11 +3,11 @@ runner:
general: general:
seed: 0 seed: 0
device: cuda device: cuda
cuda_visible_devices: "0" cuda_visible_devices: "1"
parallel: False parallel: False
experiment: experiment:
name: overfit_w_global_feat_wo_local_pts_feat_small name: full_w_global_feat_wo_local_pts_feat
root_dir: "experiments" root_dir: "experiments"
use_checkpoint: False use_checkpoint: False
epoch: -1 # -1 stands for last epoch epoch: -1 # -1 stands for last epoch
@ -25,7 +25,7 @@ runner:
test: test:
frequency: 3 # test frequency frequency: 3 # test frequency
dataset_list: dataset_list:
#- OmniObject3d_test - OmniObject3d_test
- OmniObject3d_val - OmniObject3d_val
pipeline: nbv_reconstruction_global_pts_pipeline pipeline: nbv_reconstruction_global_pts_pipeline
@ -35,7 +35,7 @@ dataset:
root_dir: "/home/data/hofee/project/nbv_rec/data/nbv_rec_data_512_preproc_npy" root_dir: "/home/data/hofee/project/nbv_rec/data/nbv_rec_data_512_preproc_npy"
model_dir: "../data/scaled_object_meshes" model_dir: "../data/scaled_object_meshes"
source: nbv_reconstruction_dataset source: nbv_reconstruction_dataset
split_file: "/home/data/hofee/project/nbv_rec/data/OmniObject3d_sample.txt" split_file: "/home/data/hofee/project/nbv_rec/data/OmniObject3d_train.txt"
type: train type: train
cache: True cache: True
ratio: 1 ratio: 1
@ -55,7 +55,7 @@ dataset:
eval_list: eval_list:
- pose_diff - pose_diff
ratio: 0.05 ratio: 0.05
batch_size: 1 batch_size: 160
num_workers: 12 num_workers: 12
pts_num: 4096 pts_num: 4096
load_from_preprocess: True load_from_preprocess: True
@ -64,14 +64,14 @@ dataset:
root_dir: "/home/data/hofee/project/nbv_rec/data/nbv_rec_data_512_preproc_npy" root_dir: "/home/data/hofee/project/nbv_rec/data/nbv_rec_data_512_preproc_npy"
model_dir: "../data/scaled_object_meshes" model_dir: "../data/scaled_object_meshes"
source: nbv_reconstruction_dataset source: nbv_reconstruction_dataset
split_file: "/home/data/hofee/project/nbv_rec/data/OmniObject3d_sample.txt" split_file: "/home/data/hofee/project/nbv_rec/data/OmniObject3d_train.txt"
type: test type: test
cache: True cache: True
filter_degree: 75 filter_degree: 75
eval_list: eval_list:
- pose_diff - pose_diff
ratio: 1 ratio: 0.005
batch_size: 1 batch_size: 160
num_workers: 12 num_workers: 12
pts_num: 4096 pts_num: 4096
load_from_preprocess: True load_from_preprocess: True

View File

@ -29,8 +29,9 @@ class PoseDiff:
gt_rot_mat = PoseUtil.rotation_6d_to_matrix_tensor_batch(gt_rot_6d) gt_rot_mat = PoseUtil.rotation_6d_to_matrix_tensor_batch(gt_rot_6d)
pred_rot_mat = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_rot_6d) pred_rot_mat = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_rot_6d)
rotation_angles = PoseUtil.rotation_angle_distance(gt_rot_mat, pred_rot_mat) rotation_angles = PoseUtil.rotation_angle_distance(gt_rot_mat, pred_rot_mat)
rot_angle_list.extend(list(rotation_angles)) rot_angle_list.extend(list(rotation_angles))
trans_dist = torch.norm(gt_trans-pred_trans) trans_dist = torch.norm(gt_trans-pred_trans, dim=1).mean().item()
trans_dist_list.append(trans_dist) trans_dist_list.append(trans_dist)

View File

@ -34,7 +34,7 @@ class NBVReconstructionDataset(BaseDataset):
self.model_dir = config["model_dir"] self.model_dir = config["model_dir"]
self.filter_degree = config["filter_degree"] self.filter_degree = config["filter_degree"]
if self.type == namespace.Mode.TRAIN: if self.type == namespace.Mode.TRAIN:
scale_ratio = 100 scale_ratio = 1
self.datalist = self.datalist*scale_ratio self.datalist = self.datalist*scale_ratio
if self.cache: if self.cache:
expr_root = ConfigManager.get("runner", "experiment", "root_dir") expr_root = ConfigManager.get("runner", "experiment", "root_dir")
@ -83,7 +83,6 @@ class NBVReconstructionDataset(BaseDataset):
"label_idx": seq_idx, "label_idx": seq_idx,
"scene_max_coverage_rate": scene_max_coverage_rate "scene_max_coverage_rate": scene_max_coverage_rate
}) })
break # TODO: for small version debug
return datalist return datalist
def preprocess_cache(self): def preprocess_cache(self):