Seed the simulation
This commit is contained in:
@@ -10,20 +10,16 @@ from sensor_msgs.msg import JointState, Image, CameraInfo
|
||||
from threading import Thread
|
||||
import tf2_ros
|
||||
|
||||
from active_grasp.srv import Reset, ResetResponse
|
||||
from active_grasp.bbox import to_bbox_msg
|
||||
from active_grasp.srv import *
|
||||
from active_grasp.simulation import Simulation
|
||||
from active_grasp.utils import *
|
||||
from robot_utils.ros.conversions import *
|
||||
from robot_helpers.ros.conversions import *
|
||||
|
||||
|
||||
class BtSimNode:
|
||||
def __init__(self):
|
||||
self.gui = rospy.get_param("~gui", True)
|
||||
seed = rospy.get_param("~seed", None)
|
||||
|
||||
rng = np.random.default_rng(seed) if seed else np.random
|
||||
self.sim = Simulation(gui=self.gui, rng=rng)
|
||||
|
||||
self.sim = Simulation(gui=self.gui)
|
||||
self._init_plugins()
|
||||
self._advertise_services()
|
||||
self._broadcast_transforms()
|
||||
@@ -33,11 +29,13 @@ class BtSimNode:
|
||||
PhysicsPlugin(self.sim),
|
||||
JointStatePlugin(self.sim.arm, self.sim.gripper),
|
||||
ArmControllerPlugin(self.sim.arm, self.sim.controller),
|
||||
GripperControllerPlugin(self.sim.gripper),
|
||||
MoveActionPlugin(self.sim.gripper),
|
||||
GraspActionPlugin(self.sim.gripper),
|
||||
CameraPlugin(self.sim.camera),
|
||||
]
|
||||
|
||||
def _advertise_services(self):
|
||||
rospy.Service("seed", Seed, self.seed)
|
||||
rospy.Service("reset", Reset, self.reset)
|
||||
|
||||
def _broadcast_transforms(self):
|
||||
@@ -50,14 +48,16 @@ class BtSimNode:
|
||||
]
|
||||
self.static_broadcaster.sendTransform(msgs)
|
||||
|
||||
def seed(self, req):
|
||||
self.sim.seed(req.seed)
|
||||
return SeedResponse()
|
||||
|
||||
def reset(self, req):
|
||||
for plugin in self.plugins:
|
||||
plugin.is_running = False
|
||||
rospy.sleep(1.0) # TODO replace with a read-write lock
|
||||
|
||||
bbox = self.sim.reset()
|
||||
res = ResetResponse(to_bbox_msg(bbox))
|
||||
|
||||
for plugin in self.plugins:
|
||||
plugin.is_running = True
|
||||
return res
|
||||
@@ -137,45 +137,54 @@ class ArmControllerPlugin(Plugin):
|
||||
self.arm.set_desired_joint_velocities(cmd)
|
||||
|
||||
|
||||
class GripperControllerPlugin(Plugin):
|
||||
class MoveActionPlugin(Plugin):
|
||||
def __init__(self, gripper, rate=10):
|
||||
super().__init__(rate)
|
||||
self.gripper = gripper
|
||||
self.dt = 1.0 / self.rate
|
||||
self._init_move_action_server()
|
||||
self._init_grasp_action_server()
|
||||
self._init_action_server()
|
||||
|
||||
def _init_move_action_server(self):
|
||||
def _init_action_server(self):
|
||||
name = "/franka_gripper/move"
|
||||
self.move_server = SimpleActionServer(name, MoveAction, auto_start=False)
|
||||
self.move_server.register_goal_callback(self._move_action_goal_cb)
|
||||
self.move_server.start()
|
||||
self.action_server = SimpleActionServer(name, MoveAction, auto_start=False)
|
||||
self.action_server.register_goal_callback(self._action_goal_cb)
|
||||
self.action_server.start()
|
||||
|
||||
def _init_grasp_action_server(self):
|
||||
name = "/franka_gripper/grasp"
|
||||
self.grasp_server = SimpleActionServer(name, GraspAction, auto_start=False)
|
||||
self.grasp_server.register_goal_callback(self._grasp_action_goal_cb)
|
||||
self.grasp_server.start()
|
||||
|
||||
def _move_action_goal_cb(self):
|
||||
self.elapsed_time_since_move_action_goal = 0.0
|
||||
goal = self.move_server.accept_new_goal()
|
||||
self.gripper.set_desired_width(goal.width)
|
||||
|
||||
def _grasp_action_goal_cb(self):
|
||||
self.elapsed_time_since_grasp_action_goal = 0.0
|
||||
goal = self.grasp_server.accept_new_goal()
|
||||
def _action_goal_cb(self):
|
||||
self.elapsed_time = 0.0
|
||||
goal = self.action_server.accept_new_goal()
|
||||
self.gripper.set_desired_width(goal.width)
|
||||
|
||||
def _update(self):
|
||||
if self.move_server.is_active():
|
||||
self.elapsed_time_since_move_action_goal += self.dt
|
||||
if self.elapsed_time_since_move_action_goal > 1.0:
|
||||
self.move_server.set_succeeded()
|
||||
if self.grasp_server.is_active():
|
||||
self.elapsed_time_since_grasp_action_goal += self.dt
|
||||
if self.elapsed_time_since_grasp_action_goal > 1.0:
|
||||
self.grasp_server.set_succeeded()
|
||||
if self.action_server.is_active():
|
||||
self.elapsed_time += self.dt
|
||||
if self.elapsed_time > 1.0:
|
||||
self.action_server.set_succeeded()
|
||||
|
||||
|
||||
class GraspActionPlugin(Plugin):
|
||||
def __init__(self, gripper, rate=10):
|
||||
super().__init__(rate)
|
||||
self.gripper = gripper
|
||||
self.dt = 1.0 / self.rate
|
||||
self._init_action_server()
|
||||
|
||||
def _init_action_server(self):
|
||||
name = "/franka_gripper/grasp"
|
||||
self.action_server = SimpleActionServer(name, GraspAction, auto_start=False)
|
||||
self.action_server.register_goal_callback(self._action_goal_cb)
|
||||
self.action_server.start()
|
||||
|
||||
def _action_goal_cb(self):
|
||||
self.elapsed_time = 0.0
|
||||
goal = self.action_server.accept_new_goal()
|
||||
self.gripper.set_desired_width(goal.width)
|
||||
|
||||
def _update(self):
|
||||
if self.action_server.is_active():
|
||||
self.elapsed_time += self.dt
|
||||
if self.elapsed_time > 1.0:
|
||||
self.action_server.set_succeeded()
|
||||
|
||||
|
||||
class CameraPlugin(Plugin):
|
||||
@@ -201,8 +210,8 @@ class CameraPlugin(Plugin):
|
||||
msg.header.stamp = stamp
|
||||
self.info_pub.publish(msg)
|
||||
|
||||
img = self.camera.get_image()
|
||||
msg = self.cv_bridge.cv2_to_imgmsg(img.depth)
|
||||
_, depth, _ = self.camera.get_image()
|
||||
msg = self.cv_bridge.cv2_to_imgmsg(depth)
|
||||
msg.header.stamp = stamp
|
||||
self.depth_pub.publish(msg)
|
||||
|
||||
|
@@ -1,10 +1,24 @@
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
import rospy
|
||||
from tqdm import tqdm
|
||||
|
||||
from active_grasp.controller import *
|
||||
from active_grasp.policy import make, registry
|
||||
from active_grasp.srv import Seed
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self, logdir, policy):
|
||||
stamp = datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
name = "{}_policy={}".format(stamp, policy)
|
||||
self.path = logdir / (name + ".csv")
|
||||
|
||||
def log_run(self, info):
|
||||
df = pd.DataFrame.from_records([info])
|
||||
df.to_csv(self.path, mode="a", header=not self.path.exists(), index=False)
|
||||
|
||||
|
||||
def create_parser():
|
||||
@@ -12,17 +26,24 @@ def create_parser():
|
||||
parser.add_argument("policy", type=str, choices=registry.keys())
|
||||
parser.add_argument("--runs", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=Path, default="logs")
|
||||
parser.add_argument("--desc", type=str, default="")
|
||||
parser.add_argument("--seed", type=int, default=12)
|
||||
return parser
|
||||
|
||||
|
||||
def seed_simulation(seed):
|
||||
rospy.ServiceProxy("seed", Seed)(seed)
|
||||
rospy.sleep(1.0)
|
||||
|
||||
|
||||
def main():
|
||||
rospy.init_node("active_grasp")
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
policy = make(args.policy)
|
||||
controller = GraspController(policy)
|
||||
logger = Logger(args.logdir, args.policy, args.desc)
|
||||
logger = Logger(args.logdir, args.policy)
|
||||
|
||||
seed_simulation(args.seed)
|
||||
|
||||
for _ in tqdm(range(args.runs)):
|
||||
info = controller.run()
|
||||
|
Reference in New Issue
Block a user