Pass arguments directly to the policy
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
from sensor_msgs.msg import CameraInfo
|
||||
from pathlib import Path
|
||||
import rospy
|
||||
|
||||
@@ -19,15 +20,17 @@ class Policy:
|
||||
|
||||
|
||||
class BasePolicy(Policy):
|
||||
def __init__(self, intrinsic):
|
||||
self.intrinsic = intrinsic
|
||||
self.rate = 5
|
||||
def __init__(self, rate=5):
|
||||
self.rate = rate
|
||||
self.load_parameters()
|
||||
self.init_visualizer()
|
||||
|
||||
def load_parameters(self):
|
||||
self.base_frame = rospy.get_param("active_grasp/base_frame_id")
|
||||
self.task_frame = "task"
|
||||
info_topic = rospy.get_param("active_grasp/camera/info_topic")
|
||||
msg = rospy.wait_for_message(info_topic, CameraInfo, rospy.Duration(2.0))
|
||||
self.intrinsic = from_camera_info_msg(msg)
|
||||
self.vgn = VGN(Path(rospy.get_param("vgn/model")))
|
||||
|
||||
def init_visualizer(self):
|
||||
@@ -56,6 +59,9 @@ class BasePolicy(Policy):
|
||||
self.visualizer.scene_cloud(self.task_frame, self.tsdf.get_scene_cloud())
|
||||
self.visualizer.path(self.viewpoints)
|
||||
|
||||
def compute_best_grasp(self):
|
||||
return self.predict_best_grasp()
|
||||
|
||||
def predict_best_grasp(self):
|
||||
tsdf_grid = self.tsdf.get_grid()
|
||||
out = self.vgn.predict(tsdf_grid)
|
||||
|
Reference in New Issue
Block a user