update inferencer: success rate
This commit is contained in:
@@ -73,7 +73,6 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
||||
|
||||
device = next(self.parameters()).device
|
||||
|
||||
pts_feat_seq_list = []
|
||||
pose_feat_seq_list = []
|
||||
|
||||
for scanned_n_to_world_pose_9d in scanned_n_to_world_pose_9d_batch:
|
||||
@@ -82,10 +81,10 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
||||
|
||||
main_feat = self.pose_seq_encoder.encode_sequence(pose_feat_seq_list)
|
||||
|
||||
if self.enable_global_scanned_feat:
|
||||
combined_scanned_pts_batch = data['combined_scanned_pts']
|
||||
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
|
||||
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
|
||||
|
||||
combined_scanned_pts_batch = data['combined_scanned_pts']
|
||||
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
|
||||
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
|
||||
|
||||
|
||||
if torch.isnan(main_feat).any():
|
||||
|
@@ -83,7 +83,6 @@ class NBVReconstructionDataset(BaseDataset):
|
||||
"label_idx": seq_idx,
|
||||
"scene_max_coverage_rate": scene_max_coverage_rate
|
||||
})
|
||||
break # TODO: for small version debug
|
||||
return datalist
|
||||
|
||||
def preprocess_cache(self):
|
||||
|
@@ -39,42 +39,32 @@ class SeqNBVReconstructionDataset(BaseDataset):
|
||||
scene_name_list.append(scene_name)
|
||||
return scene_name_list
|
||||
|
||||
def get_datalist_new(self):
|
||||
datalist = []
|
||||
for scene_name in self.scene_name_list:
|
||||
label_num = DataLoadUtil.get_label_num(self.root_dir, scene_name)
|
||||
for i in range(label_num):
|
||||
label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, i)
|
||||
label_data = DataLoadUtil.load_label(label_path)
|
||||
best_seq = label_data["best_sequence"]
|
||||
max_coverage_rate = label_data["max_coverage_rate"]
|
||||
first_frame = best_seq[0]
|
||||
best_seq_len = len(best_seq)
|
||||
datalist.append({
|
||||
"scene_name": scene_name,
|
||||
"first_frame": first_frame,
|
||||
"max_coverage_rate": max_coverage_rate,
|
||||
"best_seq_len": best_seq_len,
|
||||
"label_idx": i,
|
||||
})
|
||||
return datalist
|
||||
|
||||
def get_datalist(self):
|
||||
datalist = []
|
||||
for scene_name in self.scene_name_list:
|
||||
label_path = DataLoadUtil.get_label_path_old(self.root_dir, scene_name)
|
||||
seq_num = DataLoadUtil.get_label_num(self.root_dir, scene_name)
|
||||
scene_max_coverage_rate = 0
|
||||
scene_max_cr_idx = 0
|
||||
|
||||
for seq_idx in range(seq_num):
|
||||
label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, seq_idx)
|
||||
label_data = DataLoadUtil.load_label(label_path)
|
||||
max_coverage_rate = label_data["max_coverage_rate"]
|
||||
if max_coverage_rate > scene_max_coverage_rate:
|
||||
scene_max_coverage_rate = max_coverage_rate
|
||||
scene_max_cr_idx = seq_idx
|
||||
|
||||
label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, scene_max_cr_idx)
|
||||
label_data = DataLoadUtil.load_label(label_path)
|
||||
best_seq = label_data["best_sequence"]
|
||||
max_coverage_rate = label_data["max_coverage_rate"]
|
||||
first_frame = best_seq[0]
|
||||
best_seq_len = len(best_seq)
|
||||
first_frame = label_data["best_sequence"][0]
|
||||
best_seq_len = len(label_data["best_sequence"])
|
||||
datalist.append({
|
||||
"scene_name": scene_name,
|
||||
"first_frame": first_frame,
|
||||
"max_coverage_rate": max_coverage_rate,
|
||||
"best_seq_len": best_seq_len,
|
||||
"best_seq": best_seq,
|
||||
})
|
||||
"scene_name": scene_name,
|
||||
"first_frame": first_frame,
|
||||
"max_coverage_rate": scene_max_coverage_rate,
|
||||
"best_seq_len": best_seq_len,
|
||||
"label_idx": scene_max_cr_idx,
|
||||
})
|
||||
return datalist
|
||||
|
||||
def __getitem__(self, index):
|
||||
@@ -110,8 +100,10 @@ class SeqNBVReconstructionDataset(BaseDataset):
|
||||
first_O_to_first_L_pose = np.dot(np.linalg.inv(first_left_cam_pose), first_center_cam_pose)
|
||||
scene_path = os.path.join(self.root_dir, scene_name)
|
||||
model_points_normals = DataLoadUtil.load_points_normals(self.root_dir, scene_name)
|
||||
|
||||
data_item = {
|
||||
"first_pts": np.asarray([first_downsampled_target_point_cloud],dtype=np.float32),
|
||||
"combined_scanned_pts": np.asarray(first_downsampled_target_point_cloud,dtype=np.float32),
|
||||
"first_to_world_9d": np.asarray([first_to_world_9d],dtype=np.float32),
|
||||
"scene_name": scene_name,
|
||||
"max_coverage_rate": max_coverage_rate,
|
||||
@@ -134,8 +126,9 @@ class SeqNBVReconstructionDataset(BaseDataset):
|
||||
collate_data = {}
|
||||
collate_data["first_pts"] = [torch.tensor(item['first_pts']) for item in batch]
|
||||
collate_data["first_to_world_9d"] = [torch.tensor(item['first_to_world_9d']) for item in batch]
|
||||
collate_data["combined_scanned_pts"] = torch.stack([torch.tensor(item['combined_scanned_pts']) for item in batch])
|
||||
for key in batch[0].keys():
|
||||
if key not in ["first_pts", "first_to_world_9d"]:
|
||||
if key not in ["first_pts", "first_to_world_9d", "combined_scanned_pts"]:
|
||||
collate_data[key] = [item[key] for item in batch]
|
||||
return collate_data
|
||||
return collate_fn
|
||||
|
Reference in New Issue
Block a user