# nbv_sim/active_grasp/controller.py
from controller_manager_msgs.srv import *
import copy
import cv_bridge
from geometry_msgs.msg import Twist
import numpy as np
import rospy
from sensor_msgs.msg import Image
from .bbox import from_bbox_msg
from .timer import Timer
from active_grasp.srv import Reset, ResetRequest
from robot_helpers.ros import tf
from robot_helpers.ros.conversions import *
from robot_helpers.ros.panda import PandaGripperClient
from robot_helpers.ros.moveit import MoveItClient
from robot_helpers.spatial import Rotation, Transform


class GraspController:
    """Run one active-grasping episode: reset, view search, grasp execution.

    The controller owns the ROS plumbing (service proxies, command publisher,
    depth subscriber, MoveIt and gripper clients). It delegates view and grasp
    selection to ``policy``. From the calls made below, the policy must provide
    ``activate(bbox)`` and ``update(img, pose)``, plus the attributes ``rate``,
    ``done``, ``best_grasp`` and ``views``. It is also handed the MoveIt client
    as ``policy.moveit``.
    """

    def __init__(self, policy):
        self.policy = policy
        self.load_parameters()
        self.init_service_proxies()
        self.init_robot_connection()
        self.init_moveit()
        self.init_camera_stream()

    def load_parameters(self):
        """Read frame ids, the depth topic and the grasp offset from ~params."""
        self.base_frame = rospy.get_param("~base_frame_id")
        self.ee_frame = rospy.get_param("~ee_frame_id")
        self.cam_frame = rospy.get_param("~camera/frame_id")
        self.depth_topic = rospy.get_param("~camera/depth_topic")
        # Stored inverted so that T_base_grasp * T_grasp_ee (see execute_grasp)
        # yields the end-effector target pose.
        self.T_grasp_ee = Transform.from_list(rospy.get_param("~ee_grasp_offset")).inv()

    def init_service_proxies(self):
        """Create proxies for the environment reset and controller-switch services."""
        self.reset_env = rospy.ServiceProxy("reset", Reset)
        self.switch_controller = rospy.ServiceProxy(
            "controller_manager/switch_controller", SwitchController
        )

    def init_robot_connection(self):
        """Create the Cartesian velocity command publisher and the gripper client."""
        self.cartesian_vel_pub = rospy.Publisher("command", Twist, queue_size=10)
        self.gripper = PandaGripperClient()

    def init_moveit(self):
        """Connect MoveIt to the "panda_arm" group and share the client with the policy."""
        self.moveit = MoveItClient("panda_arm")
        rospy.sleep(1.0)  # wait for connections to be established
        # Disabled: register the table as a collision object.
        # msg = to_pose_stamped_msg(Transform.t([0.4, 0, 0.4]), self.base_frame)
        # self.moveit.scene.add_box("table", msg, size=(0.5, 0.5, 0.02))
        self.policy.moveit = self.moveit

    def switch_to_cartesian_velocity_control(self):
        """Start the Cartesian velocity controller, stop the joint trajectory one."""
        self._switch_controllers(
            start=["cartesian_velocity_controller"],
            stop=["position_joint_trajectory_controller"],
        )

    def switch_to_joint_trajectory_control(self):
        """Start the joint trajectory controller, stop the Cartesian velocity one."""
        self._switch_controllers(
            start=["position_joint_trajectory_controller"],
            stop=["cartesian_velocity_controller"],
        )

    def _switch_controllers(self, start, stop):
        """Call controller_manager/switch_controller; return the response's ``ok`` flag.

        A rejected switch is logged as an error rather than silently ignored.
        Execution continues either way, so callers keep their existing flow.
        """
        req = SwitchControllerRequest()
        req.start_controllers = start
        req.stop_controllers = stop
        res = self.switch_controller(req)
        if not res.ok:
            rospy.logerr("Controller switch failed (start=%s, stop=%s)", start, stop)
        return res.ok

    def init_camera_stream(self):
        """Subscribe to the depth topic; sensor_cb caches the newest frame."""
        self.cv_bridge = cv_bridge.CvBridge()
        # Initialise before subscribing so get_state can detect "no frame yet",
        # and so an immediate first callback is not overwritten afterwards.
        self.latest_depth_msg = None
        rospy.Subscriber(self.depth_topic, Image, self.sensor_cb, queue_size=1)

    def sensor_cb(self, msg):
        """Cache the most recent depth image message."""
        self.latest_depth_msg = msg

    def run(self):
        """Run one episode and return the summary dict built by collect_info."""
        bbox = self.reset()
        self.switch_to_cartesian_velocity_control()
        with Timer("search_time"):
            grasp = self.search_grasp(bbox)
        self.switch_to_joint_trajectory_control()
        with Timer("grasp_time"):
            res = self.execute_grasp(grasp)
        return self.collect_info(res)

    def reset(self):
        """Reset the environment and return the bounding box it reports."""
        res = self.reset_env(ResetRequest())
        rospy.sleep(1.0)  # wait for states to be updated
        return from_bbox_msg(res.bbox)

    def search_grasp(self, bbox):
        """Publish the policy's velocity commands until it is done; return its best grasp."""
        self.policy.activate(bbox)
        r = rospy.Rate(self.policy.rate)
        while not self.policy.done:
            img, pose = self.get_state()
            cmd = self.policy.update(img, pose)
            self.cartesian_vel_pub.publish(to_twist_msg(cmd))
            r.sleep()
        return self.policy.best_grasp

    def get_state(self):
        """Return the latest depth image (as float32) and the camera pose at its stamp."""
        msg = self.latest_depth_msg
        if msg is None:
            # No frame has arrived yet: wait for the first one instead of
            # failing on a missing attribute.
            msg = rospy.wait_for_message(self.depth_topic, Image)
        msg = copy.deepcopy(msg)
        img = self.cv_bridge.imgmsg_to_cv2(msg).astype(np.float32)
        pose = tf.lookup(self.base_frame, self.cam_frame, msg.header.stamp)
        return img, pose

    def execute_grasp(self, grasp):
        """Approach, close the gripper and lift.

        Returns "aborted" when no grasp was found, otherwise "succeeded" or
        "failed" depending on the gripper reading after the lift.
        """
        if not grasp:
            return "aborted"
        T_base_grasp = self.postprocess(grasp.pose)
        self.gripper.move(0.08)
        # Pre-grasp: offset -0.05 along the grasp frame's own z axis
        # (right-multiplied), then move to the grasp pose itself.
        self.moveit.goto(T_base_grasp * Transform.t([0, 0, -0.05]) * self.T_grasp_ee)
        self.moveit.goto(T_base_grasp * self.T_grasp_ee)
        self.gripper.grasp()
        # Lift: +0.1 along the base frame's z axis (left-multiplied).
        self.moveit.goto(Transform.t([0, 0, 0.1]) * T_base_grasp * self.T_grasp_ee)
        # NOTE(review): read() is presumably the finger opening; a reading above
        # the threshold is taken to mean an object is held.
        success = self.gripper.read() > 0.005
        return "succeeded" if success else "failed"

    def postprocess(self, T_base_grasp):
        """Return an adjusted copy of a grasp pose; the argument is not modified.

        If the grasp frame's x axis has a negative x component in the base
        frame, the pose is rotated by pi about its own z axis. The result is
        then offset by 0.01 along its own z axis.
        """
        # Copy first: the argument is the policy's best_grasp.pose, and
        # assigning .rotation below would otherwise mutate the policy's state.
        T_base_grasp = copy.deepcopy(T_base_grasp)
        rot = T_base_grasp.rotation
        if rot.as_matrix()[0, 0] < 0:  # Ensure that the camera is pointing forward
            T_base_grasp.rotation = rot * Rotation.from_euler("z", np.pi)
        T_base_grasp *= Transform.t([0.0, 0.0, 0.01])
        return T_base_grasp

    def collect_info(self, result):
        """Summarise the episode: result, view count, path length and timer values."""
        points = [p.translation for p in self.policy.views]
        # Total length of the polyline through consecutive view positions
        # (0.0 when there are fewer than two views).
        d = np.sum([np.linalg.norm(p2 - p1) for p1, p2 in zip(points, points[1:])])
        info = {
            "result": result,
            "view_count": len(points),
            "distance": d,
        }
        info.update(Timer.timers)
        return info