Define task frame based on bbox
This commit is contained in:
@@ -6,10 +6,33 @@ import rospy
|
||||
from tqdm import tqdm
|
||||
|
||||
from active_grasp.controller import *
|
||||
from active_grasp.policy import make, registry
|
||||
from active_grasp.policy import registry
|
||||
from active_grasp.srv import Seed
|
||||
|
||||
|
||||
def main():
|
||||
rospy.init_node("active_grasp")
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
controller = GraspController(args.policy)
|
||||
logger = Logger(args.logdir, args.policy)
|
||||
|
||||
seed_simulation(args.seed)
|
||||
|
||||
for _ in tqdm(range(args.runs)):
|
||||
info = controller.run()
|
||||
logger.log_run(info)
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("policy", type=str, choices=registry.keys())
|
||||
parser.add_argument("--runs", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=Path, default="logs")
|
||||
parser.add_argument("--seed", type=int, default=12)
|
||||
return parser
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self, logdir, policy):
|
||||
stamp = datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
@@ -21,34 +44,10 @@ class Logger:
|
||||
df.to_csv(self.path, mode="a", header=not self.path.exists(), index=False)
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("policy", type=str, choices=registry.keys())
|
||||
parser.add_argument("--runs", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=Path, default="logs")
|
||||
parser.add_argument("--seed", type=int, default=12)
|
||||
return parser
|
||||
|
||||
|
||||
def seed_simulation(seed):
|
||||
rospy.ServiceProxy("seed", Seed)(seed)
|
||||
rospy.sleep(1.0)
|
||||
|
||||
|
||||
def main():
|
||||
rospy.init_node("active_grasp")
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
policy = make(args.policy)
|
||||
controller = GraspController(policy)
|
||||
logger = Logger(args.logdir, args.policy)
|
||||
|
||||
seed_simulation(args.seed)
|
||||
|
||||
for _ in tqdm(range(args.runs)):
|
||||
info = controller.run()
|
||||
logger.log_run(info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user