From f6c4db859ea1b03a5e0e1087ccb1b4194f7231bd Mon Sep 17 00:00:00 2001
From: hofee
Date: Thu, 10 Oct 2024 14:42:57 +0800
Subject: [PATCH] add multiprocess

---
 utils/preprocess_util.py | 129 +++++++++++++++++++--------------------
 1 file changed, 63 insertions(+), 66 deletions(-)

diff --git a/utils/preprocess_util.py b/utils/preprocess_util.py
index 737b48a..dbef367 100644
--- a/utils/preprocess_util.py
+++ b/utils/preprocess_util.py
@@ -4,7 +4,7 @@ import time
 import sys
 np.random.seed(0)
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from utils.reconstruction_util import ReconstructionUtil
 from utils.data_load import DataLoadUtil
 from utils.pts_util import PtsUtil
@@ -57,81 +57,78 @@ def get_scan_points_indices(scan_points, mask, display_table_mask_label, cam_int
     selected_points_indices = np.where((mask_colors == display_table_mask_label).all(axis=-1))[0]
     selected_points_indices = np.where(valid_indices)[0][selected_points_indices]
     return selected_points_indices
-
-def save_scene_data(root, scene, scene_idx=0, scene_total=1,file_type="txt"):
+def process_frame(frame_id, root, scene, scan_points, file_type, target_mask_label, display_table_mask_label, random_downsample_N, voxel_size, filter_degree, min_z, max_z):
+    Log.info(f"[frame({frame_id})]Processing {scene} frame {frame_id}")
+    path = DataLoadUtil.get_path(root, scene, frame_id)
+    cam_info = DataLoadUtil.load_cam_info(path, binocular=True)
+    depth_L, depth_R = DataLoadUtil.load_depth(
+        path, cam_info["near_plane"],
+        cam_info["far_plane"],
+        binocular=True
+    )
+    mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True)
+
+    target_mask_img_L = (mask_L == target_mask_label).all(axis=-1)
+    target_mask_img_R = (mask_R == target_mask_label).all(axis=-1)
+
+    target_points_L = get_world_points(depth_L, target_mask_img_L, cam_info["cam_intrinsic"], cam_info["cam_to_world"])
+    target_points_R = get_world_points(depth_R, target_mask_img_R, cam_info["cam_intrinsic"], cam_info["cam_to_world_R"])
+
+    sampled_target_points_L = PtsUtil.random_downsample_point_cloud(
+        target_points_L, random_downsample_N
+    )
+    sampled_target_points_R = PtsUtil.random_downsample_point_cloud(
+        target_points_R, random_downsample_N
+    )
+
+    has_points = sampled_target_points_L.shape[0] > 0 and sampled_target_points_R.shape[0] > 0
+    target_points = np.zeros((0, 3))
+
+    if has_points:
+        target_points = PtsUtil.get_overlapping_points(
+            sampled_target_points_L, sampled_target_points_R, voxel_size
+        )
+
+    if has_points and target_points.shape[0] > 0:
+        points_normals = DataLoadUtil.load_points_normals(root, scene, display_table_as_world_space_origin=True)
+        target_points = PtsUtil.filter_points(
+            target_points, points_normals, cam_info["cam_to_world"], voxel_size=0.002, theta=filter_degree, z_range=(min_z, max_z)
+        )
+
+    scan_points_indices_L = get_scan_points_indices(scan_points, mask_L, display_table_mask_label, cam_info["cam_intrinsic"], cam_info["cam_to_world"])
+    scan_points_indices_R = get_scan_points_indices(scan_points, mask_R, display_table_mask_label, cam_info["cam_intrinsic"], cam_info["cam_to_world_R"])
+    scan_points_indices = np.intersect1d(scan_points_indices_L, scan_points_indices_R)
+
+    if not has_points:
+        target_points = np.zeros((0, 3))
 
-    ''' configuration '''
+    save_target_points(root, scene, frame_id, target_points, file_type=file_type)
+    save_scan_points_indices(root, scene, frame_id, scan_points_indices, file_type=file_type)
+
+def save_scene_data(root, scene, file_type="txt"):
     target_mask_label = (0, 255, 0, 255)
-    display_table_mask_label=(0, 0, 255, 255)
+    display_table_mask_label = (0, 0, 255, 255)
     random_downsample_N = 32768
-    voxel_size=0.002
+    voxel_size = 0.002
     filter_degree = 75
     min_z = 0.2
     max_z = 0.5
-
-    ''' scan points '''
-    scan_points = np.asarray(ReconstructionUtil.generate_scan_points(display_table_top=0,display_table_radius=0.25))
-
-    ''' read frame data(depth|mask|normal) '''
+
+    scan_points = np.asarray(ReconstructionUtil.generate_scan_points(display_table_top=0, display_table_radius=0.25))
     frame_num = DataLoadUtil.get_scene_seq_length(root, scene)
-    for frame_id in range(frame_num):
-        Log.info(f"[frame({frame_id}/{frame_num})]Processing {scene} frame {frame_id}")
-        path = DataLoadUtil.get_path(root, scene, frame_id)
-        cam_info = DataLoadUtil.load_cam_info(path, binocular=True)
-        depth_L, depth_R = DataLoadUtil.load_depth(
-            path, cam_info["near_plane"],
-            cam_info["far_plane"],
-            binocular=True
-        )
-        mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True)
+
+    with ThreadPoolExecutor() as executor:
+        futures = {executor.submit(process_frame, frame_id, root, scene, scan_points, file_type, target_mask_label, display_table_mask_label, random_downsample_N, voxel_size, filter_degree, min_z, max_z): frame_id for frame_id in range(frame_num)}
 
-        ''' target points '''
-        mask_img_L = mask_L
-        mask_img_R = mask_R
+        for future in as_completed(futures):
+            frame_id = futures[future]
+            try:
+                future.result()
+            except Exception as e:
+                Log.error(f"Error processing frame {frame_id}: {e}")
 
-        target_mask_img_L = (mask_L == target_mask_label).all(axis=-1)
-        target_mask_img_R = (mask_R == target_mask_label).all(axis=-1)
-
-
-        target_points_L = get_world_points(depth_L, target_mask_img_L, cam_info["cam_intrinsic"], cam_info["cam_to_world"])
-        target_points_R = get_world_points(depth_R, target_mask_img_R, cam_info["cam_intrinsic"], cam_info["cam_to_world_R"])
-
-        sampled_target_points_L = PtsUtil.random_downsample_point_cloud(
-            target_points_L, random_downsample_N
-        )
-        sampled_target_points_R = PtsUtil.random_downsample_point_cloud(
-            target_points_R, random_downsample_N
-        )
-
-        has_points = sampled_target_points_L.shape[0] > 0 and sampled_target_points_R.shape[0] > 0
-        if has_points:
-            target_points = PtsUtil.get_overlapping_points(
-                sampled_target_points_L, sampled_target_points_R, voxel_size
-            )
-
-        if has_points:
-            has_points = target_points.shape[0] > 0
-
-        if has_points:
-            points_normals = DataLoadUtil.load_points_normals(root, scene, display_table_as_world_space_origin=True)
-            target_points = PtsUtil.filter_points(
-                target_points, points_normals, cam_info["cam_to_world"],voxel_size=0.002, theta = filter_degree, z_range=(min_z, max_z)
-            )
-
-
-        ''' scan points indices '''
-        scan_points_indices_L = get_scan_points_indices(scan_points, mask_img_L, display_table_mask_label, cam_info["cam_intrinsic"], cam_info["cam_to_world"])
-        scan_points_indices_R = get_scan_points_indices(scan_points, mask_img_R, display_table_mask_label, cam_info["cam_intrinsic"], cam_info["cam_to_world_R"])
-        scan_points_indices = np.intersect1d(scan_points_indices_L, scan_points_indices_R)
-
-        if not has_points:
-            target_points = np.zeros((0, 3))
-
-        save_target_points(root, scene, frame_id, target_points, file_type=file_type)
-        save_scan_points_indices(root, scene, frame_id, scan_points_indices, file_type=file_type)
-
-    save_scan_points(root, scene, scan_points) # The "done" flag of scene preprocess
+    save_scan_points(root, scene, scan_points) # The "done" flag of scene preprocess
 
 if __name__ == "__main__":