Pass arguments directly to the policy

This commit is contained in:
Michel Breyer
2021-08-06 15:23:50 +02:00
parent 6fa4007727
commit 4eeb309a8f
4 changed files with 28 additions and 25 deletions

View File

@@ -6,16 +6,19 @@ import rospy
from tqdm import tqdm
from active_grasp.controller import *
from active_grasp.policy import registry
from active_grasp.policy import make, registry
from active_grasp.srv import Seed
def main():
rospy.init_node("active_grasp")
parser = create_parser()
args = parser.parse_args()
controller = GraspController(args.policy)
logger = Logger(args.logdir, args.policy)
policy = make(args.policy, args.rate)
controller = GraspController(policy)
logger = Logger(args)
seed_simulation(args.seed)
@@ -29,15 +32,16 @@ def create_parser():
parser.add_argument("policy", type=str, choices=registry.keys())
parser.add_argument("--runs", type=int, default=10)
parser.add_argument("--logdir", type=Path, default="logs")
parser.add_argument("--rate", type=int, default=5)
parser.add_argument("--seed", type=int, default=12)
return parser
class Logger:
def __init__(self, logdir, policy):
def __init__(self, args):
stamp = datetime.now().strftime("%y%m%d-%H%M%S")
name = "{}_policy={}".format(stamp, policy)
self.path = logdir / (name + ".csv")
descr = "policy={},rate={}".format(args.policy, args.rate)
self.path = args.logdir / (stamp + "_" + descr + ".csv")
def log_run(self, info):
df = pd.DataFrame.from_records([info])