diff --git a/scripts/bt_sim_node.py b/scripts/bt_sim_node.py index 768b115..f151406 100755 --- a/scripts/bt_sim_node.py +++ b/scripts/bt_sim_node.py @@ -4,12 +4,14 @@ from actionlib import SimpleActionServer import control_msgs.msg as control_msgs from controller_manager_msgs.srv import * import cv_bridge +from franka_msgs.msg import FrankaState, ErrorRecoveryAction from franka_gripper.msg import * from geometry_msgs.msg import Twist import numpy as np import rospy from sensor_msgs.msg import JointState, Image, CameraInfo from scipy import interpolate +from std_msgs.msg import Header from threading import Thread from active_grasp.bbox import to_bbox_msg @@ -31,11 +33,12 @@ class BtSimNode: def init_plugins(self): self.plugins = [ PhysicsPlugin(self.sim), - JointStatePlugin(self.sim.arm, self.sim.gripper), + RobotStatePlugin(self.sim.arm, self.sim.gripper), MoveActionPlugin(self.sim.gripper), GraspActionPlugin(self.sim.gripper), GripperActionPlugin(), CameraPlugin(self.sim.camera), + MockActionsPlugin(), ] self.controllers = { "cartesian_velocity_controller": CartesianVelocityControllerPlugin( @@ -131,24 +134,41 @@ class PhysicsPlugin(Plugin): self.sim.step() -class JointStatePlugin(Plugin): +class RobotStatePlugin(Plugin): def __init__(self, arm, gripper, rate=30): super().__init__(rate) self.arm = arm self.gripper = gripper - self.pub = rospy.Publisher("joint_states", JointState, queue_size=10) + self.arm_state_pub = rospy.Publisher( + "/franka_state_controller/franka_states", FrankaState, queue_size=10 + ) + self.gripper_state_pub = rospy.Publisher( + "/franka_gripper/joint_states", JointState, queue_size=10 + ) + self.joint_states_pub = rospy.Publisher( + "joint_states", JointState, queue_size=10 + ) def update(self): - q, _ = self.arm.get_state() + q, dq = self.arm.get_state() width = self.gripper.read() - msg = JointState() - msg.header.stamp = rospy.Time.now() + header = Header(stamp=rospy.Time.now()) + + msg = FrankaState(header=header, q=q, dq=dq) + self.arm_state_pub.publish(msg) + + msg = JointState(header=header) + msg.name = ["panda_finger_joint1", "panda_finger_joint2"] + msg.position = [0.5 * width, 0.5 * width] + self.gripper_state_pub.publish(msg) + + msg = JointState(header=header) msg.name = ["panda_joint{}".format(i) for i in range(1, 8)] + [ "panda_finger_joint1", "panda_finger_joint2", ] msg.position = np.r_[q, 0.5 * width, 0.5 * width] - self.pub.publish(msg) + self.joint_states_pub.publish(msg) class CartesianVelocityControllerPlugin(Plugin): @@ -322,6 +342,33 @@ class CameraPlugin(Plugin): self.depth_pub.publish(msg) +class MockActionsPlugin(Plugin): + def __init__(self): + super().__init__(1) + self.init_recovery_action_server() + self.init_homing_action_server() + + def init_homing_action_server(self): + self.homing_as = SimpleActionServer( + "/franka_gripper/homing", HomingAction, auto_start=False + ) + self.homing_as.register_goal_callback(self.action_goal_cb) + self.homing_as.start() + + def init_recovery_action_server(self): + self.recovery_as = SimpleActionServer( + "/franka_control/error_recovery", ErrorRecoveryAction, auto_start=False + ) + self.recovery_as.register_goal_callback(self.action_goal_cb) + self.recovery_as.start() + + def action_goal_cb(self): + pass + + def update(self): + pass + + def main(): rospy.init_node("bt_sim") server = BtSimNode() diff --git a/src/active_grasp/policy.py b/src/active_grasp/policy.py index 91a126c..0632523 100644 --- a/src/active_grasp/policy.py +++ b/src/active_grasp/policy.py @@ -112,8 +112,16 @@ class SingleViewPolicy(Policy): self.views.append(x) self.tsdf.integrate(img, self.intrinsic, x.inv() * self.T_base_task) tsdf_grid, voxel_size = self.tsdf.get_grid(), self.tsdf.voxel_size - self.vis.scene_cloud(self.task_frame, self.tsdf.get_scene_cloud()) - self.vis.map_cloud(self.task_frame, self.tsdf.get_map_cloud()) + + scene_cloud = self.tsdf.get_scene_cloud() + self.vis.scene_cloud(self.task_frame, np.asarray(scene_cloud.points)) + + map_cloud = self.tsdf.get_map_cloud() + self.vis.map_cloud( + self.task_frame, + np.asarray(map_cloud.points), + np.expand_dims(np.asarray(map_cloud.colors)[:, 0], 1), + ) out = self.vgn.predict(tsdf_grid) self.vis.quality(self.task_frame, voxel_size, out.qual, 0.5) @@ -142,8 +150,16 @@ class MultiViewPolicy(Policy): with Timer("tsdf_integration"): self.tsdf.integrate(img, self.intrinsic, x.inv() * self.T_base_task) - self.vis.scene_cloud(self.task_frame, self.tsdf.get_scene_cloud()) - self.vis.map_cloud(self.task_frame, self.tsdf.get_map_cloud()) + + scene_cloud = self.tsdf.get_scene_cloud() + self.vis.scene_cloud(self.task_frame, np.asarray(scene_cloud.points)) + + map_cloud = self.tsdf.get_map_cloud() + self.vis.map_cloud( + self.task_frame, + np.asarray(map_cloud.points), + np.expand_dims(np.asarray(map_cloud.colors)[:, 0], 1), + ) with Timer("grasp_prediction"): tsdf_grid = self.tsdf.get_grid()