add multi seq training

This commit is contained in:
2024-09-23 14:30:51 +08:00
parent 6cdff9c83f
commit 3c4077ec4f
12 changed files with 152 additions and 93 deletions

View File

@@ -3,6 +3,7 @@ import numpy as np
import json
import cv2
import trimesh
import torch
from utils.pts import PtsUtil
class DataLoadUtil:
@@ -13,8 +14,21 @@ class DataLoadUtil:
return path
@staticmethod
def get_label_path(root, scene_name):
path = os.path.join(root,scene_name, f"label.json")
def get_label_num(root, scene_name):
label_dir = os.path.join(root,scene_name,"label")
return len(os.listdir(label_dir))
@staticmethod
def get_label_path(root, scene_name, seq_idx):
label_dir = os.path.join(root,scene_name,"label")
if not os.path.exists(label_dir):
os.makedirs(label_dir)
path = os.path.join(label_dir,f"{seq_idx}.json")
return path
@staticmethod
def get_label_path_old(root, scene_name):
path = os.path.join(root,scene_name,"label.json")
return path
@staticmethod
@@ -45,11 +59,14 @@ class DataLoadUtil:
mesh.export(model_path)
@staticmethod
def save_target_mesh_at_world_space(root, model_dir, scene_name):
def save_target_mesh_at_world_space(root, model_dir, scene_name, display_table_as_world_space_origin=True):
scene_info = DataLoadUtil.load_scene_info(root, scene_name)
target_name = scene_info["target_name"]
transformation = scene_info[target_name]
location = transformation["location"]
if display_table_as_world_space_origin:
location = transformation["location"] - DataLoadUtil.DISPLAY_TABLE_POSITION
else:
location = transformation["location"]
rotation_euler = transformation["rotation_euler"]
pose_mat = trimesh.transformations.euler_matrix(*rotation_euler)
pose_mat[:3, 3] = location
@@ -181,7 +198,9 @@ class DataLoadUtil:
@staticmethod
def get_real_cam_O_from_cam_L(cam_L, cam_O_to_cam_L, display_table_as_world_space_origin=True):
nO_to_display_table_pose = cam_L.cpu().numpy() @ cam_O_to_cam_L
if isinstance(cam_L, torch.Tensor):
cam_L = cam_L.cpu().numpy()
nO_to_display_table_pose = cam_L @ cam_O_to_cam_L
if display_table_as_world_space_origin:
display_table_to_world = np.eye(4)
display_table_to_world[:3, 3] = DataLoadUtil.DISPLAY_TABLE_POSITION

View File

@@ -45,12 +45,17 @@ class ReconstructionUtil:
@staticmethod
def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, display_table_point_cloud_list = None,threshold=0.01, overlap_threshold=0.3, status_info=None):
selected_views = []
current_coverage = 0.0
def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list,threshold=0.01, overlap_threshold=0.3, init_view = 0, status_info=None):
selected_views = [point_cloud_list[init_view]]
combined_point_cloud = np.vstack(selected_views)
down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud,threshold)
new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
current_coverage = new_coverage
remaining_views = list(range(len(point_cloud_list)))
view_sequence = []
view_sequence = [(init_view, current_coverage)]
cnt_processed_view = 0
remaining_views.remove(init_view)
while remaining_views:
best_view = None
best_coverage_increase = -1
@@ -70,14 +75,13 @@ class ReconstructionUtil:
down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud,threshold)
new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
coverage_increase = new_coverage - current_coverage
#print(f"view_index: {view_index}, coverage_increase: {coverage_increase}")
if coverage_increase > best_coverage_increase:
best_coverage_increase = coverage_increase
best_view = view_index
if best_view is not None:
if best_coverage_increase <=1e-3:
if best_coverage_increase <=3e-3:
break
selected_views.append(point_cloud_list[best_view])
remaining_views.remove(best_view)

View File

@@ -12,8 +12,8 @@ class RenderUtil:
def render_pts(cam_pose, scene_path,script_path, model_points_normals, voxel_threshold=0.005, filter_degree=75, nO_to_nL_pose=None, require_full_scene=False):
nO_to_world_pose = DataLoadUtil.get_real_cam_O_from_cam_L(cam_pose, nO_to_nL_pose)
with tempfile.TemporaryDirectory() as temp_dir:
params = {
"cam_pose": nO_to_world_pose.tolist(),
@@ -30,7 +30,6 @@ class RenderUtil:
print(result.stderr)
return None
path = os.path.join(temp_dir, "tmp")
point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True)
cam_params = DataLoadUtil.load_cam_info(path, binocular=True)
filtered_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=filter_degree)
@@ -44,4 +43,5 @@ class RenderUtil:
point_cloud_R = PtsUtil.random_downsample_point_cloud(point_cloud_R, 65536)
full_scene_point_cloud = DataLoadUtil.get_overlapping_points(point_cloud_L, point_cloud_R)
return filtered_point_cloud, full_scene_point_cloud