nbv_sim/scripts/run.py

79 lines
2.1 KiB
Python
Raw Normal View History

2021-12-03 14:13:59 +01:00
#!/usr/bin/env python3
2021-07-06 14:00:04 +02:00
import argparse
2021-07-22 11:05:30 +02:00
from datetime import datetime
import pandas as pd
2021-07-12 13:12:36 +02:00
from pathlib import Path
2021-07-06 14:00:04 +02:00
import rospy
2021-07-12 13:12:36 +02:00
from tqdm import tqdm
2024-10-13 05:34:35 -05:00
import torch
2021-07-06 14:00:04 +02:00
2021-07-12 13:12:36 +02:00
from active_grasp.controller import *
2021-08-06 15:23:50 +02:00
from active_grasp.policy import make, registry
2021-07-22 11:05:30 +02:00
from active_grasp.srv import Seed
2021-09-11 12:00:52 +02:00
from robot_helpers.ros import tf
2021-07-22 11:05:30 +02:00
2021-08-03 18:11:30 +02:00
def main():
    """Run the grasp controller for a number of simulated episodes.

    Parses the command line (see ``create_parser``), builds the requested
    policy and a ``GraspController`` around it, seeds the simulation and then
    executes ``args.runs`` runs, appending the info returned by each
    ``controller.run()`` to a CSV log via ``Logger``.

    With ``--wait-for-input`` the gripper is moved to 0.08 and the arm to the
    MoveIt "ready" target before every run, and the user is asked to confirm.
    Any answer other than "y" (case-insensitive, surrounding whitespace
    ignored) stops the script.
    """
    # Release memory held by PyTorch's CUDA caching allocator before starting.
    torch.cuda.empty_cache()
    rospy.init_node("grasp_controller")
    tf.init()

    parser = create_parser()
    args = parser.parse_args()

    policy = make(args.policy)
    controller = GraspController(policy)
    logger = Logger(args)

    seed_simulation(args.seed)
    rospy.sleep(1.0)  # Prevents a rare race condition

    # The progress bar is hidden in interactive mode (presumably so it does
    # not interleave with the confirmation prompt).
    for _ in tqdm(range(args.runs), disable=args.wait_for_input):
        if args.wait_for_input:
            controller.gripper.move(0.08)
            controller.switch_to_joint_trajectory_control()
            controller.moveit.goto("ready", velocity_scaling=0.4)
            answer = input("Run policy? [y/n] ")
            if answer.strip().lower() != "y":
                # Leave main() rather than calling the site-provided exit(),
                # which is intended for the interactive interpreter only.
                return
            rospy.loginfo("Running policy ...")
        info = controller.run()
        logger.log_run(info)
2021-07-06 14:00:04 +02:00
def create_parser():
    """Build the command-line interface of the grasp experiment runner.

    The positional ``policy`` must be one of the keys of the policy
    ``registry``. The options select the number of runs, interactive
    confirmation before each run, the log directory and the simulation seed.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("policy", type=str, choices=registry.keys())
    cli.add_argument("--runs", default=5, type=int)
    # When set, main() asks for confirmation before every run.
    cli.add_argument("--wait-for-input", action="store_true")
    # argparse would run a str default through ``type`` anyway; give the Path.
    cli.add_argument("--logdir", default=Path("logs"), type=Path)
    cli.add_argument("--seed", default=1, type=int)
    return cli
2021-08-03 18:11:30 +02:00
class Logger:
    """Append one CSV row per run to a timestamped file in ``args.logdir``.

    The file is named ``<yymmdd-HHMMSS>_policy=<policy>,seed=<seed>.csv``.
    """

    def __init__(self, args):
        """Create the log directory and choose the CSV path for this session.

        Args:
            args: parsed command-line namespace. ``logdir`` (a ``Path`` or a
                plain ``str``), ``policy`` and ``seed`` are read from it.
        """
        logdir = Path(args.logdir)  # accept a plain str as well as a Path
        logdir.mkdir(parents=True, exist_ok=True)
        stamp = datetime.now().strftime("%y%m%d-%H%M%S")
        name = "{}_policy={},seed={}.csv".format(
            stamp,
            args.policy,
            args.seed,
        )
        self.path = logdir / name

    def log_run(self, info):
        """Write the mapping ``info`` as one row of the CSV file.

        Values are matched to the existing header by column name. A run whose
        dict orders its keys differently, or omits some keys, therefore stays
        aligned; missing values are left empty. If ``info`` introduces a key
        the file does not have yet, the file is rewritten with the extended
        header.

        Args:
            info: mapping of column name to value for a single run.
        """
        row = pd.DataFrame.from_records([info])
        if not self.path.exists():
            row.to_csv(self.path, index=False)
            return

        columns = pd.read_csv(self.path, nrows=0).columns
        if set(row.columns).issubset(columns):
            # Common case: a cheap append, reordered to match the file header.
            row.reindex(columns=columns).to_csv(
                self.path, mode="a", header=False, index=False
            )
        else:
            # New column(s): merge with the previous rows and rewrite the file
            # so that the header covers every column.
            table = pd.concat([pd.read_csv(self.path), row], ignore_index=True)
            table.to_csv(self.path, index=False)
2021-07-22 11:05:30 +02:00
2021-07-06 14:00:04 +02:00
2021-08-03 18:11:30 +02:00
def seed_simulation(seed, timeout=None):
    """Seed the simulation through the ``seed`` ROS service.

    Args:
        seed: value forwarded as the request of the ``Seed`` service.
        timeout: seconds to wait for the ``seed`` service to be advertised.
            ``None`` (the default) waits indefinitely. On timeout,
            ``rospy.wait_for_service`` raises ``rospy.ROSException``.
    """
    # Calling a proxy before the server is up raises ServiceException, so
    # block until the service exists.
    rospy.wait_for_service("seed", timeout=timeout)
    rospy.ServiceProxy("seed", Seed)(seed)
    # NOTE(review): presumably lets the simulation settle after reseeding.
    rospy.sleep(1.0)
2021-07-06 14:00:04 +02:00
if __name__ == "__main__":
    # Entry point: run only when executed as a script, not when imported.
    main()