Seed the simulation
This commit is contained in:
@@ -10,20 +10,16 @@ from sensor_msgs.msg import JointState, Image, CameraInfo
|
||||
from threading import Thread
|
||||
import tf2_ros
|
||||
|
||||
from active_grasp.srv import Reset, ResetResponse
|
||||
from active_grasp.bbox import to_bbox_msg
|
||||
from active_grasp.srv import *
|
||||
from active_grasp.simulation import Simulation
|
||||
from active_grasp.utils import *
|
||||
from robot_utils.ros.conversions import *
|
||||
from robot_helpers.ros.conversions import *
|
||||
|
||||
|
||||
class BtSimNode:
|
||||
def __init__(self):
|
||||
self.gui = rospy.get_param("~gui", True)
|
||||
seed = rospy.get_param("~seed", None)
|
||||
|
||||
rng = np.random.default_rng(seed) if seed else np.random
|
||||
self.sim = Simulation(gui=self.gui, rng=rng)
|
||||
|
||||
self.sim = Simulation(gui=self.gui)
|
||||
self._init_plugins()
|
||||
self._advertise_services()
|
||||
self._broadcast_transforms()
|
||||
@@ -33,11 +29,13 @@ class BtSimNode:
|
||||
PhysicsPlugin(self.sim),
|
||||
JointStatePlugin(self.sim.arm, self.sim.gripper),
|
||||
ArmControllerPlugin(self.sim.arm, self.sim.controller),
|
||||
GripperControllerPlugin(self.sim.gripper),
|
||||
MoveActionPlugin(self.sim.gripper),
|
||||
GraspActionPlugin(self.sim.gripper),
|
||||
CameraPlugin(self.sim.camera),
|
||||
]
|
||||
|
||||
def _advertise_services(self):
|
||||
rospy.Service("seed", Seed, self.seed)
|
||||
rospy.Service("reset", Reset, self.reset)
|
||||
|
||||
def _broadcast_transforms(self):
|
||||
@@ -50,14 +48,16 @@ class BtSimNode:
|
||||
]
|
||||
self.static_broadcaster.sendTransform(msgs)
|
||||
|
||||
def seed(self, req):
|
||||
self.sim.seed(req.seed)
|
||||
return SeedResponse()
|
||||
|
||||
def reset(self, req):
|
||||
for plugin in self.plugins:
|
||||
plugin.is_running = False
|
||||
rospy.sleep(1.0) # TODO replace with a read-write lock
|
||||
|
||||
bbox = self.sim.reset()
|
||||
res = ResetResponse(to_bbox_msg(bbox))
|
||||
|
||||
for plugin in self.plugins:
|
||||
plugin.is_running = True
|
||||
return res
|
||||
@@ -137,45 +137,54 @@ class ArmControllerPlugin(Plugin):
|
||||
self.arm.set_desired_joint_velocities(cmd)
|
||||
|
||||
|
||||
class GripperControllerPlugin(Plugin):
|
||||
class MoveActionPlugin(Plugin):
|
||||
def __init__(self, gripper, rate=10):
|
||||
super().__init__(rate)
|
||||
self.gripper = gripper
|
||||
self.dt = 1.0 / self.rate
|
||||
self._init_move_action_server()
|
||||
self._init_grasp_action_server()
|
||||
self._init_action_server()
|
||||
|
||||
def _init_move_action_server(self):
|
||||
def _init_action_server(self):
|
||||
name = "/franka_gripper/move"
|
||||
self.move_server = SimpleActionServer(name, MoveAction, auto_start=False)
|
||||
self.move_server.register_goal_callback(self._move_action_goal_cb)
|
||||
self.move_server.start()
|
||||
self.action_server = SimpleActionServer(name, MoveAction, auto_start=False)
|
||||
self.action_server.register_goal_callback(self._action_goal_cb)
|
||||
self.action_server.start()
|
||||
|
||||
def _init_grasp_action_server(self):
|
||||
name = "/franka_gripper/grasp"
|
||||
self.grasp_server = SimpleActionServer(name, GraspAction, auto_start=False)
|
||||
self.grasp_server.register_goal_callback(self._grasp_action_goal_cb)
|
||||
self.grasp_server.start()
|
||||
|
||||
def _move_action_goal_cb(self):
|
||||
self.elapsed_time_since_move_action_goal = 0.0
|
||||
goal = self.move_server.accept_new_goal()
|
||||
self.gripper.set_desired_width(goal.width)
|
||||
|
||||
def _grasp_action_goal_cb(self):
|
||||
self.elapsed_time_since_grasp_action_goal = 0.0
|
||||
goal = self.grasp_server.accept_new_goal()
|
||||
def _action_goal_cb(self):
|
||||
self.elapsed_time = 0.0
|
||||
goal = self.action_server.accept_new_goal()
|
||||
self.gripper.set_desired_width(goal.width)
|
||||
|
||||
def _update(self):
|
||||
if self.move_server.is_active():
|
||||
self.elapsed_time_since_move_action_goal += self.dt
|
||||
if self.elapsed_time_since_move_action_goal > 1.0:
|
||||
self.move_server.set_succeeded()
|
||||
if self.grasp_server.is_active():
|
||||
self.elapsed_time_since_grasp_action_goal += self.dt
|
||||
if self.elapsed_time_since_grasp_action_goal > 1.0:
|
||||
self.grasp_server.set_succeeded()
|
||||
if self.action_server.is_active():
|
||||
self.elapsed_time += self.dt
|
||||
if self.elapsed_time > 1.0:
|
||||
self.action_server.set_succeeded()
|
||||
|
||||
|
||||
class GraspActionPlugin(Plugin):
|
||||
def __init__(self, gripper, rate=10):
|
||||
super().__init__(rate)
|
||||
self.gripper = gripper
|
||||
self.dt = 1.0 / self.rate
|
||||
self._init_action_server()
|
||||
|
||||
def _init_action_server(self):
|
||||
name = "/franka_gripper/grasp"
|
||||
self.action_server = SimpleActionServer(name, GraspAction, auto_start=False)
|
||||
self.action_server.register_goal_callback(self._action_goal_cb)
|
||||
self.action_server.start()
|
||||
|
||||
def _action_goal_cb(self):
|
||||
self.elapsed_time = 0.0
|
||||
goal = self.action_server.accept_new_goal()
|
||||
self.gripper.set_desired_width(goal.width)
|
||||
|
||||
def _update(self):
|
||||
if self.action_server.is_active():
|
||||
self.elapsed_time += self.dt
|
||||
if self.elapsed_time > 1.0:
|
||||
self.action_server.set_succeeded()
|
||||
|
||||
|
||||
class CameraPlugin(Plugin):
|
||||
@@ -201,8 +210,8 @@ class CameraPlugin(Plugin):
|
||||
msg.header.stamp = stamp
|
||||
self.info_pub.publish(msg)
|
||||
|
||||
img = self.camera.get_image()
|
||||
msg = self.cv_bridge.cv2_to_imgmsg(img.depth)
|
||||
_, depth, _ = self.camera.get_image()
|
||||
msg = self.cv_bridge.cv2_to_imgmsg(depth)
|
||||
msg.header.stamp = stamp
|
||||
self.depth_pub.publish(msg)
|
||||
|
||||
|
Reference in New Issue
Block a user