# File metadata (not part of the program):
#   timestamp: 2025-01-15 14:42:54 +08:00
#   size: 182 lines, 8.3 KiB
#   language: Python

import os
import numpy as np
import requests
from PytorchBoot.runners import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log
from utils.pose_util import PoseUtil
from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.view_util import ViewUtil
from scipy.spatial.transform import Rotation as R
@stereotype.runner("inference_runner")
class InferenceRunner(Runner):
    """Closed-loop view-planning runner driven by a remote inference server.

    Each iteration sends every scanned object point cloud and its 9D camera
    pose to ``server_url``, moves the robot to the predicted camera pose,
    captures a new view, and appends the new object points.  Scanning stops
    when :meth:`check_stop` reports that the iteration, failure or
    "no new points" budget has been exceeded.
    """

    # Voxel size passed to PtsUtil.voxel_downsample_point_cloud when measuring
    # how many distinct points the accumulated scan contains.
    VOXEL_SIZE = 0.003

    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.load_experiment("inference")
        self.inference_config = ConfigManager.get("runner", "inference")
        self.server_url = self.inference_config["server_url"]
        self.max_iter = self.inference_config["max_iter"]
        self.max_fail = self.inference_config["max_fail"]
        self.max_no_new_pts = self.inference_config["max_no_new_pts"]
        self.min_delta_pts_num = self.inference_config["min_delta_pts_num"]
        # Interactive per-view visualization is a debugging aid; it is only run
        # when the config explicitly sets ``visualize: true``.
        # NOTE(review): assumes the config object supports dict-style ``.get``.
        self.visualize = bool(self.inference_config.get("visualize", False))

    def check_stop(self, cnt, fail, no_new_pts):
        """Return True once any counter strictly exceeds its configured limit.

        Args:
            cnt: number of robot moves issued so far.
            fail: number of views for which no data was received.
            no_new_pts: number of views that added fewer than
                ``min_delta_pts_num`` downsampled points.

        Because the comparisons are strict (``>``), the scan loop can perform
        up to ``max_iter + 1`` moves.
        NOTE(review): confirm whether ``max_iter`` was meant to be inclusive.
        """
        if cnt > self.max_iter:
            return True
        if fail > self.max_fail:
            return True
        if no_new_pts > self.max_no_new_pts:
            return True
        return False

    def split_scan_pts_and_obj_pts(self, world_pts, z_threshold=0.005):
        """Split world points by height into (scan_pts, obj_pts).

        Points whose z is below ``z_threshold`` are returned as ``scan_pts``;
        the rest are returned as ``obj_pts``.  Presumably the low points are
        the supporting surface under the object -- confirm.

        Args:
            world_pts: point array indexed as ``[:, 2]`` for z (assumed (N, 3)).
            z_threshold: height cutoff, in the units of ``world_pts``.
        """
        scan_pts = world_pts[world_pts[:, 2] < z_threshold]
        obj_pts = world_pts[world_pts[:, 2] >= z_threshold]
        return scan_pts, obj_pts

    @staticmethod
    def _pose_to_9d(pose):
        """Encode a 4x4 homogeneous pose as [6D rotation, translation] (9 values)."""
        rot_6d = PoseUtil.matrix_to_rotation_6d_numpy(pose[:3, :3])
        return np.concatenate([rot_6d, pose[:3, 3]], axis=0)

    @staticmethod
    def _pose_9d_to_matrix(pose_9d):
        """Decode a 9-vector [6D rotation, translation] into a 4x4 homogeneous pose."""
        pose = np.eye(4)
        pose[:3, :3] = PoseUtil.rotation_6d_to_matrix_numpy(pose_9d[:6])
        pose[:3, 3] = pose_9d[6:]
        return pose

    def _get_obj_pts_in_world(self, cam_shot_pts, cam_to_world):
        """Transform a captured cloud with ``cam_to_world`` and keep only object points."""
        world_shot_pts = PtsUtil.transform_point_cloud(cam_shot_pts, cam_to_world)
        _, obj_pts = self.split_scan_pts_and_obj_pts(world_shot_pts)
        return obj_pts

    def _downsample_combined_pts(self, scanned_pts_list):
        """Concatenate all scanned clouds and voxel-downsample the result."""
        combined_pts = np.concatenate(scanned_pts_list, axis=0)
        return PtsUtil.voxel_downsample_point_cloud(combined_pts, self.VOXEL_SIZE)

    def run(self):
        """Run the scan loop: query the server, move, capture, accumulate."""
        ControlUtil.connect_robot()
        ControlUtil.init()
        # Both lists are sent to the server as JSON, so they hold plain lists.
        scanned_pts_list = []
        scanned_n_to_world_pose = []
        cnt = 0
        fail = 0
        # Counts views that added too few points; it is NOT reset when a
        # later view does add points.
        no_new_pts = 0

        # ---- initial view, captured at the robot's starting pose ----
        view_data = CommunicateUtil.get_view_data(init=True)
        first_cam_to_real_world = ControlUtil.get_pose()
        if view_data is None:
            Log.error("No view data received")
            return
        cam_shot_pts = ViewUtil.get_pts(view_data)
        curr_pts = self._get_obj_pts_in_world(cam_shot_pts, first_cam_to_real_world)
        # NOTE(review): later poses are right-multiplied by
        # inv(ControlUtil.CAMERA_CORRECTION) but this first one is not --
        # confirm the first pose is in the same frame as the others.
        curr_pose = first_cam_to_real_world
        scanned_pts_list.append(curr_pts.tolist())
        scanned_n_to_world_pose.append(self._pose_to_9d(curr_pose).tolist())
        downsampled_combined_pts = self._downsample_combined_pts(scanned_pts_list)
        last_downsampled_combined_pts_num = downsampled_combined_pts.shape[0]
        Log.info(f"First downsampled combined pts: {last_downsampled_combined_pts_num}")

        while not self.check_stop(cnt, fail, no_new_pts):
            # ---- ask the server for the next camera pose ----
            data = {
                "scanned_pts": scanned_pts_list,
                "scanned_n_to_world_pose_9d": scanned_n_to_world_pose,
            }
            response = requests.post(self.server_url, json=data)
            # Fail with a clear HTTP error instead of a confusing JSON/KeyError.
            response.raise_for_status()
            result = response.json()
            # The server returns a batch of 9D poses; only the first is used.
            pred_pose = self._pose_9d_to_matrix(np.asarray(result["pred_pose_9d"])[0])

            target_camera_pose = pred_pose @ ControlUtil.CAMERA_CORRECTION
            ControlUtil.move_to(target_camera_pose)
            cnt += 1

            # ---- capture the new view ----
            view_data = CommunicateUtil.get_view_data()
            if view_data is None:
                Log.error("No view data received")
                fail += 1
                continue
            cam_shot_pts = ViewUtil.get_pts(view_data)
            left_cam_to_first_left_cam = ViewUtil.get_camera_pose(view_data)
            curr_pose = (
                first_cam_to_real_world
                @ left_cam_to_first_left_cam
                @ np.linalg.inv(ControlUtil.CAMERA_CORRECTION)
            )
            Log.info(f"pred_pose:\n{pred_pose}")
            Log.info(f"curr_pose:\n{curr_pose}")

            # NOTE(review): the new cloud is mapped to world with the FIRST
            # camera pose, not curr_pose.  This is only correct if
            # ViewUtil.get_pts returns points already registered to the first
            # left-camera frame -- confirm.
            curr_pts = self._get_obj_pts_in_world(cam_shot_pts, first_cam_to_real_world)
            if self.visualize:
                from utils.vis import visualizeUtil
                visualizeUtil.visualize_pts_and_camera(curr_pts, pred_pose)

            # ---- accumulate and measure coverage gain ----
            scanned_pts_list.append(curr_pts.tolist())
            scanned_n_to_world_pose.append(self._pose_to_9d(curr_pose).tolist())
            downsampled_combined_pts = self._downsample_combined_pts(scanned_pts_list)
            curr_downsampled_combined_pts_num = downsampled_combined_pts.shape[0]
            Log.info(f"Downsampled combined pts: {curr_downsampled_combined_pts_num}")
            if (
                curr_downsampled_combined_pts_num
                < last_downsampled_combined_pts_num + self.min_delta_pts_num
            ):
                no_new_pts += 1
                Log.info(f"No new points, cnt: {cnt}, fail: {fail}, no_new_pts: {no_new_pts}")
            # Measure the next view's gain against the coverage reached so
            # far, not against the first view.
            last_downsampled_combined_pts_num = curr_downsampled_combined_pts_num

        Log.success("Inference finished")
        # self.save_inference_result(scanned_pts_list, downsampled_combined_pts)

    def create_experiment(self, backup_name=None):
        """Create the experiment and its ``inference_result`` subdirectory."""
        super().create_experiment(backup_name)
        self.inference_result_dir = os.path.join(self.experiment_path, "inference_result")
        os.makedirs(self.inference_result_dir, exist_ok=True)

    def load_experiment(self, backup_name=None):
        """Load the experiment and record where inference results are stored."""
        super().load_experiment(backup_name)
        self.inference_result_dir = os.path.join(self.experiment_path, "inference_result")

    def save_inference_result(self, scanned_pts_list, downsampled_combined_pts):
        """Save every scanned cloud and the downsampled combined cloud as text.

        Files go to a new timestamped directory under ``inference_result_dir``:
        ``<i>.txt`` for each scanned cloud and ``downsampled_combined_pts.txt``.
        """
        import time
        # Second-resolution name: two saves within the same second share a dir.
        dir_name = time.strftime("inference_result_%Y_%m_%d_%Hh%Mm%Ss", time.localtime())
        dir_path = os.path.join(self.inference_result_dir, dir_name)
        # np.savetxt does not create parent directories.
        os.makedirs(dir_path, exist_ok=True)
        for i, scanned_pts in enumerate(scanned_pts_list):
            np.savetxt(os.path.join(dir_path, f"{i}.txt"), np.array(scanned_pts))
        np.savetxt(os.path.join(dir_path, "downsampled_combined_pts.txt"), np.array(downsampled_combined_pts))
        Log.success("Inference result saved")