Log metrics

This commit is contained in:
Michel Breyer 2021-07-12 13:12:36 +02:00
parent 1375cedcb5
commit 66cbf39516
5 changed files with 72 additions and 8 deletions

1
.gitignore vendored
View File

@ -132,3 +132,4 @@ dmypy.json
.vscode/
assets/
logs/

View File

@ -20,9 +20,11 @@ class GraspController:
def run(self):
bbox = self.reset()
with Timer("exploration_time"):
grasp = self.explore(bbox)
if grasp:
self.execute_grasp(grasp)
with Timer("execution_time"):
res = self.execute_grasp(grasp)
return self.collect_info(res)
def reset(self):
req = ResetRequest()
@ -42,6 +44,9 @@ class GraspController:
return self.policy.best_grasp
def execute_grasp(self, grasp):
if not grasp:
return "aborted"
T_B_G = self.postprocess(grasp)
self.gripper.move(0.08)
@ -65,7 +70,9 @@ class GraspController:
rospy.sleep(2.0)
# Check whether the object remains in the hand
return self.gripper.read() > 0.005
success = self.gripper.read() > 0.005
return "succeeded" if success else "failed"
def postprocess(self, T_B_G):
# Ensure that the camera is pointing forward.
@ -73,3 +80,16 @@ class GraspController:
if rot.as_matrix()[:, 0][0] < 0:
T_B_G.rotation = rot * Rotation.from_euler("z", np.pi)
return T_B_G
def collect_info(self, result):
points = [p.translation for p in self.policy.viewpoints]
d = np.sum([np.linalg.norm(p2 - p1) for p1, p2 in zip(points, points[1:])])
info = {
"result": result,
"viewpoint_count": len(points),
"distance_travelled": d,
}
info.update(self.policy.info)
info.update(Timer.timers)
return info

View File

@ -28,6 +28,7 @@ class BasePolicy:
self.connect_to_rviz()
self.rate = 5
self.info = {}
def load_parameters(self):
self.task_frame = rospy.get_param("~frame_id")

View File

@ -1,5 +1,8 @@
from datetime import datetime
from geometry_msgs.msg import PoseStamped
import pandas as pd
import rospy
import time
import active_grasp.msg
from robot_utils.ros.conversions import *
@ -34,3 +37,35 @@ def to_bbox_msg(bbox):
msg.min = to_point_msg(bbox.min)
msg.max = to_point_msg(bbox.max)
return msg
class Timer:
timers = dict()
def __init__(self, name):
self.name = name
def __enter__(self):
self.start()
return self
def __exit__(self, *exc_info):
self.stop()
def start(self):
self.tic = time.perf_counter()
def stop(self):
elapsed_time = time.perf_counter() - self.tic
self.timers[self.name] = elapsed_time
class Logger:
def __init__(self, logdir, policy, desc):
stamp = datetime.now().strftime("%y%m%d-%H%M%S")
name = "{}_policy={},{}".format(stamp, policy, desc).strip(",")
self.path = logdir / (name + ".csv")
def log_run(self, info):
df = pd.DataFrame.from_records([info])
df.to_csv(self.path, mode="a", header=not self.path.exists(), index=False)

View File

@ -1,13 +1,18 @@
import argparse
from pathlib import Path
import rospy
from tqdm import tqdm
from active_grasp.controller import GraspController
from active_grasp.controller import *
from active_grasp.policy import make, registry
def create_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--policy", type=str, choices=registry.keys())
parser.add_argument("policy", type=str, choices=registry.keys())
parser.add_argument("--runs", type=int, default=10)
parser.add_argument("--logdir", type=Path, default="logs")
parser.add_argument("--desc", type=str, default="")
return parser
@ -17,9 +22,11 @@ def main():
args = parser.parse_args()
policy = make(args.policy)
controller = GraspController(policy)
logger = Logger(args.logdir, args.policy, args.desc)
while True:
controller.run()
for _ in tqdm(range(args.runs)):
info = controller.run()
logger.log_run(info)
if __name__ == "__main__":