Seed the simulation

This commit is contained in:
Michel Breyer
2021-07-22 11:05:30 +02:00
parent ed40db562e
commit 1aa676f340
17 changed files with 348 additions and 400 deletions

View File

@@ -10,20 +10,16 @@ from sensor_msgs.msg import JointState, Image, CameraInfo
from threading import Thread
import tf2_ros
from active_grasp.srv import Reset, ResetResponse
from active_grasp.bbox import to_bbox_msg
from active_grasp.srv import *
from active_grasp.simulation import Simulation
from active_grasp.utils import *
from robot_utils.ros.conversions import *
from robot_helpers.ros.conversions import *
class BtSimNode:
def __init__(self):
self.gui = rospy.get_param("~gui", True)
seed = rospy.get_param("~seed", None)
rng = np.random.default_rng(seed) if seed else np.random
self.sim = Simulation(gui=self.gui, rng=rng)
self.sim = Simulation(gui=self.gui)
self._init_plugins()
self._advertise_services()
self._broadcast_transforms()
@@ -33,11 +29,13 @@ class BtSimNode:
PhysicsPlugin(self.sim),
JointStatePlugin(self.sim.arm, self.sim.gripper),
ArmControllerPlugin(self.sim.arm, self.sim.controller),
GripperControllerPlugin(self.sim.gripper),
MoveActionPlugin(self.sim.gripper),
GraspActionPlugin(self.sim.gripper),
CameraPlugin(self.sim.camera),
]
def _advertise_services(self):
rospy.Service("seed", Seed, self.seed)
rospy.Service("reset", Reset, self.reset)
def _broadcast_transforms(self):
@@ -50,14 +48,16 @@ class BtSimNode:
]
self.static_broadcaster.sendTransform(msgs)
def seed(self, req):
self.sim.seed(req.seed)
return SeedResponse()
def reset(self, req):
for plugin in self.plugins:
plugin.is_running = False
rospy.sleep(1.0) # TODO replace with a read-write lock
bbox = self.sim.reset()
res = ResetResponse(to_bbox_msg(bbox))
for plugin in self.plugins:
plugin.is_running = True
return res
@@ -137,45 +137,54 @@ class ArmControllerPlugin(Plugin):
self.arm.set_desired_joint_velocities(cmd)
class GripperControllerPlugin(Plugin):
class MoveActionPlugin(Plugin):
def __init__(self, gripper, rate=10):
super().__init__(rate)
self.gripper = gripper
self.dt = 1.0 / self.rate
self._init_move_action_server()
self._init_grasp_action_server()
self._init_action_server()
def _init_move_action_server(self):
def _init_action_server(self):
name = "/franka_gripper/move"
self.move_server = SimpleActionServer(name, MoveAction, auto_start=False)
self.move_server.register_goal_callback(self._move_action_goal_cb)
self.move_server.start()
self.action_server = SimpleActionServer(name, MoveAction, auto_start=False)
self.action_server.register_goal_callback(self._action_goal_cb)
self.action_server.start()
def _init_grasp_action_server(self):
name = "/franka_gripper/grasp"
self.grasp_server = SimpleActionServer(name, GraspAction, auto_start=False)
self.grasp_server.register_goal_callback(self._grasp_action_goal_cb)
self.grasp_server.start()
def _move_action_goal_cb(self):
self.elapsed_time_since_move_action_goal = 0.0
goal = self.move_server.accept_new_goal()
self.gripper.set_desired_width(goal.width)
def _grasp_action_goal_cb(self):
self.elapsed_time_since_grasp_action_goal = 0.0
goal = self.grasp_server.accept_new_goal()
def _action_goal_cb(self):
self.elapsed_time = 0.0
goal = self.action_server.accept_new_goal()
self.gripper.set_desired_width(goal.width)
def _update(self):
if self.move_server.is_active():
self.elapsed_time_since_move_action_goal += self.dt
if self.elapsed_time_since_move_action_goal > 1.0:
self.move_server.set_succeeded()
if self.grasp_server.is_active():
self.elapsed_time_since_grasp_action_goal += self.dt
if self.elapsed_time_since_grasp_action_goal > 1.0:
self.grasp_server.set_succeeded()
if self.action_server.is_active():
self.elapsed_time += self.dt
if self.elapsed_time > 1.0:
self.action_server.set_succeeded()
class GraspActionPlugin(Plugin):
def __init__(self, gripper, rate=10):
super().__init__(rate)
self.gripper = gripper
self.dt = 1.0 / self.rate
self._init_action_server()
def _init_action_server(self):
name = "/franka_gripper/grasp"
self.action_server = SimpleActionServer(name, GraspAction, auto_start=False)
self.action_server.register_goal_callback(self._action_goal_cb)
self.action_server.start()
def _action_goal_cb(self):
self.elapsed_time = 0.0
goal = self.action_server.accept_new_goal()
self.gripper.set_desired_width(goal.width)
def _update(self):
if self.action_server.is_active():
self.elapsed_time += self.dt
if self.elapsed_time > 1.0:
self.action_server.set_succeeded()
class CameraPlugin(Plugin):
@@ -201,8 +210,8 @@ class CameraPlugin(Plugin):
msg.header.stamp = stamp
self.info_pub.publish(msg)
img = self.camera.get_image()
msg = self.cv_bridge.cv2_to_imgmsg(img.depth)
_, depth, _ = self.camera.get_image()
msg = self.cv_bridge.cv2_to_imgmsg(depth)
msg.header.stamp = stamp
self.depth_pub.publish(msg)

View File

@@ -1,10 +1,24 @@
import argparse
from datetime import datetime
import pandas as pd
from pathlib import Path
import rospy
from tqdm import tqdm
from active_grasp.controller import *
from active_grasp.policy import make, registry
from active_grasp.srv import Seed
class Logger:
def __init__(self, logdir, policy):
stamp = datetime.now().strftime("%y%m%d-%H%M%S")
name = "{}_policy={}".format(stamp, policy)
self.path = logdir / (name + ".csv")
def log_run(self, info):
df = pd.DataFrame.from_records([info])
df.to_csv(self.path, mode="a", header=not self.path.exists(), index=False)
def create_parser():
@@ -12,17 +26,24 @@ def create_parser():
parser.add_argument("policy", type=str, choices=registry.keys())
parser.add_argument("--runs", type=int, default=10)
parser.add_argument("--logdir", type=Path, default="logs")
parser.add_argument("--desc", type=str, default="")
parser.add_argument("--seed", type=int, default=12)
return parser
def seed_simulation(seed):
rospy.ServiceProxy("seed", Seed)(seed)
rospy.sleep(1.0)
def main():
rospy.init_node("active_grasp")
parser = create_parser()
args = parser.parse_args()
policy = make(args.policy)
controller = GraspController(policy)
logger = Logger(args.logdir, args.policy, args.desc)
logger = Logger(args.logdir, args.policy)
seed_simulation(args.seed)
for _ in tqdm(range(args.runs)):
info = controller.run()