successfully applies segmentation input for policy class
This commit is contained in:
@@ -18,7 +18,7 @@ class ActivePerceptionPolicy(MultiViewPolicy):
|
||||
def activate(self, bbox, view_sphere):
|
||||
super().activate(bbox, view_sphere)
|
||||
|
||||
def update(self, img, x, q):
|
||||
def update(self, img, seg, x, q):
|
||||
self.depth_image_to_ap_input(img)
|
||||
# if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable():
|
||||
# self.done = True
|
||||
|
@@ -32,6 +32,7 @@ class GraspController:
|
||||
self.T_grasp_ee = Transform.from_list(rospy.get_param("~ee_grasp_offset")).inv()
|
||||
self.cam_frame = rospy.get_param("~camera/frame_id")
|
||||
self.depth_topic = rospy.get_param("~camera/depth_topic")
|
||||
self.seg_topic = rospy.get_param("~camera/seg_topic")
|
||||
self.min_z_dist = rospy.get_param("~camera/min_z_dist")
|
||||
self.control_rate = rospy.get_param("~control_rate")
|
||||
self.linear_vel = rospy.get_param("~linear_vel")
|
||||
@@ -71,10 +72,14 @@ class GraspController:
|
||||
|
||||
def init_camera_stream(self):
|
||||
self.cv_bridge = cv_bridge.CvBridge()
|
||||
rospy.Subscriber(self.depth_topic, Image, self.sensor_cb, queue_size=1)
|
||||
rospy.Subscriber(self.depth_topic, Image, self.depth_cb, queue_size=1)
|
||||
rospy.Subscriber(self.seg_topic, Image, self.seg_cb, queue_size=1)
|
||||
|
||||
def sensor_cb(self, msg):
|
||||
def depth_cb(self, msg):
|
||||
self.latest_depth_msg = msg
|
||||
|
||||
def seg_cb(self, msg):
|
||||
self.latest_seg_msg = msg
|
||||
|
||||
def run(self):
|
||||
bbox = self.reset()
|
||||
@@ -102,8 +107,8 @@ class GraspController:
|
||||
timer = rospy.Timer(rospy.Duration(1.0 / self.control_rate), self.send_vel_cmd)
|
||||
r = rospy.Rate(self.policy_rate)
|
||||
while not self.policy.done:
|
||||
img, pose, q = self.get_state()
|
||||
self.policy.update(img, pose, q)
|
||||
depth_img, seg_image, pose, q = self.get_state()
|
||||
self.policy.update(depth_img, seg_image, pose, q)
|
||||
r.sleep()
|
||||
rospy.sleep(0.2) # Wait for a zero command to be sent to the robot.
|
||||
self.policy.deactivate()
|
||||
@@ -113,9 +118,11 @@ class GraspController:
|
||||
def get_state(self):
|
||||
q, _ = self.arm.get_state()
|
||||
msg = copy.deepcopy(self.latest_depth_msg)
|
||||
img = self.cv_bridge.imgmsg_to_cv2(msg).astype(np.float32) * 0.001
|
||||
depth_img = self.cv_bridge.imgmsg_to_cv2(msg).astype(np.float32) * 0.001
|
||||
msg = copy.deepcopy(self.latest_seg_msg)
|
||||
seg_img = self.cv_bridge.imgmsg_to_cv2(msg).astype(np.float32)
|
||||
pose = tf.lookup(self.base_frame, self.cam_frame, msg.header.stamp)
|
||||
return img, pose, q
|
||||
return depth_img, seg_img, pose, q
|
||||
|
||||
def send_vel_cmd(self, event):
|
||||
if self.policy.x_d is None or self.policy.done:
|
||||
|
@@ -84,7 +84,7 @@ class NextBestView(MultiViewPolicy):
|
||||
def activate(self, bbox, view_sphere):
|
||||
super().activate(bbox, view_sphere)
|
||||
|
||||
def update(self, img, x, q):
|
||||
def update(self, img, seg, x, q):
|
||||
if len(self.views) > self.max_views or self.best_grasp_prediction_is_stable():
|
||||
self.done = True
|
||||
else:
|
||||
|
@@ -75,7 +75,7 @@ class Policy:
|
||||
rospy.sleep(1.0) # Wait for tf tree to be updated
|
||||
self.vis.roi(self.task_frame, 0.3)
|
||||
|
||||
def update(self, img, x, q):
|
||||
def update(self, img, seg, x, q):
|
||||
raise NotImplementedError
|
||||
|
||||
def filter_grasps(self, out, q):
|
||||
@@ -106,7 +106,7 @@ def select_best_grasp(grasps, qualities):
|
||||
|
||||
|
||||
class SingleViewPolicy(Policy):
|
||||
def update(self, img, x, q):
|
||||
def update(self, img, seg, x, q):
|
||||
linear, _ = compute_error(self.x_d, x)
|
||||
if np.linalg.norm(linear) < 0.02:
|
||||
self.views.append(x)
|
||||
|
Reference in New Issue
Block a user