Log metrics
This commit is contained in:
parent
1375cedcb5
commit
66cbf39516
1
.gitignore
vendored
1
.gitignore
vendored
@ -132,3 +132,4 @@ dmypy.json
|
||||
.vscode/
|
||||
|
||||
assets/
|
||||
logs/
|
||||
|
@ -20,9 +20,11 @@ class GraspController:
|
||||
|
||||
def run(self):
|
||||
bbox = self.reset()
|
||||
grasp = self.explore(bbox)
|
||||
if grasp:
|
||||
self.execute_grasp(grasp)
|
||||
with Timer("exploration_time"):
|
||||
grasp = self.explore(bbox)
|
||||
with Timer("execution_time"):
|
||||
res = self.execute_grasp(grasp)
|
||||
return self.collect_info(res)
|
||||
|
||||
def reset(self):
|
||||
req = ResetRequest()
|
||||
@ -42,6 +44,9 @@ class GraspController:
|
||||
return self.policy.best_grasp
|
||||
|
||||
def execute_grasp(self, grasp):
|
||||
if not grasp:
|
||||
return "aborted"
|
||||
|
||||
T_B_G = self.postprocess(grasp)
|
||||
|
||||
self.gripper.move(0.08)
|
||||
@ -65,7 +70,9 @@ class GraspController:
|
||||
rospy.sleep(2.0)
|
||||
|
||||
# Check whether the object remains in the hand
|
||||
return self.gripper.read() > 0.005
|
||||
success = self.gripper.read() > 0.005
|
||||
|
||||
return "succeeded" if success else "failed"
|
||||
|
||||
def postprocess(self, T_B_G):
|
||||
# Ensure that the camera is pointing forward.
|
||||
@ -73,3 +80,16 @@ class GraspController:
|
||||
if rot.as_matrix()[:, 0][0] < 0:
|
||||
T_B_G.rotation = rot * Rotation.from_euler("z", np.pi)
|
||||
return T_B_G
|
||||
|
||||
def collect_info(self, result):
|
||||
points = [p.translation for p in self.policy.viewpoints]
|
||||
d = np.sum([np.linalg.norm(p2 - p1) for p1, p2 in zip(points, points[1:])])
|
||||
|
||||
info = {
|
||||
"result": result,
|
||||
"viewpoint_count": len(points),
|
||||
"distance_travelled": d,
|
||||
}
|
||||
info.update(self.policy.info)
|
||||
info.update(Timer.timers)
|
||||
return info
|
||||
|
@ -28,6 +28,7 @@ class BasePolicy:
|
||||
self.connect_to_rviz()
|
||||
|
||||
self.rate = 5
|
||||
self.info = {}
|
||||
|
||||
def load_parameters(self):
|
||||
self.task_frame = rospy.get_param("~frame_id")
|
||||
|
@ -1,5 +1,8 @@
|
||||
from datetime import datetime
|
||||
from geometry_msgs.msg import PoseStamped
|
||||
import pandas as pd
|
||||
import rospy
|
||||
import time
|
||||
|
||||
import active_grasp.msg
|
||||
from robot_utils.ros.conversions import *
|
||||
@ -34,3 +37,35 @@ def to_bbox_msg(bbox):
|
||||
msg.min = to_point_msg(bbox.min)
|
||||
msg.max = to_point_msg(bbox.max)
|
||||
return msg
|
||||
|
||||
|
||||
class Timer:
|
||||
timers = dict()
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
self.tic = time.perf_counter()
|
||||
|
||||
def stop(self):
|
||||
elapsed_time = time.perf_counter() - self.tic
|
||||
self.timers[self.name] = elapsed_time
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self, logdir, policy, desc):
|
||||
stamp = datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
name = "{}_policy={},{}".format(stamp, policy, desc).strip(",")
|
||||
self.path = logdir / (name + ".csv")
|
||||
|
||||
def log_run(self, info):
|
||||
df = pd.DataFrame.from_records([info])
|
||||
df.to_csv(self.path, mode="a", header=not self.path.exists(), index=False)
|
||||
|
@ -1,13 +1,18 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import rospy
|
||||
from tqdm import tqdm
|
||||
|
||||
from active_grasp.controller import GraspController
|
||||
from active_grasp.controller import *
|
||||
from active_grasp.policy import make, registry
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--policy", type=str, choices=registry.keys())
|
||||
parser.add_argument("policy", type=str, choices=registry.keys())
|
||||
parser.add_argument("--runs", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=Path, default="logs")
|
||||
parser.add_argument("--desc", type=str, default="")
|
||||
return parser
|
||||
|
||||
|
||||
@ -17,9 +22,11 @@ def main():
|
||||
args = parser.parse_args()
|
||||
policy = make(args.policy)
|
||||
controller = GraspController(policy)
|
||||
logger = Logger(args.logdir, args.policy, args.desc)
|
||||
|
||||
while True:
|
||||
controller.run()
|
||||
for _ in tqdm(range(args.runs)):
|
||||
info = controller.run()
|
||||
logger.log_run(info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user