From 66cbf3951615bed391008443d4ca9f2565168a4d Mon Sep 17 00:00:00 2001 From: Michel Breyer Date: Mon, 12 Jul 2021 13:12:36 +0200 Subject: [PATCH] Log metrics --- .gitignore | 1 + active_grasp/controller.py | 28 ++++++++++++++++++++++++---- active_grasp/policy.py | 1 + active_grasp/utils.py | 35 +++++++++++++++++++++++++++++++++++ scripts/run.py | 15 +++++++++++---- 5 files changed, 72 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 751e1fa..9783f60 100644 --- a/.gitignore +++ b/.gitignore @@ -132,3 +132,4 @@ dmypy.json .vscode/ assets/ +logs/ diff --git a/active_grasp/controller.py b/active_grasp/controller.py index f7a47b5..fb0036c 100644 --- a/active_grasp/controller.py +++ b/active_grasp/controller.py @@ -20,9 +20,11 @@ class GraspController: def run(self): bbox = self.reset() - grasp = self.explore(bbox) - if grasp: - self.execute_grasp(grasp) + with Timer("exploration_time"): + grasp = self.explore(bbox) + with Timer("execution_time"): + res = self.execute_grasp(grasp) + return self.collect_info(res) def reset(self): req = ResetRequest() @@ -42,6 +44,9 @@ class GraspController: return self.policy.best_grasp def execute_grasp(self, grasp): + if not grasp: + return "aborted" + T_B_G = self.postprocess(grasp) self.gripper.move(0.08) @@ -65,7 +70,9 @@ class GraspController: rospy.sleep(2.0) # Check whether the object remains in the hand - return self.gripper.read() > 0.005 + success = self.gripper.read() > 0.005 + + return "succeeded" if success else "failed" def postprocess(self, T_B_G): # Ensure that the camera is pointing forward. @@ -73,3 +80,16 @@ class GraspController: if rot.as_matrix()[:, 0][0] < 0: T_B_G.rotation = rot * Rotation.from_euler("z", np.pi) return T_B_G + + def collect_info(self, result): + points = [p.translation for p in self.policy.viewpoints] + d = np.sum([np.linalg.norm(p2 - p1) for p1, p2 in zip(points, points[1:])]) + + info = { + "result": result, + "viewpoint_count": len(points), + "distance_travelled": d, + } + info.update(self.policy.info) + info.update(Timer.timers) + return info diff --git a/active_grasp/policy.py b/active_grasp/policy.py index d9c2045..6f18dcd 100644 --- a/active_grasp/policy.py +++ b/active_grasp/policy.py @@ -28,6 +28,7 @@ class BasePolicy: self.connect_to_rviz() self.rate = 5 + self.info = {} def load_parameters(self): self.task_frame = rospy.get_param("~frame_id") diff --git a/active_grasp/utils.py b/active_grasp/utils.py index ac6e5fd..0822c86 100644 --- a/active_grasp/utils.py +++ b/active_grasp/utils.py @@ -1,5 +1,8 @@ +from datetime import datetime from geometry_msgs.msg import PoseStamped +import pandas as pd import rospy +import time import active_grasp.msg from robot_utils.ros.conversions import * @@ -34,3 +37,35 @@ def to_bbox_msg(bbox): msg.min = to_point_msg(bbox.min) msg.max = to_point_msg(bbox.max) return msg + + +class Timer: + timers = dict() + + def __init__(self, name): + self.name = name + + def __enter__(self): + self.start() + return self + + def __exit__(self, *exc_info): + self.stop() + + def start(self): + self.tic = time.perf_counter() + + def stop(self): + elapsed_time = time.perf_counter() - self.tic + self.timers[self.name] = elapsed_time + + +class Logger: + def __init__(self, logdir, policy, desc): + stamp = datetime.now().strftime("%y%m%d-%H%M%S") + name = "{}_policy={},{}".format(stamp, policy, desc).strip(",") + self.path = logdir / (name + ".csv") + + def log_run(self, info): + df = pd.DataFrame.from_records([info]) + df.to_csv(self.path, mode="a", header=not self.path.exists(), index=False) diff --git a/scripts/run.py b/scripts/run.py index f05792d..ade6b07 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -1,13 +1,18 @@ import argparse +from pathlib import Path import rospy +from tqdm import tqdm -from active_grasp.controller import GraspController +from active_grasp.controller import * from active_grasp.policy import make, registry def create_parser(): parser = argparse.ArgumentParser() - parser.add_argument("--policy", type=str, choices=registry.keys()) + parser.add_argument("policy", type=str, choices=registry.keys()) + parser.add_argument("--runs", type=int, default=10) + parser.add_argument("--logdir", type=Path, default="logs") + parser.add_argument("--desc", type=str, default="") return parser @@ -17,9 +22,11 @@ def main(): args = parser.parse_args() policy = make(args.policy) controller = GraspController(policy) + logger = Logger(args.logdir, args.policy, args.desc) - while True: - controller.run() + for _ in tqdm(range(args.runs)): + info = controller.run() + logger.log_run(info) if __name__ == "__main__":