update basic framework

hofee
2024-08-21 17:11:56 +08:00
parent 73dcd592df
commit f977fd4b8e
29 changed files with 1393 additions and 719 deletions

utils/data_load.py (new file, 135 additions)

@@ -0,0 +1,135 @@
import os
import json

import cv2
import Imath
import numpy as np
import OpenEXR


class DataLoadUtil:

    @staticmethod
    def get_path(root, scene_idx, frame_idx):
        path = os.path.join(root, f"sequence.{scene_idx}", f"step{frame_idx}")
        return path

    @staticmethod
    def read_exr_depth(depth_path):
        file = OpenEXR.InputFile(depth_path)
        dw = file.header()['dataWindow']
        width = dw.max.x - dw.min.x + 1
        height = dw.max.y - dw.min.y + 1
        # Read the single 'R' channel as 32-bit floats and reshape to (H, W).
        pix_type = Imath.PixelType(Imath.PixelType.FLOAT)
        depth_map = np.frombuffer(file.channel('R', pix_type), dtype=np.float32)
        depth_map.shape = (height, width)
        return depth_map

    @staticmethod
    def load_depth(path):
        depth_path = path + ".camera.Depth.exr"
        depth_map = DataLoadUtil.read_exr_depth(depth_path)
        return depth_map

    @staticmethod
    def load_rgb(path):
        rgb_path = path + ".camera.png"
        rgb_image = cv2.imread(rgb_path, cv2.IMREAD_COLOR)
        return rgb_image

    @staticmethod
    def load_seg(path):
        seg_path = path + ".camera.semantic segmentation.png"
        seg_image = cv2.imread(seg_path, cv2.IMREAD_COLOR)
        return seg_image

    @staticmethod
    def load_cam_info(path):
        label_path = path + ".camera_params.json"
        with open(label_path, 'r') as f:
            label_data = json.load(f)
        cam_transform = np.asarray(label_data['cam_to_world']).reshape(
            (4, 4)
        ).T
        # Negate the camera Y axis to convert the renderer's camera
        # convention into the one used downstream.
        offset = np.asarray([
            [1,  0, 0, 0],
            [0, -1, 0, 0],
            [0,  0, 1, 0],
            [0,  0, 0, 1]])
        cam_to_world = cam_transform @ offset
        f_x = label_data['f_x']
        f_y = label_data['f_y']
        c_x = label_data['c_x']
        c_y = label_data['c_y']
        cam_intrinsic = np.array([[f_x, 0, c_x], [0, f_y, c_y], [0, 0, 1]])
        return {
            "cam_to_world": cam_to_world,
            "cam_intrinsic": cam_intrinsic
        }

    @staticmethod
    def get_target_point_cloud(depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(255, 255, 255)):
        h, w = depth.shape
        i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
        # Back-project every pixel to camera space via the pinhole model.
        z = depth
        x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0]
        y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1]
        points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3)
        # Homogeneous coordinates, then rigid transform to world space.
        points_camera_aug = np.concatenate([points_camera, np.ones((points_camera.shape[0], 1))], axis=-1)
        points_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3]
        # Keep only pixels whose segmentation color matches the target label.
        mask = mask.reshape(-1, 3)
        target_mask = np.all(mask == target_mask_label, axis=-1)
        return {
            "points_world": points_world[target_mask],
            "points_camera": points_camera[target_mask]
        }
    @staticmethod
    def get_point_cloud_world_from_path(path):
        cam_info = DataLoadUtil.load_cam_info(path)
        depth = DataLoadUtil.load_depth(path)
        mask = DataLoadUtil.load_seg(path)
        point_cloud = DataLoadUtil.get_target_point_cloud(
            depth, cam_info['cam_intrinsic'], cam_info['cam_to_world'], mask
        )
        return point_cloud['points_world']

    @staticmethod
    def get_point_cloud_list_from_seq(root, seq_idx, num_frames):
        point_cloud_list = []
        for idx in range(num_frames):
            path = DataLoadUtil.get_path(root, seq_idx, idx)
            point_cloud = DataLoadUtil.get_point_cloud_world_from_path(path)
            point_cloud_list.append(point_cloud)
        return point_cloud_list
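
A minimal usage sketch of the loader (the dataset root, scene index, and frame count below are placeholders, not part of this commit):

from utils.data_load import DataLoadUtil

data_root = "/path/to/dataset"   # hypothetical root; layout is sequence.<scene>/step<frame>.*
scene_idx, num_frames = 0, 24    # illustrative values

# World-space target points for a single frame...
path = DataLoadUtil.get_path(data_root, scene_idx, frame_idx=0)
points_world = DataLoadUtil.get_point_cloud_world_from_path(path)

# ...or one point cloud per frame for the whole sequence.
point_cloud_list = DataLoadUtil.get_point_cloud_list_from_seq(data_root, scene_idx, num_frames)
print(points_world.shape, len(point_cloud_list))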

utils/pose.py (new file, 246 additions)

@@ -0,0 +1,246 @@
import numpy as np
import torch
import torch.nn.functional as F


class PoseUtil:
    ROTATION = 1
    TRANSLATION = 2
    SCALE = 3

    @staticmethod
    def get_uniform_translation(trans_m_min, trans_m_max, trans_unit, debug=False):
        if isinstance(trans_m_min, list):
            x_min, y_min, z_min = trans_m_min
            x_max, y_max, z_max = trans_m_max
        else:
            x_min, y_min, z_min = trans_m_min, trans_m_min, trans_m_min
            x_max, y_max, z_max = trans_m_max, trans_m_max, trans_m_max
        x = np.random.uniform(x_min, x_max)
        y = np.random.uniform(y_min, y_max)
        z = np.random.uniform(z_min, z_max)
        translation = np.array([x, y, z])
        if trans_unit == "cm":
            translation = translation / 100
        if debug:
            print("uniform translation:", translation)
        return translation

    @staticmethod
    def get_uniform_rotation(rot_degree_min=0, rot_degree_max=180, debug=False):
        # Sample a random axis and angle, then build the rotation matrix
        # with Rodrigues' formula.
        axis = np.random.randn(3)
        axis /= np.linalg.norm(axis)
        theta = np.random.uniform(
            rot_degree_min / 180 * np.pi, rot_degree_max / 180 * np.pi
        )
        K = np.array(
            [[0, -axis[2], axis[1]], [axis[2], 0, -axis[0]], [-axis[1], axis[0], 0]]
        )
        R = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K)
        if debug:
            print("uniform rotation:", theta * 180 / np.pi)
        return R

    @staticmethod
    def get_uniform_pose(
        trans_min, trans_max, rot_min=0, rot_max=180, trans_unit="cm", debug=False
    ):
        translation = PoseUtil.get_uniform_translation(
            trans_min, trans_max, trans_unit, debug
        )
        rotation = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
        pose = np.eye(4)
        pose[:3, :3] = rotation
        pose[:3, 3] = translation
        return pose
    @staticmethod
    def get_n_uniform_pose(
        trans_min,
        trans_max,
        rot_min=0,
        rot_max=180,
        n=1,
        trans_unit="cm",
        fix=None,
        contain_canonical=True,
        debug=False,
    ):
        if fix == PoseUtil.ROTATION:
            translations = np.zeros((n, 3))
            for i in range(n):
                translations[i] = PoseUtil.get_uniform_translation(
                    trans_min, trans_max, trans_unit, debug
                )
            if contain_canonical:
                translations[0] = np.zeros(3)
            rotations = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
        elif fix == PoseUtil.TRANSLATION:
            rotations = np.zeros((n, 3, 3))
            for i in range(n):
                rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
            if contain_canonical:
                rotations[0] = np.eye(3)
            translations = PoseUtil.get_uniform_translation(
                trans_min, trans_max, trans_unit, debug
            )
        else:
            translations = np.zeros((n, 3))
            rotations = np.zeros((n, 3, 3))
            for i in range(n):
                translations[i] = PoseUtil.get_uniform_translation(
                    trans_min, trans_max, trans_unit, debug
                )
            for i in range(n):
                rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
            if contain_canonical:
                translations[0] = np.zeros(3)
                rotations[0] = np.eye(3)
        pose = np.eye(4)[np.newaxis, :].repeat(n, axis=0)
        pose[:, :3, :3] = rotations
        pose[:, :3, 3] = translations
        return pose

    @staticmethod
    def get_n_uniform_pose_batch(
        trans_min,
        trans_max,
        rot_min=0,
        rot_max=180,
        n=1,
        batch_size=1,
        trans_unit="cm",
        fix=None,
        contain_canonical=False,
        debug=False,
    ):
        batch_poses = []
        for i in range(batch_size):
            pose = PoseUtil.get_n_uniform_pose(
                trans_min,
                trans_max,
                rot_min,
                rot_max,
                n,
                trans_unit,
                fix,
                contain_canonical,
                debug,
            )
            batch_poses.append(pose)
        pose_batch = np.stack(batch_poses, axis=0)
        return pose_batch
    @staticmethod
    def get_uniform_scale(scale_min, scale_max, debug=False):
        if isinstance(scale_min, list):
            x_min, y_min, z_min = scale_min
            x_max, y_max, z_max = scale_max
        else:
            x_min, y_min, z_min = scale_min, scale_min, scale_min
            x_max, y_max, z_max = scale_max, scale_max, scale_max
        x = np.random.uniform(x_min, x_max)
        y = np.random.uniform(y_min, y_max)
        z = np.random.uniform(z_min, z_max)
        scale = np.array([x, y, z])
        if debug:
            print("uniform scale:", scale)
        return scale

    @staticmethod
    def normalize_rotation(rotation, rotation_mode):
        if rotation_mode == "quat_wxyz" or rotation_mode == "quat_xyzw":
            rotation /= torch.norm(rotation, dim=-1, keepdim=True)
        elif rotation_mode == "rot_matrix":
            # Re-orthonormalize the 6D representation via Gram-Schmidt.
            rot_matrix = PoseUtil.rotation_6d_to_matrix_tensor_batch(rotation)
            rotation[:, :3] = rot_matrix[:, 0, :]
            rotation[:, 3:6] = rot_matrix[:, 1, :]
        elif rotation_mode == "euler_xyz_sx_cx":
            # Re-project (sin, cos) pairs onto the unit circle.
            rot_sin_theta = rotation[:, :3]
            rot_cos_theta = rotation[:, 3:6]
            theta = torch.atan2(rot_sin_theta, rot_cos_theta)
            rotation[:, :3] = torch.sin(theta)
            rotation[:, 3:6] = torch.cos(theta)
        elif rotation_mode == "euler_xyz":
            pass
        else:
            raise NotImplementedError
        return rotation

    @staticmethod
    def get_pose_dim(rot_mode):
        assert rot_mode in [
            "quat_wxyz",
            "quat_xyzw",
            "euler_xyz",
            "euler_xyz_sx_cx",
            "rot_matrix",
        ], f"the rotation mode {rot_mode} is not supported!"
        if rot_mode == "quat_wxyz" or rot_mode == "quat_xyzw":
            pose_dim = 4
        elif rot_mode == "euler_xyz":
            pose_dim = 3
        elif rot_mode == "euler_xyz_sx_cx" or rot_mode == "rot_matrix":
            pose_dim = 6
        else:
            raise NotImplementedError
        return pose_dim

    @staticmethod
    def rotation_6d_to_matrix_tensor_batch(d6: torch.Tensor) -> torch.Tensor:
        # Gram-Schmidt: build an orthonormal frame from the two 3-vectors.
        a1, a2 = d6[..., :3], d6[..., 3:]
        b1 = F.normalize(a1, dim=-1)
        b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
        b2 = F.normalize(b2, dim=-1)
        b3 = torch.cross(b1, b2, dim=-1)
        return torch.stack((b1, b2, b3), dim=-2)

    @staticmethod
    def matrix_to_rotation_6d_tensor_batch(matrix: torch.Tensor) -> torch.Tensor:
        # The 6D representation is the first two rows of the matrix, flattened.
        batch_dim = matrix.size()[:-2]
        return matrix[..., :2, :].clone().reshape(batch_dim + (6,))

    @staticmethod
    def rotation_6d_to_matrix_numpy(d6):
        a1, a2 = d6[:3], d6[3:]
        b1 = a1 / np.linalg.norm(a1)
        b2 = a2 - np.dot(b1, a2) * b1
        b2 = b2 / np.linalg.norm(b2)
        b3 = np.cross(b1, b2)
        return np.stack((b1, b2, b3), axis=-2)

    @staticmethod
    def matrix_to_rotation_6d_numpy(matrix):
        return np.copy(matrix[:2, :]).reshape((6,))
""" ------------ Debug ------------ """
if __name__ == "__main__":
for _ in range(1):
PoseUtil.get_uniform_pose(
trans_min=[-25, -25, 10],
trans_max=[25, 25, 60],
rot_min=0,
rot_max=10,
debug=True,
)
PoseUtil.get_uniform_scale(scale_min=0.25, scale_max=0.30, debug=True)
PoseUtil.get_n_uniform_pose_batch(
trans_min=[-25, -25, 10],
trans_max=[25, 25, 60],
rot_min=0,
rot_max=10,
batch_size=2,
n=2,
fix=PoseUtil.TRANSLATION,
debug=True,
)
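
A quick round-trip check for the 6D rotation helpers (illustrative only, not part of the commit); the first two rows of a rotation matrix determine the third via the cross product, so the reconstruction is exact:

import numpy as np
from utils.pose import PoseUtil

R = PoseUtil.get_uniform_rotation(0, 180)          # random rotation matrix
d6 = PoseUtil.matrix_to_rotation_6d_numpy(R)       # first two rows, flattened
R_back = PoseUtil.rotation_6d_to_matrix_numpy(d6)  # Gram-Schmidt reconstruction
assert np.allclose(R, R_back, atol=1e-6)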

utils/reconstruction.py (new file, 139 additions)

@@ -0,0 +1,139 @@
import numpy as np
import open3d as o3d
from scipy.spatial import cKDTree


class ReconstructionUtil:

    @staticmethod
    def compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold=0.01):
        # A target point counts as covered if any combined point lies within threshold.
        kdtree = cKDTree(combined_point_cloud)
        distances, _ = kdtree.query(target_point_cloud)
        covered_points = np.sum(distances < threshold)
        coverage_rate = covered_points / target_point_cloud.shape[0]
        return coverage_rate

    @staticmethod
    def compute_overlap_rate(point_cloud1, point_cloud2, threshold=0.01):
        # Symmetric overlap: average the fraction of each cloud that has a
        # neighbor in the other cloud within threshold.
        kdtree1 = cKDTree(point_cloud1)
        kdtree2 = cKDTree(point_cloud2)
        distances1, _ = kdtree2.query(point_cloud1)
        distances2, _ = kdtree1.query(point_cloud2)
        overlapping_points1 = np.sum(distances1 < threshold)
        overlapping_points2 = np.sum(distances2 < threshold)
        overlap_rate1 = overlapping_points1 / point_cloud1.shape[0]
        overlap_rate2 = overlapping_points2 / point_cloud2.shape[0]
        return (overlap_rate1 + overlap_rate2) / 2

    @staticmethod
    def combine_point_with_view_sequence(point_list, view_sequence):
        selected_views = []
        for view_index, _ in view_sequence:
            selected_views.append(point_list[view_index])
        return np.vstack(selected_views)

    @staticmethod
    def compute_next_view_coverage_list(views, combined_point_cloud, target_point_cloud, threshold=0.01):
        best_view = None
        best_coverage_increase = -1
        current_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold)
        for view_index, view in enumerate(views):
            # Stack the candidate view onto the already-combined cloud
            # (adding a list to an ndarray would broadcast, not concatenate).
            candidate_point_cloud = np.vstack([combined_point_cloud, view])
            down_sampled_combined_point_cloud = ReconstructionUtil.downsample_point_cloud(candidate_point_cloud, threshold)
            new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
            coverage_increase = new_coverage - current_coverage
            if coverage_increase > best_coverage_increase:
                best_coverage_increase = coverage_increase
                best_view = view_index
        return best_view, best_coverage_increase

    @staticmethod
    def compute_next_best_view_sequence(target_point_cloud, point_cloud_list, threshold=0.01):
        # Greedy next-best-view: repeatedly pick the view with the largest
        # coverage gain until no view improves coverage by more than 1e-3.
        selected_views = []
        current_coverage = 0.0
        remaining_views = list(range(len(point_cloud_list)))
        view_sequence = []
        target_point_cloud = ReconstructionUtil.downsample_point_cloud(target_point_cloud, threshold)
        while remaining_views:
            best_view = None
            best_coverage_increase = -1
            for view_index in remaining_views:
                candidate_views = selected_views + [point_cloud_list[view_index]]
                combined_point_cloud = np.vstack(candidate_views)
                down_sampled_combined_point_cloud = ReconstructionUtil.downsample_point_cloud(combined_point_cloud, threshold)
                new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
                coverage_increase = new_coverage - current_coverage
                if coverage_increase > best_coverage_increase:
                    best_coverage_increase = coverage_increase
                    best_view = view_index
            if best_view is not None:
                if best_coverage_increase <= 1e-3:
                    break
                selected_views.append(point_cloud_list[best_view])
                current_coverage += best_coverage_increase
                view_sequence.append((best_view, current_coverage))
                remaining_views.remove(best_view)
        return view_sequence, remaining_views

    @staticmethod
    def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, threshold=0.01, overlap_threshold=0.3):
        # Same greedy loop, but a candidate view must overlap the already
        # selected views by at least overlap_threshold to be considered.
        selected_views = []
        current_coverage = 0.0
        remaining_views = list(range(len(point_cloud_list)))
        view_sequence = []
        target_point_cloud = ReconstructionUtil.downsample_point_cloud(target_point_cloud, threshold)
        while remaining_views:
            best_view = None
            best_coverage_increase = -1
            for view_index in remaining_views:
                if selected_views:
                    combined_old_point_cloud = np.vstack(selected_views)
                    down_sampled_old_point_cloud = ReconstructionUtil.downsample_point_cloud(combined_old_point_cloud, threshold)
                    down_sampled_new_view_point_cloud = ReconstructionUtil.downsample_point_cloud(point_cloud_list[view_index], threshold)
                    overlap_rate = ReconstructionUtil.compute_overlap_rate(down_sampled_old_point_cloud, down_sampled_new_view_point_cloud, threshold)
                    if overlap_rate < overlap_threshold:
                        continue
                candidate_views = selected_views + [point_cloud_list[view_index]]
                combined_point_cloud = np.vstack(candidate_views)
                down_sampled_combined_point_cloud = ReconstructionUtil.downsample_point_cloud(combined_point_cloud, threshold)
                new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
                coverage_increase = new_coverage - current_coverage
                if coverage_increase > best_coverage_increase:
                    best_coverage_increase = coverage_increase
                    best_view = view_index
            if best_view is not None:
                if best_coverage_increase <= 1e-3:
                    break
                selected_views.append(point_cloud_list[best_view])
                remaining_views.remove(best_view)
                if best_coverage_increase > 0:
                    current_coverage += best_coverage_increase
                    view_sequence.append((best_view, current_coverage))
                else:
                    break
            else:
                # No remaining view passes the overlap test; stop instead of
                # looping forever.
                break
        return view_sequence, remaining_views

    @staticmethod
    def downsample_point_cloud(point_cloud, voxel_size=0.005):
        # Voxel-grid downsampling via Open3D.
        o3d_pc = o3d.geometry.PointCloud()
        o3d_pc.points = o3d.utility.Vector3dVector(point_cloud)
        downsampled_pc = o3d_pc.voxel_down_sample(voxel_size)
        return np.asarray(downsampled_pc.points)
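
Putting the two utilities together, a sketch of the intended offline pipeline (paths and counts are placeholders; approximating the target model by merging all views is an assumption, a ground-truth mesh sampling could be used instead):

import numpy as np
from utils.data_load import DataLoadUtil
from utils.reconstruction import ReconstructionUtil

data_root, scene_idx, num_frames = "/path/to/dataset", 0, 24  # illustrative

# One world-space point cloud per candidate view.
point_cloud_list = DataLoadUtil.get_point_cloud_list_from_seq(data_root, scene_idx, num_frames)
target_point_cloud = np.vstack(point_cloud_list)

view_sequence, unused_views = ReconstructionUtil.compute_next_best_view_sequence(
    target_point_cloud, point_cloud_list, threshold=0.01
)
for view_index, coverage in view_sequence:
    print(f"view {view_index}: cumulative coverage {coverage:.3f}")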

(deleted file, 47 deletions)

@@ -1,47 +0,0 @@
import torch


class TensorboardWriter:

    @staticmethod
    def write_tensorboard(writer, panel, data_dict, step):
        # A "complex" dict groups values under "scalars" / "images" / "points";
        # a flat dict is treated as scalars.
        complex_dict = False
        if "scalars" in data_dict:
            scalar_data_dict = data_dict["scalars"]
            TensorboardWriter.write_scalar_tensorboard(writer, panel, scalar_data_dict, step)
            complex_dict = True
        if "images" in data_dict:
            image_data_dict = data_dict["images"]
            TensorboardWriter.write_image_tensorboard(writer, panel, image_data_dict, step)
            complex_dict = True
        if "points" in data_dict:
            point_data_dict = data_dict["points"]
            TensorboardWriter.write_points_tensorboard(writer, panel, point_data_dict, step)
            complex_dict = True
        if not complex_dict:
            TensorboardWriter.write_scalar_tensorboard(writer, panel, data_dict, step)

    @staticmethod
    def write_scalar_tensorboard(writer, panel, data_dict, step):
        for key, value in data_dict.items():
            if isinstance(value, dict):
                writer.add_scalars(f'{panel}/{key}', value, step)
            else:
                writer.add_scalar(f'{panel}/{key}', value, step)

    @staticmethod
    def write_image_tensorboard(writer, panel, data_dict, step):
        pass

    @staticmethod
    def write_points_tensorboard(writer, panel, data_dict, step):
        for key, value in data_dict.items():
            if value.shape[-1] == 3:
                # Positions only: pad with zero colors to get XYZRGB vertices.
                colors = torch.zeros_like(value)
                vertices = torch.cat([value, colors], dim=-1)
            elif value.shape[-1] == 6:
                vertices = value
            else:
                raise ValueError(f'Unexpected value shape: {value.shape}')
            faces = None
            writer.add_mesh(f'{panel}/{key}', vertices=vertices, faces=faces, global_step=step)
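
For reference, a minimal sketch of how the removed writer was presumably driven (the SummaryWriter setup, panel name, and values are illustrative, not from this repository):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/example")  # hypothetical log directory
data_dict = {"scalars": {"loss": 0.42, "coverage": 0.87}}  # illustrative values
TensorboardWriter.write_tensorboard(writer, "train", data_dict, step=0)
writer.close()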