This commit is contained in:
2024-09-27 08:06:55 +00:00
5 changed files with 81 additions and 51 deletions

View File

@@ -73,7 +73,6 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
device = next(self.parameters()).device
pts_feat_seq_list = []
pose_feat_seq_list = []
for scanned_n_to_world_pose_9d in scanned_n_to_world_pose_9d_batch:
@@ -82,10 +81,10 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
main_feat = self.pose_seq_encoder.encode_sequence(pose_feat_seq_list)
if self.enable_global_scanned_feat:
combined_scanned_pts_batch = data['combined_scanned_pts']
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
combined_scanned_pts_batch = data['combined_scanned_pts']
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
if torch.isnan(main_feat).any():

View File

@@ -39,42 +39,32 @@ class SeqNBVReconstructionDataset(BaseDataset):
scene_name_list.append(scene_name)
return scene_name_list
def get_datalist_new(self):
    """Build the flat list of training items, one per label file.

    Walks every scene in ``self.scene_name_list`` and, for each label
    index of that scene, loads the label file and records the best
    sequence's first frame, its length, the label's max coverage rate,
    and the label index.

    Returns:
        list[dict]: one entry per (scene, label) pair with keys
        ``scene_name``, ``first_frame``, ``max_coverage_rate``,
        ``best_seq_len`` and ``label_idx``.
    """
    items = []
    for scene in self.scene_name_list:
        n_labels = DataLoadUtil.get_label_num(self.root_dir, scene)
        for label_idx in range(n_labels):
            path = DataLoadUtil.get_label_path(self.root_dir, scene, label_idx)
            label = DataLoadUtil.load_label(path)
            # NOTE(review): label schema assumed from the keys read here —
            # "best_sequence" is an ordered sequence of frames, first
            # element is the starting frame.
            seq = label["best_sequence"]
            items.append({
                "scene_name": scene,
                "first_frame": seq[0],
                "max_coverage_rate": label["max_coverage_rate"],
                "best_seq_len": len(seq),
                "label_idx": label_idx,
            })
    return items
def get_datalist(self):
    """Build one training item per scene, using that scene's best label.

    For each scene, scans all of its label files, finds the one with the
    highest ``max_coverage_rate``, reloads it, and emits a single item
    describing that best sequence.

    Returns:
        list[dict]: one entry per scene with keys ``scene_name``,
        ``first_frame``, ``max_coverage_rate`` (the scene-wide maximum),
        ``best_seq_len`` and ``label_idx`` (index of the winning label).
    """
    datalist = []
    for scene_name in self.scene_name_list:
        seq_num = DataLoadUtil.get_label_num(self.root_dir, scene_name)
        # Track the label with the highest coverage rate for this scene.
        scene_max_coverage_rate = 0
        scene_max_cr_idx = 0
        for seq_idx in range(seq_num):
            label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, seq_idx)
            label_data = DataLoadUtil.load_label(label_path)
            max_coverage_rate = label_data["max_coverage_rate"]
            if max_coverage_rate > scene_max_coverage_rate:
                scene_max_coverage_rate = max_coverage_rate
                scene_max_cr_idx = seq_idx
        # Reload the winning label to read its best sequence.
        label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, scene_max_cr_idx)
        label_data = DataLoadUtil.load_label(label_path)
        first_frame = label_data["best_sequence"][0]
        best_seq_len = len(label_data["best_sequence"])
        datalist.append({
            "scene_name": scene_name,
            "first_frame": first_frame,
            "max_coverage_rate": scene_max_coverage_rate,
            "best_seq_len": best_seq_len,
            "label_idx": scene_max_cr_idx,
        })
    return datalist
def __getitem__(self, index):
@@ -110,8 +100,10 @@ class SeqNBVReconstructionDataset(BaseDataset):
first_O_to_first_L_pose = np.dot(np.linalg.inv(first_left_cam_pose), first_center_cam_pose)
scene_path = os.path.join(self.root_dir, scene_name)
model_points_normals = DataLoadUtil.load_points_normals(self.root_dir, scene_name)
data_item = {
"first_pts": np.asarray([first_downsampled_target_point_cloud],dtype=np.float32),
"combined_scanned_pts": np.asarray(first_downsampled_target_point_cloud,dtype=np.float32),
"first_to_world_9d": np.asarray([first_to_world_9d],dtype=np.float32),
"scene_name": scene_name,
"max_coverage_rate": max_coverage_rate,
@@ -134,8 +126,9 @@ class SeqNBVReconstructionDataset(BaseDataset):
collate_data = {}
collate_data["first_pts"] = [torch.tensor(item['first_pts']) for item in batch]
collate_data["first_to_world_9d"] = [torch.tensor(item['first_to_world_9d']) for item in batch]
collate_data["combined_scanned_pts"] = torch.stack([torch.tensor(item['combined_scanned_pts']) for item in batch])
for key in batch[0].keys():
if key not in ["first_pts", "first_to_world_9d"]:
if key not in ["first_pts", "first_to_world_9d", "combined_scanned_pts"]:
collate_data[key] = [item[key] for item in batch]
return collate_data
return collate_fn