update basic framework
This commit is contained in:
135
utils/data_load.py
Normal file
135
utils/data_load.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import os
|
||||
import OpenEXR
|
||||
import Imath
|
||||
import numpy as np
|
||||
import json
|
||||
import cv2
|
||||
|
||||
class DataLoadUtil:
|
||||
|
||||
@staticmethod
|
||||
def get_path(root, scene_idx, frame_idx):
|
||||
path = os.path.join(root, f"sequence.{scene_idx}", f"step{frame_idx}")
|
||||
return path
|
||||
|
||||
@staticmethod
|
||||
def read_exr_depth(depth_path):
|
||||
file = OpenEXR.InputFile(depth_path)
|
||||
|
||||
dw = file.header()['dataWindow']
|
||||
width = dw.max.x - dw.min.x + 1
|
||||
height = dw.max.y - dw.min.y + 1
|
||||
|
||||
pix_type = Imath.PixelType(Imath.PixelType.FLOAT)
|
||||
depth_map = np.frombuffer(file.channel('R', pix_type), dtype=np.float32)
|
||||
|
||||
depth_map.shape = (height, width)
|
||||
|
||||
return depth_map
|
||||
|
||||
@staticmethod
|
||||
def load_depth(path):
|
||||
depth_path = path + ".camera.Depth.exr"
|
||||
depth_map = DataLoadUtil.read_exr_depth(depth_path)
|
||||
return depth_map
|
||||
|
||||
@staticmethod
|
||||
def load_rgb(path):
|
||||
rgb_path = path + ".camera.png"
|
||||
rgb_image = cv2.imread(rgb_path, cv2.IMREAD_COLOR)
|
||||
return rgb_image
|
||||
|
||||
@staticmethod
|
||||
def load_seg(path):
|
||||
seg_path = path + ".camera.semantic segmentation.png"
|
||||
seg_image = cv2.imread(seg_path, cv2.IMREAD_COLOR)
|
||||
return seg_image
|
||||
|
||||
@staticmethod
|
||||
def load_cam_info(path):
|
||||
label_path = path + ".camera_params.json"
|
||||
with open(label_path, 'r') as f:
|
||||
label_data = json.load(f)
|
||||
cam_transform = np.asarray(label_data['cam_to_world']).reshape(
|
||||
(4, 4)
|
||||
).T
|
||||
|
||||
offset = np.asarray([
|
||||
[1, 0, 0, 0],
|
||||
[0, -1, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 0, 1]])
|
||||
|
||||
cam_to_world = cam_transform @ offset
|
||||
|
||||
|
||||
|
||||
f_x = label_data['f_x']
|
||||
f_y = label_data['f_y']
|
||||
c_x = label_data['c_x']
|
||||
c_y = label_data['c_y']
|
||||
cam_intrinsic = np.array([[f_x, 0, c_x], [0, f_y, c_y], [0, 0, 1]])
|
||||
|
||||
return {
|
||||
"cam_to_world": cam_to_world,
|
||||
"cam_intrinsic": cam_intrinsic
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_target_point_cloud(depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(255,255,255)):
|
||||
h, w = depth.shape
|
||||
i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
|
||||
|
||||
z = depth
|
||||
x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0]
|
||||
y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1]
|
||||
|
||||
points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3)
|
||||
points_camera_aug = np.concatenate([points_camera, np.ones((points_camera.shape[0], 1))], axis=-1)
|
||||
|
||||
points_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3]
|
||||
mask = mask.reshape(-1, 3)
|
||||
target_mask = np.all(mask == target_mask_label, axis=-1)
|
||||
return {
|
||||
"points_world": points_world[target_mask],
|
||||
"points_camera": points_camera[target_mask]
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_target_point_cloud(depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(255,255,255)):
|
||||
h, w = depth.shape
|
||||
i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
|
||||
|
||||
z = depth
|
||||
x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0]
|
||||
y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1]
|
||||
|
||||
points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3)
|
||||
points_camera_aug = np.concatenate([points_camera, np.ones((points_camera.shape[0], 1))], axis=-1)
|
||||
|
||||
points_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3]
|
||||
mask = mask.reshape(-1, 3)
|
||||
target_mask = np.all(mask == target_mask_label, axis=-1)
|
||||
return {
|
||||
"points_world": points_world[target_mask],
|
||||
"points_camera": points_camera[target_mask]
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_point_cloud_world_from_path(path):
|
||||
cam_info = DataLoadUtil.load_cam_info(path)
|
||||
depth = DataLoadUtil.load_depth(path)
|
||||
mask = DataLoadUtil.load_seg(path)
|
||||
point_cloud = DataLoadUtil.get_target_point_cloud(depth, cam_info['cam_intrinsic'], cam_info['cam_to_world'], mask)
|
||||
return point_cloud['points_world']
|
||||
|
||||
@staticmethod
|
||||
def get_point_cloud_list_from_seq(root, seq_idx, num_frames):
|
||||
point_cloud_list = []
|
||||
for idx in range(num_frames):
|
||||
path = DataLoadUtil.get_path(root, seq_idx, idx)
|
||||
point_cloud = DataLoadUtil.get_point_cloud_world_from_path(path)
|
||||
point_cloud_list.append(point_cloud)
|
||||
return point_cloud_list
|
||||
|
246
utils/pose.py
Normal file
246
utils/pose.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PoseUtil:
|
||||
ROTATION = 1
|
||||
TRANSLATION = 2
|
||||
SCALE = 3
|
||||
|
||||
@staticmethod
|
||||
def get_uniform_translation(trans_m_min, trans_m_max, trans_unit, debug=False):
|
||||
if isinstance(trans_m_min, list):
|
||||
x_min, y_min, z_min = trans_m_min
|
||||
x_max, y_max, z_max = trans_m_max
|
||||
else:
|
||||
x_min, y_min, z_min = trans_m_min, trans_m_min, trans_m_min
|
||||
x_max, y_max, z_max = trans_m_max, trans_m_max, trans_m_max
|
||||
|
||||
x = np.random.uniform(x_min, x_max)
|
||||
y = np.random.uniform(y_min, y_max)
|
||||
z = np.random.uniform(z_min, z_max)
|
||||
translation = np.array([x, y, z])
|
||||
if trans_unit == "cm":
|
||||
translation = translation / 100
|
||||
if debug:
|
||||
print("uniform translation:", translation)
|
||||
return translation
|
||||
|
||||
@staticmethod
|
||||
def get_uniform_rotation(rot_degree_min=0, rot_degree_max=180, debug=False):
|
||||
axis = np.random.randn(3)
|
||||
axis /= np.linalg.norm(axis)
|
||||
theta = np.random.uniform(
|
||||
rot_degree_min / 180 * np.pi, rot_degree_max / 180 * np.pi
|
||||
)
|
||||
|
||||
K = np.array(
|
||||
[[0, -axis[2], axis[1]], [axis[2], 0, -axis[0]], [-axis[1], axis[0], 0]]
|
||||
)
|
||||
R = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K)
|
||||
if debug:
|
||||
print("uniform rotation:", theta * 180 / np.pi)
|
||||
return R
|
||||
|
||||
@staticmethod
|
||||
def get_uniform_pose(
|
||||
trans_min, trans_max, rot_min=0, rot_max=180, trans_unit="cm", debug=False
|
||||
):
|
||||
translation = PoseUtil.get_uniform_translation(
|
||||
trans_min, trans_max, trans_unit, debug
|
||||
)
|
||||
rotation = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
|
||||
pose = np.eye(4)
|
||||
pose[:3, :3] = rotation
|
||||
pose[:3, 3] = translation
|
||||
return pose
|
||||
|
||||
@staticmethod
|
||||
def get_n_uniform_pose(
|
||||
trans_min,
|
||||
trans_max,
|
||||
rot_min=0,
|
||||
rot_max=180,
|
||||
n=1,
|
||||
trans_unit="cm",
|
||||
fix=None,
|
||||
contain_canonical=True,
|
||||
debug=False,
|
||||
):
|
||||
if fix == PoseUtil.ROTATION:
|
||||
translations = np.zeros((n, 3))
|
||||
for i in range(n):
|
||||
translations[i] = PoseUtil.get_uniform_translation(
|
||||
trans_min, trans_max, trans_unit, debug
|
||||
)
|
||||
if contain_canonical:
|
||||
translations[0] = np.zeros(3)
|
||||
rotations = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
|
||||
elif fix == PoseUtil.TRANSLATION:
|
||||
rotations = np.zeros((n, 3, 3))
|
||||
for i in range(n):
|
||||
rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
|
||||
if contain_canonical:
|
||||
rotations[0] = np.eye(3)
|
||||
translations = PoseUtil.get_uniform_translation(
|
||||
trans_min, trans_max, trans_unit, debug
|
||||
)
|
||||
else:
|
||||
translations = np.zeros((n, 3))
|
||||
rotations = np.zeros((n, 3, 3))
|
||||
for i in range(n):
|
||||
translations[i] = PoseUtil.get_uniform_translation(
|
||||
trans_min, trans_max, trans_unit, debug
|
||||
)
|
||||
for i in range(n):
|
||||
rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
|
||||
if contain_canonical:
|
||||
translations[0] = np.zeros(3)
|
||||
rotations[0] = np.eye(3)
|
||||
|
||||
pose = np.eye(4, 4, k=0)[np.newaxis, :].repeat(n, axis=0)
|
||||
pose[:, :3, :3] = rotations
|
||||
pose[:, :3, 3] = translations
|
||||
|
||||
return pose
|
||||
|
||||
@staticmethod
|
||||
def get_n_uniform_pose_batch(
|
||||
trans_min,
|
||||
trans_max,
|
||||
rot_min=0,
|
||||
rot_max=180,
|
||||
n=1,
|
||||
batch_size=1,
|
||||
trans_unit="cm",
|
||||
fix=None,
|
||||
contain_canonical=False,
|
||||
debug=False,
|
||||
):
|
||||
|
||||
batch_poses = []
|
||||
for i in range(batch_size):
|
||||
pose = PoseUtil.get_n_uniform_pose(
|
||||
trans_min,
|
||||
trans_max,
|
||||
rot_min,
|
||||
rot_max,
|
||||
n,
|
||||
trans_unit,
|
||||
fix,
|
||||
contain_canonical,
|
||||
debug,
|
||||
)
|
||||
batch_poses.append(pose)
|
||||
pose_batch = np.stack(batch_poses, axis=0)
|
||||
return pose_batch
|
||||
|
||||
@staticmethod
|
||||
def get_uniform_scale(scale_min, scale_max, debug=False):
|
||||
if isinstance(scale_min, list):
|
||||
x_min, y_min, z_min = scale_min
|
||||
x_max, y_max, z_max = scale_max
|
||||
else:
|
||||
x_min, y_min, z_min = scale_min, scale_min, scale_min
|
||||
x_max, y_max, z_max = scale_max, scale_max, scale_max
|
||||
|
||||
x = np.random.uniform(x_min, x_max)
|
||||
y = np.random.uniform(y_min, y_max)
|
||||
z = np.random.uniform(z_min, z_max)
|
||||
scale = np.array([x, y, z])
|
||||
if debug:
|
||||
print("uniform scale:", scale)
|
||||
return scale
|
||||
|
||||
@staticmethod
|
||||
def normalize_rotation(rotation, rotation_mode):
|
||||
if rotation_mode == "quat_wxyz" or rotation_mode == "quat_xyzw":
|
||||
rotation /= torch.norm(rotation, dim=-1, keepdim=True)
|
||||
elif rotation_mode == "rot_matrix":
|
||||
rot_matrix = PoseUtil.rotation_6d_to_matrix_tensor_batch(rotation)
|
||||
rotation[:, :3] = rot_matrix[:, 0, :]
|
||||
rotation[:, 3:6] = rot_matrix[:, 1, :]
|
||||
elif rotation_mode == "euler_xyz_sx_cx":
|
||||
rot_sin_theta = rotation[:, :3]
|
||||
rot_cos_theta = rotation[:, 3:6]
|
||||
theta = torch.atan2(rot_sin_theta, rot_cos_theta)
|
||||
rotation[:, :3] = torch.sin(theta)
|
||||
rotation[:, 3:6] = torch.cos(theta)
|
||||
elif rotation_mode == "euler_xyz":
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return rotation
|
||||
|
||||
@staticmethod
|
||||
def get_pose_dim(rot_mode):
|
||||
assert rot_mode in [
|
||||
"quat_wxyz",
|
||||
"quat_xyzw",
|
||||
"euler_xyz",
|
||||
"euler_xyz_sx_cx",
|
||||
"rot_matrix",
|
||||
], f"the rotation mode {rot_mode} is not supported!"
|
||||
|
||||
if rot_mode == "quat_wxyz" or rot_mode == "quat_xyzw":
|
||||
pose_dim = 4
|
||||
elif rot_mode == "euler_xyz":
|
||||
pose_dim = 3
|
||||
elif rot_mode == "euler_xyz_sx_cx" or rot_mode == "rot_matrix":
|
||||
pose_dim = 6
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return pose_dim
|
||||
|
||||
@staticmethod
|
||||
def rotation_6d_to_matrix_tensor_batch(d6: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
a1, a2 = d6[..., :3], d6[..., 3:]
|
||||
b1 = F.normalize(a1, dim=-1)
|
||||
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
|
||||
b2 = F.normalize(b2, dim=-1)
|
||||
b3 = torch.cross(b1, b2, dim=-1)
|
||||
return torch.stack((b1, b2, b3), dim=-2)
|
||||
|
||||
@staticmethod
|
||||
def matrix_to_rotation_6d_tensor_batch(matrix: torch.Tensor) -> torch.Tensor:
|
||||
batch_dim = matrix.size()[:-2]
|
||||
return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
|
||||
|
||||
@staticmethod
|
||||
def rotation_6d_to_matrix_numpy(d6):
|
||||
a1, a2 = d6[:3], d6[3:]
|
||||
b1 = a1 / np.linalg.norm(a1)
|
||||
b2 = a2 - np.dot(b1, a2) * b1
|
||||
b2 = b2 / np.linalg.norm(b2)
|
||||
b3 = np.cross(b1, b2)
|
||||
return np.stack((b1, b2, b3), axis=-2)
|
||||
|
||||
@staticmethod
|
||||
def matrix_to_rotation_6d_numpy(matrix):
|
||||
return np.copy(matrix[:2, :]).reshape((6,))
|
||||
|
||||
|
||||
""" ------------ Debug ------------ """
|
||||
|
||||
if __name__ == "__main__":
|
||||
for _ in range(1):
|
||||
PoseUtil.get_uniform_pose(
|
||||
trans_min=[-25, -25, 10],
|
||||
trans_max=[25, 25, 60],
|
||||
rot_min=0,
|
||||
rot_max=10,
|
||||
debug=True,
|
||||
)
|
||||
PoseUtil.get_uniform_scale(scale_min=0.25, scale_max=0.30, debug=True)
|
||||
PoseUtil.get_n_uniform_pose_batch(
|
||||
trans_min=[-25, -25, 10],
|
||||
trans_max=[25, 25, 60],
|
||||
rot_min=0,
|
||||
rot_max=10,
|
||||
batch_size=2,
|
||||
n=2,
|
||||
fix=PoseUtil.TRANSLATION,
|
||||
debug=True,
|
||||
)
|
139
utils/reconstruction.py
Normal file
139
utils/reconstruction.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import numpy as np
|
||||
import open3d as o3d
|
||||
from scipy.spatial import cKDTree
|
||||
|
||||
class ReconstructionUtil:
|
||||
|
||||
@staticmethod
|
||||
def compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold=0.01):
|
||||
kdtree = cKDTree(combined_point_cloud)
|
||||
distances, _ = kdtree.query(target_point_cloud)
|
||||
covered_points = np.sum(distances < threshold)
|
||||
coverage_rate = covered_points / target_point_cloud.shape[0]
|
||||
return coverage_rate
|
||||
|
||||
@staticmethod
|
||||
def compute_overlap_rate(point_cloud1, point_cloud2, threshold=0.01):
|
||||
kdtree1 = cKDTree(point_cloud1)
|
||||
kdtree2 = cKDTree(point_cloud2)
|
||||
distances1, _ = kdtree2.query(point_cloud1)
|
||||
distances2, _ = kdtree1.query(point_cloud2)
|
||||
overlapping_points1 = np.sum(distances1 < threshold)
|
||||
overlapping_points2 = np.sum(distances2 < threshold)
|
||||
|
||||
overlap_rate1 = overlapping_points1 / point_cloud1.shape[0]
|
||||
overlap_rate2 = overlapping_points2 / point_cloud2.shape[0]
|
||||
|
||||
return (overlap_rate1 + overlap_rate2) / 2
|
||||
|
||||
@staticmethod
|
||||
def combine_point_with_view_sequence(point_list, view_sequence):
|
||||
selected_views = []
|
||||
for view_index, _ in view_sequence:
|
||||
selected_views.append(point_list[view_index])
|
||||
return np.vstack(selected_views)
|
||||
|
||||
@staticmethod
|
||||
def compute_next_view_coverage_list(views, combined_point_cloud, target_point_cloud, threshold=0.01):
|
||||
best_view = None
|
||||
best_coverage_increase = -1
|
||||
current_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold)
|
||||
|
||||
for view_index, view in enumerate(views):
|
||||
candidate_views = combined_point_cloud + [view]
|
||||
down_sampled_combined_point_cloud = ReconstructionUtil.downsample_point_cloud(candidate_views, threshold)
|
||||
new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
|
||||
coverage_increase = new_coverage - current_coverage
|
||||
if coverage_increase > best_coverage_increase:
|
||||
best_coverage_increase = coverage_increase
|
||||
best_view = view_index
|
||||
return best_view, best_coverage_increase
|
||||
|
||||
@staticmethod
|
||||
def compute_next_best_view_sequence(target_point_cloud, point_cloud_list, threshold=0.01):
|
||||
selected_views = []
|
||||
current_coverage = 0.0
|
||||
remaining_views = list(range(len(point_cloud_list)))
|
||||
view_sequence = []
|
||||
target_point_cloud = ReconstructionUtil.downsample_point_cloud(target_point_cloud, threshold)
|
||||
while remaining_views:
|
||||
best_view = None
|
||||
best_coverage_increase = -1
|
||||
|
||||
for view_index in remaining_views:
|
||||
candidate_views = selected_views + [point_cloud_list[view_index]]
|
||||
combined_point_cloud = np.vstack(candidate_views)
|
||||
|
||||
down_sampled_combined_point_cloud = ReconstructionUtil.downsample_point_cloud(combined_point_cloud,threshold)
|
||||
new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
|
||||
coverage_increase = new_coverage - current_coverage
|
||||
|
||||
if coverage_increase > best_coverage_increase:
|
||||
best_coverage_increase = coverage_increase
|
||||
best_view = view_index
|
||||
|
||||
if best_view is not None:
|
||||
if best_coverage_increase <=1e-3:
|
||||
break
|
||||
selected_views.append(point_cloud_list[best_view])
|
||||
current_coverage += best_coverage_increase
|
||||
view_sequence.append((best_view, current_coverage))
|
||||
remaining_views.remove(best_view)
|
||||
return view_sequence, remaining_views
|
||||
|
||||
|
||||
@staticmethod
|
||||
def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, threshold=0.01, overlap_threshold=0.3):
|
||||
selected_views = []
|
||||
current_coverage = 0.0
|
||||
remaining_views = list(range(len(point_cloud_list)))
|
||||
view_sequence = []
|
||||
target_point_cloud = ReconstructionUtil.downsample_point_cloud(target_point_cloud, threshold)
|
||||
|
||||
while remaining_views:
|
||||
best_view = None
|
||||
best_coverage_increase = -1
|
||||
|
||||
for view_index in remaining_views:
|
||||
|
||||
if selected_views:
|
||||
combined_old_point_cloud = np.vstack(selected_views)
|
||||
down_sampled_old_point_cloud = ReconstructionUtil.downsample_point_cloud(combined_old_point_cloud,threshold)
|
||||
down_sampled_new_view_point_cloud = ReconstructionUtil.downsample_point_cloud(point_cloud_list[view_index],threshold)
|
||||
overlap_rate = ReconstructionUtil.compute_overlap_rate(down_sampled_old_point_cloud,down_sampled_new_view_point_cloud , threshold)
|
||||
if overlap_rate < overlap_threshold:
|
||||
continue
|
||||
|
||||
candidate_views = selected_views + [point_cloud_list[view_index]]
|
||||
combined_point_cloud = np.vstack(candidate_views)
|
||||
down_sampled_combined_point_cloud = ReconstructionUtil.downsample_point_cloud(combined_point_cloud,threshold)
|
||||
new_coverage = ReconstructionUtil.compute_coverage_rate(target_point_cloud, down_sampled_combined_point_cloud, threshold)
|
||||
coverage_increase = new_coverage - current_coverage
|
||||
#print(f"view_index: {view_index}, coverage_increase: {coverage_increase}")
|
||||
if coverage_increase > best_coverage_increase:
|
||||
best_coverage_increase = coverage_increase
|
||||
best_view = view_index
|
||||
|
||||
if best_view is not None:
|
||||
if best_coverage_increase <=1e-3:
|
||||
break
|
||||
selected_views.append(point_cloud_list[best_view])
|
||||
remaining_views.remove(best_view)
|
||||
if best_coverage_increase > 0:
|
||||
current_coverage += best_coverage_increase
|
||||
|
||||
view_sequence.append((best_view, current_coverage))
|
||||
|
||||
else:
|
||||
break
|
||||
|
||||
return view_sequence, remaining_views
|
||||
|
||||
|
||||
def downsample_point_cloud(point_cloud, voxel_size=0.005):
|
||||
o3d_pc = o3d.geometry.PointCloud()
|
||||
o3d_pc.points = o3d.utility.Vector3dVector(point_cloud)
|
||||
downsampled_pc = o3d_pc.voxel_down_sample(voxel_size)
|
||||
return np.asarray(downsampled_pc.points)
|
||||
|
||||
|
@@ -1,47 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class TensorboardWriter:
|
||||
@staticmethod
|
||||
def write_tensorboard(writer, panel, data_dict, step):
|
||||
complex_dict = False
|
||||
if "scalars" in data_dict:
|
||||
scalar_data_dict = data_dict["scalars"]
|
||||
TensorboardWriter.write_scalar_tensorboard(writer, panel, scalar_data_dict, step)
|
||||
complex_dict = True
|
||||
if "images" in data_dict:
|
||||
image_data_dict = data_dict["images"]
|
||||
TensorboardWriter.write_image_tensorboard(writer, panel, image_data_dict, step)
|
||||
complex_dict = True
|
||||
if "points" in data_dict:
|
||||
point_data_dict = data_dict["points"]
|
||||
TensorboardWriter.write_points_tensorboard(writer, panel, point_data_dict, step)
|
||||
complex_dict = True
|
||||
|
||||
if not complex_dict:
|
||||
TensorboardWriter.write_scalar_tensorboard(writer, panel, data_dict, step)
|
||||
|
||||
@staticmethod
|
||||
def write_scalar_tensorboard(writer, panel, data_dict, step):
|
||||
for key, value in data_dict.items():
|
||||
if isinstance(value, dict):
|
||||
writer.add_scalars(f'{panel}/{key}', value, step)
|
||||
else:
|
||||
writer.add_scalar(f'{panel}/{key}', value, step)
|
||||
|
||||
@staticmethod
|
||||
def write_image_tensorboard(writer, panel, data_dict, step):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def write_points_tensorboard(writer, panel, data_dict, step):
|
||||
for key, value in data_dict.items():
|
||||
if value.shape[-1] == 3:
|
||||
colors = torch.zeros_like(value)
|
||||
vertices = torch.cat([value, colors], dim=-1)
|
||||
elif value.shape[-1] == 6:
|
||||
vertices = value
|
||||
else:
|
||||
raise ValueError(f'Unexpected value shape: {value.shape}')
|
||||
faces = None
|
||||
writer.add_mesh(f'{panel}/{key}', vertices=vertices, faces=faces, global_step=step)
|
Reference in New Issue
Block a user