From 8989115bd7844812fda0a4b6ed88096946acc881 Mon Sep 17 00:00:00 2001 From: Michel Breyer Date: Wed, 14 Jul 2021 16:52:53 +0200 Subject: [PATCH] Publish simulation state in multiple threads --- active_grasp/controller.py | 2 +- active_grasp/simulation.py | 5 +- launch/active_grasp.launch | 2 +- launch/active_grasp.yaml | 7 +- scripts/bt_sim_node.py | 201 +++++++++++++++++++++---------------- 5 files changed, 124 insertions(+), 93 deletions(-) diff --git a/active_grasp/controller.py b/active_grasp/controller.py index fb0036c..c531405 100644 --- a/active_grasp/controller.py +++ b/active_grasp/controller.py @@ -59,7 +59,7 @@ class GraspController: # Approach grasp pose. self.controller.send_target(T_B_G * self.T_G_EE) - rospy.sleep(1.0) + rospy.sleep(2.0) # Close the fingers. self.gripper.grasp() diff --git a/active_grasp/simulation.py b/active_grasp/simulation.py index 9e2ad9b..e0813ef 100644 --- a/active_grasp/simulation.py +++ b/active_grasp/simulation.py @@ -19,7 +19,6 @@ class Simulation(BtSim): self.load_table() self.load_robot() self.load_controller() - self.reset() def configure_visualizer(self): # p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0) @@ -45,7 +44,9 @@ class Simulation(BtSim): self.camera = BtCamera(320, 240, 1.047, 0.1, 1.0, self.arm.uid, 11) def load_controller(self): - self.controller = CartesianPoseController(self.model, self.arm.ee_frame, None) + q, _ = self.arm.get_state() + x0 = self.model.pose(self.arm.ee_frame, q) + self.controller = CartesianPoseController(self.model, self.arm.ee_frame, x0) def reset(self): self.remove_all_objects() diff --git a/launch/active_grasp.launch b/launch/active_grasp.launch index dcafce5..a034301 100644 --- a/launch/active_grasp.launch +++ b/launch/active_grasp.launch @@ -1,6 +1,6 @@ - + diff --git a/launch/active_grasp.yaml b/launch/active_grasp.yaml index 727018e..a1cd0e8 100644 --- a/launch/active_grasp.yaml +++ b/launch/active_grasp.yaml @@ -1,7 +1,6 @@ bt_sim: gui: True seed: 12 - cam_pub_rate: 10 active_grasp: frame_id: task @@ -10,9 +9,9 @@ active_grasp: ee_frame_id: panda_hand ee_grasp_offset: [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.065] camera: - frame_id: cam_optical_frame - info_topic: /cam/depth/camera_info - depth_topic: /cam/depth/image_raw + frame_id: camera_optical_frame + info_topic: /camera/depth/camera_info + depth_topic: /camera/depth/image_raw vgn: model: $(find vgn)/assets/models/vgn_conv.pth diff --git a/scripts/bt_sim_node.py b/scripts/bt_sim_node.py index be1ed42..13dda65 100755 --- a/scripts/bt_sim_node.py +++ b/scripts/bt_sim_node.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 -import actionlib +from actionlib import SimpleActionServer import cv_bridge -import franka_gripper.msg +from franka_gripper.msg import * from geometry_msgs.msg import PoseStamped import numpy as np import rospy from sensor_msgs.msg import JointState, Image, CameraInfo +from threading import Thread import tf2_ros from active_grasp.srv import Reset, ResetResponse @@ -17,31 +18,29 @@ from robot_utils.ros.conversions import * class BtSimNode: def __init__(self): - self.load_parameters() - rng = self.get_rng() - self.sim = Simulation(gui=self.gui, rng=rng) - self.robot_state_interface = RobotStateInterface(self.sim.arm, self.sim.gripper) - self.arm_interface = ArmInterface(self.sim.arm, self.sim.controller) - self.gripper_interface = GripperInterface(self.sim.gripper) - self.camera_interface = CameraInterface(self.sim.camera) - self.step_cnt = 0 - self.reset_requested = False - - self.advertise_services() - self.broadcast_transforms() - - def load_parameters(self): self.gui = rospy.get_param("~gui", True) - self.cam_pub_rate = rospy.get_param("~cam_pub_rate") - - def get_rng(self): seed = rospy.get_param("~seed", None) - return np.random.default_rng(seed) if seed else np.random - def advertise_services(self): + rng = np.random.default_rng(seed) if seed else np.random + self.sim = Simulation(gui=self.gui, rng=rng) + + self._init_plugins() + self._advertise_services() + self._broadcast_transforms() + + def _init_plugins(self): + self.plugins = [ + PhysicsPlugin(self.sim), + JointStatePlugin(self.sim.arm, self.sim.gripper), + ArmControllerPlugin(self.sim.arm, self.sim.controller), + GripperControllerPlugin(self.sim.gripper), + CameraPlugin(self.sim.camera), + ] + + def _advertise_services(self): rospy.Service("reset", Reset, self.reset) - def broadcast_transforms(self): + def _broadcast_transforms(self): self.static_broadcaster = tf2_ros.StaticTransformBroadcaster() msgs = [ to_transform_stamped_msg(self.sim.T_W_B, "world", "panda_link0"), @@ -52,38 +51,63 @@ class BtSimNode: self.static_broadcaster.sendTransform(msgs) def reset(self, req): - self.reset_requested = True - rospy.sleep(1.0) # wait for the latest sim step to finish + for plugin in self.plugins: + plugin.is_running = False + rospy.sleep(1.0) # TODO replace with a read-write lock + bbox = self.sim.reset() res = ResetResponse(to_bbox_msg(bbox)) - self.step_cnt = 0 - self.reset_requested = False + + for plugin in self.plugins: + plugin.is_running = True return res def run(self): - rate = rospy.Rate(self.sim.rate) + self._start_plugins() + rospy.spin() + + def _start_plugins(self): + for plugin in self.plugins: + plugin.thread.start() + plugin.is_running = True + + +class Plugin: + """A plugin that spins at a constant rate in its own thread.""" + + def __init__(self, rate): + self.rate = rate + self.thread = Thread(target=self._loop, daemon=True) + self.is_running = False + + def _loop(self): + rate = rospy.Rate(self.rate) while not rospy.is_shutdown(): - if not self.reset_requested: - self.handle_updates() - self.sim.step() - self.step_cnt = (self.step_cnt + 1) % self.sim.rate + if self.is_running: + self._update() rate.sleep() - def handle_updates(self): - self.robot_state_interface.update() - self.arm_interface.update() - self.gripper_interface.update(self.sim.dt) - if self.step_cnt % int(self.sim.rate / self.cam_pub_rate) == 0: - self.camera_interface.update() + def _update(self): + raise NotImplementedError -class RobotStateInterface: - def __init__(self, arm, gripper): +class PhysicsPlugin(Plugin): + def __init__(self, sim): + super().__init__(sim.rate) + self.sim = sim + + def _update(self): + self.sim.step() + + +class JointStatePlugin(Plugin): + def __init__(self, arm, gripper, rate=30): + super().__init__(rate) self.arm = arm self.gripper = gripper - self.joint_pub = rospy.Publisher("joint_states", JointState, queue_size=10) + self.pub = rospy.Publisher("joint_states", JointState, queue_size=10) - def update(self): + def _update(self): q, _ = self.arm.get_state() width = self.gripper.read() msg = JointState() @@ -93,86 +117,93 @@ class RobotStateInterface: "panda_finger_joint2", ] msg.position = np.r_[q, 0.5 * width, 0.5 * width] - self.joint_pub.publish(msg) + self.pub.publish(msg) -class ArmInterface: - def __init__(self, arm, controller): +class ArmControllerPlugin(Plugin): + def __init__(self, arm, controller, rate=30): + super().__init__(rate) self.arm = arm self.controller = controller - rospy.Subscriber("command", PoseStamped, self.target_cb) + rospy.Subscriber("command", PoseStamped, self._target_cb) - def update(self): + def _target_cb(self, msg): + assert msg.header.frame_id == self.arm.base_frame + self.controller.x_d = from_pose_msg(msg.pose) + + def _update(self): q, _ = self.arm.get_state() cmd = self.controller.update(q) self.arm.set_desired_joint_velocities(cmd) - def target_cb(self, msg): - assert msg.header.frame_id == self.arm.base_frame - self.controller.x_d = from_pose_msg(msg.pose) - -class GripperInterface: - def __init__(self, gripper): +class GripperControllerPlugin(Plugin): + def __init__(self, gripper, rate=10): + super().__init__(rate) self.gripper = gripper - self.move_server = actionlib.SimpleActionServer( - "/franka_gripper/move", - franka_gripper.msg.MoveAction, - auto_start=False, - ) - self.move_server.register_goal_callback(self.move_action_goal_cb) + self.dt = 1.0 / self.rate + self._init_move_action_server() + self._init_grasp_action_server() + + def _init_move_action_server(self): + name = "/franka_gripper/move" + self.move_server = SimpleActionServer(name, MoveAction, auto_start=False) + self.move_server.register_goal_callback(self._move_action_goal_cb) self.move_server.start() - self.grasp_server = actionlib.SimpleActionServer( - "/franka_gripper/grasp", - franka_gripper.msg.GraspAction, - auto_start=False, - ) - self.grasp_server.register_goal_callback(self.grasp_action_goal_cb) + def _init_grasp_action_server(self): + name = "/franka_gripper/grasp" + self.grasp_server = SimpleActionServer(name, GraspAction, auto_start=False) + self.grasp_server.register_goal_callback(self._grasp_action_goal_cb) self.grasp_server.start() - def move_action_goal_cb(self): + def _move_action_goal_cb(self): self.elapsed_time_since_move_action_goal = 0.0 goal = self.move_server.accept_new_goal() self.gripper.set_desired_width(goal.width) - def grasp_action_goal_cb(self): + def _grasp_action_goal_cb(self): self.elapsed_time_since_grasp_action_goal = 0.0 goal = self.grasp_server.accept_new_goal() self.gripper.set_desired_width(goal.width) - def update(self, dt): + def _update(self): if self.move_server.is_active(): - self.elapsed_time_since_move_action_goal += dt + self.elapsed_time_since_move_action_goal += self.dt if self.elapsed_time_since_move_action_goal > 1.0: self.move_server.set_succeeded() if self.grasp_server.is_active(): - self.elapsed_time_since_grasp_action_goal += dt + self.elapsed_time_since_grasp_action_goal += self.dt if self.elapsed_time_since_grasp_action_goal > 1.0: self.grasp_server.set_succeeded() -class CameraInterface: - def __init__(self, camera): +class CameraPlugin(Plugin): + def __init__(self, camera, name="camera", rate=10): + super().__init__(rate) self.camera = camera + self.name = name self.cv_bridge = cv_bridge.CvBridge() - self.cam_info_msg = to_camera_info_msg(self.camera.intrinsic) - self.cam_info_msg.header.frame_id = "cam_optical_frame" - self.cam_info_pub = rospy.Publisher( - "/cam/depth/camera_info", - CameraInfo, - queue_size=10, - ) - self.depth_pub = rospy.Publisher("/cam/depth/image_raw", Image, queue_size=10) + self._init_publishers() - def update(self): + def _init_publishers(self): + topic = self.name + "/depth/camera_info" + self.info_pub = rospy.Publisher(topic, CameraInfo, queue_size=10) + topic = self.name + "/depth/image_raw" + self.depth_pub = rospy.Publisher(topic, Image, queue_size=10) + + def _update(self): stamp = rospy.Time.now() - self.cam_info_msg.header.stamp = stamp - self.cam_info_pub.publish(self.cam_info_msg) + + msg = to_camera_info_msg(self.camera.intrinsic) + msg.header.frame_id = self.name + "_optical_frame" + msg.header.stamp = stamp + self.info_pub.publish(msg) + img = self.camera.get_image() - depth_msg = self.cv_bridge.cv2_to_imgmsg(img.depth) - depth_msg.header.stamp = stamp - self.depth_pub.publish(depth_msg) + msg = self.cv_bridge.cv2_to_imgmsg(img.depth) + msg.header.stamp = stamp + self.depth_pub.publish(msg) def main():