Pass arguments directly to the policy
This commit is contained in:
@@ -13,7 +13,7 @@ class SingleView(BasePolicy):
|
||||
|
||||
def update(self, img, extrinsic):
|
||||
self.integrate_img(img, extrinsic)
|
||||
self.best_grasp = self.predict_best_grasp()
|
||||
self.best_grasp = self.compute_best_grasp()
|
||||
self.done = True
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ class TopView(BasePolicy):
|
||||
self.integrate_img(img, extrinsic)
|
||||
error = extrinsic.translation - self.target.translation
|
||||
if np.linalg.norm(error) < 0.01:
|
||||
self.best_grasp = self.predict_best_grasp()
|
||||
self.best_grasp = self.compute_best_grasp()
|
||||
self.done = True
|
||||
return self.target
|
||||
|
||||
@@ -58,7 +58,7 @@ class RandomView(BasePolicy):
|
||||
self.integrate_img(img, extrinsic)
|
||||
error = extrinsic.translation - self.target.translation
|
||||
if np.linalg.norm(error) < 0.01:
|
||||
self.best_grasp = self.predict_best_grasp()
|
||||
self.best_grasp = self.compute_best_grasp()
|
||||
self.done = True
|
||||
return self.target
|
||||
|
||||
@@ -83,7 +83,7 @@ class FixedTrajectory(BasePolicy):
|
||||
self.integrate_img(img, extrinsic)
|
||||
elapsed_time = (rospy.Time.now() - self.tic).to_sec()
|
||||
if elapsed_time > self.duration:
|
||||
self.best_grasp = self.predict_best_grasp()
|
||||
self.best_grasp = self.compute_best_grasp()
|
||||
self.done = True
|
||||
else:
|
||||
t = self.m(elapsed_time)
|
||||
@@ -106,7 +106,7 @@ class AlignmentView(BasePolicy):
|
||||
self.integrate_img(img, extrinsic)
|
||||
|
||||
if not self.target:
|
||||
grasp = self.predict_best_grasp()
|
||||
grasp = self.compute_best_grasp()
|
||||
if not grasp:
|
||||
self.done = True
|
||||
return
|
||||
@@ -118,6 +118,6 @@ class AlignmentView(BasePolicy):
|
||||
|
||||
error = extrinsic.translation - self.target.translation
|
||||
if np.linalg.norm(error) < 0.01:
|
||||
self.best_grasp = self.predict_best_grasp()
|
||||
self.best_grasp = self.compute_best_grasp()
|
||||
self.done = True
|
||||
return self.target
|
||||
|
@@ -3,10 +3,9 @@ import cv_bridge
|
||||
from geometry_msgs.msg import PoseStamped
|
||||
import numpy as np
|
||||
import rospy
|
||||
from sensor_msgs.msg import CameraInfo, Image
|
||||
from sensor_msgs.msg import Image
|
||||
|
||||
from .bbox import from_bbox_msg
|
||||
from .policy import make
|
||||
from .timer import Timer
|
||||
from active_grasp.srv import Reset, ResetRequest
|
||||
from robot_helpers.ros import tf
|
||||
@@ -16,19 +15,18 @@ from robot_helpers.spatial import Rotation, Transform
|
||||
|
||||
|
||||
class GraspController:
|
||||
def __init__(self, policy_id):
|
||||
def __init__(self, policy):
|
||||
self.policy = policy
|
||||
self.reset_env = rospy.ServiceProxy("reset", Reset)
|
||||
self.load_parameters()
|
||||
self.lookup_transforms()
|
||||
self.init_robot_connection()
|
||||
self.init_camera_stream()
|
||||
self.make_policy(policy_id)
|
||||
|
||||
def load_parameters(self):
|
||||
self.base_frame = rospy.get_param("~base_frame_id")
|
||||
self.ee_frame = rospy.get_param("~ee_frame_id")
|
||||
self.cam_frame = rospy.get_param("~camera/frame_id")
|
||||
self.info_topic = rospy.get_param("~camera/info_topic")
|
||||
self.depth_topic = rospy.get_param("~camera/depth_topic")
|
||||
self.T_grasp_ee = Transform.from_list(rospy.get_param("~ee_grasp_offset")).inv()
|
||||
|
||||
@@ -44,17 +42,12 @@ class GraspController:
|
||||
self.target_pose_pub.publish(msg)
|
||||
|
||||
def init_camera_stream(self):
|
||||
msg = rospy.wait_for_message(self.info_topic, CameraInfo, rospy.Duration(2.0))
|
||||
self.intrinsic = from_camera_info_msg(msg)
|
||||
self.cv_bridge = cv_bridge.CvBridge()
|
||||
rospy.Subscriber(self.depth_topic, Image, self.sensor_cb, queue_size=1)
|
||||
|
||||
def sensor_cb(self, msg):
|
||||
self.latest_depth_msg = msg
|
||||
|
||||
def make_policy(self, name):
|
||||
self.policy = make(name, self.intrinsic)
|
||||
|
||||
def run(self):
|
||||
bbox = self.reset()
|
||||
with Timer("search_time"):
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
from sensor_msgs.msg import CameraInfo
|
||||
from pathlib import Path
|
||||
import rospy
|
||||
|
||||
@@ -19,15 +20,17 @@ class Policy:
|
||||
|
||||
|
||||
class BasePolicy(Policy):
|
||||
def __init__(self, intrinsic):
|
||||
self.intrinsic = intrinsic
|
||||
self.rate = 5
|
||||
def __init__(self, rate=5):
|
||||
self.rate = rate
|
||||
self.load_parameters()
|
||||
self.init_visualizer()
|
||||
|
||||
def load_parameters(self):
|
||||
self.base_frame = rospy.get_param("active_grasp/base_frame_id")
|
||||
self.task_frame = "task"
|
||||
info_topic = rospy.get_param("active_grasp/camera/info_topic")
|
||||
msg = rospy.wait_for_message(info_topic, CameraInfo, rospy.Duration(2.0))
|
||||
self.intrinsic = from_camera_info_msg(msg)
|
||||
self.vgn = VGN(Path(rospy.get_param("vgn/model")))
|
||||
|
||||
def init_visualizer(self):
|
||||
@@ -56,6 +59,9 @@ class BasePolicy(Policy):
|
||||
self.visualizer.scene_cloud(self.task_frame, self.tsdf.get_scene_cloud())
|
||||
self.visualizer.path(self.viewpoints)
|
||||
|
||||
def compute_best_grasp(self):
|
||||
return self.predict_best_grasp()
|
||||
|
||||
def predict_best_grasp(self):
|
||||
tsdf_grid = self.tsdf.get_grid()
|
||||
out = self.vgn.predict(tsdf_grid)
|
||||
|
Reference in New Issue
Block a user