Define task frame based on bbox
This commit is contained in:
@@ -20,11 +20,10 @@ class BtSimNode:
|
||||
def __init__(self):
|
||||
self.gui = rospy.get_param("~gui", True)
|
||||
self.sim = Simulation(gui=self.gui)
|
||||
self._init_plugins()
|
||||
self._advertise_services()
|
||||
self._broadcast_transforms()
|
||||
self.init_plugins()
|
||||
self.advertise_services()
|
||||
|
||||
def _init_plugins(self):
|
||||
def init_plugins(self):
|
||||
self.plugins = [
|
||||
PhysicsPlugin(self.sim),
|
||||
JointStatePlugin(self.sim.arm, self.sim.gripper),
|
||||
@@ -34,20 +33,10 @@ class BtSimNode:
|
||||
CameraPlugin(self.sim.camera),
|
||||
]
|
||||
|
||||
def _advertise_services(self):
|
||||
def advertise_services(self):
|
||||
rospy.Service("seed", Seed, self.seed)
|
||||
rospy.Service("reset", Reset, self.reset)
|
||||
|
||||
def _broadcast_transforms(self):
|
||||
self.static_broadcaster = tf2_ros.StaticTransformBroadcaster()
|
||||
msgs = [
|
||||
to_transform_stamped_msg(self.sim.T_W_B, "world", "panda_link0"),
|
||||
to_transform_stamped_msg(
|
||||
Transform.translation(self.sim.origin), "world", "task"
|
||||
),
|
||||
]
|
||||
self.static_broadcaster.sendTransform(msgs)
|
||||
|
||||
def seed(self, req):
|
||||
self.sim.seed(req.seed)
|
||||
return SeedResponse()
|
||||
@@ -63,10 +52,10 @@ class BtSimNode:
|
||||
return res
|
||||
|
||||
def run(self):
|
||||
self._start_plugins()
|
||||
self.start_plugins()
|
||||
rospy.spin()
|
||||
|
||||
def _start_plugins(self):
|
||||
def start_plugins(self):
|
||||
for plugin in self.plugins:
|
||||
plugin.thread.start()
|
||||
plugin.is_running = True
|
||||
@@ -77,17 +66,17 @@ class Plugin:
|
||||
|
||||
def __init__(self, rate):
|
||||
self.rate = rate
|
||||
self.thread = Thread(target=self._loop, daemon=True)
|
||||
self.thread = Thread(target=self.loop, daemon=True)
|
||||
self.is_running = False
|
||||
|
||||
def _loop(self):
|
||||
def loop(self):
|
||||
rate = rospy.Rate(self.rate)
|
||||
while not rospy.is_shutdown():
|
||||
if self.is_running:
|
||||
self._update()
|
||||
self.update()
|
||||
rate.sleep()
|
||||
|
||||
def _update(self):
|
||||
def update(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -96,7 +85,7 @@ class PhysicsPlugin(Plugin):
|
||||
super().__init__(sim.rate)
|
||||
self.sim = sim
|
||||
|
||||
def _update(self):
|
||||
def update(self):
|
||||
self.sim.step()
|
||||
|
||||
|
||||
@@ -107,7 +96,7 @@ class JointStatePlugin(Plugin):
|
||||
self.gripper = gripper
|
||||
self.pub = rospy.Publisher("joint_states", JointState, queue_size=10)
|
||||
|
||||
def _update(self):
|
||||
def update(self):
|
||||
q, _ = self.arm.get_state()
|
||||
width = self.gripper.read()
|
||||
msg = JointState()
|
||||
@@ -125,13 +114,13 @@ class ArmControllerPlugin(Plugin):
|
||||
super().__init__(rate)
|
||||
self.arm = arm
|
||||
self.controller = controller
|
||||
rospy.Subscriber("command", PoseStamped, self._target_cb)
|
||||
rospy.Subscriber("command", PoseStamped, self.target_cb)
|
||||
|
||||
def _target_cb(self, msg):
|
||||
def target_cb(self, msg):
|
||||
assert msg.header.frame_id == self.arm.base_frame
|
||||
self.controller.x_d = from_pose_msg(msg.pose)
|
||||
|
||||
def _update(self):
|
||||
def update(self):
|
||||
q, _ = self.arm.get_state()
|
||||
cmd = self.controller.update(q)
|
||||
self.arm.set_desired_joint_velocities(cmd)
|
||||
@@ -142,20 +131,20 @@ class MoveActionPlugin(Plugin):
|
||||
super().__init__(rate)
|
||||
self.gripper = gripper
|
||||
self.dt = 1.0 / self.rate
|
||||
self._init_action_server()
|
||||
self.init_action_server()
|
||||
|
||||
def _init_action_server(self):
|
||||
def init_action_server(self):
|
||||
name = "/franka_gripper/move"
|
||||
self.action_server = SimpleActionServer(name, MoveAction, auto_start=False)
|
||||
self.action_server.register_goal_callback(self._action_goal_cb)
|
||||
self.action_server.register_goal_callback(self.action_goal_cb)
|
||||
self.action_server.start()
|
||||
|
||||
def _action_goal_cb(self):
|
||||
def action_goal_cb(self):
|
||||
self.elapsed_time = 0.0
|
||||
goal = self.action_server.accept_new_goal()
|
||||
self.gripper.set_desired_width(goal.width)
|
||||
|
||||
def _update(self):
|
||||
def update(self):
|
||||
if self.action_server.is_active():
|
||||
self.elapsed_time += self.dt
|
||||
if self.elapsed_time > 1.0:
|
||||
@@ -167,20 +156,20 @@ class GraspActionPlugin(Plugin):
|
||||
super().__init__(rate)
|
||||
self.gripper = gripper
|
||||
self.dt = 1.0 / self.rate
|
||||
self._init_action_server()
|
||||
self.init_action_server()
|
||||
|
||||
def _init_action_server(self):
|
||||
def init_action_server(self):
|
||||
name = "/franka_gripper/grasp"
|
||||
self.action_server = SimpleActionServer(name, GraspAction, auto_start=False)
|
||||
self.action_server.register_goal_callback(self._action_goal_cb)
|
||||
self.action_server.register_goal_callback(self.action_goal_cb)
|
||||
self.action_server.start()
|
||||
|
||||
def _action_goal_cb(self):
|
||||
def action_goal_cb(self):
|
||||
self.elapsed_time = 0.0
|
||||
goal = self.action_server.accept_new_goal()
|
||||
self.gripper.set_desired_width(goal.width)
|
||||
|
||||
def _update(self):
|
||||
def update(self):
|
||||
if self.action_server.is_active():
|
||||
self.elapsed_time += self.dt
|
||||
if self.elapsed_time > 1.0:
|
||||
@@ -188,21 +177,20 @@ class GraspActionPlugin(Plugin):
|
||||
|
||||
|
||||
class CameraPlugin(Plugin):
|
||||
def __init__(self, camera, name="camera"):
|
||||
rate = rospy.get_param("~cam_rate", 5)
|
||||
def __init__(self, camera, name="camera", rate=5):
|
||||
super().__init__(rate)
|
||||
self.camera = camera
|
||||
self.name = name
|
||||
self.cv_bridge = cv_bridge.CvBridge()
|
||||
self._init_publishers()
|
||||
self.init_publishers()
|
||||
|
||||
def _init_publishers(self):
|
||||
def init_publishers(self):
|
||||
topic = self.name + "/depth/camera_info"
|
||||
self.info_pub = rospy.Publisher(topic, CameraInfo, queue_size=10)
|
||||
topic = self.name + "/depth/image_raw"
|
||||
self.depth_pub = rospy.Publisher(topic, Image, queue_size=10)
|
||||
|
||||
def _update(self):
|
||||
def update(self):
|
||||
stamp = rospy.Time.now()
|
||||
|
||||
msg = to_camera_info_msg(self.camera.intrinsic)
|
||||
|
@@ -6,10 +6,33 @@ import rospy
|
||||
from tqdm import tqdm
|
||||
|
||||
from active_grasp.controller import *
|
||||
from active_grasp.policy import make, registry
|
||||
from active_grasp.policy import registry
|
||||
from active_grasp.srv import Seed
|
||||
|
||||
|
||||
def main():
|
||||
rospy.init_node("active_grasp")
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
controller = GraspController(args.policy)
|
||||
logger = Logger(args.logdir, args.policy)
|
||||
|
||||
seed_simulation(args.seed)
|
||||
|
||||
for _ in tqdm(range(args.runs)):
|
||||
info = controller.run()
|
||||
logger.log_run(info)
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("policy", type=str, choices=registry.keys())
|
||||
parser.add_argument("--runs", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=Path, default="logs")
|
||||
parser.add_argument("--seed", type=int, default=12)
|
||||
return parser
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self, logdir, policy):
|
||||
stamp = datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
@@ -21,34 +44,10 @@ class Logger:
|
||||
df.to_csv(self.path, mode="a", header=not self.path.exists(), index=False)
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("policy", type=str, choices=registry.keys())
|
||||
parser.add_argument("--runs", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=Path, default="logs")
|
||||
parser.add_argument("--seed", type=int, default=12)
|
||||
return parser
|
||||
|
||||
|
||||
def seed_simulation(seed):
|
||||
rospy.ServiceProxy("seed", Seed)(seed)
|
||||
rospy.sleep(1.0)
|
||||
|
||||
|
||||
def main():
|
||||
rospy.init_node("active_grasp")
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
policy = make(args.policy)
|
||||
controller = GraspController(policy)
|
||||
logger = Logger(args.logdir, args.policy)
|
||||
|
||||
seed_simulation(args.seed)
|
||||
|
||||
for _ in tqdm(range(args.runs)):
|
||||
info = controller.run()
|
||||
logger.log_run(info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user