multi-view policy test passed
@@ -4,43 +4,109 @@ import numpy as np
import rospy
from .policy import MultiViewPolicy
from .timer import Timer

from .active_perception_demo import APInferenceEngine
from robot_helpers.spatial import Transform
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


class RealTime3DVisualizer:
    def __init__(self):
        points = np.random.rand(1, 1, 3)
        self.points = points[0]  # Extract the points (n, 3)
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111, projection='3d')

        # Initial plot setup
        self.scatter = self.ax.scatter(
            self.points[:, 0], self.points[:, 1], self.points[:, 2], c='b', marker='o')

        # Set labels for each axis
        self.ax.set_xlabel('X')
        self.ax.set_ylabel('Y')
        self.ax.set_zlabel('Z')

        # Set title
        self.ax.set_title('Real-time 3D Points Visualization')

        # Show the plot in interactive mode
        plt.ion()
        plt.show()

    def update_points(self, new_points):
        # Ensure the points have the expected shape (1, n, 3)
        assert new_points.shape[0] == 1 and new_points.shape[2] == 3, \
            "Input points must have shape (1, n, 3)"

        # Update the stored points
        self.points = new_points[0]  # Extract the points (n, 3)

        # Remove the old scatter plot and draw new points
        self.scatter.remove()
        self.scatter = self.ax.scatter(
            self.points[:, 0], self.points[:, 1], self.points[:, 2], c='b', marker='o')

        # Pause briefly to allow the plot to update
        plt.pause(0.001)

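# Usage sketch for the visualizer (hypothetical, not part of the commit):
# any (1, n, 3) batch works, matching what depth_image_to_ap_input produces.
#
#   vis = RealTime3DVisualizer()
#   for _ in range(100):
#       vis.update_points(np.random.rand(1, 500, 3))

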
class ActivePerceptionMultiViewPolicy(MultiViewPolicy):
    def __init__(self):
        super().__init__()
        self.max_views = rospy.get_param("ap_grasp/max_views")
        self.ap_config_path = rospy.get_param("ap_grasp/ap_config_path")
        self.ap_inference_engine = APInferenceEngine(self.ap_config_path)
        self.pcdvis = RealTime3DVisualizer()

    def activate(self, bbox, view_sphere):
        super().activate(bbox, view_sphere)

    def update(self, img, seg, target_id, x, q):
        # Previous information-gain-based next-best-view selection, kept for
        # reference:
        # if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable():
        #     self.done = True
        # else:
        #     with Timer("state_update"):
        #         self.integrate(img, x, q)
        #     with Timer("view_generation"):
        #         views = self.generate_views(q)
        #     with Timer("ig_computation"):
        #         gains = [self.ig_fn(v, self.downsample) for v in views]
        #     with Timer("cost_computation"):
        #         costs = [self.cost_fn(v) for v in views]
        #     utilities = gains / np.sum(gains) - costs / np.sum(costs)
        #     self.vis.ig_views(self.base_frame, self.intrinsic, views, utilities)
        #     i = np.argmax(utilities)
        #     nbv, gain = views[i], gains[i]
        #     if gain < self.min_gain and len(self.views) > self.T:
        #         self.done = True
        #     self.x_d = nbv

        if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable():
            self.done = True
        else:
            with Timer("state_update"):
                self.integrate(img, x, q)
            with Timer("view_generation"):
                target_points, scene_points = self.depth_image_to_ap_input(img, seg, target_id)
                ap_input = {'target_pts': target_points,
                            'scene_pts': scene_points}
                ap_output = self.ap_inference_engine.inference(ap_input)
                delta_rot_6d = ap_output['estimated_delta_rot_6d']

            # Turn the predicted 6D rotation delta into the next-best-view pose
            current_cam_pose = torch.from_numpy(x.as_matrix()).float().to("cuda:0")
            est_delta_rot_mat = self.rotation_6d_to_matrix_tensor_batch(delta_rot_6d)[0]
            look_at_center = torch.from_numpy(self.bbox.center).float().to("cuda:0")
            nbv_tensor = self.get_transformed_mat(current_cam_pose,
                                                  est_delta_rot_mat,
                                                  look_at_center)
            nbv_numpy = nbv_tensor.cpu().numpy()
            nbv_transform = Transform.from_matrix(nbv_numpy)
            self.x_d = nbv_transform

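    # In matrix terms, the next-best-view block above computes
    # (R = current camera rotation, dR = predicted delta, c = bbox center,
    # p = current camera position):
    #     R' = R @ dR.T
    #     p' = c - ||c - p|| * R'[:, 2]
    # i.e. the camera keeps its distance to the target and re-aims its
    # z-axis at the bbox center; see get_transformed_mat below.
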
    def rotation_6d_to_matrix_tensor_batch(self, d6: torch.Tensor) -> torch.Tensor:
        a1, a2 = d6[..., :3], d6[..., 3:]
        b1 = F.normalize(a1, dim=-1)
        b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
        b2 = F.normalize(b2, dim=-1)
        b3 = torch.cross(b1, b2, dim=-1)
        return torch.stack((b1, b2, b3), dim=-2)

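    # The method above is the Gram-Schmidt-style 6D rotation representation
    # (two 3-vectors orthonormalized into a full basis), as popularized by
    # Zhou et al., "On the Continuity of Rotation Representations in Neural
    # Networks". A quick sanity check (hypothetical, assuming a policy
    # instance exists):
    #     d6 = torch.tensor([[1., 0., 0., 0., 1., 0.]])
    #     policy.rotation_6d_to_matrix_tensor_batch(d6)  # -> 3x3 identity
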
    def get_transformed_mat(self, src_mat, delta_rot, target_center_w):
        # Apply the predicted rotation delta to the current camera orientation
        src_rot = src_mat[:3, :3]
        dst_rot = src_rot @ delta_rot.T
        dst_mat = torch.eye(4).to(dst_rot.device)
        dst_mat[:3, :3] = dst_rot
        # Keep the camera's distance to the target and aim its z-axis at it
        distance = torch.norm(target_center_w - src_mat[:3, 3])
        z_axis_camera = dst_rot[:3, 2].reshape(-1)
        new_camera_position_w = target_center_w - distance * z_axis_camera
        dst_mat[:3, 3] = new_camera_position_w
        return dst_mat

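    # Sanity check sketch (hypothetical): an identity delta keeps the orbit
    # radius and only re-aims the camera at the target:
    #     src = torch.eye(4); src[:3, 3] = torch.tensor([0., 0., 1.])
    #     m = policy.get_transformed_mat(src, torch.eye(3), torch.zeros(3))
    #     torch.norm(m[:3, 3])  # -> 1.0, still one unit from the target
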
    def depth_image_to_ap_input(self, depth_img, seg_img, target_id):
        target_points = []
        scene_points = []
@@ -76,8 +142,11 @@ class ActivePerceptionPolicy(MultiViewPolicy):

        # Batch the clouds as (1, n, 3) and move them to the GPU
        target_points = np.asarray(target_points)
        target_points = target_points.reshape(1, target_points.shape[0], 3)
        self.pcdvis.update_points(target_points)
        target_points = torch.from_numpy(target_points).float().to("cuda:0")
        scene_points = np.asarray(scene_points)
        scene_points = scene_points.reshape(1, scene_points.shape[0], 3)
        scene_points = torch.from_numpy(scene_points).float().to("cuda:0")

        return target_points, scene_points
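For reference, a minimal sketch of the shape contract that depth_image_to_ap_input establishes for the inference engine (hypothetical sizes, and assuming a CUDA device as the code above does):

    import numpy as np
    import torch

    target = np.zeros((137, 3))                     # n target points
    target = target.reshape(1, target.shape[0], 3)  # batch of one: (1, n, 3)
    target = torch.from_numpy(target).float().to("cuda:0")
    scene = torch.zeros((1, 2048, 3), device="cuda:0")
    ap_input = {'target_pts': target, 'scene_pts': scene}  # as built in update()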