add preprocess

This commit is contained in:
hofee
2024-10-03 01:59:13 +08:00
parent f460e6e6b2
commit d7561738c6
8 changed files with 243 additions and 142 deletions

View File

@@ -157,8 +157,8 @@ class DataLoadUtil:
return depth_meters
@staticmethod
def load_seg(path, binocular=False):
if binocular:
def load_seg(path, binocular=False, left_only=False):
if binocular and not left_only:
def clean_mask(mask_image):
green = [0, 255, 0, 255]
@@ -182,11 +182,41 @@ class DataLoadUtil:
mask_image_R = clean_mask(cv2.imread(mask_path_R, cv2.IMREAD_UNCHANGED))
return mask_image_L, mask_image_R
else:
mask_path = os.path.join(
os.path.dirname(path), "mask", os.path.basename(path) + ".png"
)
if binocular and left_only:
mask_path = os.path.join(
os.path.dirname(path), "mask", os.path.basename(path) + "_L.png"
)
else:
mask_path = os.path.join(
os.path.dirname(path), "mask", os.path.basename(path) + ".png"
)
mask_image = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
return mask_image
@staticmethod
def load_normal(path, binocular=False, left_only=False):
if binocular and not left_only:
normal_path_L = os.path.join(
os.path.dirname(path), "normal", os.path.basename(path) + "_L.png"
)
normal_image_L = cv2.imread(normal_path_L, cv2.IMREAD_UNCHANGED)
normal_path_R = os.path.join(
os.path.dirname(path), "normal", os.path.basename(path) + "_R.png"
)
normal_image_R = cv2.imread(normal_path_R, cv2.IMREAD_UNCHANGED)
return normal_image_L, normal_image_R
else:
if binocular and left_only:
normal_path = os.path.join(
os.path.dirname(path), "normal", os.path.basename(path) + "_L.png"
)
else:
normal_path = os.path.join(
os.path.dirname(path), "normal", os.path.basename(path) + ".png"
)
normal_image = cv2.imread(normal_path, cv2.IMREAD_UNCHANGED)
return normal_image
@staticmethod
def load_label(path):
@@ -273,7 +303,7 @@ class DataLoadUtil:
@staticmethod
def get_target_point_cloud(
depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(0, 255, 0, 255)
depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(0, 255, 0, 255), require_full_points=False
):
h, w = depth.shape
i, j = np.meshgrid(np.arange(w), np.arange(h), indexing="xy")
@@ -293,10 +323,11 @@ class DataLoadUtil:
)
target_points_world = np.dot(cam_extrinsic, target_points_camera_aug.T).T[:, :3]
return {
data = {
"points_world": target_points_world,
"points_camera": target_points_camera,
}
return data
@staticmethod
def get_point_cloud(depth, cam_intrinsic, cam_extrinsic):
@@ -323,7 +354,8 @@ class DataLoadUtil:
voxel_size=0.005,
target_mask_label=(0, 255, 0, 255),
display_table_mask_label=(0, 0, 255, 255),
get_display_table_pts=False
get_display_table_pts=False,
require_normal=False,
):
cam_info = DataLoadUtil.load_cam_info(path, binocular=binocular)
if binocular:
@@ -351,34 +383,9 @@ class DataLoadUtil:
point_cloud_R = PtsUtil.random_downsample_point_cloud(
point_cloud_R, random_downsample_N
)
overlap_points = DataLoadUtil.get_overlapping_points(
overlap_points = PtsUtil.get_overlapping_points(
point_cloud_L, point_cloud_R, voxel_size
)
if get_display_table_pts:
display_pts_L = DataLoadUtil.get_target_point_cloud(
depth_L,
cam_info["cam_intrinsic"],
cam_info["cam_to_world"],
mask_L,
display_table_mask_label,
)["points_world"]
display_pts_R = DataLoadUtil.get_target_point_cloud(
depth_R,
cam_info["cam_intrinsic"],
cam_info["cam_to_world_R"],
mask_R,
display_table_mask_label,
)["points_world"]
display_pts_L = PtsUtil.random_downsample_point_cloud(
display_pts_L, random_downsample_N
)
point_cloud_R = PtsUtil.random_downsample_point_cloud(
display_pts_R, random_downsample_N
)
display_pts_overlap = DataLoadUtil.get_overlapping_points(
display_pts_L, display_pts_R, voxel_size
)
return overlap_points, display_pts_overlap
return overlap_points
else:
depth = DataLoadUtil.load_depth(
@@ -390,27 +397,6 @@ class DataLoadUtil:
)["points_world"]
return point_cloud
@staticmethod
def voxelize_points(points, voxel_size):
voxel_indices = np.floor(points / voxel_size).astype(np.int32)
unique_voxels = np.unique(voxel_indices, axis=0, return_inverse=True)
return unique_voxels
@staticmethod
def get_overlapping_points(point_cloud_L, point_cloud_R, voxel_size=0.005):
voxels_L, indices_L = DataLoadUtil.voxelize_points(point_cloud_L, voxel_size)
voxels_R, _ = DataLoadUtil.voxelize_points(point_cloud_R, voxel_size)
voxel_indices_L = voxels_L.view([("", voxels_L.dtype)] * 3)
voxel_indices_R = voxels_R.view([("", voxels_R.dtype)] * 3)
overlapping_voxels = np.intersect1d(voxel_indices_L, voxel_indices_R)
mask_L = np.isin(
indices_L, np.where(np.isin(voxel_indices_L, overlapping_voxels))[0]
)
overlapping_points = point_cloud_L[mask_L]
return overlapping_points
@staticmethod
def load_points_normals(root, scene_name, display_table_as_world_space_origin=True):
points_path = os.path.join(root, scene_name, "points_and_normals.txt")

View File

@@ -18,13 +18,49 @@ class PtsUtil:
return points_h[:, :3]
@staticmethod
def random_downsample_point_cloud(point_cloud, num_points):
def random_downsample_point_cloud(point_cloud, num_points, require_idx=False):
if point_cloud.shape[0] == 0:
return point_cloud
idx = np.random.choice(len(point_cloud), num_points, replace=True)
if require_idx:
return point_cloud[idx], idx
return point_cloud[idx]
@staticmethod
def random_downsample_point_cloud_tensor(point_cloud, num_points):
idx = torch.randint(0, len(point_cloud), (num_points,))
return point_cloud[idx]
return point_cloud[idx]
@staticmethod
def voxelize_points(points, voxel_size):
voxel_indices = np.floor(points / voxel_size).astype(np.int32)
unique_voxels = np.unique(voxel_indices, axis=0, return_inverse=True)
return unique_voxels
@staticmethod
def get_overlapping_points(point_cloud_L, point_cloud_R, voxel_size=0.005, require_idx=False):
voxels_L, indices_L = PtsUtil.voxelize_points(point_cloud_L, voxel_size)
voxels_R, _ = PtsUtil.voxelize_points(point_cloud_R, voxel_size)
voxel_indices_L = voxels_L.view([("", voxels_L.dtype)] * 3)
voxel_indices_R = voxels_R.view([("", voxels_R.dtype)] * 3)
overlapping_voxels = np.intersect1d(voxel_indices_L, voxel_indices_R)
mask_L = np.isin(
indices_L, np.where(np.isin(voxel_indices_L, overlapping_voxels))[0]
)
overlapping_points = point_cloud_L[mask_L]
if require_idx:
return overlapping_points, mask_L
return overlapping_points
@staticmethod
def filter_points(points, normals, cam_pose, theta=75, require_idx=False):
camera_axis = -cam_pose[:3, 2]
normals_normalized = normals / np.linalg.norm(normals, axis=1, keepdims=True)
cos_theta = np.dot(normals_normalized, camera_axis)
theta_rad = np.deg2rad(theta)
idx = cos_theta > np.cos(theta_rad)
filtered_points= points[idx]
if require_idx:
return filtered_points, idx
return filtered_points

View File

@@ -129,22 +129,7 @@ class ReconstructionUtil:
runner_name = status_info["runner_name"]
sm.set_progress(app_name, runner_name, "processed view", len(point_cloud_list), len(point_cloud_list))
return view_sequence, remaining_views, down_sampled_combined_point_cloud
@staticmethod
def filter_points(points, points_normals, cam_pose, voxel_size=0.005, theta=75):
sampled_points = PtsUtil.voxel_downsample_point_cloud(points, voxel_size)
kdtree = cKDTree(points_normals[:,:3])
_, indices = kdtree.query(sampled_points)
nearest_points = points_normals[indices]
normals = nearest_points[:, 3:]
camera_axis = -cam_pose[:3, 2]
normals_normalized = normals / np.linalg.norm(normals, axis=1, keepdims=True)
cos_theta = np.dot(normals_normalized, camera_axis)
theta_rad = np.deg2rad(theta)
filtered_sampled_points= sampled_points[cos_theta > np.cos(theta_rad)]
return filtered_sampled_points[:, :3]
@staticmethod
def generate_scan_points(display_table_top, display_table_radius, min_distance=0.03, max_points_num = 100, max_attempts = 1000):

View File

@@ -33,12 +33,11 @@ class RenderUtil:
print(result.stderr)
return None
path = os.path.join(temp_dir, "tmp")
# ------ Debug Start ------
# import ipdb;ipdb.set_trace()
# ------ Debug End ------
point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True)
cam_params = DataLoadUtil.load_cam_info(path, binocular=True)
filtered_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=filter_degree)
''' TODO: old code: filter_points api is changed, need to update the code '''
filtered_point_cloud = PtsUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=filter_degree)
full_scene_point_cloud = None
if require_full_scene:
depth_L, depth_R = DataLoadUtil.load_depth(path, cam_params['near_plane'], cam_params['far_plane'], binocular=True)
@@ -47,7 +46,7 @@ class RenderUtil:
point_cloud_L = PtsUtil.random_downsample_point_cloud(point_cloud_L, 65536)
point_cloud_R = PtsUtil.random_downsample_point_cloud(point_cloud_R, 65536)
full_scene_point_cloud = DataLoadUtil.get_overlapping_points(point_cloud_L, point_cloud_R)
full_scene_point_cloud = PtsUtil.get_overlapping_points(point_cloud_L, point_cloud_R)
return filtered_point_cloud, full_scene_point_cloud