diff --git a/src/active_grasp/active_perception.py b/src/active_grasp/active_perception.py
new file mode 100644
index 0000000..cc11be1
--- /dev/null
+++ b/src/active_grasp/active_perception.py
@@ -0,0 +1,89 @@
+import os
+import sys
+import numpy as np
+import torch
+
+# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+
+
+path = os.path.abspath(__file__)
+for i in range(2):
+    path = os.path.dirname(path)
+PROJECT_ROOT = path
+sys.path.append(PROJECT_ROOT)
+
+from active_perception.configs.config import ConfigManager
+from active_perception.modules.pipeline import Pipeline
+
+class InferenceEngine():
+    RESULTS_DIR_NAME: str = 'results'
+    LOG_DIR_NAME: str = 'log'
+
+    def __init__(self, config_path):
+        ''' Config Manager '''
+        ConfigManager.load_config_with(config_path)
+
+        ''' Pytorch Seed '''
+        seed = ConfigManager.get("settings", "general", "seed")
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+
+        ''' Pipeline '''
+        # self.pipeline_config = {'pts_encoder': 'pointnet', 'view_finder': 'gradient_field'}
+        self.pipeline_config = ConfigManager.get("settings", "pipeline")
+        self.device = ConfigManager.get("settings", "general", "device")
+        self.pipeline = Pipeline(self.pipeline_config)
+        self.parallel = ConfigManager.get("settings", "general", "parallel")
+        if self.parallel and self.device == "cuda":
+            self.pipeline = torch.nn.DataParallel(self.pipeline)
+        self.pipeline = self.pipeline.to(self.device)
+
+        ''' Experiment '''
+        # self.model_path = '~/Downloads/full_149_241009.pth'
+        self.model_path = ConfigManager.get("settings", "experiment", "model_path")
+        self.load(self.model_path)
+
+
+    def load(self, path):
+        state_dict = torch.load(path)
+        if self.parallel:
+            self.pipeline.module.load_state_dict(state_dict)
+        else:
+            self.pipeline.load_state_dict(state_dict)
+
+
+    def inference(self, data):
+        self.pipeline.eval()
+        with torch.no_grad():
+            output = self.pipeline(data, Pipeline.TEST_MODE)
+
+        return output
+
+
+if __name__ == "__main__":
+    ''' Load Configs '''
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default=PROJECT_ROOT+"/active_grasp/active_perception/configs/local_inference_config.yaml")
+    args = parser.parse_args()
+
+    ''' Initialize Test Data '''
+    test_scene = torch.rand(1, 1024, 3).to("cuda:0")
+    test_target = torch.rand(1, 1024, 3).to("cuda:0")
+    test_delta_rot_6d = torch.rand(1, 6).to("cuda:0")
+    a = test_delta_rot_6d[:, :3]
+    b = test_delta_rot_6d[:, 3:]
+    a_norm = a / a.norm(dim=1, keepdim=True)
+    b_norm = b / b.norm(dim=1, keepdim=True)
+    normalized_test_delta_rot_6d = torch.cat((a_norm, b_norm), dim=1)
+    test_data = {
+        'target_pts': test_target,
+        'scene_pts': test_scene,
+    }
+
+    ''' Inference '''
+    inference_engine = InferenceEngine(args.config)
+    output = inference_engine.inference(test_data)
+    print(output.keys())
+    print(output['estimated_delta_rot_6d'])
\ No newline at end of file
diff --git a/src/active_grasp/active_perception/__init__.py b/src/active_grasp/active_perception/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/active_grasp/active_perception/annotations/external_module.py b/src/active_grasp/active_perception/annotations/external_module.py
new file mode 100755
index 0000000..530bea4
--- /dev/null
+++ b/src/active_grasp/active_perception/annotations/external_module.py
@@ -0,0 +1,7 @@
+EXTERNAL_FREEZE_MODULES = set()
+
+def external_freeze(cls):
+    if not hasattr(cls, 'load') or not callable(getattr(cls, 'load')):
+        raise TypeError(f"external module <{cls.__name__}> must implement a 'load' method")
+    EXTERNAL_FREEZE_MODULES.add(cls)
+    return cls
\ No newline at end of file
diff --git a/src/active_grasp/active_perception/annotations/singleton.py b/src/active_grasp/active_perception/annotations/singleton.py
new file mode 100755
index 0000000..8291aa1
--- /dev/null
+++ b/src/active_grasp/active_perception/annotations/singleton.py
@@ -0,0 +1,8 @@
+
+def singleton(cls):
+    instances = {}
+    def get_instance(*args, **kwargs):
+        if cls not in instances:
+            instances[cls] = cls(*args, **kwargs)
+        return instances[cls]
+    return get_instance
\ No newline at end of file
diff --git a/src/active_grasp/active_perception/annotations/stereotype.py b/src/active_grasp/active_perception/annotations/stereotype.py
new file mode 100755
index 0000000..a4b4eae
--- /dev/null
+++ b/src/active_grasp/active_perception/annotations/stereotype.py
@@ -0,0 +1,34 @@
+# --- Classes --- #
+
+def dataset():
+    pass
+
+def module():
+    pass
+
+def pipeline():
+    pass
+
+def runner():
+    pass
+
+def factory():
+    pass
+
+# --- Functions --- #
+
+evaluation_methods = {}
+def evaluation_method(eval_type):
+    def decorator(func):
+        evaluation_methods[eval_type] = func
+        return func
+    return decorator
+
+
+def loss_function():
+    pass
+
+
+# --- Main --- #
+
+
\ No newline at end of file
diff --git a/src/active_grasp/active_perception/configs/__init__.py b/src/active_grasp/active_perception/configs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/active_grasp/active_perception/configs/config.py b/src/active_grasp/active_perception/configs/config.py
new file mode 100755
index 0000000..e8ed05c
--- /dev/null
+++ b/src/active_grasp/active_perception/configs/config.py
@@ -0,0 +1,74 @@
+import argparse
+import os.path
+import shutil
+import yaml
+
+
+class ConfigManager:
+    config = None
+    config_path = None
+
+    @staticmethod
+    def get(*args):
+        result = ConfigManager.config
+        for arg in args:
+            result = result[arg]
+        return result
+
+    @staticmethod
+    def load_config_with(config_file_path):
+        ConfigManager.config_path = config_file_path
+        if not os.path.exists(ConfigManager.config_path):
+            raise ValueError(f"Config file <{config_file_path}> does not exist")
+        with open(config_file_path, 'r') as file:
+            ConfigManager.config = yaml.safe_load(file)
+
+    @staticmethod
+    def backup_config_to(target_config_dir, file_name, prefix="config"):
+        file_name = f"{prefix}_{file_name}.yaml"
+        target_config_file_path = str(os.path.join(target_config_dir, file_name))
+        shutil.copy(ConfigManager.config_path, target_config_file_path)
+
+    @staticmethod
+    def load_config():
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--config', type=str, default='', help='config file path')
+        args = parser.parse_args()
+        if args.config:
+            ConfigManager.load_config_with(args.config)
+
+    @staticmethod
+    def print_config(key: str = None, group: dict = None, level=0):
+        table_size = 80
+        if key and group:
+            value = group[key]
+            if type(value) is dict:
+                print("\t" * level + f"+-{key}:")
+                for k in value:
+                    ConfigManager.print_config(k, value, level=level + 1)
+            else:
+                print("\t" * level + f"| {key}: {value}")
+        elif key:
+            ConfigManager.print_config(key, ConfigManager.config, level=level)
+        else:
+            print("+" + "-" * table_size + "+")
+            print(f"| Configurations in <{ConfigManager.config_path}>:")
+            print("+" + "-" * table_size + "+")
+            for key in ConfigManager.config:
+                ConfigManager.print_config(key, level=level + 1)
+            print("+" + "-" * table_size + "+")
+
+
+''' ------------ Debug 
------------ ''' +if __name__ == "__main__": + test_args = ['--config', 'local_train_config.yaml'] + test_parser = argparse.ArgumentParser() + test_parser.add_argument('--config', type=str, default='', help='config file path') + test_args = test_parser.parse_args(test_args) + if test_args.config: + ConfigManager.load_config_with(test_args.config) + ConfigManager.print_config() + print() + pipeline = ConfigManager.get('settings', 'train', 'batch_size') + ConfigManager.print_config('settings') + print(pipeline) diff --git a/src/active_grasp/active_perception/configs/local_inference_config.yaml b/src/active_grasp/active_perception/configs/local_inference_config.yaml new file mode 100755 index 0000000..8644472 --- /dev/null +++ b/src/active_grasp/active_perception/configs/local_inference_config.yaml @@ -0,0 +1,66 @@ +# Train config file + +settings: + general: + seed: 0 + cuda_visible_devices: "0,1,2,3,4,5,6,7" + device: cuda + test_dir: "" + print: True + parallel: True + + experiment: + name: test_inference + root_dir: "experiments" + model_path: "/home/zhengxiao-han/Downloads/full_149_241009.pth" + use_cache: True + small_batch_overfit: False + + test: + batch_size: 96 + dataset_list: + - name: synthetic_test_sample + source: nbv1 + data_type: sample + synthetic: True + ratio: 1.0 + batch_size: 96 + num_workers: 8 + + results: + save_data_keys: ["target_name","src_rot_mat"] + save_output_keys: ["in_process_sample"] + + pipeline: # module_type: name + pts_encoder: pointnet + view_finder: gradient_field + +datasets: + general: + data_dir: "/mnt/d/Datasets" + score_limit: 0.3 + target_pts_num: 1024 + scene_pts_num: 16384 + canonical: False + rgb_feat_cache: True + + +modules: + general: + pts_channels: 3 + feature_dim: 1024 + per_point_feature: False + pts_encoder: + pointnet: + pointnet++: + params_name: light + view_finder: + gradient_field: + pose_mode: rot_matrix + regression_head: Rx_Ry + sample_mode: ode + sample_repeat: 50 + sampling_steps: 500 + sde_mode: ve + rgb_encoder: + dinov2: diff --git a/src/active_grasp/active_perception/modules/__init__.py b/src/active_grasp/active_perception/modules/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/src/active_grasp/active_perception/modules/func_lib/__init__.py b/src/active_grasp/active_perception/modules/func_lib/__init__.py new file mode 100755 index 0000000..5d3879a --- /dev/null +++ b/src/active_grasp/active_perception/modules/func_lib/__init__.py @@ -0,0 +1,7 @@ +from modules.func_lib.samplers import ( + cond_pc_sampler, + cond_ode_sampler +) +from modules.func_lib.sde import ( + init_sde +) diff --git a/src/active_grasp/active_perception/modules/func_lib/samplers.py b/src/active_grasp/active_perception/modules/func_lib/samplers.py new file mode 100755 index 0000000..923dc9f --- /dev/null +++ b/src/active_grasp/active_perception/modules/func_lib/samplers.py @@ -0,0 +1,282 @@ +import sys +import os +import torch +import numpy as np + +from scipy import integrate +from utils.pose_util import PoseUtil + + +def global_prior_likelihood(z, sigma_max): + """The likelihood of a Gaussian distribution with mean zero and + standard deviation sigma.""" + # z: [bs, pose_dim] + shape = z.shape + N = np.prod(shape[1:]) # pose_dim + return -N / 2. 
* torch.log(2 * np.pi * sigma_max ** 2) - torch.sum(z ** 2, dim=-1) / (2 * sigma_max ** 2) + + +def cond_ode_likelihood( + score_model, + data, + prior, + sde_coeff, + marginal_prob_fn, + atol=1e-5, + rtol=1e-5, + device='cuda', + eps=1e-5, + num_steps=None, + pose_mode='quat_wxyz', + init_x=None, +): + pose_dim = PoseUtil.get_pose_dim(pose_mode) + batch_size = data['pts'].shape[0] + epsilon = prior((batch_size, pose_dim)).to(device) + init_x = data['sampled_pose'].clone().cpu().numpy() if init_x is None else init_x + shape = init_x.shape + init_logp = np.zeros((shape[0],)) # [bs] + init_inp = np.concatenate([init_x.reshape(-1), init_logp], axis=0) + + def score_eval_wrapper(data): + """A wrapper of the score-based model for use by the ODE solver.""" + with torch.no_grad(): + score = score_model(data) + return score.cpu().numpy().reshape((-1,)) + + def divergence_eval(data, epsilon): + """Compute the divergence of the score-based model with Skilling-Hutchinson.""" + # save ckpt of sampled_pose + origin_sampled_pose = data['sampled_pose'].clone() + with torch.enable_grad(): + # make sampled_pose differentiable + data['sampled_pose'].requires_grad_(True) + score = score_model(data) + score_energy = torch.sum(score * epsilon) # [, ] + grad_score_energy = torch.autograd.grad(score_energy, data['sampled_pose'])[0] # [bs, pose_dim] + # reset sampled_pose + data['sampled_pose'] = origin_sampled_pose + return torch.sum(grad_score_energy * epsilon, dim=-1) # [bs, 1] + + def divergence_eval_wrapper(data): + """A wrapper for evaluating the divergence of score for the black-box ODE solver.""" + with torch.no_grad(): + # Compute likelihood. + div = divergence_eval(data, epsilon) # [bs, 1] + return div.cpu().numpy().reshape((-1,)).astype(np.float64) + + def ode_func(t, inp): + """The ODE function for use by the ODE solver.""" + # split x, logp from inp + x = inp[:-shape[0]] + # calc x-grad + x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device) + time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t + drift, diffusion = sde_coeff(torch.tensor(t)) + drift = drift.cpu().numpy() + diffusion = diffusion.cpu().numpy() + data['sampled_pose'] = x + data['t'] = time_steps + x_grad = drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data) + # calc logp-grad + logp_grad = drift - 0.5 * (diffusion ** 2) * divergence_eval_wrapper(data) + # concat curr grad + return np.concatenate([x_grad, logp_grad], axis=0) + + # Run the black-box ODE solver, note the + res = integrate.solve_ivp(ode_func, (eps, 1.0), init_inp, rtol=rtol, atol=atol, method='RK45') + zp = torch.tensor(res.y[:, -1], device=device) # [bs * (pose_dim + 1)] + z = zp[:-shape[0]].reshape(shape) # [bs, pose_dim] + delta_logp = zp[-shape[0]:].reshape(shape[0]) # [bs,] logp + _, sigma_max = marginal_prob_fn(None, torch.tensor(1.).to(device)) # we assume T = 1 + prior_logp = global_prior_likelihood(z, sigma_max) + log_likelihoods = (prior_logp + delta_logp) / np.log(2) # negative log-likelihoods (nlls) + return z, log_likelihoods + + +def cond_pc_sampler( + score_model, + data, + prior, + sde_coeff, + num_steps=500, + snr=0.16, + device='cuda', + eps=1e-5, + pose_mode='quat_wxyz', + init_x=None, +): + pose_dim = PoseUtil.get_pose_dim(pose_mode) + batch_size = data['target_pts_feat'].shape[0] + init_x = prior((batch_size, pose_dim)).to(device) if init_x is None else init_x + time_steps = torch.linspace(1., eps, num_steps, device=device) + step_size = time_steps[0] - time_steps[1] + noise_norm = np.sqrt(pose_dim) + x = 
init_x + poses = [] + with torch.no_grad(): + for time_step in time_steps: + batch_time_step = torch.ones(batch_size, device=device).unsqueeze(-1) * time_step + # Corrector step (Langevin MCMC) + data['sampled_pose'] = x + data['t'] = batch_time_step + grad = score_model(data) + grad_norm = torch.norm(grad.reshape(batch_size, -1), dim=-1).mean() + langevin_step_size = 2 * (snr * noise_norm / grad_norm) ** 2 + x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x) + + # normalisation + if pose_mode == 'quat_wxyz' or pose_mode == 'quat_xyzw': + # quat, should be normalised + x[:, :4] /= torch.norm(x[:, :4], dim=-1, keepdim=True) + elif pose_mode == 'euler_xyz': + pass + else: + # rotation(x axis, y axis), should be normalised + x[:, :3] /= torch.norm(x[:, :3], dim=-1, keepdim=True) + x[:, 3:6] /= torch.norm(x[:, 3:6], dim=-1, keepdim=True) + + # Predictor step (Euler-Maruyama) + drift, diffusion = sde_coeff(batch_time_step) + drift = drift - diffusion ** 2 * grad # R-SDE + mean_x = x + drift * step_size + x = mean_x + diffusion * torch.sqrt(step_size) * torch.randn_like(x) + + # normalisation + x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode) + poses.append(x.unsqueeze(0)) + + xs = torch.cat(poses, dim=0) + xs[:, :, -3:] += data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1) + mean_x[:, -3:] += data['pts_center'] + mean_x[:, :-3] = PoseUtil.normalize_rotation(mean_x[:, :-3], pose_mode) + # The last step does not include any noise + return xs.permute(1, 0, 2), mean_x + + +def cond_ode_sampler( + score_model, + data, + prior, + sde_coeff, + atol=1e-5, + rtol=1e-5, + device='cuda', + eps=1e-5, + T=1.0, + num_steps=None, + pose_mode='quat_wxyz', + denoise=True, + init_x=None, +): + pose_dim = PoseUtil.get_pose_dim(pose_mode) + batch_size = data['target_feat'].shape[0] + init_x = prior((batch_size, pose_dim), T=T).to(device) if init_x is None else init_x + prior((batch_size, pose_dim), + T=T).to(device) + shape = init_x.shape + + def score_eval_wrapper(data): + """A wrapper of the score-based model for use by the ODE solver.""" + with torch.no_grad(): + score = score_model(data) + return score.cpu().numpy().reshape((-1,)) + + def ode_func(t, x): + """The ODE function for use by the ODE solver.""" + x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device) + time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t + drift, diffusion = sde_coeff(torch.tensor(t)) + drift = drift.cpu().numpy() + diffusion = diffusion.cpu().numpy() + data['sampled_pose'] = x + data['t'] = time_steps + return drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data) + + # Run the black-box ODE solver, note the + t_eval = None + if num_steps is not None: + # num_steps, from T -> eps + t_eval = np.linspace(T, eps, num_steps) + res = integrate.solve_ivp(ode_func, (T, eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45', + t_eval=t_eval) + xs = torch.tensor(res.y, device=device).T.view(-1, batch_size, pose_dim) # [num_steps, bs, pose_dim] + x = torch.tensor(res.y[:, -1], device=device).reshape(shape) # [bs, pose_dim] + # denoise, using the predictor step in P-C sampler + if denoise: + # Reverse diffusion predictor for denoising + vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps + drift, diffusion = sde_coeff(vec_eps) + data['sampled_pose'] = x.float() + data['t'] = vec_eps + grad = score_model(data) + drift = drift - diffusion ** 2 * grad # R-SDE + mean_x = x + drift * ((1 - eps) / (1000 if num_steps 
is None else num_steps)) + x = mean_x + + num_steps = xs.shape[0] + xs = xs.reshape(batch_size * num_steps, -1) + xs = PoseUtil.normalize_rotation(xs, pose_mode) + xs = xs.reshape(num_steps, batch_size, -1) + x = PoseUtil.normalize_rotation(x, pose_mode) + return xs.permute(1, 0, 2), x + + +def cond_edm_sampler( + decoder_model, data, prior_fn, randn_like=torch.randn_like, + num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, + S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, + pose_mode='quat_wxyz', device='cuda' +): + pose_dim = PoseUtil.get_pose_dim(pose_mode) + batch_size = data['pts'].shape[0] + latents = prior_fn((batch_size, pose_dim)).to(device) + + # Time step discretion. note that sigma and t is interchangeable + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * ( + sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho + t_steps = torch.cat([torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 + + def decoder_wrapper(decoder, data, x, t): + # save temp + x_, t_ = data['sampled_pose'], data['t'] + # init data + data['sampled_pose'], data['t'] = x, t + # denoise + data, denoised = decoder(data) + # recover data + data['sampled_pose'], data['t'] = x_, t_ + return denoised.to(torch.float64) + + # Main sampling loop. + x_next = latents.to(torch.float64) * t_steps[0] + xs = [] + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + t_hat = torch.as_tensor(t_cur + gamma * t_cur) + x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. + denoised = decoder_wrapper(decoder_model, data, x_hat, t_hat) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. 
+ if i < num_steps - 1: + denoised = decoder_wrapper(decoder_model, data, x_next, t_next) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + xs.append(x_next.unsqueeze(0)) + + xs = torch.stack(xs, dim=0) # [num_steps, bs, pose_dim] + x = xs[-1] # [bs, pose_dim] + + # post-processing + xs = xs.reshape(batch_size * num_steps, -1) + xs = PoseUtil.normalize_rotation(xs, pose_mode) + xs = xs.reshape(num_steps, batch_size, -1) + x = PoseUtil.normalize_rotation(x, pose_mode) + return xs.permute(1, 0, 2), x diff --git a/src/active_grasp/active_perception/modules/func_lib/sde.py b/src/active_grasp/active_perception/modules/func_lib/sde.py new file mode 100755 index 0000000..d93c999 --- /dev/null +++ b/src/active_grasp/active_perception/modules/func_lib/sde.py @@ -0,0 +1,121 @@ +import functools +import torch +import numpy as np + + +# ----- VE SDE ----- +# ------------------ +def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90): + std = sigma_min * (sigma_max / sigma_min) ** t + mean = x + return mean, std + + +def ve_sde(t, sigma_min=0.01, sigma_max=90): + sigma = sigma_min * (sigma_max / sigma_min) ** t + drift_coeff = torch.tensor(0) + diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device)) + return drift_coeff, diffusion_coeff + + +def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0): + _, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max) + return torch.randn(*shape) * sigma_max_prior + + +# ----- VP SDE ----- +# ------------------ +def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20): + log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0 + mean = torch.exp(log_mean_coeff) * x + std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) + return mean, std + + +def vp_sde(t, beta_0=0.1, beta_1=20): + beta_t = beta_0 + t * (beta_1 - beta_0) + drift_coeff = -0.5 * beta_t + diffusion_coeff = torch.sqrt(beta_t) + return drift_coeff, diffusion_coeff + + +def vp_prior(shape, beta_0=0.1, beta_1=20): + return torch.randn(*shape) + + +# ----- sub-VP SDE ----- +# ---------------------- +def subvp_marginal_prob(x, t, beta_0, beta_1): + log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0 + mean = torch.exp(log_mean_coeff) * x + std = 1 - torch.exp(2. * log_mean_coeff) + return mean, std + + +def subvp_sde(t, beta_0, beta_1): + beta_t = beta_0 + t * (beta_1 - beta_0) + drift_coeff = -0.5 * beta_t + discount = 1. 
- torch.exp(-2 * beta_0 * t - (beta_1 - beta_0) * t ** 2) + diffusion_coeff = torch.sqrt(beta_t * discount) + return drift_coeff, diffusion_coeff + + +def subvp_prior(shape, beta_0=0.1, beta_1=20): + return torch.randn(*shape) + + +# ----- EDM SDE ----- +# ------------------ +def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80): + std = t + mean = x + return mean, std + + +def edm_sde(t, sigma_min=0.002, sigma_max=80): + drift_coeff = torch.tensor(0) + diffusion_coeff = torch.sqrt(2 * t) + return drift_coeff, diffusion_coeff + + +def edm_prior(shape, sigma_min=0.002, sigma_max=80): + return torch.randn(*shape) * sigma_max + + +def init_sde(sde_mode): + # the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch + if sde_mode == 'edm': + sigma_min = 0.002 + sigma_max = 80 + eps = 0.002 + prior_fn = functools.partial(edm_prior, sigma_min=sigma_min, sigma_max=sigma_max) + marginal_prob_fn = functools.partial(edm_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max) + sde_fn = functools.partial(edm_sde, sigma_min=sigma_min, sigma_max=sigma_max) + T = sigma_max + elif sde_mode == 've': + sigma_min = 0.01 + sigma_max = 50 + eps = 1e-5 + marginal_prob_fn = functools.partial(ve_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max) + sde_fn = functools.partial(ve_sde, sigma_min=sigma_min, sigma_max=sigma_max) + T = 1.0 + prior_fn = functools.partial(ve_prior, sigma_min=sigma_min, sigma_max=sigma_max) + elif sde_mode == 'vp': + beta_0 = 0.1 + beta_1 = 20 + eps = 1e-3 + prior_fn = functools.partial(vp_prior, beta_0=beta_0, beta_1=beta_1) + marginal_prob_fn = functools.partial(vp_marginal_prob, beta_0=beta_0, beta_1=beta_1) + sde_fn = functools.partial(vp_sde, beta_0=beta_0, beta_1=beta_1) + T = 1.0 + elif sde_mode == 'subvp': + beta_0 = 0.1 + beta_1 = 20 + eps = 1e-3 + prior_fn = functools.partial(subvp_prior, beta_0=beta_0, beta_1=beta_1) + marginal_prob_fn = functools.partial(subvp_marginal_prob, beta_0=beta_0, beta_1=beta_1) + sde_fn = functools.partial(subvp_sde, beta_0=beta_0, beta_1=beta_1) + T = 1.0 + else: + raise NotImplementedError + return prior_fn, marginal_prob_fn, sde_fn, eps, T diff --git a/src/active_grasp/active_perception/modules/module_lib/__init__.py b/src/active_grasp/active_perception/modules/module_lib/__init__.py new file mode 100755 index 0000000..0c2f0c1 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/__init__.py @@ -0,0 +1,4 @@ +from modules.module_lib.gaussian_fourier_projection import GaussianFourierProjection +from modules.module_lib.linear import Linear +from modules.module_lib.position_embedding import PositionalEmbedding +from modules.module_lib.rot_head import RotHead diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/__init__.py new file mode 100755 index 0000000..ae847e4 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
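# A quick sketch, not part of this patch: how the pieces returned by init_sde() in
# modules/func_lib/sde.py above fit together for the 've' mode selected in
# local_inference_config.yaml. The import path and the pose_dim value are assumptions.
import torch
from modules.func_lib.sde import init_sde

prior_fn, marginal_prob_fn, sde_fn, eps, T = init_sde('ve')
pose_dim = 9                                          # illustrative; the real value comes from PoseUtil.get_pose_dim(pose_mode)
x_T = prior_fn((4, pose_dim))                         # initial poses drawn from the VE prior, shape [bs, pose_dim]
_, sigma_T = marginal_prob_fn(None, torch.tensor(T))  # perturbation std at t = T
drift, diffusion = sde_fn(torch.tensor(0.5))          # SDE coefficients at an intermediate time step
print(x_T.shape, sigma_T, drift, diffusion)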
+ +__version__ = "0.0.1" diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/__init__.py new file mode 100755 index 0000000..68e0830 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import pathlib + +from omegaconf import OmegaConf + + +def load_config(config_name: str): + config_filename = config_name + ".yaml" + return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename) + + +dinov2_default_config = load_config("ssl_default_config") + + +def load_and_merge_config(config_name: str): + default_config = OmegaConf.create(dinov2_default_config) + loaded_config = load_config(config_name) + return OmegaConf.merge(default_config, loaded_config) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitb14_pretrain.yaml b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitb14_pretrain.yaml new file mode 100755 index 0000000..117d0f0 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitb14_pretrain.yaml @@ -0,0 +1,6 @@ +student: + arch: vit_base + patch_size: 14 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitb14_reg4_pretrain.yaml b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitb14_reg4_pretrain.yaml new file mode 100755 index 0000000..d53edc0 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitb14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_base + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitg14_pretrain.yaml b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitg14_pretrain.yaml new file mode 100755 index 0000000..a96dd5b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitg14_pretrain.yaml @@ -0,0 +1,7 @@ +student: + arch: vit_giant2 + patch_size: 14 + ffn_layer: swiglufused +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitg14_reg4_pretrain.yaml b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitg14_reg4_pretrain.yaml new file mode 100755 index 0000000..15948f8 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitg14_reg4_pretrain.yaml @@ -0,0 +1,10 @@ +student: + arch: vit_giant2 + patch_size: 14 + ffn_layer: swiglufused + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is 
to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitl14_pretrain.yaml b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitl14_pretrain.yaml new file mode 100755 index 0000000..7a98454 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitl14_pretrain.yaml @@ -0,0 +1,6 @@ +student: + arch: vit_large + patch_size: 14 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitl14_reg4_pretrain.yaml b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitl14_reg4_pretrain.yaml new file mode 100755 index 0000000..0e2bc4e --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vitl14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_large + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vits14_pretrain.yaml b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vits14_pretrain.yaml new file mode 100755 index 0000000..afbdb4b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vits14_pretrain.yaml @@ -0,0 +1,6 @@ +student: + arch: vit_small + patch_size: 14 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vits14_reg4_pretrain.yaml b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vits14_reg4_pretrain.yaml new file mode 100755 index 0000000..d25fd63 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/eval/vits14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_small + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/ssl_default_config.yaml b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/ssl_default_config.yaml new file mode 100755 index 0000000..ccaae1c --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/ssl_default_config.yaml @@ -0,0 +1,118 @@ +MODEL: + WEIGHTS: '' +compute_precision: + grad_scaler: true + teacher: + backbone: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + dino_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + ibot_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + student: + backbone: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + 
param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + dino_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp32 + buffer_dtype: fp32 + ibot_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp32 + buffer_dtype: fp32 +dino: + loss_weight: 1.0 + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_nlayers: 3 + head_hidden_dim: 2048 + koleo_loss_weight: 0.1 +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + separate_head: false + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_nlayers: 3 + head_hidden_dim: 2048 +train: + batch_size_per_gpu: 64 + dataset_path: ImageNet:split=TRAIN + output_dir: . + saveckp_freq: 20 + seed: 0 + num_workers: 10 + OFFICIAL_EPOCH_LENGTH: 1250 + cache_dataset: true + centering: "centering" # or "sinkhorn_knopp" +student: + arch: vit_large + patch_size: 16 + drop_path_rate: 0.3 + layerscale: 1.0e-05 + drop_path_uniform: true + pretrained_weights: '' + ffn_layer: "mlp" + block_chunks: 0 + qkv_bias: true + proj_bias: true + ffn_bias: true + num_register_tokens: 0 + interpolate_antialias: false + interpolate_offset: 0.1 +teacher: + momentum_teacher: 0.992 + final_momentum_teacher: 1 + warmup_teacher_temp: 0.04 + teacher_temp: 0.07 + warmup_teacher_temp_epochs: 30 +optim: + epochs: 100 + weight_decay: 0.04 + weight_decay_end: 0.4 + base_lr: 0.004 # learning rate for a batch size of 1024 + lr: 0. # will be set after applying scaling rule + warmup_epochs: 10 + min_lr: 1.0e-06 + clip_grad: 3.0 + freeze_last_layer_epochs: 1 + scaling_rule: sqrt_wrt_1024 + patch_embed_lr_mult: 0.2 + layerwise_decay: 0.9 + adamw_beta1: 0.9 + adamw_beta2: 0.999 +crops: + global_crops_scale: + - 0.32 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.05 + - 0.32 + global_crops_size: 224 + local_crops_size: 96 +evaluation: + eval_period_iterations: 12500 diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/train/vitg14.yaml b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/train/vitg14.yaml new file mode 100755 index 0000000..d05cf0d --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/train/vitg14.yaml @@ -0,0 +1,26 @@ +dino: + head_n_prototypes: 131072 + head_bottleneck_dim: 384 +ibot: + separate_head: true + head_n_prototypes: 131072 +train: + batch_size_per_gpu: 12 + dataset_path: ImageNet22k + centering: sinkhorn_knopp +student: + arch: vit_giant2 + patch_size: 14 + drop_path_rate: 0.4 + ffn_layer: swiglufused + block_chunks: 4 +teacher: + momentum_teacher: 0.994 +optim: + epochs: 500 + weight_decay_end: 0.2 + base_lr: 2.0e-04 # learning rate for a batch size of 1024 + warmup_epochs: 80 + layerwise_decay: 1.0 +crops: + local_crops_size: 98 \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/train/vitl14.yaml b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/train/vitl14.yaml new file mode 100755 index 0000000..d9b491d --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/train/vitl14.yaml @@ -0,0 +1,26 @@ +dino: + head_n_prototypes: 131072 + head_bottleneck_dim: 384 +ibot: + separate_head: true + head_n_prototypes: 131072 +train: + batch_size_per_gpu: 32 + dataset_path: ImageNet22k + centering: sinkhorn_knopp +student: + arch: vit_large + patch_size: 14 + drop_path_rate: 0.4 + 
ffn_layer: swiglufused + block_chunks: 4 +teacher: + momentum_teacher: 0.994 +optim: + epochs: 500 + weight_decay_end: 0.2 + base_lr: 2.0e-04 # learning rate for a batch size of 1024 + warmup_epochs: 80 + layerwise_decay: 1.0 +crops: + local_crops_size: 98 \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/train/vitl16_short.yaml b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/train/vitl16_short.yaml new file mode 100755 index 0000000..3e7e728 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/configs/train/vitl16_short.yaml @@ -0,0 +1,6 @@ +# this corresponds to the default config +train: + dataset_path: ImageNet:split=TRAIN + batch_size_per_gpu: 64 +student: + block_chunks: 4 diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/__init__.py new file mode 100755 index 0000000..2ded47e --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .adapters import DatasetWithEnumeratedTargets +from .loaders import make_data_loader, make_dataset, SamplerType +from .collate import collate_data_and_cast +from .masking import MaskingGenerator +from .augmentations import DataAugmentationDINO diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/adapters.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/adapters.py new file mode 100755 index 0000000..2097bad --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/adapters.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from typing import Any, Tuple + +from torch.utils.data import Dataset + + +class DatasetWithEnumeratedTargets(Dataset): + def __init__(self, dataset): + self._dataset = dataset + + def get_image_data(self, index: int) -> bytes: + return self._dataset.get_image_data(index) + + def get_target(self, index: int) -> Tuple[Any, int]: + target = self._dataset.get_target(index) + return (index, target) + + def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]: + image, target = self._dataset[index] + target = index if target is None else target + return image, (index, target) + + def __len__(self) -> int: + return len(self._dataset) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/augmentations.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/augmentations.py new file mode 100755 index 0000000..05b1eaa --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/augmentations.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
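# A small sketch, not part of this patch: DatasetWithEnumeratedTargets above wraps any
# dataset so every sample also carries its index; torchvision's FakeData stands in for a
# real dataset here.
from torchvision import transforms
from torchvision.datasets import FakeData

base = FakeData(size=8, transform=transforms.ToTensor())
wrapped = DatasetWithEnumeratedTargets(base)
image, (index, target) = wrapped[0]
print(index, target, image.shape)  # 0, <class label>, torch.Size([3, 224, 224])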
+ +import logging + +from torchvision import transforms + +from .transforms import ( + GaussianBlur, + make_normalize_transform, +) + + +logger = logging.getLogger("dinov2") + + +class DataAugmentationDINO(object): + def __init__( + self, + global_crops_scale, + local_crops_scale, + local_crops_number, + global_crops_size=224, + local_crops_size=96, + ): + self.global_crops_scale = global_crops_scale + self.local_crops_scale = local_crops_scale + self.local_crops_number = local_crops_number + self.global_crops_size = global_crops_size + self.local_crops_size = local_crops_size + + logger.info("###################################") + logger.info("Using data augmentation parameters:") + logger.info(f"global_crops_scale: {global_crops_scale}") + logger.info(f"local_crops_scale: {local_crops_scale}") + logger.info(f"local_crops_number: {local_crops_number}") + logger.info(f"global_crops_size: {global_crops_size}") + logger.info(f"local_crops_size: {local_crops_size}") + logger.info("###################################") + + # random resized crop and flip + self.geometric_augmentation_global = transforms.Compose( + [ + transforms.RandomResizedCrop( + global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.RandomHorizontalFlip(p=0.5), + ] + ) + + self.geometric_augmentation_local = transforms.Compose( + [ + transforms.RandomResizedCrop( + local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.RandomHorizontalFlip(p=0.5), + ] + ) + + # color distorsions / blurring + color_jittering = transforms.Compose( + [ + transforms.RandomApply( + [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], + p=0.8, + ), + transforms.RandomGrayscale(p=0.2), + ] + ) + + global_transfo1_extra = GaussianBlur(p=1.0) + + global_transfo2_extra = transforms.Compose( + [ + GaussianBlur(p=0.1), + transforms.RandomSolarize(threshold=128, p=0.2), + ] + ) + + local_transfo_extra = GaussianBlur(p=0.5) + + # normalization + self.normalize = transforms.Compose( + [ + transforms.ToTensor(), + make_normalize_transform(), + ] + ) + + self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize]) + self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize]) + self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize]) + + def __call__(self, image): + output = {} + + # global crops: + im1_base = self.geometric_augmentation_global(image) + global_crop_1 = self.global_transfo1(im1_base) + + im2_base = self.geometric_augmentation_global(image) + global_crop_2 = self.global_transfo2(im2_base) + + output["global_crops"] = [global_crop_1, global_crop_2] + + # global crops for teacher: + output["global_crops_teacher"] = [global_crop_1, global_crop_2] + + # local crops: + local_crops = [ + self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number) + ] + output["local_crops"] = local_crops + output["offsets"] = () + + return output diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/collate.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/collate.py new file mode 100755 index 0000000..b3e32f3 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/collate.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
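# A usage sketch, not part of this patch: constructing DataAugmentationDINO above with the
# crop settings from ssl_default_config.yaml; the blank PIL image stands in for a real input.
from PIL import Image

augmenter = DataAugmentationDINO(
    global_crops_scale=(0.32, 1.0),
    local_crops_scale=(0.05, 0.32),
    local_crops_number=8,
    global_crops_size=224,
    local_crops_size=96,
)
crops = augmenter(Image.new("RGB", (256, 256)))
print(len(crops["global_crops"]), len(crops["local_crops"]))  # 2 global crops, 8 local crops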
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import random + + +def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None): + # dtype = torch.half # TODO: Remove + + n_global_crops = len(samples_list[0][0]["global_crops"]) + n_local_crops = len(samples_list[0][0]["local_crops"]) + + collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list]) + + collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list]) + + B = len(collated_global_crops) + N = n_tokens + n_samples_masked = int(B * mask_probability) + probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1) + upperbound = 0 + masks_list = [] + for i in range(0, n_samples_masked): + prob_min = probs[i] + prob_max = probs[i + 1] + masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max))))) + upperbound += int(N * prob_max) + for i in range(n_samples_masked, B): + masks_list.append(torch.BoolTensor(mask_generator(0))) + + random.shuffle(masks_list) + + collated_masks = torch.stack(masks_list).flatten(1) + mask_indices_list = collated_masks.flatten().nonzero().flatten() + + masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks] + + return { + "collated_global_crops": collated_global_crops.to(dtype), + "collated_local_crops": collated_local_crops.to(dtype), + "collated_masks": collated_masks, + "mask_indices_list": mask_indices_list, + "masks_weight": masks_weight, + "upperbound": upperbound, + "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long), + } diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/__init__.py new file mode 100755 index 0000000..5550fdc --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .image_net import ImageNet +from .image_net_22k import ImageNet22k diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/decoders.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/decoders.py new file mode 100755 index 0000000..3769f77 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/decoders.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
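# A sketch, not part of this patch: feeding collate_data_and_cast above. crops_dict mimics
# the dict produced by DataAugmentationDINO, and toy_mask_generator stands in for the
# MaskingGenerator that the training loop would normally supply.
import numpy as np
import torch

def toy_mask_generator(num_masked, n_tokens=256):
    mask = np.zeros(n_tokens, dtype=bool)  # patch-level mask for a 16x16 token grid
    mask[:num_masked] = True
    return mask.reshape(16, 16)

crops_dict = {
    "global_crops": [torch.rand(3, 224, 224) for _ in range(2)],
    "local_crops": [torch.rand(3, 96, 96) for _ in range(8)],
}
samples = [(crops_dict, (i, 0)) for i in range(4)]  # (augmented crops, (index, target)) pairs
batch = collate_data_and_cast(
    samples, mask_ratio_tuple=(0.1, 0.5), mask_probability=0.5,
    dtype=torch.float32, n_tokens=256, mask_generator=toy_mask_generator,
)
print(batch["collated_global_crops"].shape, int(batch["n_masked_patches"]))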
+ +from io import BytesIO +from typing import Any + +from PIL import Image + + +class Decoder: + def decode(self) -> Any: + raise NotImplementedError + + +class ImageDataDecoder(Decoder): + def __init__(self, image_data: bytes) -> None: + self._image_data = image_data + + def decode(self) -> Image: + f = BytesIO(self._image_data) + return Image.open(f).convert(mode="RGB") + + +class TargetDecoder(Decoder): + def __init__(self, target: Any): + self._target = target + + def decode(self) -> Any: + return self._target diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/extended.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/extended.py new file mode 100755 index 0000000..f60b619 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/extended.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from typing import Any, Tuple + +from torchvision.datasets import VisionDataset + +from .decoders import TargetDecoder, ImageDataDecoder + + +class ExtendedVisionDataset(VisionDataset): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) # type: ignore + + def get_image_data(self, index: int) -> bytes: + raise NotImplementedError + + def get_target(self, index: int) -> Any: + raise NotImplementedError + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + try: + image_data = self.get_image_data(index) + image = ImageDataDecoder(image_data).decode() + except Exception as e: + raise RuntimeError(f"can not read image for sample {index}") from e + target = self.get_target(index) + target = TargetDecoder(target).decode() + + if self.transforms is not None: + image, target = self.transforms(image, target) + + return image, target + + def __len__(self) -> int: + raise NotImplementedError diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/image_net.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/image_net.py new file mode 100755 index 0000000..8d08446 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/image_net.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
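# A minimal subclass sketch, not part of this patch: a concrete dataset built on
# ExtendedVisionDataset above only needs get_image_data(), get_target() and __len__();
# the in-memory JPEG bytes stand in for real files on disk.
from io import BytesIO
from PIL import Image

class TwoSampleDataset(ExtendedVisionDataset):
    def __init__(self, root="/tmp"):
        super().__init__(root)
        buffer = BytesIO()
        Image.new("RGB", (32, 32)).save(buffer, format="JPEG")
        self._blob = buffer.getvalue()  # raw JPEG bytes served for every index

    def get_image_data(self, index: int) -> bytes:
        return self._blob

    def get_target(self, index: int) -> int:
        return index % 2

    def __len__(self) -> int:
        return 2

image, target = TwoSampleDataset()[0]
print(image.size, target)  # (32, 32) 0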
+ +import csv +from enum import Enum +import logging +import os +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np + +from .extended import ExtendedVisionDataset + + +logger = logging.getLogger("dinov2") +_Target = int + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + TEST = "test" # NOTE: torchvision does not support the test split + + @property + def length(self) -> int: + split_lengths = { + _Split.TRAIN: 1_281_167, + _Split.VAL: 50_000, + _Split.TEST: 100_000, + } + return split_lengths[self] + + def get_dirname(self, class_id: Optional[str] = None) -> str: + return self.value if class_id is None else os.path.join(self.value, class_id) + + def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str: + dirname = self.get_dirname(class_id) + if self == _Split.TRAIN: + basename = f"{class_id}_{actual_index}" + else: # self in (_Split.VAL, _Split.TEST): + basename = f"ILSVRC2012_{self.value}_{actual_index:08d}" + return os.path.join(dirname, basename + ".JPEG") + + def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]: + assert self != _Split.TEST + dirname, filename = os.path.split(image_relpath) + class_id = os.path.split(dirname)[-1] + basename, _ = os.path.splitext(filename) + actual_index = int(basename.split("_")[-1]) + return class_id, actual_index + + +class ImageNet(ExtendedVisionDataset): + Target = Union[_Target] + Split = Union[_Split] + + def __init__( + self, + *, + split: "ImageNet.Split", + root: str, + extra: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__(root, transforms, transform, target_transform) + self._extra_root = extra + self._split = split + + self._entries = None + self._class_ids = None + self._class_names = None + + @property + def split(self) -> "ImageNet.Split": + return self._split + + def _get_extra_full_path(self, extra_path: str) -> str: + return os.path.join(self._extra_root, extra_path) + + def _load_extra(self, extra_path: str) -> np.ndarray: + extra_full_path = self._get_extra_full_path(extra_path) + return np.load(extra_full_path, mmap_mode="r") + + def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None: + extra_full_path = self._get_extra_full_path(extra_path) + os.makedirs(self._extra_root, exist_ok=True) + np.save(extra_full_path, extra_array) + + @property + def _entries_path(self) -> str: + return f"entries-{self._split.value.upper()}.npy" + + @property + def _class_ids_path(self) -> str: + return f"class-ids-{self._split.value.upper()}.npy" + + @property + def _class_names_path(self) -> str: + return f"class-names-{self._split.value.upper()}.npy" + + def _get_entries(self) -> np.ndarray: + if self._entries is None: + self._entries = self._load_extra(self._entries_path) + assert self._entries is not None + return self._entries + + def _get_class_ids(self) -> np.ndarray: + if self._split == _Split.TEST: + assert False, "Class IDs are not available in TEST split" + if self._class_ids is None: + self._class_ids = self._load_extra(self._class_ids_path) + assert self._class_ids is not None + return self._class_ids + + def _get_class_names(self) -> np.ndarray: + if self._split == _Split.TEST: + assert False, "Class names are not available in TEST split" + if self._class_names is None: + self._class_names = self._load_extra(self._class_names_path) + assert self._class_names is not None + return self._class_names + + def find_class_id(self, 
class_index: int) -> str: + class_ids = self._get_class_ids() + return str(class_ids[class_index]) + + def find_class_name(self, class_index: int) -> str: + class_names = self._get_class_names() + return str(class_names[class_index]) + + def get_image_data(self, index: int) -> bytes: + entries = self._get_entries() + actual_index = entries[index]["actual_index"] + + class_id = self.get_class_id(index) + + image_relpath = self.split.get_image_relpath(actual_index, class_id) + image_full_path = os.path.join(self.root, image_relpath) + with open(image_full_path, mode="rb") as f: + image_data = f.read() + return image_data + + def get_target(self, index: int) -> Optional[Target]: + entries = self._get_entries() + class_index = entries[index]["class_index"] + return None if self.split == _Split.TEST else int(class_index) + + def get_targets(self) -> Optional[np.ndarray]: + entries = self._get_entries() + return None if self.split == _Split.TEST else entries["class_index"] + + def get_class_id(self, index: int) -> Optional[str]: + entries = self._get_entries() + class_id = entries[index]["class_id"] + return None if self.split == _Split.TEST else str(class_id) + + def get_class_name(self, index: int) -> Optional[str]: + entries = self._get_entries() + class_name = entries[index]["class_name"] + return None if self.split == _Split.TEST else str(class_name) + + def __len__(self) -> int: + entries = self._get_entries() + assert len(entries) == self.split.length + return len(entries) + + def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]: + labels_full_path = os.path.join(self.root, labels_path) + labels = [] + + try: + with open(labels_full_path, "r") as f: + reader = csv.reader(f) + for row in reader: + class_id, class_name = row + labels.append((class_id, class_name)) + except OSError as e: + raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e + + return labels + + def _dump_entries(self) -> None: + split = self.split + if split == ImageNet.Split.TEST: + dataset = None + sample_count = split.length + max_class_id_length, max_class_name_length = 0, 0 + else: + labels_path = "labels.txt" + logger.info(f'loading labels from "{labels_path}"') + labels = self._load_labels(labels_path) + + # NOTE: Using torchvision ImageFolder for consistency + from torchvision.datasets import ImageFolder + + dataset_root = os.path.join(self.root, split.get_dirname()) + dataset = ImageFolder(dataset_root) + sample_count = len(dataset) + max_class_id_length, max_class_name_length = -1, -1 + for sample in dataset.samples: + _, class_index = sample + class_id, class_name = labels[class_index] + max_class_id_length = max(len(class_id), max_class_id_length) + max_class_name_length = max(len(class_name), max_class_name_length) + + dtype = np.dtype( + [ + ("actual_index", " old_percent: + logger.info(f"creating entries: {percent}%") + old_percent = percent + + actual_index = index + 1 + class_index = np.uint32(-1) + class_id, class_name = "", "" + entries_array[index] = (actual_index, class_index, class_id, class_name) + else: + class_names = {class_id: class_name for class_id, class_name in labels} + + assert dataset + old_percent = -1 + for index in range(sample_count): + percent = 100 * (index + 1) // sample_count + if percent > old_percent: + logger.info(f"creating entries: {percent}%") + old_percent = percent + + image_full_path, class_index = dataset.samples[index] + image_relpath = os.path.relpath(image_full_path, self.root) + class_id, actual_index = 
split.parse_image_relpath(image_relpath) + class_name = class_names[class_id] + entries_array[index] = (actual_index, class_index, class_id, class_name) + + logger.info(f'saving entries to "{self._entries_path}"') + self._save_extra(entries_array, self._entries_path) + + def _dump_class_ids_and_names(self) -> None: + split = self.split + if split == ImageNet.Split.TEST: + return + + entries_array = self._load_extra(self._entries_path) + + max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1 + for entry in entries_array: + class_index, class_id, class_name = ( + entry["class_index"], + entry["class_id"], + entry["class_name"], + ) + max_class_index = max(int(class_index), max_class_index) + max_class_id_length = max(len(str(class_id)), max_class_id_length) + max_class_name_length = max(len(str(class_name)), max_class_name_length) + + class_count = max_class_index + 1 + class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}") + class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}") + for entry in entries_array: + class_index, class_id, class_name = ( + entry["class_index"], + entry["class_id"], + entry["class_name"], + ) + class_ids_array[class_index] = class_id + class_names_array[class_index] = class_name + + logger.info(f'saving class IDs to "{self._class_ids_path}"') + self._save_extra(class_ids_array, self._class_ids_path) + + logger.info(f'saving class names to "{self._class_names_path}"') + self._save_extra(class_names_array, self._class_names_path) + + def dump_extra(self) -> None: + self._dump_entries() + self._dump_class_ids_and_names() diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/image_net_22k.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/image_net_22k.py new file mode 100755 index 0000000..52b36a2 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/datasets/image_net_22k.py @@ -0,0 +1,302 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass +from enum import Enum +from functools import lru_cache +from gzip import GzipFile +from io import BytesIO +from mmap import ACCESS_READ, mmap +import os +from typing import Any, Callable, List, Optional, Set, Tuple +import warnings + +import numpy as np + +from .extended import ExtendedVisionDataset + + +_Labels = int + +_DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors + + +@dataclass +class _ClassEntry: + block_offset: int + maybe_filename: Optional[str] = None + + +@dataclass +class _Entry: + class_index: int # noqa: E701 + start_offset: int + end_offset: int + filename: str + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + + @property + def length(self) -> int: + return { + _Split.TRAIN: 11_797_647, + _Split.VAL: 561_050, + }[self] + + def entries_path(self): + return f"imagenet21kp_{self.value}.txt" + + +def _get_tarball_path(class_id: str) -> str: + return f"{class_id}.tar" + + +def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int): + @lru_cache(maxsize=mmap_cache_size) + def _mmap_tarball(class_id: str) -> mmap: + tarball_path = _get_tarball_path(class_id) + tarball_full_path = os.path.join(tarballs_root, tarball_path) + with open(tarball_full_path) as f: + return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ) + + return _mmap_tarball + + +class ImageNet22k(ExtendedVisionDataset): + _GZIPPED_INDICES: Set[int] = { + 841_545, + 1_304_131, + 2_437_921, + 2_672_079, + 2_795_676, + 2_969_786, + 6_902_965, + 6_903_550, + 6_903_628, + 7_432_557, + 7_432_589, + 7_813_809, + 8_329_633, + 10_296_990, + 10_417_652, + 10_492_265, + 10_598_078, + 10_782_398, + 10_902_612, + 11_203_736, + 11_342_890, + 11_397_596, + 11_589_762, + 11_705_103, + 12_936_875, + 13_289_782, + } + Labels = _Labels + + def __init__( + self, + *, + root: str, + extra: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE, + ) -> None: + super().__init__(root, transforms, transform, target_transform) + self._extra_root = extra + + entries_path = self._get_entries_path(root) + self._entries = self._load_extra(entries_path) + + class_ids_path = self._get_class_ids_path(root) + self._class_ids = self._load_extra(class_ids_path) + + self._gzipped_indices = ImageNet22k._GZIPPED_INDICES + self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size) + + def _get_entries_path(self, root: Optional[str] = None) -> str: + return "entries.npy" + + def _get_class_ids_path(self, root: Optional[str] = None) -> str: + return "class-ids.npy" + + def _find_class_ids(self, path: str) -> List[str]: + class_ids = [] + + with os.scandir(path) as entries: + for entry in entries: + root, ext = os.path.splitext(entry.name) + if ext != ".tar": + continue + class_ids.append(root) + + return sorted(class_ids) + + def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]: + root = self.get_root(root) + entries: List[_Entry] = [] + class_ids = self._find_class_ids(root) + + for class_index, class_id in enumerate(class_ids): + path = os.path.join(root, "blocks", f"{class_id}.log") + class_entries = [] + + try: + with open(path) as f: + for line in f: + line = line.rstrip() + block, filename = line.split(":") + block_offset = int(block[6:]) + filename = filename[1:] + + maybe_filename = None + if filename != "** Block of NULs **": + maybe_filename = filename + _, ext = 
os.path.splitext(filename) + # assert ext == ".JPEG" + + class_entry = _ClassEntry(block_offset, maybe_filename) + class_entries.append(class_entry) + except OSError as e: + raise RuntimeError(f'can not read blocks file "{path}"') from e + + assert class_entries[-1].maybe_filename is None + + for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]): + assert class_entry1.block_offset <= class_entry2.block_offset + start_offset = 512 * class_entry1.block_offset + end_offset = 512 * class_entry2.block_offset + assert class_entry1.maybe_filename is not None + filename = class_entry1.maybe_filename + entry = _Entry(class_index, start_offset, end_offset, filename) + # Skip invalid image files (PIL throws UnidentifiedImageError) + if filename == "n06470073_47249.JPEG": + continue + entries.append(entry) + + return entries, class_ids + + def _load_extra(self, extra_path: str) -> np.ndarray: + extra_root = self._extra_root + extra_full_path = os.path.join(extra_root, extra_path) + return np.load(extra_full_path, mmap_mode="r") + + def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None: + extra_root = self._extra_root + extra_full_path = os.path.join(extra_root, extra_path) + os.makedirs(extra_root, exist_ok=True) + np.save(extra_full_path, extra_array) + + @property + def _tarballs_root(self) -> str: + return self.root + + def find_class_id(self, class_index: int) -> str: + return str(self._class_ids[class_index]) + + def get_image_data(self, index: int) -> bytes: + entry = self._entries[index] + class_id = entry["class_id"] + class_mmap = self._mmap_tarball(class_id) + + start_offset, end_offset = entry["start_offset"], entry["end_offset"] + try: + mapped_data = class_mmap[start_offset:end_offset] + data = mapped_data[512:] # Skip entry header block + + if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B): + assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}" + with GzipFile(fileobj=BytesIO(data)) as g: + data = g.read() + except Exception as e: + raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e + + return data + + def get_target(self, index: int) -> Any: + return int(self._entries[index]["class_index"]) + + def get_targets(self) -> np.ndarray: + return self._entries["class_index"] + + def get_class_id(self, index: int) -> str: + return str(self._entries[index]["class_id"]) + + def get_class_ids(self) -> np.ndarray: + return self._entries["class_id"] + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return super().__getitem__(index) + + def __len__(self) -> int: + return len(self._entries) + + def _dump_entries(self, *args, **kwargs) -> None: + entries, class_ids = self._load_entries_class_ids(*args, **kwargs) + + max_class_id_length, max_filename_length, max_class_index = -1, -1, -1 + for entry in entries: + class_id = class_ids[entry.class_index] + max_class_index = max(entry.class_index, max_class_index) + max_class_id_length = max(len(class_id), max_class_id_length) + max_filename_length = max(len(entry.filename), max_filename_length) + + dtype = np.dtype( + [ + ("class_index", " None: + entries_path = self._get_entries_path(*args, **kwargs) + entries_array = self._load_extra(entries_path) + + max_class_id_length, max_class_index = -1, -1 + for entry in entries_array: + class_index, class_id = entry["class_index"], entry["class_id"] + max_class_index = max(int(class_index), max_class_index) + 
max_class_id_length = max(len(str(class_id)), max_class_id_length) + + class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}") + for entry in entries_array: + class_index, class_id = entry["class_index"], entry["class_id"] + class_ids_array[class_index] = class_id + class_ids_path = self._get_class_ids_path(*args, **kwargs) + self._save_extra(class_ids_array, class_ids_path) + + def _dump_extra(self, *args, **kwargs) -> None: + self._dump_entries(*args, *kwargs) + self._dump_class_ids(*args, *kwargs) + + def dump_extra(self, root: Optional[str] = None) -> None: + return self._dump_extra(root) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/loaders.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/loaders.py new file mode 100755 index 0000000..d6a2f02 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/loaders.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +from enum import Enum +from typing import Any, Callable, List, Optional, TypeVar + +import torch +from torch.utils.data import Sampler + +from .datasets import ImageNet, ImageNet22k +from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler + + +logger = logging.getLogger("dinov2") + + +class SamplerType(Enum): + DISTRIBUTED = 0 + EPOCH = 1 + INFINITE = 2 + SHARDED_INFINITE = 3 + SHARDED_INFINITE_NEW = 4 + + +def _make_bool_str(b: bool) -> str: + return "yes" if b else "no" + + +def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None): + def transform(sample): + image, target = sample + if image_transform is not None: + image = image_transform(image) + if target_transform is not None: + target = target_transform(target) + return image, target + + return transform + + +def _parse_dataset_str(dataset_str: str): + tokens = dataset_str.split(":") + + name = tokens[0] + kwargs = {} + + for token in tokens[1:]: + key, value = token.split("=") + assert key in ("root", "extra", "split") + kwargs[key] = value + + if name == "ImageNet": + class_ = ImageNet + if "split" in kwargs: + kwargs["split"] = ImageNet.Split[kwargs["split"]] + elif name == "ImageNet22k": + class_ = ImageNet22k + else: + raise ValueError(f'Unsupported dataset "{name}"') + + return class_, kwargs + + +def make_dataset( + *, + dataset_str: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, +): + """ + Creates a dataset with the specified parameters. + + Args: + dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN). + transform: A transform to apply to images. + target_transform: A transform to apply to targets. + + Returns: + The created dataset. + """ + logger.info(f'using dataset: "{dataset_str}"') + + class_, kwargs = _parse_dataset_str(dataset_str) + dataset = class_(transform=transform, target_transform=target_transform, **kwargs) + + logger.info(f"# of dataset samples: {len(dataset):,d}") + + # Aggregated datasets do not expose (yet) these attributes, so add them. 
+ if not hasattr(dataset, "transform"): + setattr(dataset, "transform", transform) + if not hasattr(dataset, "target_transform"): + setattr(dataset, "target_transform", target_transform) + + return dataset + + +def _make_sampler( + *, + dataset, + type: Optional[SamplerType] = None, + shuffle: bool = False, + seed: int = 0, + size: int = -1, + advance: int = 0, +) -> Optional[Sampler]: + sample_count = len(dataset) + + if type == SamplerType.INFINITE: + logger.info("sampler: infinite") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + return InfiniteSampler( + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + advance=advance, + ) + elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW): + logger.info("sampler: sharded infinite") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + # TODO: Remove support for old shuffling + use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW + return ShardedInfiniteSampler( + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + advance=advance, + use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice, + ) + elif type == SamplerType.EPOCH: + logger.info("sampler: epoch") + if advance > 0: + raise NotImplementedError("sampler advance > 0 is not supported") + size = size if size > 0 else sample_count + logger.info(f"# of samples / epoch: {size:,d}") + return EpochSampler( + size=size, + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + ) + elif type == SamplerType.DISTRIBUTED: + logger.info("sampler: distributed") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + if advance > 0: + raise ValueError("sampler advance > 0 is invalid") + return torch.utils.data.DistributedSampler( + dataset=dataset, + shuffle=shuffle, + seed=seed, + drop_last=False, + ) + + logger.info("sampler: none") + return None + + +T = TypeVar("T") + + +def make_data_loader( + *, + dataset, + batch_size: int, + num_workers: int, + shuffle: bool = True, + seed: int = 0, + sampler_type: Optional[SamplerType] = SamplerType.INFINITE, + sampler_size: int = -1, + sampler_advance: int = 0, + drop_last: bool = True, + persistent_workers: bool = False, + collate_fn: Optional[Callable[[List[T]], Any]] = None, +): + """ + Creates a data loader with the specified parameters. + + Args: + dataset: A dataset (third party, LaViDa or WebDataset). + batch_size: The size of batches to generate. + num_workers: The number of workers to use. + shuffle: Whether to shuffle samples. + seed: The random seed to use. + sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None. + sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset. + sampler_advance: How many samples to skip (when applicable). + drop_last: Whether the last non-full batch of data should be dropped. + persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once. 
+ collate_fn: Function that performs batch collation + """ + + sampler = _make_sampler( + dataset=dataset, + type=sampler_type, + shuffle=shuffle, + seed=seed, + size=sampler_size, + advance=sampler_advance, + ) + + logger.info("using PyTorch data loader") + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=True, + drop_last=drop_last, + persistent_workers=persistent_workers, + collate_fn=collate_fn, + ) + + try: + logger.info(f"# of batches: {len(data_loader):,d}") + except TypeError: # data loader has no length + logger.info("infinite data loader") + return data_loader diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/masking.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/masking.py new file mode 100755 index 0000000..ab12aa7 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/masking.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import random +import math +import numpy as np + + +class MaskingGenerator: + def __init__( + self, + input_size, + num_masking_patches=None, + min_num_patches=4, + max_num_patches=None, + min_aspect=0.3, + max_aspect=None, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.height, self.width = input_size + + self.num_patches = self.height * self.width + self.num_masking_patches = num_masking_patches + + self.min_num_patches = min_num_patches + self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches + + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def __repr__(self): + repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( + self.height, + self.width, + self.min_num_patches, + self.max_num_patches, + self.num_masking_patches, + self.log_aspect_ratio[0], + self.log_aspect_ratio[1], + ) + return repr_str + + def get_shape(self): + return self.height, self.width + + def _mask(self, mask, max_mask_patches): + delta = 0 + for _ in range(10): + target_area = random.uniform(self.min_num_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = random.randint(0, self.height - h) + left = random.randint(0, self.width - w) + + num_masked = mask[top : top + h, left : left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + + if delta > 0: + break + return delta + + def __call__(self, num_masking_patches=0): + mask = np.zeros(shape=self.get_shape(), dtype=bool) + mask_count = 0 + while mask_count < num_masking_patches: + max_mask_patches = num_masking_patches - mask_count + max_mask_patches = min(max_mask_patches, self.max_num_patches) + + delta = self._mask(mask, max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + + return mask diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/samplers.py 
b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/samplers.py new file mode 100755 index 0000000..6562197 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/samplers.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +from typing import Any, Optional +import warnings + +import numpy as np +import torch +from torch.utils.data.sampler import Sampler + +import dinov2.distributed as distributed + + +class EpochSampler(Sampler): + def __init__( + self, + *, + size: int, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + ): + self._size = size + self._sample_count = sample_count + self._shuffle = shuffle + self._seed = seed + self._start = distributed.get_global_rank() if start is None else start + self._step = distributed.get_global_size() if step is None else step + self._epoch = 0 + + def __iter__(self): + count = (self._size + self._sample_count - 1) // self._sample_count + tiled_indices = np.tile(np.arange(self._sample_count), count) + if self._shuffle: + seed = self._seed * self._epoch if self._seed != 0 else self._epoch + rng = np.random.default_rng(seed) + iterable = rng.choice(tiled_indices, self._size, replace=False) + else: + iterable = tiled_indices[: self._size] + + yield from itertools.islice(iterable, self._start, None, self._step) + + def __len__(self): + return (self._size - self._start + self._step - 1) // self._step + + def set_epoch(self, epoch): + self._epoch = epoch + + +def _get_numpy_dtype(size: int) -> Any: + return np.int32 if size <= 2**31 else np.int64 + + +def _get_torch_dtype(size: int) -> Any: + return torch.int32 if size <= 2**31 else torch.int64 + + +def _generate_randperm_indices(*, size: int, generator: torch.Generator): + """Generate the indices of a random permutation.""" + dtype = _get_torch_dtype(size) + # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921 + perm = torch.arange(size, dtype=dtype) + for i in range(size): + j = torch.randint(i, size, size=(1,), generator=generator).item() + + # Always swap even if no-op + value = perm[j].item() + perm[j] = perm[i].item() + perm[i] = value + yield value + + +class InfiniteSampler(Sampler): + def __init__( + self, + *, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + advance: int = 0, + ): + self._sample_count = sample_count + self._seed = seed + self._shuffle = shuffle + self._start = distributed.get_global_rank() if start is None else start + self._step = distributed.get_global_size() if step is None else step + self._advance = advance + + def __iter__(self): + if self._shuffle: + iterator = self._shuffled_iterator() + else: + iterator = self._iterator() + + yield from itertools.islice(iterator, self._advance, None) + + def _iterator(self): + assert not self._shuffle + + while True: + iterable = range(self._sample_count) + yield from itertools.islice(iterable, self._start, None, self._step) + + def _shuffled_iterator(self): + assert self._shuffle + + # Instantiate a generator here (rather than in the ctor) to keep the class + # picklable (requirement of mp.spawn) + generator = 
torch.Generator().manual_seed(self._seed) + + while True: + iterable = _generate_randperm_indices(size=self._sample_count, generator=generator) + yield from itertools.islice(iterable, self._start, None, self._step) + + +# The following function is somewhat equivalent to _new_shuffle_tensor_slice below, +# but avoids a full in-place random permutation generation. +def _shuffle_tensor_slice( + *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator +) -> np.ndarray: + stop = len(tensor) + count = stop // step + drop_count = stop - step * count + if drop_count: + warnings.warn(f"# of dropped samples: {drop_count}") + + dtype = _get_numpy_dtype(stop) + result = np.empty(count, dtype=dtype) + + for i in range(count): + j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0 + + result[i] = result[j] + result[j] = tensor[start + i * step].item() + + return result + + +def _new_shuffle_tensor_slice( + *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator +) -> np.ndarray: + stop = len(tensor) + count = stop // step + dtype = torch.int64 # Needed for using randperm result as indices + count = stop // step + drop_count = stop - step * count + if drop_count: + warnings.warn(f"# of dropped samples: {drop_count}") + indices = torch.randperm(count, dtype=dtype, generator=generator) + return tensor[start::step][indices].numpy() + + +def _make_seed(seed: int, start: int, iter_count: int) -> int: + # NOTE: Tried a few variants (including iter_count << 32), this one worked best. + return seed + start + (iter_count << 24) + + +class ShardedInfiniteSampler(Sampler): + def __init__( + self, + *, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + advance: int = 0, + use_new_shuffle_tensor_slice: bool = False, + ): + self._sample_count = sample_count + self._seed = seed + self._shuffle = shuffle + self._start = distributed.get_global_rank() if start is None else start + self._step = distributed.get_global_size() if step is None else step + self._advance = advance + self._iter_count = 0 + self._shuffle_tensor_slice_fn = ( + _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice + ) + + def __iter__(self): + iter_count = self._advance // self._sample_count + if iter_count > 0: + self._advance -= iter_count * self._sample_count + self._iter_count += iter_count + + if self._shuffle: + iterator = self._shuffled_iterator() + else: + iterator = self._iterator() + + yield from itertools.islice(iterator, self._advance, None) + + def _iterator(self): + assert not self._shuffle + + while True: + iterable = range(self._sample_count) + yield from itertools.islice(iterable, self._start, None, self._step) + + def _shuffled_iterator(self): + assert self._shuffle + + # Instantiate a generator here (rather than in the ctor) to be keep the class + # picklable (requirement of mp.spawn) + generator = torch.Generator() + + # Always shuffle everything first + generator.manual_seed(self._seed) + dtype = _get_torch_dtype(self._sample_count) + perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator) + + while True: + # Re-seed on each iteration to allow skipping whole permutations + seed = _make_seed(self._seed, self._start, self._iter_count) + generator.manual_seed(seed) + + iterable = self._shuffle_tensor_slice_fn( + tensor=perm, start=self._start, step=self._step, generator=generator + ) + yield from iterable + self._iter_count += 
1 diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/transforms.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/transforms.py new file mode 100755 index 0000000..eb5f252 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/data/transforms.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from typing import Sequence + +import torch +from torchvision import transforms + + +class GaussianBlur(transforms.RandomApply): + """ + Apply Gaussian Blur to the PIL image. + """ + + def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0): + # NOTE: torchvision is applying 1 - probability to return the original image + keep_p = 1 - p + transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max)) + super().__init__(transforms=[transform], p=keep_p) + + +class MaybeToTensor(transforms.ToTensor): + """ + Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor. + """ + + def __call__(self, pic): + """ + Args: + pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor. + Returns: + Tensor: Converted image. + """ + if isinstance(pic, torch.Tensor): + return pic + return super().__call__(pic) + + +# Use timm's names +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + + +def make_normalize_transform( + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> transforms.Normalize: + return transforms.Normalize(mean=mean, std=std) + + +# This roughly matches torchvision's preset for classification training: +# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44 +def make_classification_train_transform( + *, + crop_size: int = 224, + interpolation=transforms.InterpolationMode.BICUBIC, + hflip_prob: float = 0.5, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +): + transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + if hflip_prob > 0.0: + transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob)) + transforms_list.extend( + [ + MaybeToTensor(), + make_normalize_transform(mean=mean, std=std), + ] + ) + return transforms.Compose(transforms_list) + + +# This matches (roughly) torchvision's preset for classification evaluation: +# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69 +def make_classification_eval_transform( + *, + resize_size: int = 256, + interpolation=transforms.InterpolationMode.BICUBIC, + crop_size: int = 224, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> transforms.Compose: + transforms_list = [ + transforms.Resize(resize_size, interpolation=interpolation), + transforms.CenterCrop(crop_size), + MaybeToTensor(), + make_normalize_transform(mean=mean, std=std), + ] + return transforms.Compose(transforms_list) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/distributed/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/distributed/__init__.py new file mode 100755 index 0000000..23226f4 --- /dev/null +++ 
b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/distributed/__init__.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +import random +import re +import socket +from typing import Dict, List + +import torch +import torch.distributed as dist + +_LOCAL_RANK = -1 +_LOCAL_WORLD_SIZE = -1 + + +def is_enabled() -> bool: + """ + Returns: + True if distributed training is enabled + """ + return dist.is_available() and dist.is_initialized() + + +def get_global_size() -> int: + """ + Returns: + The number of processes in the process group + """ + return dist.get_world_size() if is_enabled() else 1 + + +def get_global_rank() -> int: + """ + Returns: + The rank of the current process within the global process group. + """ + return dist.get_rank() if is_enabled() else 0 + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not is_enabled(): + return 0 + assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE + return _LOCAL_RANK + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not is_enabled(): + return 1 + assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE + return _LOCAL_WORLD_SIZE + + +def is_main_process() -> bool: + """ + Returns: + True if the current process is the main one. + """ + return get_global_rank() == 0 + + +def _restrict_print_to_main_process() -> None: + """ + This function disables printing when not in the main process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_main_process() or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def _get_master_port(seed: int = 0) -> int: + MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000) + + master_port_str = os.environ.get("MASTER_PORT") + if master_port_str is None: + rng = random.Random(seed) + return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) + + return int(master_port_str) + + +def _get_available_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + # A "" host address means INADDR_ANY i.e. binding to all interfaces. + # Note this is not compatible with IPv6. 
+ s.bind(("", 0)) + port = s.getsockname()[1] + return port + + +_TORCH_DISTRIBUTED_ENV_VARS = ( + "MASTER_ADDR", + "MASTER_PORT", + "RANK", + "WORLD_SIZE", + "LOCAL_RANK", + "LOCAL_WORLD_SIZE", +) + + +def _collect_env_vars() -> Dict[str, str]: + return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ} + + +def _is_slurm_job_process() -> bool: + return "SLURM_JOB_ID" in os.environ + + +def _parse_slurm_node_list(s: str) -> List[str]: + nodes = [] + # Extract "hostname", "hostname[1-2,3,4-5]," substrings + p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?") + for m in p.finditer(s): + prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)] + for suffix in suffixes.split(","): + span = suffix.split("-") + if len(span) == 1: + nodes.append(prefix + suffix) + else: + width = len(span[0]) + start, end = int(span[0]), int(span[1]) + 1 + nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)]) + return nodes + + +def _check_env_variable(key: str, new_value: str): + # Only check for difference with preset environment variables + if key in os.environ and os.environ[key] != new_value: + raise RuntimeError(f"Cannot export environment variables as {key} is already set") + + +class _TorchDistributedEnvironment: + def __init__(self): + self.master_addr = "127.0.0.1" + self.master_port = 0 + self.rank = -1 + self.world_size = -1 + self.local_rank = -1 + self.local_world_size = -1 + + if _is_slurm_job_process(): + return self._set_from_slurm_env() + + env_vars = _collect_env_vars() + if not env_vars: + # Environment is not set + pass + elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS): + # Environment is fully set + return self._set_from_preset_env() + else: + # Environment is partially set + collected_env_vars = ", ".join(env_vars.keys()) + raise RuntimeError(f"Partially set environment: {collected_env_vars}") + + if torch.cuda.device_count() > 0: + return self._set_from_local() + + raise RuntimeError("Can't initialize PyTorch distributed environment") + + # Slurm job created with sbatch, submitit, etc... + def _set_from_slurm_env(self): + # logger.info("Initialization from Slurm environment") + job_id = int(os.environ["SLURM_JOB_ID"]) + node_count = int(os.environ["SLURM_JOB_NUM_NODES"]) + nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"]) + assert len(nodes) == node_count + + self.master_addr = nodes[0] + self.master_port = _get_master_port(seed=job_id) + self.rank = int(os.environ["SLURM_PROCID"]) + self.world_size = int(os.environ["SLURM_NTASKS"]) + assert self.rank < self.world_size + self.local_rank = int(os.environ["SLURM_LOCALID"]) + self.local_world_size = self.world_size // node_count + assert self.local_rank < self.local_world_size + + # Single node job with preset environment (i.e. torchrun) + def _set_from_preset_env(self): + # logger.info("Initialization from preset environment") + self.master_addr = os.environ["MASTER_ADDR"] + self.master_port = os.environ["MASTER_PORT"] + self.rank = int(os.environ["RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + assert self.rank < self.world_size + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + assert self.local_rank < self.local_world_size + + # Single node and GPU job (i.e. 
local script run) + def _set_from_local(self): + # logger.info("Initialization from local") + self.master_addr = "127.0.0.1" + self.master_port = _get_available_port() + self.rank = 0 + self.world_size = 1 + self.local_rank = 0 + self.local_world_size = 1 + + def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment": + # See the "Environment variable initialization" section from + # https://pytorch.org/docs/stable/distributed.html for the complete list of + # environment variables required for the env:// initialization method. + env_vars = { + "MASTER_ADDR": self.master_addr, + "MASTER_PORT": str(self.master_port), + "RANK": str(self.rank), + "WORLD_SIZE": str(self.world_size), + "LOCAL_RANK": str(self.local_rank), + "LOCAL_WORLD_SIZE": str(self.local_world_size), + } + if not overwrite: + for k, v in env_vars.items(): + _check_env_variable(k, v) + + os.environ.update(env_vars) + return self + + +def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False): + """Enable distributed mode + + Args: + set_cuda_current_device: If True, call torch.cuda.set_device() to set the + current PyTorch CUDA device to the one matching the local rank. + overwrite: If True, overwrites already set variables. Else fails. + """ + + global _LOCAL_RANK, _LOCAL_WORLD_SIZE + if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0: + raise RuntimeError("Distributed mode has already been enabled") + torch_env = _TorchDistributedEnvironment() + torch_env.export(overwrite=overwrite) + + if set_cuda_current_device: + torch.cuda.set_device(torch_env.local_rank) + + if allow_nccl_timeout: + # This allows to use torch distributed timeout in a NCCL backend + key, value = "NCCL_ASYNC_ERROR_HANDLING", "1" + if not overwrite: + _check_env_variable(key, value) + os.environ[key] = value + + dist.init_process_group(backend="nccl") + dist.barrier() + + # Finalize setup + _LOCAL_RANK = torch_env.local_rank + _LOCAL_WORLD_SIZE = torch_env.local_world_size + _restrict_print_to_main_process() diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/__init__.py new file mode 100755 index 0000000..b88da6b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/__init__.py new file mode 100755 index 0000000..b88da6b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
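# A hedged sketch of how the distributed helpers above are typically used (illustrative only):
# `enable()` expects either a Slurm/torchrun-style environment or at least one visible CUDA
# device, and it exports MASTER_ADDR/RANK/... before calling init_process_group with NCCL.
import dinov2.distributed as distributed

distributed.enable(overwrite=True)
if distributed.is_main_process():
    print(f"rank {distributed.get_global_rank()} of {distributed.get_global_size()} processes")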
diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/__init__.py new file mode 100755 index 0000000..9a58251 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .backbones import * # noqa: F403 +from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss +from .decode_heads import * # noqa: F403 +from .depther import * # noqa: F403 +from .losses import * # noqa: F403 diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/backbones/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/backbones/__init__.py new file mode 100755 index 0000000..520d75b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/backbones/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .vision_transformer import DinoVisionTransformer diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/backbones/vision_transformer.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/backbones/vision_transformer.py new file mode 100755 index 0000000..69bda46 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/backbones/vision_transformer.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmcv.runner import BaseModule + +from ..builder import BACKBONES + + +@BACKBONES.register_module() +class DinoVisionTransformer(BaseModule): + """Vision Transformer.""" + + def __init__(self, *args, **kwargs): + super().__init__() diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/builder.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/builder.py new file mode 100755 index 0000000..c152643 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/builder.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
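# A hedged sketch of the mmcv registry pattern the builders below rely on: a class registered via
# `@BACKBONES.register_module()` (e.g. DinoVisionTransformer above) is instantiated from a config
# dict whose "type" key names the registered class. The import path is assumed for this vendored copy.
from dinov2.eval.depth.models import build_backbone

backbone = build_backbone(dict(type="DinoVisionTransformer"))   # equivalent to BACKBONES.build(...)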
+ +import warnings + +from mmcv.cnn import MODELS as MMCV_MODELS +from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION +from mmcv.utils import Registry + +MODELS = Registry("models", parent=MMCV_MODELS) +ATTENTION = Registry("attention", parent=MMCV_ATTENTION) + + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +DEPTHER = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_depther(cfg, train_cfg=None, test_cfg=None): + """Build depther.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning) + assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field " + assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field " + return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/decode_heads/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/decode_heads/__init__.py new file mode 100755 index 0000000..bd0f075 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/decode_heads/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dpt_head import DPTHead +from .linear_head import BNHead diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/decode_heads/decode_head.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/decode_heads/decode_head.py new file mode 100755 index 0000000..f8c867a --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/decode_heads/decode_head.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import copy +from abc import ABCMeta, abstractmethod + +import mmcv +import numpy as np +import torch +import torch.nn as nn +from mmcv.runner import BaseModule, auto_fp16, force_fp32 + +from ...ops import resize +from ..builder import build_loss + + +class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta): + """Base class for BaseDecodeHead. + + Args: + in_channels (List): Input channels. + channels (int): Channels after modules, before conv_depth. + conv_cfg (dict|None): Config of conv layers. Default: None. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU') + loss_decode (dict): Config of decode loss. + Default: dict(type='SigLoss'). + sampler (dict|None): The config of depth map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + min_depth (int): Min depth in dataset setting. + Default: 1e-3. + max_depth (int): Max depth in dataset setting. + Default: None. 
+ norm_cfg (dict|None): Config of norm layers. + Default: None. + classify (bool): Whether predict depth in a cls.-reg. manner. + Default: False. + n_bins (int): The number of bins used in cls. step. + Default: 256. + bins_strategy (str): The discrete strategy used in cls. step. + Default: 'UD'. + norm_strategy (str): The norm strategy on cls. probability + distribution. Default: 'linear' + scale_up (str): Whether predict depth in a scale-up manner. + Default: False. + """ + + def __init__( + self, + in_channels, + channels=96, + conv_cfg=None, + act_cfg=dict(type="ReLU"), + loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10), + sampler=None, + align_corners=False, + min_depth=1e-3, + max_depth=None, + norm_cfg=None, + classify=False, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + scale_up=False, + ): + super(DepthBaseDecodeHead, self).__init__() + + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.act_cfg = act_cfg + if isinstance(loss_decode, dict): + self.loss_decode = build_loss(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(build_loss(loss)) + self.align_corners = align_corners + self.min_depth = min_depth + self.max_depth = max_depth + self.norm_cfg = norm_cfg + self.classify = classify + self.n_bins = n_bins + self.scale_up = scale_up + + if self.classify: + assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID" + assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid" + + self.bins_strategy = bins_strategy + self.norm_strategy = norm_strategy + self.softmax = nn.Softmax(dim=1) + self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1) + else: + self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1) + + self.fp16_enabled = False + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def extra_repr(self): + """Extra repr.""" + s = f"align_corners={self.align_corners}" + return s + + @auto_fp16() + @abstractmethod + def forward(self, inputs, img_metas): + """Placeholder of forward function.""" + pass + + def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): GT depth + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + depth_pred = self.forward(inputs, img_metas) + losses = self.losses(depth_pred, depth_gt) + + log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0]) + losses.update(**log_imgs) + + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Forward function for testing. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. 
+ + Returns: + Tensor: Output depth map. + """ + return self.forward(inputs, img_metas) + + def depth_pred(self, feat): + """Prediction each pixel.""" + if self.classify: + logit = self.conv_depth(feat) + + if self.bins_strategy == "UD": + bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + elif self.bins_strategy == "SID": + bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + + # following Adabins, default linear + if self.norm_strategy == "linear": + logit = torch.relu(logit) + eps = 0.1 + logit = logit + eps + logit = logit / logit.sum(dim=1, keepdim=True) + elif self.norm_strategy == "softmax": + logit = torch.softmax(logit, dim=1) + elif self.norm_strategy == "sigmoid": + logit = torch.sigmoid(logit) + logit = logit / logit.sum(dim=1, keepdim=True) + + output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) + + else: + if self.scale_up: + output = self.sigmoid(self.conv_depth(feat)) * self.max_depth + else: + output = self.relu(self.conv_depth(feat)) + self.min_depth + return output + + @force_fp32(apply_to=("depth_pred",)) + def losses(self, depth_pred, depth_gt): + """Compute depth loss.""" + loss = dict() + depth_pred = resize( + input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False + ) + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt) + else: + loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt) + return loss + + def log_images(self, img_path, depth_pred, depth_gt, img_meta): + show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0)) + show_img = show_img.numpy().astype(np.float32) + show_img = mmcv.imdenormalize( + show_img, + img_meta["img_norm_cfg"]["mean"], + img_meta["img_norm_cfg"]["std"], + img_meta["img_norm_cfg"]["to_rgb"], + ) + show_img = np.clip(show_img, 0, 255) + show_img = show_img.astype(np.uint8) + show_img = show_img[:, :, ::-1] + show_img = show_img.transpose(0, 2, 1) + show_img = show_img.transpose(1, 0, 2) + + depth_pred = depth_pred / torch.max(depth_pred) + depth_gt = depth_gt / torch.max(depth_gt) + + depth_pred_color = copy.deepcopy(depth_pred.detach().cpu()) + depth_gt_color = copy.deepcopy(depth_gt.detach().cpu()) + + return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color} diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/decode_heads/dpt_head.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/decode_heads/dpt_head.py new file mode 100755 index 0000000..c6c6d94 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/decode_heads/dpt_head.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
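# A small torch-only sketch (illustrative values) of the binned depth prediction implemented in
# DepthBaseDecodeHead.depth_pred above when `classify=True` with the "UD" bins strategy and the
# "linear" normalization: per-pixel logits over n_bins become a distribution whose expectation
# over uniformly spaced depth bins is the predicted depth.
import torch

n_bins, min_depth, max_depth = 256, 1e-3, 10.0
logit = torch.relu(torch.randn(2, n_bins, 8, 8)) + 0.1            # relu + eps, as in the "linear" branch
prob = logit / logit.sum(dim=1, keepdim=True)                     # normalize to a per-pixel distribution
bins = torch.linspace(min_depth, max_depth, n_bins)               # uniformly discretized ("UD") depths
depth = torch.einsum("ikmn,k->imn", prob, bins).unsqueeze(dim=1)  # (2, 1, 8, 8) depth map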
+ +import math + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, Linear, build_activation_layer +from mmcv.runner import BaseModule + +from ...ops import resize +from ..builder import HEADS +from .decode_head import DepthBaseDecodeHead + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, align_corners=False): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + return x + + +class HeadDepth(nn.Module): + def __init__(self, features): + super(HeadDepth, self).__init__() + self.head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, x): + x = self.head(x) + return x + + +class ReassembleBlocks(BaseModule): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__( + self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None + ): + super(ReassembleBlocks, self).__init__(init_cfg) + + assert readout_type in ["ignore", "add", "project"] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList( + [ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_cfg=None, + ) + for out_channel in out_channels + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + if self.readout_type == "project": + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU"))) + ) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == "project": + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == "add": + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(BaseModule): + """ResidualConvUnit, pre-activate residual unit. 
+ Args: + in_channels (int): number of channels in the input feature map. + act_cfg (dict): dictionary to construct and config activation layer. + norm_cfg (dict): dictionary to construct and config norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None): + super(PreActResidualConvUnit, self).__init__(init_cfg) + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=("act", "conv", "norm"), + ) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=("act", "conv", "norm"), + ) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(BaseModule): + """FeatureFusionBlock, merge feature map from different stages. + Args: + in_channels (int): Input channels. + act_cfg (dict): The activation config for ResidualConvUnit. + norm_cfg (dict): Config dict for normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None): + super(FeatureFusionBlock, self).__init__(init_cfg) + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) + x = self.project(x) + return x + + +@HEADS.register_module() +class DPTHead(DepthBaseDecodeHead): + """Vision Transformers for Dense Prediction. + This head is implemented of `DPT `_. + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. 
+ """ + + def __init__( + self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type="ignore", + patch_size=16, + expand_channels=False, + **kwargs + ): + super(DPTHead, self).__init__(**kwargs) + + self.in_channels = self.in_channels + self.expand_channels = expand_channels + self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size) + + self.post_process_channels = [ + channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False)) + self.fusion_blocks = nn.ModuleList() + for _ in range(len(self.convs)): + self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + self.conv_depth = HeadDepth(self.channels) + + def forward(self, inputs, img_metas): + assert len(inputs) == self.num_reassemble_blocks + x = [inp for inp in inputs] + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + out = self.project(out) + out = self.depth_pred(out) + return out diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/decode_heads/linear_head.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/decode_heads/linear_head.py new file mode 100755 index 0000000..3da1436 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/decode_heads/linear_head.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from ...ops import resize +from ..builder import HEADS +from .decode_head import DepthBaseDecodeHead + + +@HEADS.register_module() +class BNHead(DepthBaseDecodeHead): + """Just a batchnorm.""" + + def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs): + super().__init__(**kwargs) + self.input_transform = input_transform + self.in_index = in_index + self.upsample = upsample + # self.bn = nn.SyncBatchNorm(self.in_channels) + if self.classify: + self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1) + else: + self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. 
+ Returns: + Tensor: The transformed inputs + """ + + if "concat" in self.input_transform: + inputs = [inputs[i] for i in self.in_index] + if "resize" in self.input_transform: + inputs = [ + resize( + input=x, + size=[s * self.upsample for s in inputs[0].shape[2:]], + mode="bilinear", + align_corners=self.align_corners, + ) + for x in inputs + ] + inputs = torch.cat(inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _forward_feature(self, inputs, img_metas=None, **kwargs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + # accept lists (for cls token) + inputs = list(inputs) + for i, x in enumerate(inputs): + if len(x) == 2: + x, cls_token = x[0], x[1] + if len(x.shape) == 2: + x = x[:, :, None, None] + cls_token = cls_token[:, :, None, None].expand_as(x) + inputs[i] = torch.cat((x, cls_token), 1) + else: + x = x[0] + if len(x.shape) == 2: + x = x[:, :, None, None] + inputs[i] = x + x = self._transform_inputs(inputs) + # feats = self.bn(x) + return x + + def forward(self, inputs, img_metas=None, **kwargs): + """Forward function.""" + output = self._forward_feature(inputs, img_metas=img_metas, **kwargs) + output = self.depth_pred(output) + + return output diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/depther/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/depther/__init__.py new file mode 100755 index 0000000..be99743 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/depther/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .base import BaseDepther +from .encoder_decoder import DepthEncoderDecoder diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/depther/base.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/depther/base.py new file mode 100755 index 0000000..e133a82 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/depther/base.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
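A note on the `resize_concat` input transform used by BNHead above: it selects the feature maps listed in `in_index`, bilinearly upsamples each one to the (optionally `upsample`-scaled) spatial size of the first map, and concatenates them along the channel dimension before the 1x1 depth prediction conv. A minimal standalone sketch of that behaviour, assuming illustrative tensor shapes and with plain `F.interpolate` standing in for the project's `resize` wrapper (the function name `resize_concat` below is mine, not from the patch):

import torch
import torch.nn.functional as F

def resize_concat(feats, in_index=(0, 1, 2, 3), upsample=1, align_corners=False):
    # Select the requested levels, upsample each to the first level's
    # (optionally scaled) resolution, then concatenate along channels.
    feats = [feats[i] for i in in_index]
    target = [s * upsample for s in feats[0].shape[2:]]
    feats = [F.interpolate(x, size=target, mode="bilinear", align_corners=align_corners)
             for x in feats]
    return torch.cat(feats, dim=1)

# Illustrative shapes only: four levels with 64 channels each.
feats = [torch.rand(1, 64, 32 // 2 ** i, 32 // 2 ** i) for i in range(4)]
print(resize_concat(feats, upsample=2).shape)  # torch.Size([1, 256, 64, 64])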
+ +from abc import ABCMeta, abstractmethod +from collections import OrderedDict + +import torch +import torch.distributed as dist +from mmcv.runner import BaseModule, auto_fp16 + + +class BaseDepther(BaseModule, metaclass=ABCMeta): + """Base class for depther.""" + + def __init__(self, init_cfg=None): + super(BaseDepther, self).__init__(init_cfg) + self.fp16_enabled = False + + @property + def with_neck(self): + """bool: whether the depther has neck""" + return hasattr(self, "neck") and self.neck is not None + + @property + def with_auxiliary_head(self): + """bool: whether the depther has auxiliary head""" + return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None + + @property + def with_decode_head(self): + """bool: whether the depther has decode head""" + return hasattr(self, "decode_head") and self.decode_head is not None + + @abstractmethod + def extract_feat(self, imgs): + """Placeholder for extract features from images.""" + pass + + @abstractmethod + def encode_decode(self, img, img_metas): + """Placeholder for encode images with backbone and decode into a + semantic depth map of the same size as input.""" + pass + + @abstractmethod + def forward_train(self, imgs, img_metas, **kwargs): + """Placeholder for Forward function for training.""" + pass + + @abstractmethod + def simple_test(self, img, img_meta, **kwargs): + """Placeholder for single image test.""" + pass + + @abstractmethod + def aug_test(self, imgs, img_metas, **kwargs): + """Placeholder for augmentation test.""" + pass + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: + if not isinstance(var, list): + raise TypeError(f"{name} must be a list, but got " f"{type(var)}") + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})") + # all images in the same aug batch all of the same ori_shape and pad + # shape + for img_meta in img_metas: + ori_shapes = [_["ori_shape"] for _ in img_meta] + assert all(shape == ori_shapes[0] for shape in ori_shapes) + img_shapes = [_["img_shape"] for _ in img_meta] + assert all(shape == img_shapes[0] for shape in img_shapes) + pad_shapes = [_["pad_shape"] for _ in img_meta] + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + @auto_fp16(apply_to=("img",)) + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. 
+ """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data_batch) + + # split losses and images + real_losses = {} + log_imgs = {} + for k, v in losses.items(): + if "img" in k: + log_imgs[k] = v + else: + real_losses[k] = v + + loss, log_vars = self._parse_losses(real_losses) + + outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs) + + return outputs + + def val_step(self, data_batch, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + output = self(**data_batch, **kwargs) + return output + + @staticmethod + def _parse_losses(losses): + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError(f"{loss_name} is not a tensor or list of tensors") + + loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) + + log_vars["loss"] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/depther/encoder_decoder.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/depther/encoder_decoder.py new file mode 100755 index 0000000..6b0ec2d --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/depther/encoder_decoder.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F + +from ...models import builder +from ...models.builder import DEPTHER +from ...ops import resize +from .base import BaseDepther + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs + + +@DEPTHER.register_module() +class DepthEncoderDecoder(BaseDepther): + """Encoder Decoder depther. + + EncoderDecoder typically consists of backbone, (neck) and decode_head. + """ + + def __init__(self, backbone, decode_head, neck=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None): + super(DepthEncoderDecoder, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get("pretrained") is None, "both backbone and depther set pretrained weight" + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + self._init_decode_head(decode_head) + + if neck is not None: + self.neck = builder.build_neck(neck) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas, rescale=True, size=None): + """Encode images with backbone and decode into a depth estimation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + # crop the pred depth to the certain range. + out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) + if rescale: + if size is None: + if img_metas is not None: + size = img_metas[0]["ori_shape"][:2] + else: + size = img.shape[2:] + out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs) + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg) + return depth_pred + + def forward_dummy(self, img): + """Dummy forward function.""" + depth = self.encode_decode(img, None) + + return depth + + def forward_train(self, img, img_metas, depth_gt, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. 
+ depth_gt (Tensor): Depth gt + used if the architecture supports depth estimation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + # the last of x saves the info from neck + loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) + + losses.update(loss_decode) + + return losses + + def whole_inference(self, img, img_meta, rescale, size=None): + """Inference with full image.""" + depth_pred = self.encode_decode(img, img_meta, rescale, size=size) + + return depth_pred + + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, 1, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + depth_pred = self.encode_decode(crop_img, img_meta, rescale) + preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + return preds + + def inference(self, img, img_meta, rescale, size=None): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output depth map. 
+ """ + + assert self.test_cfg.mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if self.test_cfg.mode == "slide": + depth_pred = self.slide_inference(img, img_meta, rescale) + else: + depth_pred = self.whole_inference(img, img_meta, rescale, size=size) + output = depth_pred + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + depth_pred = self.inference(img, img_meta, rescale) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + depth_pred = depth_pred.unsqueeze(0) + return depth_pred + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented depth logit inplace + depth_pred = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:]) + depth_pred += cur_depth_pred + depth_pred /= len(imgs) + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/losses/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/losses/__init__.py new file mode 100755 index 0000000..2f86242 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/losses/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .gradientloss import GradientLoss +from .sigloss import SigLoss diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/losses/gradientloss.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/losses/gradientloss.py new file mode 100755 index 0000000..1599878 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/losses/gradientloss.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from ...models.builder import LOSSES + + +@LOSSES.register_module() +class GradientLoss(nn.Module): + """GradientLoss. + + Adapted from https://www.cs.cornell.edu/projects/megadepth/ + + Args: + valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True. + loss_weight (float): Weight of the loss. Default: 1.0. + max_depth (int): When filtering invalid gt, set a max threshold. Default: None. 
+ """ + + def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"): + super(GradientLoss, self).__init__() + self.valid_mask = valid_mask + self.loss_weight = loss_weight + self.max_depth = max_depth + self.loss_name = loss_name + + self.eps = 0.001 # avoid grad explode + + def gradientloss(self, input, target): + input_downscaled = [input] + [input[:: 2 * i, :: 2 * i] for i in range(1, 4)] + target_downscaled = [target] + [target[:: 2 * i, :: 2 * i] for i in range(1, 4)] + + gradient_loss = 0 + for input, target in zip(input_downscaled, target_downscaled): + if self.valid_mask: + mask = target > 0 + if self.max_depth is not None: + mask = torch.logical_and(target > 0, target <= self.max_depth) + N = torch.sum(mask) + else: + mask = torch.ones_like(target) + N = input.numel() + input_log = torch.log(input + self.eps) + target_log = torch.log(target + self.eps) + log_d_diff = input_log - target_log + + log_d_diff = torch.mul(log_d_diff, mask) + + v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :]) + v_mask = torch.mul(mask[0:-2, :], mask[2:, :]) + v_gradient = torch.mul(v_gradient, v_mask) + + h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:]) + h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:]) + h_gradient = torch.mul(h_gradient, h_mask) + + gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N + + return gradient_loss + + def forward(self, depth_pred, depth_gt): + """Forward function.""" + + gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt) + return gradient_loss diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/losses/sigloss.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/losses/sigloss.py new file mode 100755 index 0000000..e12fad3 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/models/losses/sigloss.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from ...models.builder import LOSSES + + +@LOSSES.register_module() +class SigLoss(nn.Module): + """SigLoss. + + This follows `AdaBins `_. + + Args: + valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True. + loss_weight (float): Weight of the loss. Default: 1.0. + max_depth (int): When filtering invalid gt, set a max threshold. Default: None. + warm_up (bool): A simple warm up stage to help convergence. Default: False. + warm_iter (int): The number of warm up stage. Default: 100. 
+ """ + + def __init__( + self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss" + ): + super(SigLoss, self).__init__() + self.valid_mask = valid_mask + self.loss_weight = loss_weight + self.max_depth = max_depth + self.loss_name = loss_name + + self.eps = 0.001 # avoid grad explode + + # HACK: a hack implementation for warmup sigloss + self.warm_up = warm_up + self.warm_iter = warm_iter + self.warm_up_counter = 0 + + def sigloss(self, input, target): + if self.valid_mask: + valid_mask = target > 0 + if self.max_depth is not None: + valid_mask = torch.logical_and(target > 0, target <= self.max_depth) + input = input[valid_mask] + target = target[valid_mask] + + if self.warm_up: + if self.warm_up_counter < self.warm_iter: + g = torch.log(input + self.eps) - torch.log(target + self.eps) + g = 0.15 * torch.pow(torch.mean(g), 2) + self.warm_up_counter += 1 + return torch.sqrt(g) + + g = torch.log(input + self.eps) - torch.log(target + self.eps) + Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2) + return torch.sqrt(Dg) + + def forward(self, depth_pred, depth_gt): + """Forward function.""" + + loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt) + return loss_depth diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/ops/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/ops/__init__.py new file mode 100755 index 0000000..78181c2 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/ops/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .wrappers import resize diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/ops/wrappers.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/ops/wrappers.py new file mode 100755 index 0000000..15880ee --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/depth/ops/wrappers.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import warnings + +import torch.nn.functional as F + + +def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/knn.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/knn.py new file mode 100755 index 0000000..f3a4845 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/knn.py @@ -0,0 +1,404 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +from functools import partial +import json +import logging +import os +import sys +from typing import List, Optional + +import torch +from torch.nn.functional import one_hot, softmax + +import dinov2.distributed as distributed +from dinov2.data import SamplerType, make_data_loader, make_dataset +from dinov2.data.transforms import make_classification_eval_transform +from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric +from dinov2.eval.setup import get_args_parser as get_setup_args_parser +from dinov2.eval.setup import setup_and_build_model +from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--nb_knn", + nargs="+", + type=int, + help="Number of NN to use. 20 is usually working the best.", + ) + parser.add_argument( + "--temperature", + type=float, + help="Temperature used in the voting coefficient", + ) + parser.add_argument( + "--gather-on-cpu", + action="store_true", + help="Whether to gather the train features on cpu, slower" + "but useful to avoid OOM for large datasets (e.g. 
ImageNet22k).", + ) + parser.add_argument( + "--batch-size", + type=int, + help="Batch size.", + ) + parser.add_argument( + "--n-per-class-list", + nargs="+", + type=int, + help="Number to take per class", + ) + parser.add_argument( + "--n-tries", + type=int, + help="Number of tries", + ) + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + nb_knn=[10, 20, 100, 200], + temperature=0.07, + batch_size=256, + n_per_class_list=[-1], + n_tries=1, + ) + return parser + + +class KnnModule(torch.nn.Module): + """ + Gets knn of test features from all processes on a chunk of the train features + + Each rank gets a chunk of the train features as well as a chunk of the test features. + In `compute_neighbors`, for each rank one after the other, its chunk of test features + is sent to all devices, partial knns are computed with each chunk of train features + then collated back on the original device. + """ + + def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000): + super().__init__() + + self.global_rank = distributed.get_global_rank() + self.global_size = distributed.get_global_size() + + self.device = device + self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device) + self.candidates = train_labels.chunk(self.global_size)[self.global_rank].view(1, -1).to(self.device) + + self.nb_knn = nb_knn + self.max_k = max(self.nb_knn) + self.T = T + self.num_classes = num_classes + + def _get_knn_sims_and_labels(self, similarity, train_labels): + topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True) + neighbors_labels = torch.gather(train_labels, 1, indices) + return topk_sims, neighbors_labels + + def _similarity_for_rank(self, features_rank, source_rank): + # Send the features from `source_rank` to all ranks + broadcast_shape = torch.tensor(features_rank.shape).to(self.device) + torch.distributed.broadcast(broadcast_shape, source_rank) + + broadcasted = features_rank + if self.global_rank != source_rank: + broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device) + torch.distributed.broadcast(broadcasted, source_rank) + + # Compute the neighbors for `source_rank` among `train_features_rank_T` + similarity_rank = torch.mm(broadcasted, self.train_features_rank_T) + candidate_labels = self.candidates.expand(len(similarity_rank), -1) + return self._get_knn_sims_and_labels(similarity_rank, candidate_labels) + + def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank): + # Gather all neighbors for `target_rank` + topk_sims_rank = retrieved_rank = None + if self.global_rank == target_rank: + topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)] + retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)] + + torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank) + torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank) + + if self.global_rank == target_rank: + # Perform a second top-k on the k * global_size retrieved neighbors + topk_sims_rank = torch.cat(topk_sims_rank, dim=1) + retrieved_rank = torch.cat(retrieved_rank, dim=1) + results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank) + return results + return None + + def compute_neighbors(self, features_rank): + for rank in range(self.global_size): + topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank) + results = 
self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank) + if results is not None: + topk_sims_rank, neighbors_labels_rank = results + return topk_sims_rank, neighbors_labels_rank + + def forward(self, features_rank): + """ + Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k` + """ + assert all(k <= self.max_k for k in self.nb_knn) + + topk_sims, neighbors_labels = self.compute_neighbors(features_rank) + batch_size = neighbors_labels.shape[0] + topk_sims_transform = softmax(topk_sims / self.T, 1) + matmul = torch.mul( + one_hot(neighbors_labels, num_classes=self.num_classes), + topk_sims_transform.view(batch_size, -1, 1), + ) + probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn} + return probas_for_k + + +class DictKeysModule(torch.nn.Module): + def __init__(self, keys): + super().__init__() + self.keys = keys + + def forward(self, features_dict, targets): + for k in self.keys: + features_dict = features_dict[k] + return {"preds": features_dict, "target": targets} + + +def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels): + modules = {} + mapping = create_class_indices_mapping(train_labels) + for npc in n_per_class_list: + if npc < 0: # Only one try needed when using the full data + full_module = module( + train_features=train_features, + train_labels=train_labels, + nb_knn=nb_knn, + ) + modules["full"] = ModuleDictWithForward({"1": full_module}) + continue + all_tries = {} + for t in range(n_tries): + final_indices = filter_train(mapping, npc, seed=t) + k_list = list(set(nb_knn + [npc])) + k_list = sorted([el for el in k_list if el <= npc]) + all_tries[str(t)] = module( + train_features=train_features[final_indices], + train_labels=train_labels[final_indices], + nb_knn=k_list, + ) + modules[f"{npc} per class"] = ModuleDictWithForward(all_tries) + + return ModuleDictWithForward(modules) + + +def filter_train(mapping, n_per_class, seed): + torch.manual_seed(seed) + final_indices = [] + for k in mapping.keys(): + index = torch.randperm(len(mapping[k]))[:n_per_class] + final_indices.append(mapping[k][index]) + return torch.cat(final_indices).squeeze() + + +def create_class_indices_mapping(labels): + unique_labels, inverse = torch.unique(labels, return_inverse=True) + mapping = {unique_labels[i]: (inverse == i).nonzero() for i in range(len(unique_labels))} + return mapping + + +class ModuleDictWithForward(torch.nn.ModuleDict): + def forward(self, *args, **kwargs): + return {k: module(*args, **kwargs) for k, module in self._modules.items()} + + +def eval_knn( + model, + train_dataset, + val_dataset, + accuracy_averaging, + nb_knn, + temperature, + batch_size, + num_workers, + gather_on_cpu, + n_per_class_list=[-1], + n_tries=1, +): + model = ModelWithNormalize(model) + + logger.info("Extracting features for train set...") + train_features, train_labels = extract_features( + model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu + ) + logger.info(f"Train features created, shape {train_features.shape}.") + + val_dataloader = make_data_loader( + dataset=val_dataset, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + persistent_workers=True, + ) + num_classes = train_labels.max() + 1 + metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes) + + device = torch.cuda.current_device() + partial_module = partial(KnnModule, T=temperature, device=device, 
num_classes=num_classes) + knn_module_dict = create_module_dict( + module=partial_module, + n_per_class_list=n_per_class_list, + n_tries=n_tries, + nb_knn=nb_knn, + train_features=train_features, + train_labels=train_labels, + ) + postprocessors, metrics = {}, {} + for n_per_class, knn_module in knn_module_dict.items(): + for t, knn_try in knn_module.items(): + postprocessors = { + **postprocessors, + **{(n_per_class, t, k): DictKeysModule([n_per_class, t, k]) for k in knn_try.nb_knn}, + } + metrics = {**metrics, **{(n_per_class, t, k): metric_collection.clone() for k in knn_try.nb_knn}} + model_with_knn = torch.nn.Sequential(model, knn_module_dict) + + # ============ evaluation ... ============ + logger.info("Start the k-NN classification.") + _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device) + + # Averaging the results over the n tries for each value of n_per_class + for n_per_class, knn_module in knn_module_dict.items(): + first_try = list(knn_module.keys())[0] + k_list = knn_module[first_try].nb_knn + for k in k_list: + keys = results_dict[(n_per_class, first_try, k)].keys() # keys are e.g. `top-1` and `top-5` + results_dict[(n_per_class, k)] = { + key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()])) + for key in keys + } + for t in knn_module.keys(): + del results_dict[(n_per_class, t, k)] + + return results_dict + + +def eval_knn_with_model( + model, + output_dir, + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + nb_knn=(10, 20, 100, 200), + temperature=0.07, + autocast_dtype=torch.float, + accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY, + transform=None, + gather_on_cpu=False, + batch_size=256, + num_workers=5, + n_per_class_list=[-1], + n_tries=1, +): + transform = transform or make_classification_eval_transform() + + train_dataset = make_dataset( + dataset_str=train_dataset_str, + transform=transform, + ) + val_dataset = make_dataset( + dataset_str=val_dataset_str, + transform=transform, + ) + + with torch.cuda.amp.autocast(dtype=autocast_dtype): + results_dict_knn = eval_knn( + model=model, + train_dataset=train_dataset, + val_dataset=val_dataset, + accuracy_averaging=accuracy_averaging, + nb_knn=nb_knn, + temperature=temperature, + batch_size=batch_size, + num_workers=num_workers, + gather_on_cpu=gather_on_cpu, + n_per_class_list=n_per_class_list, + n_tries=n_tries, + ) + + results_dict = {} + if distributed.is_main_process(): + for knn_ in results_dict_knn.keys(): + top1 = results_dict_knn[knn_]["top-1"].item() * 100.0 + top5 = results_dict_knn[knn_]["top-5"].item() * 100.0 + results_dict[f"{knn_} Top 1"] = top1 + results_dict[f"{knn_} Top 5"] = top5 + logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}") + + metrics_file_path = os.path.join(output_dir, "results_eval_knn.json") + with open(metrics_file_path, "a") as f: + for k, v in results_dict.items(): + f.write(json.dumps({k: v}) + "\n") + + if distributed.is_enabled(): + torch.distributed.barrier() + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + eval_knn_with_model( + model=model, + output_dir=args.output_dir, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + nb_knn=args.nb_knn, + temperature=args.temperature, + autocast_dtype=autocast_dtype, + accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY, + transform=None, + gather_on_cpu=args.gather_on_cpu, + batch_size=args.batch_size, + 
num_workers=5, + n_per_class_list=args.n_per_class_list, + n_tries=args.n_tries, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 k-NN evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/linear.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/linear.py new file mode 100755 index 0000000..1bd4c5d --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/linear.py @@ -0,0 +1,625 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +from functools import partial +import json +import logging +import os +import sys +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel +from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer + +from dinov2.data import SamplerType, make_data_loader, make_dataset +from dinov2.data.transforms import make_classification_eval_transform, make_classification_train_transform +import dinov2.distributed as distributed +from dinov2.eval.metrics import MetricType, build_metric +from dinov2.eval.setup import get_args_parser as get_setup_args_parser +from dinov2.eval.setup import setup_and_build_model +from dinov2.eval.utils import ModelWithIntermediateLayers, evaluate +from dinov2.logging import MetricLogger + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--test-datasets", + dest="test_dataset_strs", + type=str, + nargs="+", + help="Test datasets, none to reuse the validation dataset", + ) + parser.add_argument( + "--epochs", + type=int, + help="Number of training epochs", + ) + parser.add_argument( + "--batch-size", + type=int, + help="Batch Size (per GPU)", + ) + parser.add_argument( + "--num-workers", + type=int, + help="Number de Workers", + ) + parser.add_argument( + "--epoch-length", + type=int, + help="Length of an epoch in number of iterations", + ) + parser.add_argument( + "--save-checkpoint-frequency", + type=int, + help="Number of epochs between two named checkpoint saves.", + ) + parser.add_argument( + "--eval-period-iterations", + type=int, + help="Number of iterations between two evaluations.", + ) + parser.add_argument( + "--learning-rates", + nargs="+", + type=float, + help="Learning rates to grid search.", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Whether to not resume from existing checkpoints", + ) + parser.add_argument( + "--val-metric-type", + type=MetricType, + choices=list(MetricType), + help="Validation metric", + ) + parser.add_argument( + "--test-metric-types", + 
type=MetricType, + choices=list(MetricType), + nargs="+", + help="Evaluation metric", + ) + parser.add_argument( + "--classifier-fpath", + type=str, + help="Path to a file containing pretrained linear classifiers", + ) + parser.add_argument( + "--val-class-mapping-fpath", + type=str, + help="Path to a file containing a mapping to adjust classifier outputs", + ) + parser.add_argument( + "--test-class-mapping-fpaths", + nargs="+", + type=str, + help="Path to a file containing a mapping to adjust classifier outputs", + ) + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + test_dataset_strs=None, + epochs=10, + batch_size=128, + num_workers=8, + epoch_length=1250, + save_checkpoint_frequency=20, + eval_period_iterations=1250, + learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 0.1], + val_metric_type=MetricType.MEAN_ACCURACY, + test_metric_types=None, + classifier_fpath=None, + val_class_mapping_fpath=None, + test_class_mapping_fpaths=[None], + ) + return parser + + +def has_ddp_wrapper(m: nn.Module) -> bool: + return isinstance(m, DistributedDataParallel) + + +def remove_ddp_wrapper(m: nn.Module) -> nn.Module: + return m.module if has_ddp_wrapper(m) else m + + +def _pad_and_collate(batch): + maxlen = max(len(targets) for image, targets in batch) + padded_batch = [ + (image, np.pad(targets, (0, maxlen - len(targets)), constant_values=-1)) for image, targets in batch + ] + return torch.utils.data.default_collate(padded_batch) + + +def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool): + intermediate_output = x_tokens_list[-use_n_blocks:] + output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1) + if use_avgpool: + output = torch.cat( + ( + output, + torch.mean(intermediate_output[-1][0], dim=1), # patch tokens + ), + dim=-1, + ) + output = output.reshape(output.shape[0], -1) + return output.float() + + +class LinearClassifier(nn.Module): + """Linear layer to train on top of frozen features""" + + def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000): + super().__init__() + self.out_dim = out_dim + self.use_n_blocks = use_n_blocks + self.use_avgpool = use_avgpool + self.num_classes = num_classes + self.linear = nn.Linear(out_dim, num_classes) + self.linear.weight.data.normal_(mean=0.0, std=0.01) + self.linear.bias.data.zero_() + + def forward(self, x_tokens_list): + output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool) + return self.linear(output) + + +class AllClassifiers(nn.Module): + def __init__(self, classifiers_dict): + super().__init__() + self.classifiers_dict = nn.ModuleDict() + self.classifiers_dict.update(classifiers_dict) + + def forward(self, inputs): + return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()} + + def __len__(self): + return len(self.classifiers_dict) + + +class LinearPostprocessor(nn.Module): + def __init__(self, linear_classifier, class_mapping=None): + super().__init__() + self.linear_classifier = linear_classifier + self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping)) + + def forward(self, samples, targets): + preds = self.linear_classifier(samples) + return { + "preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds, + "target": targets, + } + + +def scale_lr(learning_rates, batch_size): + return learning_rates * (batch_size * distributed.get_global_size()) / 256.0 + + +def 
setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, batch_size, num_classes=1000): + linear_classifiers_dict = nn.ModuleDict() + optim_param_groups = [] + for n in n_last_blocks_list: + for avgpool in [False, True]: + for _lr in learning_rates: + lr = scale_lr(_lr, batch_size) + out_dim = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool).shape[1] + linear_classifier = LinearClassifier( + out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes + ) + linear_classifier = linear_classifier.cuda() + linear_classifiers_dict[ + f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_") + ] = linear_classifier + optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr}) + + linear_classifiers = AllClassifiers(linear_classifiers_dict) + if distributed.is_enabled(): + linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers) + + return linear_classifiers, optim_param_groups + + +@torch.no_grad() +def evaluate_linear_classifiers( + feature_model, + linear_classifiers, + data_loader, + metric_type, + metrics_file_path, + training_num_classes, + iteration, + prefixstring="", + class_mapping=None, + best_classifier_on_val=None, +): + logger.info("running validation !") + + num_classes = len(class_mapping) if class_mapping is not None else training_num_classes + metric = build_metric(metric_type, num_classes=num_classes) + postprocessors = {k: LinearPostprocessor(v, class_mapping) for k, v in linear_classifiers.classifiers_dict.items()} + metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict} + + _, results_dict_temp = evaluate( + feature_model, + data_loader, + postprocessors, + metrics, + torch.cuda.current_device(), + ) + + logger.info("") + results_dict = {} + max_accuracy = 0 + best_classifier = "" + for i, (classifier_string, metric) in enumerate(results_dict_temp.items()): + logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}") + if ( + best_classifier_on_val is None and metric["top-1"].item() > max_accuracy + ) or classifier_string == best_classifier_on_val: + max_accuracy = metric["top-1"].item() + best_classifier = classifier_string + + results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy} + + logger.info(f"best classifier: {results_dict['best_classifier']}") + + if distributed.is_main_process(): + with open(metrics_file_path, "a") as f: + f.write(f"iter: {iteration}\n") + for k, v in results_dict.items(): + f.write(json.dumps({k: v}) + "\n") + f.write("\n") + + return results_dict + + +def eval_linear( + *, + feature_model, + linear_classifiers, + train_data_loader, + val_data_loader, + metrics_file_path, + optimizer, + scheduler, + output_dir, + max_iter, + checkpoint_period, # In number of iter, creates a new file every period + running_checkpoint_period, # Period to update main checkpoint file + eval_period, + metric_type, + training_num_classes, + resume=True, + classifier_fpath=None, + val_class_mapping=None, +): + checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler) + start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1 + + periodic_checkpointer = PeriodicCheckpointer(checkpointer, checkpoint_period, max_iter=max_iter) + iteration = start_iter + logger.info("Starting training from iteration {}".format(start_iter)) + metric_logger = MetricLogger(delimiter=" ") + header = "Training" + + for data, labels 
in metric_logger.log_every( + train_data_loader, + 10, + header, + max_iter, + start_iter, + ): + data = data.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + features = feature_model(data) + outputs = linear_classifiers(features) + + losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()} + loss = sum(losses.values()) + + # compute the gradients + optimizer.zero_grad() + loss.backward() + + # step + optimizer.step() + scheduler.step() + + # log + if iteration % 10 == 0: + torch.cuda.synchronize() + metric_logger.update(loss=loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + print("lr", optimizer.param_groups[0]["lr"]) + + if iteration - start_iter > 5: + if iteration % running_checkpoint_period == 0: + torch.cuda.synchronize() + if distributed.is_main_process(): + logger.info("Checkpointing running_checkpoint") + periodic_checkpointer.save("running_checkpoint_linear_eval", iteration=iteration) + torch.cuda.synchronize() + periodic_checkpointer.step(iteration) + + if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1: + _ = evaluate_linear_classifiers( + feature_model=feature_model, + linear_classifiers=remove_ddp_wrapper(linear_classifiers), + data_loader=val_data_loader, + metrics_file_path=metrics_file_path, + prefixstring=f"ITER: {iteration}", + metric_type=metric_type, + training_num_classes=training_num_classes, + iteration=iteration, + class_mapping=val_class_mapping, + ) + torch.cuda.synchronize() + + iteration = iteration + 1 + + val_results_dict = evaluate_linear_classifiers( + feature_model=feature_model, + linear_classifiers=remove_ddp_wrapper(linear_classifiers), + data_loader=val_data_loader, + metrics_file_path=metrics_file_path, + metric_type=metric_type, + training_num_classes=training_num_classes, + iteration=iteration, + class_mapping=val_class_mapping, + ) + return val_results_dict, feature_model, linear_classifiers, iteration + + +def make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type): + test_dataset = make_dataset( + dataset_str=test_dataset_str, + transform=make_classification_eval_transform(), + ) + test_data_loader = make_data_loader( + dataset=test_dataset, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + persistent_workers=False, + collate_fn=_pad_and_collate if metric_type == MetricType.IMAGENET_REAL_ACCURACY else None, + ) + return test_data_loader + + +def test_on_datasets( + feature_model, + linear_classifiers, + test_dataset_strs, + batch_size, + num_workers, + test_metric_types, + metrics_file_path, + training_num_classes, + iteration, + best_classifier_on_val, + prefixstring="", + test_class_mappings=[None], +): + results_dict = {} + for test_dataset_str, class_mapping, metric_type in zip(test_dataset_strs, test_class_mappings, test_metric_types): + logger.info(f"Testing on {test_dataset_str}") + test_data_loader = make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type) + dataset_results_dict = evaluate_linear_classifiers( + feature_model, + remove_ddp_wrapper(linear_classifiers), + test_data_loader, + metric_type, + metrics_file_path, + training_num_classes, + iteration, + prefixstring="", + class_mapping=class_mapping, + best_classifier_on_val=best_classifier_on_val, + ) + results_dict[f"{test_dataset_str}_accuracy"] = 100.0 * dataset_results_dict["best_classifier"]["accuracy"] + return results_dict + + +def 
run_eval_linear( + model, + output_dir, + train_dataset_str, + val_dataset_str, + batch_size, + epochs, + epoch_length, + num_workers, + save_checkpoint_frequency, + eval_period_iterations, + learning_rates, + autocast_dtype, + test_dataset_strs=None, + resume=True, + classifier_fpath=None, + val_class_mapping_fpath=None, + test_class_mapping_fpaths=[None], + val_metric_type=MetricType.MEAN_ACCURACY, + test_metric_types=None, +): + seed = 0 + + if test_dataset_strs is None: + test_dataset_strs = [val_dataset_str] + if test_metric_types is None: + test_metric_types = [val_metric_type] * len(test_dataset_strs) + else: + assert len(test_metric_types) == len(test_dataset_strs) + assert len(test_dataset_strs) == len(test_class_mapping_fpaths) + + train_transform = make_classification_train_transform() + train_dataset = make_dataset( + dataset_str=train_dataset_str, + transform=train_transform, + ) + training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int)))) + sampler_type = SamplerType.SHARDED_INFINITE + # sampler_type = SamplerType.INFINITE + + n_last_blocks_list = [1, 4] + n_last_blocks = max(n_last_blocks_list) + autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype) + feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx) + sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda()) + + linear_classifiers, optim_param_groups = setup_linear_classifiers( + sample_output, + n_last_blocks_list, + learning_rates, + batch_size, + training_num_classes, + ) + + optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0) + max_iter = epochs * epoch_length + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0) + checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler) + start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1 + train_data_loader = make_data_loader( + dataset=train_dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=True, + seed=seed, + sampler_type=sampler_type, + sampler_advance=start_iter, + drop_last=True, + persistent_workers=True, + ) + val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type) + + checkpoint_period = save_checkpoint_frequency * epoch_length + + if val_class_mapping_fpath is not None: + logger.info(f"Using class mapping from {val_class_mapping_fpath}") + val_class_mapping = np.load(val_class_mapping_fpath) + else: + val_class_mapping = None + + test_class_mappings = [] + for class_mapping_fpath in test_class_mapping_fpaths: + if class_mapping_fpath is not None and class_mapping_fpath != "None": + logger.info(f"Using class mapping from {class_mapping_fpath}") + class_mapping = np.load(class_mapping_fpath) + else: + class_mapping = None + test_class_mappings.append(class_mapping) + + metrics_file_path = os.path.join(output_dir, "results_eval_linear.json") + val_results_dict, feature_model, linear_classifiers, iteration = eval_linear( + feature_model=feature_model, + linear_classifiers=linear_classifiers, + train_data_loader=train_data_loader, + val_data_loader=val_data_loader, + metrics_file_path=metrics_file_path, + optimizer=optimizer, + scheduler=scheduler, + output_dir=output_dir, + max_iter=max_iter, + checkpoint_period=checkpoint_period, + running_checkpoint_period=epoch_length, + eval_period=eval_period_iterations, + metric_type=val_metric_type, + 
training_num_classes=training_num_classes, + resume=resume, + val_class_mapping=val_class_mapping, + classifier_fpath=classifier_fpath, + ) + results_dict = {} + if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str: + results_dict = test_on_datasets( + feature_model, + linear_classifiers, + test_dataset_strs, + batch_size, + 0, # num_workers, + test_metric_types, + metrics_file_path, + training_num_classes, + iteration, + val_results_dict["best_classifier"]["name"], + prefixstring="", + test_class_mappings=test_class_mappings, + ) + results_dict["best_classifier"] = val_results_dict["best_classifier"]["name"] + results_dict[f"{val_dataset_str}_accuracy"] = 100.0 * val_results_dict["best_classifier"]["accuracy"] + logger.info("Test Results Dict " + str(results_dict)) + + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + run_eval_linear( + model=model, + output_dir=args.output_dir, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + test_dataset_strs=args.test_dataset_strs, + batch_size=args.batch_size, + epochs=args.epochs, + epoch_length=args.epoch_length, + num_workers=args.num_workers, + save_checkpoint_frequency=args.save_checkpoint_frequency, + eval_period_iterations=args.eval_period_iterations, + learning_rates=args.learning_rates, + autocast_dtype=autocast_dtype, + resume=not args.no_resume, + classifier_fpath=args.classifier_fpath, + val_metric_type=args.val_metric_type, + test_metric_types=args.test_metric_types, + val_class_mapping_fpath=args.val_class_mapping_fpath, + test_class_mapping_fpaths=args.test_class_mapping_fpaths, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 linear evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/log_regression.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/log_regression.py new file mode 100755 index 0000000..5f36ec1 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/log_regression.py @@ -0,0 +1,444 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
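The linear evaluation above trains a whole grid of classifiers in parallel: for each `n_last_blocks` value in {1, 4}, with and without average-pooled patch tokens, and for every candidate learning rate (scaled by global batch size / 256 in `scale_lr`), a separate `LinearClassifier` is fitted on the concatenation of the CLS tokens from the last n blocks. A minimal sketch of how each classifier's input width comes out, assuming illustrative ViT-B-like dimensions (the helper name `linear_input` is mine):

import torch

embed_dim, n_patches, batch = 768, 256, 2
# Each entry mimics one block's output: (patch_tokens, cls_token).
x_tokens_list = [
    (torch.rand(batch, n_patches, embed_dim), torch.rand(batch, embed_dim))
    for _ in range(4)
]

def linear_input(x_tokens_list, use_n_blocks, use_avgpool):
    blocks = x_tokens_list[-use_n_blocks:]
    out = torch.cat([cls for _, cls in blocks], dim=-1)            # concatenate CLS tokens
    if use_avgpool:
        out = torch.cat((out, blocks[-1][0].mean(dim=1)), dim=-1)  # append mean-pooled patch tokens
    return out

for n in (1, 4):
    for avgpool in (False, True):
        print(n, avgpool, linear_input(x_tokens_list, n, avgpool).shape[-1])
# expected widths: 768, 1536, 3072, 3840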
+ +import argparse +import gc +import logging +import sys +import time +from typing import List, Optional + +from cuml.linear_model import LogisticRegression +import torch +import torch.backends.cudnn as cudnn +import torch.distributed +from torch import nn +from torch.utils.data import TensorDataset +from torchmetrics import MetricTracker + +from dinov2.data import make_dataset +from dinov2.data.transforms import make_classification_eval_transform +from dinov2.distributed import get_global_rank, get_global_size +from dinov2.eval.metrics import MetricType, build_metric +from dinov2.eval.setup import get_args_parser as get_setup_args_parser +from dinov2.eval.setup import setup_and_build_model +from dinov2.eval.utils import evaluate, extract_features +from dinov2.utils.dtype import as_torch_dtype + + +logger = logging.getLogger("dinov2") + +DEFAULT_MAX_ITER = 1_000 +C_POWER_RANGE = torch.linspace(-6, 5, 45) +_CPU_DEVICE = torch.device("cpu") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--finetune-dataset-str", + dest="finetune_dataset_str", + type=str, + help="Fine-tuning dataset", + ) + parser.add_argument( + "--finetune-on-val", + action="store_true", + help="If there is no finetune dataset, whether to choose the " + "hyperparameters on the val set instead of 10%% of the train dataset", + ) + parser.add_argument( + "--metric-type", + type=MetricType, + choices=list(MetricType), + help="Metric type", + ) + parser.add_argument( + "--train-features-device", + type=str, + help="Device to gather train features (cpu, cuda, cuda:0, etc.), default: %(default)s", + ) + parser.add_argument( + "--train-dtype", + type=str, + help="Data type to convert the train features to (default: %(default)s)", + ) + parser.add_argument( + "--max-train-iters", + type=int, + help="Maximum number of train iterations (default: %(default)s)", + ) + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + finetune_dataset_str=None, + metric_type=MetricType.MEAN_ACCURACY, + train_features_device="cpu", + train_dtype="float64", + max_train_iters=DEFAULT_MAX_ITER, + finetune_on_val=False, + ) + return parser + + +class LogRegModule(nn.Module): + def __init__( + self, + C, + max_iter=DEFAULT_MAX_ITER, + dtype=torch.float64, + device=_CPU_DEVICE, + ): + super().__init__() + self.dtype = dtype + self.device = device + self.estimator = LogisticRegression( + penalty="l2", + C=C, + max_iter=max_iter, + output_type="numpy", + tol=1e-12, + linesearch_max_iter=50, + ) + + def forward(self, samples, targets): + samples_device = samples.device + samples = samples.to(dtype=self.dtype, device=self.device) + if self.device == _CPU_DEVICE: + samples = samples.numpy() + probas = self.estimator.predict_proba(samples) + return {"preds": torch.from_numpy(probas).to(samples_device), "target": targets} + + def fit(self, train_features, train_labels): + train_features = train_features.to(dtype=self.dtype, 
device=self.device) + train_labels = train_labels.to(dtype=self.dtype, device=self.device) + if self.device == _CPU_DEVICE: + # both cuML and sklearn only work with numpy arrays on CPU + train_features = train_features.numpy() + train_labels = train_labels.numpy() + self.estimator.fit(train_features, train_labels) + + +def evaluate_model(*, logreg_model, logreg_metric, test_data_loader, device): + postprocessors = {"metrics": logreg_model} + metrics = {"metrics": logreg_metric} + return evaluate(nn.Identity(), test_data_loader, postprocessors, metrics, device) + + +def train_for_C(*, C, max_iter, train_features, train_labels, dtype=torch.float64, device=_CPU_DEVICE): + logreg_model = LogRegModule(C, max_iter=max_iter, dtype=dtype, device=device) + logreg_model.fit(train_features, train_labels) + return logreg_model + + +def train_and_evaluate( + *, + C, + max_iter, + train_features, + train_labels, + logreg_metric, + test_data_loader, + train_dtype=torch.float64, + train_features_device, + eval_device, +): + logreg_model = train_for_C( + C=C, + max_iter=max_iter, + train_features=train_features, + train_labels=train_labels, + dtype=train_dtype, + device=train_features_device, + ) + return evaluate_model( + logreg_model=logreg_model, + logreg_metric=logreg_metric, + test_data_loader=test_data_loader, + device=eval_device, + ) + + +def sweep_C_values( + *, + train_features, + train_labels, + test_data_loader, + metric_type, + num_classes, + train_dtype=torch.float64, + train_features_device=_CPU_DEVICE, + max_train_iters=DEFAULT_MAX_ITER, +): + if metric_type == MetricType.PER_CLASS_ACCURACY: + # If we want to output per-class accuracy, we select the hyperparameters with mean per class + metric_type = MetricType.MEAN_PER_CLASS_ACCURACY + logreg_metric = build_metric(metric_type, num_classes=num_classes) + metric_tracker = MetricTracker(logreg_metric, maximize=True) + ALL_C = 10**C_POWER_RANGE + logreg_models = {} + + train_features = train_features.to(dtype=train_dtype, device=train_features_device) + train_labels = train_labels.to(device=train_features_device) + + for i in range(get_global_rank(), len(ALL_C), get_global_size()): + C = ALL_C[i].item() + logger.info( + f"Training for C = {C:.5f}, dtype={train_dtype}, " + f"features: {train_features.shape}, {train_features.dtype}, " + f"labels: {train_labels.shape}, {train_labels.dtype}" + ) + logreg_models[C] = train_for_C( + C=C, + max_iter=max_train_iters, + train_features=train_features, + train_labels=train_labels, + dtype=train_dtype, + device=train_features_device, + ) + + gather_list = [None for _ in range(get_global_size())] + torch.distributed.all_gather_object(gather_list, logreg_models) + + logreg_models_gathered = {} + for logreg_dict in gather_list: + logreg_models_gathered.update(logreg_dict) + + for i in range(len(ALL_C)): + metric_tracker.increment() + C = ALL_C[i].item() + evals = evaluate_model( + logreg_model=logreg_models_gathered[C], + logreg_metric=metric_tracker, + test_data_loader=test_data_loader, + device=torch.cuda.current_device(), + ) + logger.info(f"Trained for C = {C:.5f}, accuracies = {evals}") + + best_stats, which_epoch = metric_tracker.best_metric(return_step=True) + best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()} + if which_epoch["top-1"] == i: + best_C = C + logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.6f}") + + return best_stats, best_C + + +def eval_log_regression( + *, + model, + train_dataset, + val_dataset, + finetune_dataset, + metric_type, + batch_size, + num_workers, 
+ finetune_on_val=False, + train_dtype=torch.float64, + train_features_device=_CPU_DEVICE, + max_train_iters=DEFAULT_MAX_ITER, +): + """ + Implements the "standard" process for log regression evaluation: + The value of C is chosen by training on train_dataset and evaluating on + finetune_dataset. Then, the final model is trained on a concatenation of + train_dataset and finetune_dataset, and is evaluated on val_dataset. + If there is no finetune_dataset, the value of C is the one that yields + the best results on a random 10% subset of the train dataset + """ + + start = time.time() + + train_features, train_labels = extract_features( + model, train_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) + ) + val_features, val_labels = extract_features( + model, val_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) + ) + val_data_loader = torch.utils.data.DataLoader( + TensorDataset(val_features, val_labels), + batch_size=batch_size, + drop_last=False, + num_workers=0, + persistent_workers=False, + ) + + if finetune_dataset is None and finetune_on_val: + logger.info("Choosing hyperparameters on the val dataset") + finetune_features, finetune_labels = val_features, val_labels + elif finetune_dataset is None and not finetune_on_val: + logger.info("Choosing hyperparameters on 10% of the train dataset") + torch.manual_seed(0) + indices = torch.randperm(len(train_features), device=train_features.device) + finetune_index = indices[: len(train_features) // 10] + train_index = indices[len(train_features) // 10 :] + finetune_features, finetune_labels = train_features[finetune_index], train_labels[finetune_index] + train_features, train_labels = train_features[train_index], train_labels[train_index] + else: + logger.info("Choosing hyperparameters on the finetune dataset") + finetune_features, finetune_labels = extract_features( + model, finetune_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) + ) + # release the model - free GPU memory + del model + gc.collect() + torch.cuda.empty_cache() + finetune_data_loader = torch.utils.data.DataLoader( + TensorDataset(finetune_features, finetune_labels), + batch_size=batch_size, + drop_last=False, + ) + + if len(train_labels.shape) > 1: + num_classes = train_labels.shape[1] + else: + num_classes = train_labels.max() + 1 + + logger.info("Using cuML for logistic regression") + + best_stats, best_C = sweep_C_values( + train_features=train_features, + train_labels=train_labels, + test_data_loader=finetune_data_loader, + metric_type=metric_type, + num_classes=num_classes, + train_dtype=train_dtype, + train_features_device=train_features_device, + max_train_iters=max_train_iters, + ) + + if not finetune_on_val: + logger.info("Best parameter found, concatenating features") + train_features = torch.cat((train_features, finetune_features)) + train_labels = torch.cat((train_labels, finetune_labels)) + + logger.info("Training final model") + logreg_metric = build_metric(metric_type, num_classes=num_classes) + evals = train_and_evaluate( + C=best_C, + max_iter=max_train_iters, + train_features=train_features, + train_labels=train_labels, + logreg_metric=logreg_metric.clone(), + test_data_loader=val_data_loader, + eval_device=torch.cuda.current_device(), + train_dtype=train_dtype, + train_features_device=train_features_device, + ) + + best_stats = evals[1]["metrics"] + + best_stats["best_C"] = best_C + + logger.info(f"Log regression evaluation done in 
{int(time.time() - start)}s") + return best_stats + + +def eval_log_regression_with_model( + model, + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + finetune_dataset_str=None, + autocast_dtype=torch.float, + finetune_on_val=False, + metric_type=MetricType.MEAN_ACCURACY, + train_dtype=torch.float64, + train_features_device=_CPU_DEVICE, + max_train_iters=DEFAULT_MAX_ITER, +): + cudnn.benchmark = True + + transform = make_classification_eval_transform(resize_size=224) + target_transform = None + + train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform, target_transform=target_transform) + val_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform, target_transform=target_transform) + if finetune_dataset_str is not None: + finetune_dataset = make_dataset( + dataset_str=finetune_dataset_str, transform=transform, target_transform=target_transform + ) + else: + finetune_dataset = None + + with torch.cuda.amp.autocast(dtype=autocast_dtype): + results_dict_logreg = eval_log_regression( + model=model, + train_dataset=train_dataset, + val_dataset=val_dataset, + finetune_dataset=finetune_dataset, + metric_type=metric_type, + batch_size=256, + num_workers=0, # 5, + finetune_on_val=finetune_on_val, + train_dtype=train_dtype, + train_features_device=train_features_device, + max_train_iters=max_train_iters, + ) + + results_dict = { + "top-1": results_dict_logreg["top-1"].cpu().numpy() * 100.0, + "top-5": results_dict_logreg.get("top-5", torch.tensor(0.0)).cpu().numpy() * 100.0, + "best_C": results_dict_logreg["best_C"], + } + logger.info( + "\n".join( + [ + "Training of the supervised logistic regression on frozen features completed.\n" + "Top-1 test accuracy: {acc:.1f}".format(acc=results_dict["top-1"]), + "Top-5 test accuracy: {acc:.1f}".format(acc=results_dict["top-5"]), + "obtained for C = {c:.6f}".format(c=results_dict["best_C"]), + ] + ) + ) + + torch.distributed.barrier() + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + eval_log_regression_with_model( + model=model, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + finetune_dataset_str=args.finetune_dataset_str, + autocast_dtype=autocast_dtype, + finetune_on_val=args.finetune_on_val, + metric_type=args.metric_type, + train_dtype=as_torch_dtype(args.train_dtype), + train_features_device=torch.device(args.train_features_device), + max_train_iters=args.max_train_iters, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 logistic regression evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/metrics.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/metrics.py new file mode 100755 index 0000000..52be81a --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/metrics.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from enum import Enum +import logging +from typing import Any, Dict, Optional + +import torch +from torch import Tensor +from torchmetrics import Metric, MetricCollection +from torchmetrics.classification import MulticlassAccuracy +from torchmetrics.utilities.data import dim_zero_cat, select_topk + + +logger = logging.getLogger("dinov2") + + +class MetricType(Enum): + MEAN_ACCURACY = "mean_accuracy" + MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy" + PER_CLASS_ACCURACY = "per_class_accuracy" + IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy" + + @property + def accuracy_averaging(self): + return getattr(AccuracyAveraging, self.name, None) + + def __str__(self): + return self.value + + +class AccuracyAveraging(Enum): + MEAN_ACCURACY = "micro" + MEAN_PER_CLASS_ACCURACY = "macro" + PER_CLASS_ACCURACY = "none" + + def __str__(self): + return self.value + + +def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None): + if metric_type.accuracy_averaging is not None: + return build_topk_accuracy_metric( + average_type=metric_type.accuracy_averaging, + num_classes=num_classes, + ks=(1, 5) if ks is None else ks, + ) + elif metric_type == MetricType.IMAGENET_REAL_ACCURACY: + return build_topk_imagenet_real_accuracy_metric( + num_classes=num_classes, + ks=(1, 5) if ks is None else ks, + ) + + raise ValueError(f"Unknown metric type {metric_type}") + + +def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)): + metrics: Dict[str, Metric] = { + f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks + } + return MetricCollection(metrics) + + +def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)): + metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks} + return MetricCollection(metrics) + + +class ImageNetReaLAccuracy(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + top_k: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.num_classes = num_classes + self.top_k = top_k + self.add_state("tp", [], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + # preds [B, D] + # target [B, A] + # preds_oh [B, D] with 0 and 1 + # select top K highest probabilities, use one hot representation + preds_oh = select_topk(preds, self.top_k) + # target_oh [B, D + 1] with 0 and 1 + target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32) + target = target.long() + # for undefined targets (-1) use a fake value `num_classes` + target[target == -1] = self.num_classes + # fill targets, use one hot representation + target_oh.scatter_(1, target, 1) + # target_oh [B, D] (remove the fake target at index `num_classes`) + target_oh = target_oh[:, :-1] + # tp [B] with 0 and 1 + tp = (preds_oh * target_oh == 1).sum(dim=1) + # at least one match between prediction and target + tp.clip_(max=1) + # ignore instances where no targets are defined + mask = target_oh.sum(dim=1) > 0 + tp = tp[mask] + self.tp.append(tp) # type: ignore + + def compute(self) -> Tensor: + tp = dim_zero_cat(self.tp) # type: ignore + return tp.float().mean() diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/__init__.py 
b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/__init__.py new file mode 100755 index 0000000..b88da6b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/hooks/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/hooks/__init__.py new file mode 100755 index 0000000..738cc2d --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/hooks/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .optimizer import DistOptimizerHook diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/hooks/optimizer.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/hooks/optimizer.py new file mode 100755 index 0000000..f593f26 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/hooks/optimizer.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +try: + import apex +except ImportError: + print("apex is not installed") + +from mmcv.runner import OptimizerHook, HOOKS + + +@HOOKS.register_module() +class DistOptimizerHook(OptimizerHook): + """Optimizer hook for distributed training.""" + + def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.update_interval = update_interval + self.use_fp16 = use_fp16 + + def before_run(self, runner): + runner.optimizer.zero_grad() + + def after_train_iter(self, runner): + runner.outputs["loss"] /= self.update_interval + if self.use_fp16: + # runner.outputs['loss'].backward() + with apex.amp.scale_loss(runner.outputs["loss"], runner.optimizer) as scaled_loss: + scaled_loss.backward() + else: + runner.outputs["loss"].backward() + if self.every_n_iters(runner, self.update_interval): + if self.grad_clip is not None: + self.clip_grads(runner.model.parameters()) + runner.optimizer.step() + runner.optimizer.zero_grad() diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/__init__.py new file mode 100755 index 0000000..88e4563 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
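DistOptimizerHook above is a small gradient-accumulation hook for the mmcv runner: the loss is scaled by 1/update_interval, gradients accumulate across iterations, and the optimizer steps (with optional clipping) only every update_interval-th iteration. A plain-PyTorch sketch of the same pattern, with a toy model and random data, for readers not running under mmcv:

    import torch
    from torch import nn

    model = nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()
    update_interval = 4                                   # as in DistOptimizerHook(update_interval=4)

    optimizer.zero_grad()
    for i in range(16):
        x, y = torch.randn(4, 8), torch.randint(0, 2, (4,))
        loss = criterion(model(x), y) / update_interval   # scale, as in after_train_iter
        loss.backward()                                   # gradients accumulate in .grad
        if (i + 1) % update_interval == 0:
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # stands in for grad_clip
            optimizer.step()
            optimizer.zero_grad()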
+ +from .backbones import * # noqa: F403 +from .decode_heads import * # noqa: F403 diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/backbones/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/backbones/__init__.py new file mode 100755 index 0000000..520d75b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/backbones/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .vision_transformer import DinoVisionTransformer diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/backbones/vision_transformer.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/backbones/vision_transformer.py new file mode 100755 index 0000000..c3e9753 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/backbones/vision_transformer.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmcv.runner import BaseModule +from mmseg.models.builder import BACKBONES + + +@BACKBONES.register_module() +class DinoVisionTransformer(BaseModule): + """Vision Transformer.""" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__() diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/decode_heads/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/decode_heads/__init__.py new file mode 100755 index 0000000..c553178 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/decode_heads/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .linear_head import BNHead diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/decode_heads/linear_head.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/decode_heads/linear_head.py new file mode 100755 index 0000000..d1f39c6 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/models/decode_heads/linear_head.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
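DinoVisionTransformer above is only a registration stub (its __init__ ignores every argument); registering it under BACKBONES is what lets mmseg configs refer to the DINOv2 backbone by type name. A minimal sketch of that registry lookup, using the same build_from_cfg helper the segmentation_m2f builders later in this patch use:

    from mmcv.utils import build_from_cfg
    from mmseg.models.builder import BACKBONES

    backbone = build_from_cfg(dict(type="DinoVisionTransformer"), BACKBONES)
    print(type(backbone).__name__)   # "DinoVisionTransformer"

The actual ViT weights and forward logic are presumably attached elsewhere, which is why the class body here is empty.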
+ +import torch +import torch.nn as nn + +from mmseg.models.builder import HEADS +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.ops import resize + + +@HEADS.register_module() +class BNHead(BaseDecodeHead): + """Just a batchnorm.""" + + def __init__(self, resize_factors=None, **kwargs): + super().__init__(**kwargs) + assert self.in_channels == self.channels + self.bn = nn.SyncBatchNorm(self.in_channels) + self.resize_factors = resize_factors + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + # print("inputs", [i.shape for i in inputs]) + x = self._transform_inputs(inputs) + # print("x", x.shape) + feats = self.bn(x) + # print("feats", feats.shape) + return feats + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + Tensor: The transformed inputs + """ + + if self.input_transform == "resize_concat": + # accept lists (for cls token) + input_list = [] + for x in inputs: + if isinstance(x, list): + input_list.extend(x) + else: + input_list.append(x) + inputs = input_list + # an image descriptor can be a local descriptor with resolution 1x1 + for i, x in enumerate(inputs): + if len(x.shape) == 2: + inputs[i] = x[:, :, None, None] + # select indices + inputs = [inputs[i] for i in self.in_index] + # Resizing shenanigans + # print("before", *(x.shape for x in inputs)) + if self.resize_factors is not None: + assert len(self.resize_factors) == len(inputs), (len(self.resize_factors), len(inputs)) + inputs = [ + resize(input=x, scale_factor=f, mode="bilinear" if f >= 1 else "area") + for x, f in zip(inputs, self.resize_factors) + ] + # print("after", *(x.shape for x in inputs)) + upsampled_inputs = [ + resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners) + for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/utils/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/utils/__init__.py new file mode 100755 index 0000000..b88da6b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
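BNHead above is deliberately minimal: it concatenates the (optionally resized) input feature maps, applies a SyncBatchNorm, and classifies each pixel with the inherited cls_seg convolution, so the head is effectively a linear probe over frozen features. A sketch of building it from a config dict; the channel and class counts are illustrative, chosen so that the assert in __init__ (summed in_channels equals channels) holds:

    from mmcv.utils import build_from_cfg
    from mmseg.models.builder import HEADS

    head_cfg = dict(
        type="BNHead",
        in_channels=[1536],              # summed by BaseDecodeHead under resize_concat
        in_index=[0],
        input_transform="resize_concat",
        channels=1536,                   # must equal sum(in_channels)
        num_classes=150,                 # e.g. ADE20K
        dropout_ratio=0,
        align_corners=False,
    )
    head = build_from_cfg(head_cfg, HEADS)
    # forward() takes a list of feature maps; for a single (B, 1536, H, W) input it returns
    # per-pixel logits of shape (B, num_classes, H, W). SyncBatchNorm expects CUDA tensors,
    # so the forward pass is left out of this CPU-only sketch.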
diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/utils/colormaps.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/utils/colormaps.py new file mode 100755 index 0000000..e6ef604 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation/utils/colormaps.py @@ -0,0 +1,362 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +ADE20K_COLORMAP = [ + (0, 0, 0), + (120, 120, 120), + (180, 120, 120), + (6, 230, 230), + (80, 50, 50), + (4, 200, 3), + (120, 120, 80), + (140, 140, 140), + (204, 5, 255), + (230, 230, 230), + (4, 250, 7), + (224, 5, 255), + (235, 255, 7), + (150, 5, 61), + (120, 120, 70), + (8, 255, 51), + (255, 6, 82), + (143, 255, 140), + (204, 255, 4), + (255, 51, 7), + (204, 70, 3), + (0, 102, 200), + (61, 230, 250), + (255, 6, 51), + (11, 102, 255), + (255, 7, 71), + (255, 9, 224), + (9, 7, 230), + (220, 220, 220), + (255, 9, 92), + (112, 9, 255), + (8, 255, 214), + (7, 255, 224), + (255, 184, 6), + (10, 255, 71), + (255, 41, 10), + (7, 255, 255), + (224, 255, 8), + (102, 8, 255), + (255, 61, 6), + (255, 194, 7), + (255, 122, 8), + (0, 255, 20), + (255, 8, 41), + (255, 5, 153), + (6, 51, 255), + (235, 12, 255), + (160, 150, 20), + (0, 163, 255), + (140, 140, 140), + (250, 10, 15), + (20, 255, 0), + (31, 255, 0), + (255, 31, 0), + (255, 224, 0), + (153, 255, 0), + (0, 0, 255), + (255, 71, 0), + (0, 235, 255), + (0, 173, 255), + (31, 0, 255), + (11, 200, 200), + (255, 82, 0), + (0, 255, 245), + (0, 61, 255), + (0, 255, 112), + (0, 255, 133), + (255, 0, 0), + (255, 163, 0), + (255, 102, 0), + (194, 255, 0), + (0, 143, 255), + (51, 255, 0), + (0, 82, 255), + (0, 255, 41), + (0, 255, 173), + (10, 0, 255), + (173, 255, 0), + (0, 255, 153), + (255, 92, 0), + (255, 0, 255), + (255, 0, 245), + (255, 0, 102), + (255, 173, 0), + (255, 0, 20), + (255, 184, 184), + (0, 31, 255), + (0, 255, 61), + (0, 71, 255), + (255, 0, 204), + (0, 255, 194), + (0, 255, 82), + (0, 10, 255), + (0, 112, 255), + (51, 0, 255), + (0, 194, 255), + (0, 122, 255), + (0, 255, 163), + (255, 153, 0), + (0, 255, 10), + (255, 112, 0), + (143, 255, 0), + (82, 0, 255), + (163, 255, 0), + (255, 235, 0), + (8, 184, 170), + (133, 0, 255), + (0, 255, 92), + (184, 0, 255), + (255, 0, 31), + (0, 184, 255), + (0, 214, 255), + (255, 0, 112), + (92, 255, 0), + (0, 224, 255), + (112, 224, 255), + (70, 184, 160), + (163, 0, 255), + (153, 0, 255), + (71, 255, 0), + (255, 0, 163), + (255, 204, 0), + (255, 0, 143), + (0, 255, 235), + (133, 255, 0), + (255, 0, 235), + (245, 0, 255), + (255, 0, 122), + (255, 245, 0), + (10, 190, 212), + (214, 255, 0), + (0, 204, 255), + (20, 0, 255), + (255, 255, 0), + (0, 153, 255), + (0, 41, 255), + (0, 255, 204), + (41, 0, 255), + (41, 255, 0), + (173, 0, 255), + (0, 245, 255), + (71, 0, 255), + (122, 0, 255), + (0, 255, 184), + (0, 92, 255), + (184, 255, 0), + (0, 133, 255), + (255, 214, 0), + (25, 194, 194), + (102, 255, 0), + (92, 0, 255), +] + +ADE20K_CLASS_NAMES = [ + "", + "wall", + "building;edifice", + "sky", + "floor;flooring", + "tree", + "ceiling", + "road;route", + "bed", + "windowpane;window", + "grass", + "cabinet", + "sidewalk;pavement", + "person;individual;someone;somebody;mortal;soul", + "earth;ground", + "door;double;door", + "table", + "mountain;mount", + "plant;flora;plant;life", + 
"curtain;drape;drapery;mantle;pall", + "chair", + "car;auto;automobile;machine;motorcar", + "water", + "painting;picture", + "sofa;couch;lounge", + "shelf", + "house", + "sea", + "mirror", + "rug;carpet;carpeting", + "field", + "armchair", + "seat", + "fence;fencing", + "desk", + "rock;stone", + "wardrobe;closet;press", + "lamp", + "bathtub;bathing;tub;bath;tub", + "railing;rail", + "cushion", + "base;pedestal;stand", + "box", + "column;pillar", + "signboard;sign", + "chest;of;drawers;chest;bureau;dresser", + "counter", + "sand", + "sink", + "skyscraper", + "fireplace;hearth;open;fireplace", + "refrigerator;icebox", + "grandstand;covered;stand", + "path", + "stairs;steps", + "runway", + "case;display;case;showcase;vitrine", + "pool;table;billiard;table;snooker;table", + "pillow", + "screen;door;screen", + "stairway;staircase", + "river", + "bridge;span", + "bookcase", + "blind;screen", + "coffee;table;cocktail;table", + "toilet;can;commode;crapper;pot;potty;stool;throne", + "flower", + "book", + "hill", + "bench", + "countertop", + "stove;kitchen;stove;range;kitchen;range;cooking;stove", + "palm;palm;tree", + "kitchen;island", + "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system", + "swivel;chair", + "boat", + "bar", + "arcade;machine", + "hovel;hut;hutch;shack;shanty", + "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle", + "towel", + "light;light;source", + "truck;motortruck", + "tower", + "chandelier;pendant;pendent", + "awning;sunshade;sunblind", + "streetlight;street;lamp", + "booth;cubicle;stall;kiosk", + "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box", + "airplane;aeroplane;plane", + "dirt;track", + "apparel;wearing;apparel;dress;clothes", + "pole", + "land;ground;soil", + "bannister;banister;balustrade;balusters;handrail", + "escalator;moving;staircase;moving;stairway", + "ottoman;pouf;pouffe;puff;hassock", + "bottle", + "buffet;counter;sideboard", + "poster;posting;placard;notice;bill;card", + "stage", + "van", + "ship", + "fountain", + "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter", + "canopy", + "washer;automatic;washer;washing;machine", + "plaything;toy", + "swimming;pool;swimming;bath;natatorium", + "stool", + "barrel;cask", + "basket;handbasket", + "waterfall;falls", + "tent;collapsible;shelter", + "bag", + "minibike;motorbike", + "cradle", + "oven", + "ball", + "food;solid;food", + "step;stair", + "tank;storage;tank", + "trade;name;brand;name;brand;marque", + "microwave;microwave;oven", + "pot;flowerpot", + "animal;animate;being;beast;brute;creature;fauna", + "bicycle;bike;wheel;cycle", + "lake", + "dishwasher;dish;washer;dishwashing;machine", + "screen;silver;screen;projection;screen", + "blanket;cover", + "sculpture", + "hood;exhaust;hood", + "sconce", + "vase", + "traffic;light;traffic;signal;stoplight", + "tray", + "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin", + "fan", + "pier;wharf;wharfage;dock", + "crt;screen", + "plate", + "monitor;monitoring;device", + "bulletin;board;notice;board", + "shower", + "radiator", + "glass;drinking;glass", + "clock", + "flag", +] + + +VOC2012_COLORMAP = [ + (0, 0, 0), + (128, 0, 0), + (0, 128, 0), + (128, 128, 0), + (0, 0, 128), + (128, 0, 128), + (0, 128, 128), + (128, 128, 128), + (64, 0, 0), + (192, 0, 0), + (64, 128, 0), + (192, 128, 0), + (64, 0, 128), + (192, 0, 128), + (64, 128, 128), + (192, 128, 128), + (0, 64, 0), + 
(128, 64, 0), + (0, 192, 0), + (128, 192, 0), + (0, 64, 128), +] + + +VOC2012_CLASS_NAMES = [ + "", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", +] diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/__init__.py new file mode 100755 index 0000000..6c678fd --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .core import * # noqa: F403 +from .models import * # noqa: F403 +from .ops import * # noqa: F403 diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/__init__.py new file mode 100755 index 0000000..9259980 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmseg.core.evaluation import * # noqa: F403 +from mmseg.core.seg import * # noqa: F403 + +from .anchor import * # noqa: F403 +from .box import * # noqa: F403 +from .utils import * # noqa: F403 diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/anchor/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/anchor/__init__.py new file mode 100755 index 0000000..e71ac4d --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/anchor/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .point_generator import MlvlPointGenerator # noqa: F403 diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/anchor/builder.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/anchor/builder.py new file mode 100755 index 0000000..6dba90e --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/anchor/builder.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
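The palettes above map one RGB triplet to each class index (index 0 is the empty/background label), so rendering a predicted label map is a single array lookup. A small sketch, assuming the vendored package is importable as dinov2 and using a random label map in place of a real prediction:

    import numpy as np
    from dinov2.eval.segmentation.utils.colormaps import ADE20K_CLASS_NAMES, ADE20K_COLORMAP

    palette = np.array(ADE20K_COLORMAP, dtype=np.uint8)            # one RGB row per class index
    seg = np.random.randint(0, len(ADE20K_CLASS_NAMES), (64, 64))  # stand-in for a predicted map
    rgb = palette[seg]                                             # (64, 64, 3) color rendering
    print(ADE20K_CLASS_NAMES[int(seg[0, 0])])                      # class name at one pixel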
+ +import warnings + +from mmcv.utils import Registry, build_from_cfg + +PRIOR_GENERATORS = Registry("Generator for anchors and points") + +ANCHOR_GENERATORS = PRIOR_GENERATORS + + +def build_prior_generator(cfg, default_args=None): + return build_from_cfg(cfg, PRIOR_GENERATORS, default_args) + + +def build_anchor_generator(cfg, default_args=None): + warnings.warn("``build_anchor_generator`` would be deprecated soon, please use " "``build_prior_generator`` ") + return build_prior_generator(cfg, default_args=default_args) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py new file mode 100755 index 0000000..574d719 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn.modules.utils import _pair + +from .builder import PRIOR_GENERATORS + + +@PRIOR_GENERATORS.register_module() +class MlvlPointGenerator: + """Standard points generator for multi-level (Mlvl) feature maps in 2D + points-based detectors. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels in order (w, h). + offset (float): The offset of points, the value is normalized with + corresponding stride. Defaults to 0.5. + """ + + def __init__(self, strides, offset=0.5): + self.strides = [_pair(stride) for stride in strides] + self.offset = offset + + @property + def num_levels(self): + """int: number of feature levels that the generator will be applied""" + return len(self.strides) + + @property + def num_base_priors(self): + """list[int]: The number of priors (points) at a point + on the feature grid""" + return [1 for _ in range(len(self.strides))] + + def _meshgrid(self, x, y, row_major=True): + yy, xx = torch.meshgrid(y, x) + if row_major: + # warning .flatten() would cause error in ONNX exporting + # have to use reshape here + return xx.reshape(-1), yy.reshape(-1) + + else: + return yy.reshape(-1), xx.reshape(-1) + + def grid_priors(self, featmap_sizes, dtype=torch.float32, device="cuda", with_stride=False): + """Generate grid points of multiple feature levels. + + Args: + featmap_sizes (list[tuple]): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32. + device (str): The device where the anchors will be put on. + with_stride (bool): Whether to concatenate the stride to + the last dimension of points. + + Return: + list[torch.Tensor]: Points of multiple feature levels. + The sizes of each tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). 
+ """ + + assert self.num_levels == len(featmap_sizes) + multi_level_priors = [] + for i in range(self.num_levels): + priors = self.single_level_grid_priors( + featmap_sizes[i], level_idx=i, dtype=dtype, device=device, with_stride=with_stride + ) + multi_level_priors.append(priors) + return multi_level_priors + + def single_level_grid_priors(self, featmap_size, level_idx, dtype=torch.float32, device="cuda", with_stride=False): + """Generate grid Points of a single level. + + Note: + This function is usually called by method ``self.grid_priors``. + + Args: + featmap_size (tuple[int]): Size of the feature maps, arrange as + (h, w). + level_idx (int): The index of corresponding feature map level. + dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32. + device (str, optional): The device the tensor will be put on. + Defaults to 'cuda'. + with_stride (bool): Concatenate the stride to the last dimension + of points. + + Return: + Tensor: Points of single feature levels. + The shape of tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). + """ + feat_h, feat_w = featmap_size + stride_w, stride_h = self.strides[level_idx] + shift_x = (torch.arange(0, feat_w, device=device) + self.offset) * stride_w + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_x = shift_x.to(dtype) + + shift_y = (torch.arange(0, feat_h, device=device) + self.offset) * stride_h + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_y = shift_y.to(dtype) + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + if not with_stride: + shifts = torch.stack([shift_xx, shift_yy], dim=-1) + else: + # use `shape[0]` instead of `len(shift_xx)` for ONNX export + stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to(dtype) + stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to(dtype) + shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h], dim=-1) + all_points = shifts.to(device) + return all_points + + def valid_flags(self, featmap_sizes, pad_shape, device="cuda"): + """Generate valid flags of points of multiple feature levels. + + Args: + featmap_sizes (list(tuple)): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + pad_shape (tuple(int)): The padded shape of the image, + arrange as (h, w). + device (str): The device where the anchors will be put on. + + Return: + list(torch.Tensor): Valid flags of points of multiple levels. + """ + assert self.num_levels == len(featmap_sizes) + multi_level_flags = [] + for i in range(self.num_levels): + point_stride = self.strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = pad_shape[:2] + valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h) + valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w) + flags = self.single_level_valid_flags((feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device) + multi_level_flags.append(flags) + return multi_level_flags + + def single_level_valid_flags(self, featmap_size, valid_size, device="cuda"): + """Generate the valid flags of points of a single feature map. + + Args: + featmap_size (tuple[int]): The size of feature maps, arrange as + as (h, w). 
+ valid_size (tuple[int]): The valid size of the feature maps. + The size arrange as as (h, w). + device (str, optional): The device where the flags will be put on. + Defaults to 'cuda'. + + Returns: + torch.Tensor: The valid flags of each points in a single level \ + feature map. + """ + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device) + valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device) + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + return valid + + def sparse_priors(self, prior_idxs, featmap_size, level_idx, dtype=torch.float32, device="cuda"): + """Generate sparse points according to the ``prior_idxs``. + + Args: + prior_idxs (Tensor): The index of corresponding anchors + in the feature map. + featmap_size (tuple[int]): feature map size arrange as (w, h). + level_idx (int): The level index of corresponding feature + map. + dtype (obj:`torch.dtype`): Date type of points. Defaults to + ``torch.float32``. + device (obj:`torch.device`): The device where the points is + located. + Returns: + Tensor: Anchor with shape (N, 2), N should be equal to + the length of ``prior_idxs``. And last dimension + 2 represent (coord_x, coord_y). + """ + height, width = featmap_size + x = (prior_idxs % width + self.offset) * self.strides[level_idx][0] + y = ((prior_idxs // width) % height + self.offset) * self.strides[level_idx][1] + prioris = torch.stack([x, y], 1).to(dtype) + prioris = prioris.to(device) + return prioris diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/__init__.py new file mode 100755 index 0000000..bf35a61 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .builder import * # noqa: F403 +from .samplers import MaskPseudoSampler # noqa: F403 diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/builder.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/builder.py new file mode 100755 index 0000000..9538c0d --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/builder.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
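MlvlPointGenerator above emits one prior point per feature-map cell at every pyramid level, shifted by offset * stride and optionally carrying the stride in the last two columns. A usage sketch, assuming the vendored segmentation_m2f package and its mmcv/mmdet dependencies import cleanly:

    from dinov2.eval.segmentation_m2f.core.anchor import MlvlPointGenerator

    generator = MlvlPointGenerator(strides=[8, 16, 32], offset=0.5)
    featmap_sizes = [(64, 64), (32, 32), (16, 16)]       # (h, w) per level
    priors = generator.grid_priors(featmap_sizes, device="cpu", with_stride=True)
    # each row is (coord_x, coord_y, stride_w, stride_h)
    print([p.shape for p in priors])                     # [torch.Size([4096, 4]), torch.Size([1024, 4]), torch.Size([256, 4])]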
+ +from mmcv.utils import Registry, build_from_cfg + +BBOX_SAMPLERS = Registry("bbox_sampler") +BBOX_CODERS = Registry("bbox_coder") + + +def build_sampler(cfg, **default_args): + """Builder of box sampler.""" + return build_from_cfg(cfg, BBOX_SAMPLERS, default_args) + + +def build_bbox_coder(cfg, **default_args): + """Builder of box coder.""" + return build_from_cfg(cfg, BBOX_CODERS, default_args) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py new file mode 100755 index 0000000..19c363e --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .mask_pseudo_sampler import MaskPseudoSampler # noqa: F403 diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py new file mode 100755 index 0000000..c45cec3 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from abc import ABCMeta, abstractmethod + +import torch + +from .sampling_result import SamplingResult + + +class BaseSampler(metaclass=ABCMeta): + """Base class of samplers.""" + + def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs): + self.num = num + self.pos_fraction = pos_fraction + self.neg_pos_ub = neg_pos_ub + self.add_gt_as_proposals = add_gt_as_proposals + self.pos_sampler = self + self.neg_sampler = self + + @abstractmethod + def _sample_pos(self, assign_result, num_expected, **kwargs): + """Sample positive samples.""" + pass + + @abstractmethod + def _sample_neg(self, assign_result, num_expected, **kwargs): + """Sample negative samples.""" + pass + + def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs): + """Sample positive and negative bboxes. + + This is a simple implementation of bbox sampling given candidates, + assigning results and ground truth bboxes. + + Args: + assign_result (:obj:`AssignResult`): Bbox assigning results. + bboxes (Tensor): Boxes to be sampled from. + gt_bboxes (Tensor): Ground truth bboxes. + gt_labels (Tensor, optional): Class labels of ground truth bboxes. + + Returns: + :obj:`SamplingResult`: Sampling result. 
+ + Example: + >>> from mmdet.core.bbox import RandomSampler + >>> from mmdet.core.bbox import AssignResult + >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes + >>> rng = ensure_rng(None) + >>> assign_result = AssignResult.random(rng=rng) + >>> bboxes = random_boxes(assign_result.num_preds, rng=rng) + >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng) + >>> gt_labels = None + >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1, + >>> add_gt_as_proposals=False) + >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels) + """ + if len(bboxes.shape) < 2: + bboxes = bboxes[None, :] + + bboxes = bboxes[:, :4] + + gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8) + if self.add_gt_as_proposals and len(gt_bboxes) > 0: + if gt_labels is None: + raise ValueError("gt_labels must be given when add_gt_as_proposals is True") + bboxes = torch.cat([gt_bboxes, bboxes], dim=0) + assign_result.add_gt_(gt_labels) + gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) + gt_flags = torch.cat([gt_ones, gt_flags]) + + num_expected_pos = int(self.num * self.pos_fraction) + pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs) + # We found that sampled indices have duplicated items occasionally. + # (may be a bug of PyTorch) + pos_inds = pos_inds.unique() + num_sampled_pos = pos_inds.numel() + num_expected_neg = self.num - num_sampled_pos + if self.neg_pos_ub >= 0: + _pos = max(1, num_sampled_pos) + neg_upper_bound = int(self.neg_pos_ub * _pos) + if num_expected_neg > neg_upper_bound: + num_expected_neg = neg_upper_bound + neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=bboxes, **kwargs) + neg_inds = neg_inds.unique() + + sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags) + return sampling_result diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py new file mode 100755 index 0000000..3e67ea6 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py + +import torch + +from ..builder import BBOX_SAMPLERS +from .base_sampler import BaseSampler +from .mask_sampling_result import MaskSamplingResult + + +@BBOX_SAMPLERS.register_module() +class MaskPseudoSampler(BaseSampler): + """A pseudo sampler that does not do sampling actually.""" + + def __init__(self, **kwargs): + pass + + def _sample_pos(self, **kwargs): + """Sample positive samples.""" + raise NotImplementedError + + def _sample_neg(self, **kwargs): + """Sample negative samples.""" + raise NotImplementedError + + def sample(self, assign_result, masks, gt_masks, **kwargs): + """Directly returns the positive and negative indices of samples. 
+ + Args: + assign_result (:obj:`AssignResult`): Assigned results + masks (torch.Tensor): Bounding boxes + gt_masks (torch.Tensor): Ground truth boxes + Returns: + :obj:`SamplingResult`: sampler results + """ + pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() + gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8) + sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags) + return sampling_result diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py new file mode 100755 index 0000000..270ffd3 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py + +import torch + +from .sampling_result import SamplingResult + + +class MaskSamplingResult(SamplingResult): + """Mask sampling result.""" + + def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.pos_masks = masks[pos_inds] + self.neg_masks = masks[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_masks.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_masks.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_masks = torch.empty_like(gt_masks) + else: + self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :] + + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None + + @property + def masks(self): + """torch.Tensor: concatenated positive and negative boxes""" + return torch.cat([self.pos_masks, self.neg_masks]) + + def __nice__(self): + data = self.info.copy() + data["pos_masks"] = data.pop("pos_masks").shape + data["neg_masks"] = data.pop("neg_masks").shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = " " + ",\n ".join(parts) + return "{\n" + body + "\n}" + + @property + def info(self): + """Returns a dictionary of info about the object.""" + return { + "pos_inds": self.pos_inds, + "neg_inds": self.neg_inds, + "pos_masks": self.pos_masks, + "neg_masks": self.neg_masks, + "pos_is_gt": self.pos_is_gt, + "num_gts": self.num_gts, + "pos_assigned_gt_inds": self.pos_assigned_gt_inds, + } diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py new file mode 100755 index 0000000..aaee3fe --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch + + +class SamplingResult: + """Bbox sampling result. + + Example: + >>> # xdoctest: +IGNORE_WANT + >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random(rng=10) + >>> print(f'self = {self}') + self = + """ + + def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.pos_bboxes = bboxes[pos_inds] + self.neg_bboxes = bboxes[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_bboxes.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_bboxes.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.view(-1, 4) + + self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :] + + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None + + @property + def bboxes(self): + """torch.Tensor: concatenated positive and negative boxes""" + return torch.cat([self.pos_bboxes, self.neg_bboxes]) + + def to(self, device): + """Change the device of the data inplace. + + Example: + >>> self = SamplingResult.random() + >>> print(f'self = {self.to(None)}') + >>> # xdoctest: +REQUIRES(--gpu) + >>> print(f'self = {self.to(0)}') + """ + _dict = self.__dict__ + for key, value in _dict.items(): + if isinstance(value, torch.Tensor): + _dict[key] = value.to(device) + return self + + def __nice__(self): + data = self.info.copy() + data["pos_bboxes"] = data.pop("pos_bboxes").shape + data["neg_bboxes"] = data.pop("neg_bboxes").shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = " " + ",\n ".join(parts) + return "{\n" + body + "\n}" + + @property + def info(self): + """Returns a dictionary of info about the object.""" + return { + "pos_inds": self.pos_inds, + "neg_inds": self.neg_inds, + "pos_bboxes": self.pos_bboxes, + "neg_bboxes": self.neg_bboxes, + "pos_is_gt": self.pos_is_gt, + "num_gts": self.num_gts, + "pos_assigned_gt_inds": self.pos_assigned_gt_inds, + } + + @classmethod + def random(cls, rng=None, **kwargs): + """ + Args: + rng (None | int | numpy.random.RandomState): seed or state. + kwargs (keyword arguments): + - num_preds: number of predicted boxes + - num_gts: number of true boxes + - p_ignore (float): probability of a predicted box assigned to \ + an ignored truth. + - p_assigned (float): probability of a predicted box not being \ + assigned. + - p_use_label (float | bool): with labels or not. + + Returns: + :obj:`SamplingResult`: Randomly generated sampling result. + + Example: + >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random() + >>> print(self.__dict__) + """ + from mmdet.core.bbox import demodata + from mmdet.core.bbox.assigners.assign_result import AssignResult + from mmdet.core.bbox.samplers.random_sampler import RandomSampler + + rng = demodata.ensure_rng(rng) + + # make probabalistic? 
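+        # Demo helper: the sampler hyper-parameters below are fixed rather than
+        # randomized, and note that both branches of the `assign_result.labels`
+        # check further down set gt_labels to None, so add_gt_as_proposals is
+        # always False when this helper builds its RandomSampler.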
+ num = 32 + pos_fraction = 0.5 + neg_pos_ub = -1 + + assign_result = AssignResult.random(rng=rng, **kwargs) + + # Note we could just compute an assignment + bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng) + gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng) + + if rng.rand() > 0.2: + # sometimes algorithms squeeze their data, be robust to that + gt_bboxes = gt_bboxes.squeeze() + bboxes = bboxes.squeeze() + + if assign_result.labels is None: + gt_labels = None + else: + gt_labels = None + + if gt_labels is None: + add_gt_as_proposals = False + else: + add_gt_as_proposals = True # make probabalistic? + + sampler = RandomSampler( + num, pos_fraction, neg_pos_ub=neg_pos_ub, add_gt_as_proposals=add_gt_as_proposals, rng=rng + ) + self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels) + return self diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/utils/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/utils/__init__.py new file mode 100755 index 0000000..6cdc9e1 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dist_utils import reduce_mean +from .misc import add_prefix, multi_apply diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py new file mode 100755 index 0000000..7dfed42 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch.distributed as dist + + +def reduce_mean(tensor): + """ "Obtain the mean of tensor on different GPUs.""" + if not (dist.is_available() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) + return tensor diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/utils/misc.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/utils/misc.py new file mode 100755 index 0000000..e07579e --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/core/utils/misc.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from functools import partial + + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. 
+ + Args: + func (Function): A function that will be applied to a list of + arguments + + Returns: + tuple(list): A tuple containing multiple list, each list contains \ + a kind of returned results by the function + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/__init__.py new file mode 100755 index 0000000..ed89bb0 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .backbones import * # noqa: F403 +from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost +from .decode_heads import * # noqa: F403 +from .losses import * # noqa: F403 +from .plugins import * # noqa: F403 +from .segmentors import * # noqa: F403 diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/__init__.py new file mode 100755 index 0000000..c4bf73b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .vit_adapter import ViTAdapter diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py new file mode 100755 index 0000000..26bfdf8 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py @@ -0,0 +1,442 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
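+# A minimal usage sketch for the two helpers in misc.py above (multi_apply and
+# add_prefix). `scale_and_tag` is a made-up function for illustration only, and
+# the sketch assumes both helpers are in scope (e.g. imported from that module):
+#
+#     def scale_and_tag(x, factor=2):
+#         return x * factor, f"id_{x}"
+#
+#     scaled, tags = multi_apply(scale_and_tag, [1, 2, 3], factor=10)
+#     # scaled == [10, 20, 30], tags == ['id_1', 'id_2', 'id_3']
+#
+#     losses = add_prefix({"loss_mask": 0.5, "loss_dice": 0.2}, "decode")
+#     # losses == {'decode.loss_mask': 0.5, 'decode.loss_dice': 0.2}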
+ +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp + +from ...ops.modules import MSDeformAttn +from .drop_path import DropPath + + +def get_reference_points(spatial_shapes, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / H_ + ref_x = ref_x.reshape(-1)[None] / W_ + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] + return reference_points + + +def deform_inputs(x, patch_size): + bs, c, h, w = x.shape + spatial_shapes = torch.as_tensor( + [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], dtype=torch.long, device=x.device + ) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // patch_size, w // patch_size)], x.device) + deform_inputs1 = [reference_points, spatial_shapes, level_start_index] + + spatial_shapes = torch.as_tensor([(h // patch_size, w // patch_size)], dtype=torch.long, device=x.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], x.device) + deform_inputs2 = [reference_points, spatial_shapes, level_start_index] + + return deform_inputs1, deform_inputs2 + + +class ConvFFN(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + n = N // 21 + x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous() + x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous() + x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous() + x1 = self.dwconv(x1).flatten(2).transpose(1, 2) + x2 = self.dwconv(x2).flatten(2).transpose(1, 2) + x3 = self.dwconv(x3).flatten(2).transpose(1, 2) + x = torch.cat([x1, x2, x3], dim=1) + return x + + +class Extractor(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + with_cffn=True, + cffn_ratio=0.25, + drop=0.0, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + with_cp=False, + ): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn( + d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio + ) + self.with_cffn = with_cffn + self.with_cp = with_cp + if with_cffn: + self.ffn = ConvFFN(in_features=dim, 
hidden_features=int(dim * cffn_ratio), drop=drop) + self.ffn_norm = norm_layer(dim) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W): + def _inner_forward(query, feat): + + attn = self.attn( + self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None + ) + query = query + attn + + if self.with_cffn: + query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W)) + return query + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class Injector(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + init_values=0.0, + with_cp=False, + ): + super().__init__() + self.with_cp = with_cp + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn( + d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio + ) + self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + + def forward(self, query, reference_points, feat, spatial_shapes, level_start_index): + def _inner_forward(query, feat): + + attn = self.attn( + self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None + ) + return query + self.gamma * attn + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class InteractionBlock(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0.0, + drop_path=0.0, + with_cffn=True, + cffn_ratio=0.25, + init_values=0.0, + deform_ratio=1.0, + extra_extractor=False, + with_cp=False, + ): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cp=with_cp, + ) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + if extra_extractor: + self.extra_extractors = nn.Sequential( + *[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + for _ in range(2) + ] + ) + else: + self.extra_extractors = None + + def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2], + ) + for idx, blk in enumerate(blocks): + x = blk(x, H_toks, W_toks) + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + 
level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + return x, c + + +class InteractionBlockWithCls(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0.0, + drop_path=0.0, + with_cffn=True, + cffn_ratio=0.25, + init_values=0.0, + deform_ratio=1.0, + extra_extractor=False, + with_cp=False, + ): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cp=with_cp, + ) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + if extra_extractor: + self.extra_extractors = nn.Sequential( + *[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + for _ in range(2) + ] + ) + else: + self.extra_extractors = None + + def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2], + ) + x = torch.cat((cls, x), dim=1) + for idx, blk in enumerate(blocks): + x = blk(x, H_toks, W_toks) + cls, x = ( + x[ + :, + :1, + ], + x[ + :, + 1:, + ], + ) + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + return x, c, cls + + +class SpatialPriorModule(nn.Module): + def __init__(self, inplanes=64, embed_dim=384, with_cp=False): + super().__init__() + self.with_cp = with_cp + + self.stem = nn.Sequential( + *[ + nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ] + ) + self.conv2 = nn.Sequential( + *[ + nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(2 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.conv3 = nn.Sequential( + *[ + nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.conv4 = nn.Sequential( + *[ + nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, 
kernel_size=1, stride=1, padding=0, bias=True) + self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, x): + def _inner_forward(x): + c1 = self.stem(x) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + c1 = self.fc1(c1) + c2 = self.fc2(c2) + c3 = self.fc3(c3) + c4 = self.fc4(c4) + + bs, dim, _, _ = c1.shape + # c1 = c1.view(bs, dim, -1).transpose(1, 2) # 4s + c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s + c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s + c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s + + return c1, c2, c3, c4 + + if self.with_cp and x.requires_grad: + outs = cp.checkpoint(_inner_forward, x) + else: + outs = _inner_forward(x) + return outs diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py new file mode 100755 index 0000000..864eb87 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/vit.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/vit.py new file mode 100755 index 0000000..8a14757 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/vit.py @@ -0,0 +1,552 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +"""Vision Transformer (ViT) in PyTorch. + +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? 
Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2021 Ross Wightman +""" +import logging +import math +from functools import partial +from itertools import repeat +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.runner import BaseModule, load_checkpoint +from mmseg.ops import resize +from mmseg.utils import get_root_logger +from torch import Tensor + +from .drop_path import DropPath + + +def to_2tuple(x): + return tuple(repeat(x, 2)) + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + swiglu_hidden_features = int(2 * hidden_features / 3) + align_as = 8 + swiglu_hidden_features = (swiglu_hidden_features + align_as - 1) // align_as * align_as + self.w1 = nn.Linear(in_features, swiglu_hidden_features) + self.w2 = nn.Linear(in_features, swiglu_hidden_features) + self.w3 = nn.Linear(swiglu_hidden_features, out_features) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.w1(x) + x2 = self.w2(x) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding.""" + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + 
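+        # proj is a Conv2d with kernel_size == stride == patch_size, so the H and
+        # W unpacked below are the height and width of the patch (token) grid.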
_, _, H, W = x.shape + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, H, W + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, H, W): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, H, W) -> Tensor: + from xformers.ops import memory_efficient_attention, unbind + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowedAttention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, window_size=14, pad_mode="constant" + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.window_size = window_size + self.pad_mode = pad_mode + + def forward(self, x, H, W): + B, N, C = x.shape + N_ = self.window_size * self.window_size + H_ = math.ceil(H / self.window_size) * self.window_size + W_ = math.ceil(W / self.window_size) * self.window_size + + qkv = self.qkv(x) # [B, N, C] + qkv = qkv.transpose(1, 2).reshape(B, C * 3, H, W) # [B, C, H, W] + qkv = F.pad(qkv, [0, W_ - W, 0, H_ - H], mode=self.pad_mode) + 
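+        # F.unfold with kernel_size == stride == window_size cuts the padded qkv
+        # map into non-overlapping window_size x window_size tiles, so the
+        # attention below runs independently inside each window; F.fold at the
+        # end stitches the tiles back into an H_ x W_ feature map.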
+ qkv = F.unfold( + qkv, kernel_size=(self.window_size, self.window_size), stride=(self.window_size, self.window_size) + ) + B, C_kw_kw, L = qkv.shape # L - the num of windows + qkv = qkv.reshape(B, C * 3, N_, L).permute(0, 3, 2, 1) # [B, L, N_, C] + qkv = qkv.reshape(B, L, N_, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + # q,k,v [B, L, num_head, N_, C/num_head] + attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_] + # if self.mask: + # attn = attn * mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) # [B, L, num_head, N_, N_] + # attn @ v = [B, L, num_head, N_, C/num_head] + x = (attn @ v).permute(0, 2, 4, 3, 1).reshape(B, C_kw_kw // 3, L) + + x = F.fold( + x, + output_size=(H_, W_), + kernel_size=(self.window_size, self.window_size), + stride=(self.window_size, self.window_size), + ) # [B, C, H_, W_] + x = x[:, :, :H, :W].reshape(B, C, N).transpose(-1, -2) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +# class WindowedAttention(nn.Module): +# def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., window_size=14, pad_mode="constant"): +# super().__init__() +# self.num_heads = num_heads +# head_dim = dim // num_heads +# self.scale = head_dim ** -0.5 +# +# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) +# self.attn_drop = nn.Dropout(attn_drop) +# self.proj = nn.Linear(dim, dim) +# self.proj_drop = nn.Dropout(proj_drop) +# self.window_size = window_size +# self.pad_mode = pad_mode +# +# def forward(self, x, H, W): +# B, N, C = x.shape +# +# N_ = self.window_size * self.window_size +# H_ = math.ceil(H / self.window_size) * self.window_size +# W_ = math.ceil(W / self.window_size) * self.window_size +# x = x.view(B, H, W, C) +# x = F.pad(x, [0, 0, 0, W_ - W, 0, H_- H], mode=self.pad_mode) +# +# x = window_partition(x, window_size=self.window_size)# nW*B, window_size, window_size, C +# x = x.view(-1, N_, C) +# +# qkv = self.qkv(x).view(-1, N_, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) +# q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) +# attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_] +# attn = attn.softmax(dim=-1) +# attn = self.attn_drop(attn) # [B, L, num_head, N_, N_] +# x = (attn @ v).transpose(1, 2).reshape(-1, self.window_size, self.window_size, C) +# +# x = window_reverse(x, self.window_size, H_, W_) +# x = x[:, :H, :W, :].reshape(B, N, C).contiguous() +# x = self.proj(x) +# x = self.proj_drop(x) +# return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + windowed=False, + window_size=14, + pad_mode="constant", + layer_scale=False, + with_cp=False, + ffn_layer=Mlp, + memeff=False, + ): + super().__init__() + self.with_cp = with_cp + self.norm1 = norm_layer(dim) + if windowed: + self.attn = WindowedAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + pad_mode=pad_mode, + ) + elif memeff: + self.attn = MemEffAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop + ) + else: + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + 
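+        # DropPath (drop_path.py above) zeroes the entire residual branch for a
+        # random subset of samples during training and rescales the kept ones by
+        # 1 / keep_prob, i.e. per-sample stochastic depth.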
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.layer_scale = layer_scale + if layer_scale: + self.gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True) + self.gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True) + + def forward(self, x, H, W): + def _inner_forward(x): + if self.layer_scale: + x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class TIMMVisionTransformer(BaseModule): + """Vision Transformer. + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + layer_scale=True, + embed_layer=PatchEmbed, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + window_attn=False, + window_size=14, + pretrained=None, + with_cp=False, + pre_norm=False, + ffn_type="mlp", + memeff=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + pretrained: (str): pretrained path + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + self.norm_layer = norm_layer + self.act_layer = act_layer + self.pretrain_size = img_size + self.drop_path_rate = drop_path_rate + self.drop_rate = drop_rate + self.patch_size = patch_size + + window_attn = [window_attn] * depth if not isinstance(window_attn, list) else window_attn + window_size = [window_size] * depth if not isinstance(window_size, list) else window_size + logging.info("window attention:", window_attn) + logging.info("window size:", window_size) + logging.info("layer scale:", layer_scale) + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm + ) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + ffn_types = {"mlp": 
Mlp, "swiglu": SwiGLUFFN} + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + windowed=window_attn[i], + window_size=window_size[i], + layer_scale=layer_scale, + with_cp=with_cp, + ffn_layer=ffn_types[ffn_type], + memeff=memeff, + ) + for i in range(depth) + ] + ) + + # self.norm = norm_layer(embed_dim) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # For CLIP + if pre_norm: + norm_pre = norm_layer(embed_dim) + self.norm_pre = norm_pre + else: + self.norm_pre = nn.Identity() + self.init_weights(pretrained) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location="cpu", strict=False, logger=logger) + + def forward_features(self, x): + x, H, W = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_token, x), dim=1) + x = self.pos_drop(x + self.pos_embed) + + # For CLIP + x = self.norm_pre(x) + + for blk in self.blocks: + x = blk(x, H, W) + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + return x + + @staticmethod + def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): + """Resize pos_embed weights. + + Resize pos_embed using bicubic interpolate method. + Args: + pos_embed (torch.Tensor): Position embedding weights. + input_shpae (tuple): Tuple for (downsampled input image height, + downsampled input image width). + pos_shape (tuple): The resolution of downsampled origin training + image. + mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'`` + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C] + """ + assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]" + pos_h, pos_w = pos_shape + # keep dim for easy deployment + cls_token_weight = pos_embed[:, 0:1] + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :] + pos_embed_weight = pos_embed_weight.reshape(1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + pos_embed_weight = resize(pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) + pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) + pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1) + return pos_embed diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py new file mode 100755 index 0000000..ebc4f0f --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py @@ -0,0 +1,217 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
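+# A minimal sketch of what TIMMVisionTransformer.resize_pos_embed above does,
+# using torch.nn.functional.interpolate directly instead of mmseg's resize
+# wrapper (that substitution is an assumption for illustration; the helper name
+# resize_pos_embed_sketch is made up and not part of this codebase). It assumes
+# a single leading cls token in pos_embed.
+import torch
+import torch.nn.functional as F
+
+def resize_pos_embed_sketch(pos_embed, old_hw, new_hw, mode="bicubic"):
+    # pos_embed: [1, 1 + old_h * old_w, C]
+    cls_tok, grid = pos_embed[:, :1], pos_embed[:, 1:]
+    grid = grid.reshape(1, old_hw[0], old_hw[1], -1).permute(0, 3, 1, 2)
+    grid = F.interpolate(grid, size=new_hw, mode=mode, align_corners=False)
+    grid = grid.flatten(2).transpose(1, 2)  # back to [1, new_h * new_w, C]
+    return torch.cat((cls_tok, grid), dim=1)
+
+# Example: pos_embed for a 14x14 grid resized to 20x20 tokens.
+# resize_pos_embed_sketch(torch.zeros(1, 1 + 14 * 14, 768), (14, 14), (20, 20))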
+ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.models.builder import BACKBONES +from torch.nn.init import normal_ + +from ...ops.modules import MSDeformAttn +from .adapter_modules import InteractionBlock, InteractionBlockWithCls, SpatialPriorModule, deform_inputs +from .vit import TIMMVisionTransformer + + +@BACKBONES.register_module() +class ViTAdapter(TIMMVisionTransformer): + def __init__( + self, + pretrain_size=224, + num_heads=12, + conv_inplane=64, + n_points=4, + deform_num_heads=6, + init_values=0.0, + interaction_indexes=None, + with_cffn=True, + cffn_ratio=0.25, + deform_ratio=1.0, + add_vit_feature=True, + pretrained=None, + use_extra_extractor=True, + freeze_vit=False, + use_cls=True, + with_cp=False, + *args, + **kwargs + ): + + super().__init__(num_heads=num_heads, pretrained=pretrained, with_cp=with_cp, *args, **kwargs) + if freeze_vit: + for param in self.parameters(): + param.requires_grad = False + + # self.num_classes = 80 + self.use_cls = use_cls + if not self.use_cls: + self.cls_token = None + self.num_block = len(self.blocks) + self.pretrain_size = (pretrain_size, pretrain_size) + self.interaction_indexes = interaction_indexes + self.add_vit_feature = add_vit_feature + embed_dim = self.embed_dim + + block_fn = InteractionBlockWithCls if use_cls else InteractionBlock + + self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) + self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False) + self.interactions = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=deform_num_heads, + n_points=n_points, + init_values=init_values, + drop_path=self.drop_path_rate, + norm_layer=self.norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + extra_extractor=((True if i == len(interaction_indexes) - 1 else False) and use_extra_extractor), + with_cp=with_cp, + ) + for i in range(len(interaction_indexes)) + ] + ) + self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) + self.norm1 = nn.SyncBatchNorm(embed_dim) + self.norm2 = nn.SyncBatchNorm(embed_dim) + self.norm3 = nn.SyncBatchNorm(embed_dim) + self.norm4 = nn.SyncBatchNorm(embed_dim) + + self.up.apply(self._init_weights) + self.spm.apply(self._init_weights) + self.interactions.apply(self._init_weights) + self.apply(self._init_deform_weights) + normal_(self.level_embed) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def _get_pos_embed(self, pos_embed, H, W): + pos_embed = pos_embed.reshape( + 1, self.pretrain_size[0] // self.patch_size, self.pretrain_size[1] // self.patch_size, -1 + ).permute(0, 3, 1, 2) + pos_embed = ( + F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False) + .reshape(1, -1, H * W) + .permute(0, 2, 1) + ) + return pos_embed + + def _init_deform_weights(self, m): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + + def _add_level_embed(self, c2, c3, c4): + c2 = c2 + self.level_embed[0] + c3 = c3 + self.level_embed[1] + 
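+        # level_embed holds one learned vector per feature scale (1/8, 1/16, 1/32
+        # from the SpatialPriorModule); adding it lets the deformable attention
+        # distinguish the three levels once they are concatenated into one
+        # token sequence in forward().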
c4 = c4 + self.level_embed[2] + return c2, c3, c4 + + def forward(self, x): + deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size) + + # SPM forward + c1, c2, c3, c4 = self.spm(x) + c2, c3, c4 = self._add_level_embed(c2, c3, c4) + c = torch.cat([c2, c3, c4], dim=1) + + # Patch Embedding forward + H_c, W_c = x.shape[2] // 16, x.shape[3] // 16 + x, H_toks, W_toks = self.patch_embed(x) + # print("H_toks, W_toks =", H_toks, W_toks) + bs, n, dim = x.shape + pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H_toks, W_toks) + if self.use_cls: + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_token, x), dim=1) + pos_embed = torch.cat((self.pos_embed[:, :1], pos_embed), dim=1) + x = self.pos_drop(x + pos_embed) + # For CLIP + x = self.norm_pre(x) + + # Interaction + if self.use_cls: + cls, x = ( + x[ + :, + :1, + ], + x[ + :, + 1:, + ], + ) + outs = list() + for i, layer in enumerate(self.interactions): + indexes = self.interaction_indexes[i] + if self.use_cls: + x, c, cls = layer( + x, + c, + cls, + self.blocks[indexes[0] : indexes[-1] + 1], + deform_inputs1, + deform_inputs2, + H_c, + W_c, + H_toks, + W_toks, + ) + else: + x, c = layer( + x, + c, + self.blocks[indexes[0] : indexes[-1] + 1], + deform_inputs1, + deform_inputs2, + H_c, + W_c, + H_toks, + W_toks, + ) + outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous()) + + # Split & Reshape + c2 = c[:, 0 : c2.size(1), :] + c3 = c[:, c2.size(1) : c2.size(1) + c3.size(1), :] + c4 = c[:, c2.size(1) + c3.size(1) :, :] + + c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous() + c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous() + c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous() + c1 = self.up(c2) + c1 + + if self.add_vit_feature: + x1, x2, x3, x4 = outs + + x1 = F.interpolate(x1, size=(4 * H_c, 4 * W_c), mode="bilinear", align_corners=False) + x2 = F.interpolate(x2, size=(2 * H_c, 2 * W_c), mode="bilinear", align_corners=False) + x3 = F.interpolate(x3, size=(1 * H_c, 1 * W_c), mode="bilinear", align_corners=False) + x4 = F.interpolate(x4, size=(H_c // 2, W_c // 2), mode="bilinear", align_corners=False) + # print(c1.shape, c2.shape, c3.shape, c4.shape, x1.shape, x2.shape, x3.shape, x4.shape, H_c, H_toks) + c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 + + # Final Norm + f1 = self.norm1(c1) + f2 = self.norm2(c2) + f3 = self.norm3(c3) + f4 = self.norm4(c4) + return [f1, f2, f3, f4] diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/builder.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/builder.py new file mode 100755 index 0000000..d7cf7b9 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/builder.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
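+# Note on ViTAdapter.forward above: with H_c = H // 16 and W_c = W // 16, the
+# returned features f1..f4 have strides 4, 8, 16 and 32 relative to the input
+# image, i.e. a standard multi-scale pyramid for the segmentation decode head.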
+ +from mmcv.utils import Registry + +TRANSFORMER = Registry("Transformer") +MASK_ASSIGNERS = Registry("mask_assigner") +MATCH_COST = Registry("match_cost") + + +def build_match_cost(cfg): + """Build Match Cost.""" + return MATCH_COST.build(cfg) + + +def build_assigner(cfg): + """Build Assigner.""" + return MASK_ASSIGNERS.build(cfg) + + +def build_transformer(cfg): + """Build Transformer.""" + return TRANSFORMER.build(cfg) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py new file mode 100755 index 0000000..01f08b8 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .mask2former_head import Mask2FormerHead diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py new file mode 100755 index 0000000..d1705fc --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py @@ -0,0 +1,544 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init +from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence +from mmcv.ops import point_sample +from mmcv.runner import ModuleList, force_fp32 +from mmseg.models.builder import HEADS, build_loss +from mmseg.models.decode_heads.decode_head import BaseDecodeHead + +from ...core import build_sampler, multi_apply, reduce_mean +from ..builder import build_assigner +from ..utils import get_uncertain_point_coords_with_randomness + + +@HEADS.register_module() +class Mask2FormerHead(BaseDecodeHead): + """Implements the Mask2Former head. + + See `Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + in_channels (list[int]): Number of channels in the input feature map. + feat_channels (int): Number of channels for features. + out_channels (int): Number of channels for output. + num_things_classes (int): Number of things. + num_stuff_classes (int): Number of stuff. + num_queries (int): Number of query in Transformer decoder. + pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel + decoder. Defaults to None. + enforce_decoder_input_project (bool, optional): Whether to add + a layer to change the embed_dim of tranformer encoder in + pixel decoder to the embed_dim of transformer decoder. + Defaults to False. + transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder. Defaults to None. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder position encoding. Defaults to None. 
+ loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification + loss. Defaults to None. + loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss. + Defaults to None. + loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss. + Defaults to None. + train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of + Mask2Former head. + test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of + Mask2Former head. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + in_channels, + feat_channels, + out_channels, + num_things_classes=80, + num_stuff_classes=53, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=None, + enforce_decoder_input_project=False, + transformer_decoder=None, + positional_encoding=None, + loss_cls=None, + loss_mask=None, + loss_dice=None, + train_cfg=None, + test_cfg=None, + init_cfg=None, + **kwargs, + ): + super(Mask2FormerHead, self).__init__( + in_channels=in_channels, + channels=feat_channels, + num_classes=(num_things_classes + num_stuff_classes), + init_cfg=init_cfg, + input_transform="multiple_select", + **kwargs, + ) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.num_transformer_feat_level = num_transformer_feat_level + self.num_heads = transformer_decoder.transformerlayers.attn_cfgs.num_heads + self.num_transformer_decoder_layers = transformer_decoder.num_layers + assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update(in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels) + self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1] + self.transformer_decoder = build_transformer_layer_sequence(transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + + self.decoder_input_projs = ModuleList() + # from low resolution to high resolution + for _ in range(num_transformer_feat_level): + if self.decoder_embed_dims != feat_channels or enforce_decoder_input_project: + self.decoder_input_projs.append(Conv2d(feat_channels, self.decoder_embed_dims, kernel_size=1)) + else: + self.decoder_input_projs.append(nn.Identity()) + self.decoder_positional_encoding = build_positional_encoding(positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, feat_channels) + self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # from low resolution to high resolution + self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels) + + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), + nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), + nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels), + ) + self.conv_seg = None # fix a bug here (conv_seg is not used) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + self.sampler = build_sampler(self.train_cfg.sampler, context=self) + self.num_points = self.train_cfg.get("num_points", 12544) + self.oversample_ratio = self.train_cfg.get("oversample_ratio", 3.0) + self.importance_sample_ratio = self.train_cfg.get("importance_sample_ratio", 0.75) + + self.class_weight = 
loss_cls.class_weight + self.loss_cls = build_loss(loss_cls) + self.loss_mask = build_loss(loss_mask) + self.loss_dice = build_loss(loss_dice) + + def init_weights(self): + for m in self.decoder_input_projs: + if isinstance(m, Conv2d): + caffe2_xavier_init(m, bias=0) + + self.pixel_decoder.init_weights() + + for p in self.transformer_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas): + """Compute classification and mask targets for all images for a decoder + layer. + + Args: + cls_scores_list (list[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape [num_queries, + cls_out_channels]. + mask_preds_list (list[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape [num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for all + images. Each with shape (n, ), n is the sum of number of stuff + type and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[list[Tensor]]: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels of all images. + Each with shape [num_queries, ]. + - label_weights_list (list[Tensor]): Label weights of all + images.Each with shape [num_queries, ]. + - mask_targets_list (list[Tensor]): Mask targets of all images. + Each with shape [num_queries, h, w]. + - mask_weights_list (list[Tensor]): Mask weights of all images. + Each with shape [num_queries, ]. + - num_total_pos (int): Number of positive samples in all + images. + - num_total_neg (int): Number of negative samples in all + images. + """ + ( + labels_list, + label_weights_list, + mask_targets_list, + mask_weights_list, + pos_inds_list, + neg_inds_list, + ) = multi_apply( + self._get_target_single, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas + ) + + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, num_total_neg) + + def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, img_metas): + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_labels (Tensor): Ground truth class indices for one image with + shape (num_gts, ). + gt_masks (Tensor): Ground truth mask for each image, each with + shape (num_gts, h, w). + img_metas (dict): Image informtation. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each \ + image. + - neg_inds (Tensor): Sampled negative indices for each \ + image. 
+ """ + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample(mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, 1)).squeeze(1) + + # assign and sample + assign_result = self.assigner.assign(cls_score, mask_points_pred, gt_labels, gt_points_masks, img_metas) + sampling_result = self.sampler.sample(assign_result, mask_pred, gt_masks) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((self.num_queries,), self.num_classes, dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones((self.num_queries,)) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((self.num_queries,)) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds) + + def loss_single(self, cls_scores, mask_preds, gt_labels_list, gt_masks_list, img_metas): + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). + gt_labels_list (list[Tensor]): Ground truth class indices for each + image, each with shape (num_gts, ). + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (num_gts, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single \ + decoder layer. 
+ """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + ( + labels_list, + label_weights_list, + mask_targets_list, + mask_weights_list, + num_total_pos, + num_total_neg, + ) = self.get_targets(cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas) + # shape (batch_size, num_queries) + labels = torch.stack(labels_list, dim=0) + # shape (batch_size, num_queries) + label_weights = torch.stack(label_weights_list, dim=0) + # shape (num_total_gts, h, w) + mask_targets = torch.cat(mask_targets_list, dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(mask_weights_list, dim=0) + + # classfication loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + + class_weight = cls_scores.new_tensor(self.class_weight) + loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=class_weight[labels].sum()) + + num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos])) + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio + ) + # shape (num_total_gts, h, w) -> (num_total_gts, num_points) + mask_point_targets = point_sample(mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample(mask_preds.unsqueeze(1), points_coords).squeeze(1) + + # dice loss + loss_dice = self.loss_dice(mask_point_preds, mask_point_targets, avg_factor=num_total_masks) + + # mask loss + # shape (num_queries, num_points) -> (num_queries * num_points, ) + mask_point_preds = mask_point_preds.reshape(-1, 1) + # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) + mask_point_targets = mask_point_targets.reshape(-1) + loss_mask = self.loss_mask(mask_point_preds, mask_point_targets, avg_factor=num_total_masks * self.num_points) + + return loss_cls, loss_mask, loss_dice + + @force_fp32(apply_to=("all_cls_scores", "all_mask_preds")) + def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, gt_masks_list, img_metas): + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape [num_decoder, batch_size, num_queries, + cls_out_channels]. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape [num_decoder, batch_size, num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (n, ). n is the sum of number of stuff type + and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image with + shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + num_dec_layers = len(all_cls_scores) + all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] + all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + losses_cls, losses_mask, losses_dice = multi_apply( + self.loss_single, all_cls_scores, all_mask_preds, all_gt_labels_list, all_gt_masks_list, img_metas_list + ) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict["loss_cls"] = losses_cls[-1] + loss_dict["loss_mask"] = losses_mask[-1] + loss_dict["loss_dice"] = losses_dice[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_mask_i, loss_dice_i in zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): + loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i + loss_dict[f"d{num_dec_layer}.loss_mask"] = loss_mask_i + loss_dict[f"d{num_dec_layer}.loss_dice"] = loss_dice_i + num_dec_layer += 1 + return loss_dict + + def forward_head(self, decoder_out, mask_feature, attn_mask_target_size): + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (num_queries, batch_size, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + decoder_out = decoder_out.transpose(0, 1) + # shape (num_queries, batch_size, c) + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature) + attn_mask = F.interpolate(mask_pred, attn_mask_target_size, mode="bilinear", align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat((1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask + + def forward(self, feats, img_metas): + """Forward function. + + Args: + feats (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + + Returns: + tuple: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). 
+ """ + batch_size = len(img_metas) + mask_features, multi_scale_memorys = self.pixel_decoder(feats) + # multi_scale_memorys (from low resolution to high resolution) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + decoder_input = decoder_input.flatten(2).permute(2, 0, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + mask = decoder_input.new_zeros((batch_size,) + multi_scale_memorys[i].shape[-2:], dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding(mask) + decoder_positional_encoding = decoder_positional_encoding.flatten(2).permute(2, 0, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + # shape (num_queries, c) -> (num_queries, batch_size, c) + query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, batch_size, 1)) + query_embed = self.query_embed.weight.unsqueeze(1).repeat((1, batch_size, 1)) + + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + attn_masks = [attn_mask, None] + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + attn_masks=attn_masks, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None, + ) + cls_pred, mask_pred, attn_mask = self.forward_head( + query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:] + ) + + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + return cls_pred_list, mask_pred_list + + def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, gt_masks): + """Forward function for training mode. + + Args: + x (list[Tensor]): Multi-level features from the upstream network, + each is a 4D-tensor. + img_metas (list[Dict]): List of image information. + gt_semantic_seg (list[tensor]):Each element is the ground truth + of semantic segmentation with the shape (N, H, W). + train_cfg (dict): The training config, which not been used in + maskformer. + gt_labels (list[Tensor]): Each element is ground truth labels of + each box, shape (num_gts,). + gt_masks (list[BitmapMasks]): Each element is masks of instances + of a image, shape (num_gts, h, w). + + Returns: + losses (dict[str, Tensor]): a dictionary of loss components + """ + + # forward + all_cls_scores, all_mask_preds = self(x, img_metas) + + # loss + losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, img_metas) + + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Test segment without test-time aumengtation. + + Only the output of last decoder layers was used. 
+ + Args: + inputs (list[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + test_cfg (dict): Testing config. + + Returns: + seg_mask (Tensor): Predicted semantic segmentation logits. + """ + all_cls_scores, all_mask_preds = self(inputs, img_metas) + cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1] + ori_h, ori_w, _ = img_metas[0]["ori_shape"] + + # semantic inference + cls_score = F.softmax(cls_score, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + seg_mask = torch.einsum("bqc,bqhw->bchw", cls_score, mask_pred) + return seg_mask diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/losses/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/losses/__init__.py new file mode 100755 index 0000000..229a887 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/losses/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .cross_entropy_loss import CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy +from .dice_loss import DiceLoss +from .match_costs import ClassificationCost, CrossEntropyLossCost, DiceCost diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py new file mode 100755 index 0000000..0a1f9dd --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.models.builder import LOSSES +from mmseg.models.losses.utils import get_class_weight, weight_reduce_loss + + +def cross_entropy( + pred, + label, + weight=None, + class_weight=None, + reduction="mean", + avg_factor=None, + ignore_index=-100, + avg_non_ignore=False, +): + """cross_entropy. The wrapper function for :func:`F.cross_entropy` + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + Default: None. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. + Options are 'none', 'mean' and 'sum'. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Default: None. + ignore_index (int): Specifies a target value that is ignored and + does not contribute to the input gradients. When + ``avg_non_ignore `` is ``True``, and the ``reduction`` is + ``''mean''``, the loss is averaged over non-ignored targets. + Defaults: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. 
+ `New in version 0.23.0.` + """ + + # class_weight is a manual rescaling weight given to each class. + # If given, has to be a Tensor of size C element-wise losses + loss = F.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index) + + # apply weights and do the reduction + # average loss over non-ignored elements + # pytorch's official cross_entropy average loss over non-ignored elements + # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa + if (avg_factor is None) and avg_non_ignore and reduction == "mean": + avg_factor = label.numel() - (label == ignore_index).sum().item() + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_zeros(target_shape) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask, as_tuple=True) + + if inds[0].numel() > 0: + if labels.dim() == 3: + bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 + else: + bin_labels[inds[0], labels[valid_mask]] = 1 + + valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() + + if label_weights is None: + bin_label_weights = valid_mask + else: + bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) + bin_label_weights = bin_label_weights * valid_mask + + return bin_labels, bin_label_weights, valid_mask + + +def binary_cross_entropy( + pred, + label, + weight=None, + reduction="mean", + avg_factor=None, + class_weight=None, + ignore_index=-100, + avg_non_ignore=False, + **kwargs, +): + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + Note: In bce loss, label < 0 is invalid. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int): The label index to be ignored. Default: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + + Returns: + torch.Tensor: The calculated loss + """ + if pred.size(1) == 1: + # For binary class segmentation, the shape of pred is + # [N, 1, H, W] and that of label is [N, H, W]. 
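+        # The squeeze() below drops the channel dim so pred matches label's
+        # (N, H, W) layout; the one-hot expansion branch is then skipped and
+        # BCE-with-logits compares pred and label element-wise.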
+ assert label.max() <= 1, "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes" + pred = pred.squeeze() + if pred.dim() != label.dim(): + assert (pred.dim() == 2 and label.dim() == 1) or (pred.dim() == 4 and label.dim() == 3), ( + "Only pred shape [N, C], label shape [N] or pred shape [N, C, " "H, W], label shape [N, H, W] are supported" + ) + # `weight` returned from `_expand_onehot_labels` + # has been treated for valid (non-ignore) pixels + label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.shape, ignore_index) + else: + # should mask out the ignored elements + valid_mask = ((label >= 0) & (label != ignore_index)).float() + if weight is not None: + weight = weight * valid_mask + else: + weight = valid_mask + # average loss over non-ignored and valid elements + if reduction == "mean" and avg_factor is None and avg_non_ignore: + avg_factor = valid_mask.sum().item() + + loss = F.binary_cross_entropy_with_logits(pred, label.float(), pos_weight=class_weight, reduction="none") + # do the reduction for the weighted loss + loss = weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def mask_cross_entropy( + pred, target, label, reduction="mean", avg_factor=None, class_weight=None, ignore_index=None, **kwargs +): + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask' + corresponding object. This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (None): Placeholder, to be consistent with other loss. + Default: None. + + Returns: + torch.Tensor: The calculated loss + """ + assert ignore_index is None, "BCE loss does not support ignore_index" + assert reduction == "mean" and avg_factor is None + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return F.binary_cross_entropy_with_logits(pred_slice, target, weight=class_weight, reduction="mean")[None] + + +@LOSSES.register_module(force=True) +class CrossEntropyLoss(nn.Module): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_mask (bool, optional): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str, optional): . Defaults to 'mean'. + Options are "none", "mean" and "sum". + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. 
+ `New in version 0.23.0.` + """ + + def __init__( + self, + use_sigmoid=False, + use_mask=False, + reduction="mean", + class_weight=None, + loss_weight=1.0, + loss_name="loss_ce", + avg_non_ignore=False, + ): + super(CrossEntropyLoss, self).__init__() + assert (use_sigmoid is False) or (use_mask is False) + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self.avg_non_ignore = avg_non_ignore + if not self.avg_non_ignore and self.reduction == "mean": + warnings.warn( + "Default ``avg_non_ignore`` is False, if you would like to " + "ignore the certain label and average loss over non-ignore " + "labels, which is the same with PyTorch official " + "cross_entropy, set ``avg_non_ignore=True``." + ) + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy + else: + self.cls_criterion = cross_entropy + self._loss_name = loss_name + + def extra_repr(self): + """Extra repr.""" + s = f"avg_non_ignore={self.avg_non_ignore}" + return s + + def forward( + self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, ignore_index=-100, **kwargs + ): + """Forward function.""" + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + # Note: for BCE loss, label < 0 is invalid. + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + avg_non_ignore=self.avg_non_ignore, + ignore_index=ignore_index, + **kwargs, + ) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py new file mode 100755 index 0000000..1bc5ba8 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from mmseg.models.builder import LOSSES +from mmseg.models.losses.utils import weight_reduce_loss + + +def dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None): + """Calculate dice loss, which is proposed in + `V-Net: Fully Convolutional Neural Networks for Volumetric + Medical Image Segmentation `_. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *) + target (torch.Tensor): The learning label of the prediction, + shape (n, *), same shape of pred. 
+ weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + eps (float): Avoid dividing by zero. Default: 1e-3. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + + input = pred.flatten(1) + target = target.flatten(1).float() + + a = torch.sum(input * target, 1) + b = torch.sum(input * input, 1) + eps + c = torch.sum(target * target, 1) + eps + d = (2 * a) / (b + c) + loss = 1 - d + if weight is not None: + assert weight.ndim == loss.ndim + assert len(weight) == len(pred) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +def naive_dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None): + """Calculate naive dice loss, the coefficient in the denominator is the + first power instead of the second power. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *) + target (torch.Tensor): The learning label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + eps (float): Avoid dividing by zero. Default: 1e-3. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + input = pred.flatten(1) + target = target.flatten(1).float() + + a = torch.sum(input * target, 1) + b = torch.sum(input, 1) + c = torch.sum(target, 1) + d = (2 * a + eps) / (b + c + eps) + loss = 1 - d + if weight is not None: + assert weight.ndim == loss.ndim + assert len(weight) == len(pred) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@LOSSES.register_module(force=True) +class DiceLoss(nn.Module): + def __init__(self, use_sigmoid=True, activate=True, reduction="mean", naive_dice=False, loss_weight=1.0, eps=1e-3): + """Dice Loss, there are two forms of dice loss is supported: + + - the one proposed in `V-Net: Fully Convolutional Neural + Networks for Volumetric Medical Image Segmentation + `_. + - the dice loss in which the power of the number in the + denominator is the first power instead of the second + power. + + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + activate (bool): Whether to activate the predictions inside, + this will disable the inside sigmoid operation. + Defaults to True. + reduction (str, optional): The method used + to reduce the loss. Options are "none", + "mean" and "sum". Defaults to 'mean'. + naive_dice (bool, optional): If false, use the dice + loss defined in the V-Net paper, otherwise, use the + naive dice loss in which the power of the number in the + denominator is the first power instead of the second + power.Defaults to False. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + eps (float): Avoid dividing by zero. Defaults to 1e-3. + """ + + super(DiceLoss, self).__init__() + self.use_sigmoid = use_sigmoid + self.reduction = reduction + self.naive_dice = naive_dice + self.loss_weight = loss_weight + self.eps = eps + self.activate = activate + + def forward(self, pred, target, weight=None, reduction_override=None, avg_factor=None): + """Forward function. 
+ + Args: + pred (torch.Tensor): The prediction, has a shape (n, *). + target (torch.Tensor): The label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Options are "none", "mean" and "sum". + + Returns: + torch.Tensor: The calculated loss + """ + + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + + if self.activate: + if self.use_sigmoid: + pred = pred.sigmoid() + else: + raise NotImplementedError + + if self.naive_dice: + loss = self.loss_weight * naive_dice_loss( + pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor + ) + else: + loss = self.loss_weight * dice_loss( + pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor + ) + + return loss diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/losses/match_costs.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/losses/match_costs.py new file mode 100755 index 0000000..4917d2a --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/losses/match_costs.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F + +from ..builder import MATCH_COST + + +@MATCH_COST.register_module() +class ClassificationCost: + """ClsSoftmaxCost.Borrow from + mmdet.core.bbox.match_costs.match_cost.ClassificationCost. + + Args: + weight (int | float, optional): loss_weight + + Examples: + >>> import torch + >>> self = ClassificationCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3430, -0.3525, -0.3045], + [-0.3077, -0.2931, -0.3992], + [-0.3664, -0.3455, -0.2881], + [-0.3343, -0.2701, -0.3956]]) + """ + + def __init__(self, weight=1.0): + self.weight = weight + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + # Following the official DETR repo, contrary to the loss that + # NLL is used, we approximate it in 1 - cls_score[gt_label]. + # The 1 is a constant that doesn't change the matching, + # so it can be omitted. + cls_score = cls_pred.softmax(-1) + cls_cost = -cls_score[:, gt_labels] + return cls_cost * self.weight + + +@MATCH_COST.register_module() +class DiceCost: + """Cost of mask assignments based on dice losses. + + Args: + weight (int | float, optional): loss_weight. Defaults to 1. + pred_act (bool, optional): Whether to apply sigmoid to mask_pred. + Defaults to False. + eps (float, optional): default 1e-12. 
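+
+    Examples:
+        >>> # illustrative usage with random masks; exact cost values vary
+        >>> import torch
+        >>> self = DiceCost(weight=1.0, pred_act=True)
+        >>> mask_preds = torch.rand(4, 16, 16)
+        >>> gt_masks = (torch.rand(3, 16, 16) > 0.5).float()
+        >>> self(mask_preds, gt_masks).shape
+        torch.Size([4, 3])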
+ """ + + def __init__(self, weight=1.0, pred_act=False, eps=1e-3): + self.weight = weight + self.pred_act = pred_act + self.eps = eps + + def binary_mask_dice_loss(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction in shape (N1, H, W). + gt_masks (Tensor): Ground truth in shape (N2, H, W) + store 0 or 1, 0 for negative class and 1 for + positive class. + + Returns: + Tensor: Dice cost matrix in shape (N1, N2). + """ + mask_preds = mask_preds.reshape((mask_preds.shape[0], -1)) + gt_masks = gt_masks.reshape((gt_masks.shape[0], -1)).float() + numerator = 2 * torch.einsum("nc,mc->nm", mask_preds, gt_masks) + denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :] + loss = 1 - (numerator + self.eps) / (denominator + self.eps) + return loss + + def __call__(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction logits in shape (N1, H, W). + gt_masks (Tensor): Ground truth in shape (N2, H, W). + + Returns: + Tensor: Dice cost matrix in shape (N1, N2). + """ + if self.pred_act: + mask_preds = mask_preds.sigmoid() + dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks) + return dice_cost * self.weight + + +@MATCH_COST.register_module() +class CrossEntropyLossCost: + """CrossEntropyLossCost. + + Args: + weight (int | float, optional): loss weight. Defaults to 1. + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to True. + """ + + def __init__(self, weight=1.0, use_sigmoid=True): + assert use_sigmoid, "use_sigmoid = False is not supported yet." + self.weight = weight + self.use_sigmoid = use_sigmoid + + def _binary_cross_entropy(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): The prediction with shape (num_query, 1, *) or + (num_query, *). + gt_labels (Tensor): The learning label of prediction with + shape (num_gt, *). + Returns: + Tensor: Cross entropy cost matrix in shape (num_query, num_gt). + """ + cls_pred = cls_pred.flatten(1).float() + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + pos = F.binary_cross_entropy_with_logits(cls_pred, torch.ones_like(cls_pred), reduction="none") + neg = F.binary_cross_entropy_with_logits(cls_pred, torch.zeros_like(cls_pred), reduction="none") + cls_cost = torch.einsum("nc,mc->nm", pos, gt_labels) + torch.einsum("nc,mc->nm", neg, 1 - gt_labels) + cls_cost = cls_cost / n + + return cls_cost + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits. + gt_labels (Tensor): Labels. + Returns: + Tensor: Cross entropy cost matrix with weight in + shape (num_query, num_gt). + """ + if self.use_sigmoid: + cls_cost = self._binary_cross_entropy(cls_pred, gt_labels) + else: + raise NotImplementedError + + return cls_cost * self.weight diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/plugins/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/plugins/__init__.py new file mode 100755 index 0000000..81a60db --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/plugins/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py new file mode 100755 index 0000000..db19471 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init, normal_init, xavier_init +from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence +from mmcv.runner import BaseModule, ModuleList + +from ...core.anchor import MlvlPointGenerator +from ..utils.transformer import MultiScaleDeformableAttention + + +@PLUGIN_LAYERS.register_module() +class MSDeformAttnPixelDecoder(BaseModule): + """Pixel decoder with multi-scale deformable attention. + + Args: + in_channels (list[int] | tuple[int]): Number of channels in the + input feature maps. + strides (list[int] | tuple[int]): Output strides of feature from + backbone. + feat_channels (int): Number of channels for feature. + out_channels (int): Number of channels for output. + num_outs (int): Number of output scales. + norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization. + Defaults to dict(type='GN', num_groups=32). + act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation. + Defaults to dict(type='ReLU'). + encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer + encoder. Defaults to `DetrTransformerEncoder`. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer encoder position encoding. Defaults to + dict(type='SinePositionalEncoding', num_feats=128, + normalize=True). + init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict. 
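+
+    Example (illustrative, with the default arguments above):
+
+        >>> # the 3 lowest-resolution inputs (strides 32, 16, 8) pass through
+        >>> # the deformable encoder, the stride-4 level is fused FPN-style,
+        >>> # and forward() returns (mask_feature, multi_scale_features) with
+        >>> # len(multi_scale_features) == num_outs == 3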
+ """ + + def __init__( + self, + in_channels=[256, 512, 1024, 2048], + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_outs=3, + norm_cfg=dict(type="GN", num_groups=32), + act_cfg=dict(type="ReLU"), + encoder=dict( + type="DetrTransformerEncoder", + num_layers=6, + transformerlayers=dict( + type="BaseTransformerLayer", + attn_cfgs=dict( + type="MultiScaleDeformableAttention", + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None, + ), + feedforward_channels=1024, + ffn_dropout=0.0, + operation_order=("self_attn", "norm", "ffn", "norm"), + ), + init_cfg=None, + ), + positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self.strides = strides + self.num_input_levels = len(in_channels) + self.num_encoder_levels = encoder.transformerlayers.attn_cfgs.num_levels + assert self.num_encoder_levels >= 1, "num_levels in attn_cfgs must be at least one" + input_conv_list = [] + # from top to down (low to high resolution) + for i in range(self.num_input_levels - 1, self.num_input_levels - self.num_encoder_levels - 1, -1): + input_conv = ConvModule( + in_channels[i], feat_channels, kernel_size=1, norm_cfg=norm_cfg, act_cfg=None, bias=True + ) + input_conv_list.append(input_conv) + self.input_convs = ModuleList(input_conv_list) + + self.encoder = build_transformer_layer_sequence(encoder) + self.postional_encoding = build_positional_encoding(positional_encoding) + # high resolution to low resolution + self.level_encoding = nn.Embedding(self.num_encoder_levels, feat_channels) + + # fpn-like structure + self.lateral_convs = ModuleList() + self.output_convs = ModuleList() + self.use_bias = norm_cfg is None + # from top to down (low to high resolution) + # fpn for the rest features that didn't pass in encoder + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1): + lateral_conv = ConvModule( + in_channels[i], feat_channels, kernel_size=1, bias=self.use_bias, norm_cfg=norm_cfg, act_cfg=None + ) + output_conv = ConvModule( + feat_channels, + feat_channels, + kernel_size=3, + stride=1, + padding=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + self.lateral_convs.append(lateral_conv) + self.output_convs.append(output_conv) + + self.mask_feature = Conv2d(feat_channels, out_channels, kernel_size=1, stride=1, padding=0) + + self.num_outs = num_outs + self.point_generator = MlvlPointGenerator(strides) + + def init_weights(self): + """Initialize weights.""" + for i in range(0, self.num_encoder_levels): + xavier_init(self.input_convs[i].conv, gain=1, bias=0, distribution="uniform") + + for i in range(0, self.num_input_levels - self.num_encoder_levels): + caffe2_xavier_init(self.lateral_convs[i].conv, bias=0) + caffe2_xavier_init(self.output_convs[i].conv, bias=0) + + caffe2_xavier_init(self.mask_feature, bias=0) + + normal_init(self.level_encoding, mean=0, std=1) + for p in self.encoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + # init_weights defined in MultiScaleDeformableAttention + for layer in self.encoder.layers: + for attn in layer.attentions: + if isinstance(attn, MultiScaleDeformableAttention): + attn.init_weights() + + def forward(self, feats): + """ + Args: + feats (list[Tensor]): Feature maps of each level. Each has + shape of (batch_size, c, h, w). 
+ + Returns: + tuple: A tuple containing the following: + + - mask_feature (Tensor): shape (batch_size, c, h, w). + - multi_scale_features (list[Tensor]): Multi scale \ + features, each in shape (batch_size, c, h, w). + """ + # generate padding mask for each level, for each image + batch_size = feats[0].shape[0] + encoder_input_list = [] + padding_mask_list = [] + level_positional_encoding_list = [] + spatial_shapes = [] + reference_points_list = [] + for i in range(self.num_encoder_levels): + level_idx = self.num_input_levels - i - 1 + feat = feats[level_idx] + feat_projected = self.input_convs[i](feat) + h, w = feat.shape[-2:] + + # no padding + padding_mask_resized = feat.new_zeros((batch_size,) + feat.shape[-2:], dtype=torch.bool) + pos_embed = self.postional_encoding(padding_mask_resized) + level_embed = self.level_encoding.weight[i] + level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed + # (h_i * w_i, 2) + reference_points = self.point_generator.single_level_grid_priors( + feat.shape[-2:], level_idx, device=feat.device + ) + # normalize + factor = feat.new_tensor([[w, h]]) * self.strides[level_idx] + reference_points = reference_points / factor + + # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c) + feat_projected = feat_projected.flatten(2).permute(2, 0, 1) + level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1) + padding_mask_resized = padding_mask_resized.flatten(1) + + encoder_input_list.append(feat_projected) + padding_mask_list.append(padding_mask_resized) + level_positional_encoding_list.append(level_pos_embed) + spatial_shapes.append(feat.shape[-2:]) + reference_points_list.append(reference_points) + # shape (batch_size, total_num_query), + # total_num_query=sum([., h_i * w_i,.]) + padding_masks = torch.cat(padding_mask_list, dim=1) + # shape (total_num_query, batch_size, c) + encoder_inputs = torch.cat(encoder_input_list, dim=0) + level_positional_encodings = torch.cat(level_positional_encoding_list, dim=0) + device = encoder_inputs.device + # shape (num_encoder_levels, 2), from low + # resolution to high resolution + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=device) + # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...) 
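+        # e.g. (illustrative) spatial_shapes [[16, 16], [32, 32], [64, 64]]
+        # gives level_start_index [0, 256, 1280]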
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = torch.cat(reference_points_list, dim=0) + reference_points = reference_points[None, :, None].repeat(batch_size, 1, self.num_encoder_levels, 1) + valid_radios = reference_points.new_ones((batch_size, self.num_encoder_levels, 2)) + # shape (num_total_query, batch_size, c) + memory = self.encoder( + query=encoder_inputs, + key=None, + value=None, + query_pos=level_positional_encodings, + key_pos=None, + attn_masks=None, + key_padding_mask=None, + query_key_padding_mask=padding_masks, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_radios=valid_radios, + ) + # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query) + memory = memory.permute(1, 2, 0) + + # from low resolution to high resolution + num_query_per_level = [e[0] * e[1] for e in spatial_shapes] + outs = torch.split(memory, num_query_per_level, dim=-1) + outs = [x.reshape(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1]) for i, x in enumerate(outs)] + + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1): + x = feats[i] + cur_feat = self.lateral_convs[i](x) + y = cur_feat + F.interpolate(outs[-1], size=cur_feat.shape[-2:], mode="bilinear", align_corners=False) + y = self.output_convs[i](y) + outs.append(y) + multi_scale_features = outs[: self.num_outs] + + mask_feature = self.mask_feature(outs[-1]) + return mask_feature, multi_scale_features diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py new file mode 100755 index 0000000..adf0062 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .encoder_decoder_mask2former import EncoderDecoderMask2Former diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py new file mode 100755 index 0000000..cfe572c --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.core import add_prefix +from mmseg.models import builder +from mmseg.models.builder import SEGMENTORS +from mmseg.models.segmentors.base import BaseSegmentor +from mmseg.ops import resize + + +@SEGMENTORS.register_module() +class EncoderDecoderMask2Former(BaseSegmentor): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. 
+ Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + """ + + def __init__( + self, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None, + ): + super(EncoderDecoderMask2Former, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get("pretrained") is None, "both backbone and segmentor set pretrained weight" + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + if neck is not None: + self.neck = builder.build_neck(neck) + decode_head.update(train_cfg=train_cfg) + decode_head.update(test_cfg=test_cfg) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + + def _init_auxiliary_head(self, auxiliary_head): + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(builder.build_head(head_cfg)) + else: + self.auxiliary_head = builder.build_head(auxiliary_head) + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + out = resize(input=out, size=img.shape[2:], mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(x, img_metas, gt_semantic_seg, **kwargs) + + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) + return seg_logits + + def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, f"aux_{idx}")) + else: + loss_aux = self.auxiliary_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, "aux")) + + return losses + + def forward_dummy(self, img): + """Dummy forward function.""" + seg_logit = self.encode_decode(img, None) + + return seg_logit + + def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. 
+ img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, img_metas, gt_semantic_seg, **kwargs) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train(x, img_metas, gt_semantic_seg) + losses.update(loss_aux) + + return losses + + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + num_classes = self.num_classes + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + crop_seg_logit = self.encode_decode(crop_img, img_meta) + preds += F.pad(crop_seg_logit, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + if rescale: + preds = resize( + preds, + size=img_meta[0]["ori_shape"][:2], + mode="bilinear", + align_corners=self.align_corners, + warning=False, + ) + return preds + + def whole_inference(self, img, img_meta, rescale): + """Inference with full image.""" + + seg_logit = self.encode_decode(img, img_meta) + if rescale: + # support dynamic shape for onnx + if torch.onnx.is_in_onnx_export(): + size = img.shape[2:] + else: + size = img_meta[0]["ori_shape"][:2] + seg_logit = resize(seg_logit, size=size, mode="bilinear", align_corners=self.align_corners, warning=False) + + return seg_logit + + def inference(self, img, img_meta, rescale): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output segmentation map. 
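+
+        Example (illustrative, assuming ``test_cfg.mode == 'whole'``, rescale
+            and no flip augmentation):
+
+            >>> # output has shape (N, num_classes, ori_H, ori_W) and
+            >>> # output.sum(dim=1) is all ones (softmax over classes)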
+ """ + + assert self.test_cfg.mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if self.test_cfg.mode == "slide": + seg_logit = self.slide_inference(img, img_meta, rescale) + else: + seg_logit = self.whole_inference(img, img_meta, rescale) + output = F.softmax(seg_logit, dim=1) + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + seg_logit = self.inference(img, img_meta, rescale) + seg_pred = seg_logit.argmax(dim=1) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + seg_pred = seg_pred.unsqueeze(0) + return seg_pred + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) + seg_logit += cur_seg_logit + seg_logit /= len(imgs) + seg_pred = seg_logit.argmax(dim=1) + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/__init__.py new file mode 100755 index 0000000..e7fdc16 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .assigner import MaskHungarianAssigner +from .point_sample import get_uncertain_point_coords_with_randomness +from .positional_encoding import LearnedPositionalEncoding, SinePositionalEncoding +from .transformer import DetrTransformerDecoder, DetrTransformerDecoderLayer, DynamicConv, Transformer diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/assigner.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/assigner.py new file mode 100755 index 0000000..3cb08fc --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/assigner.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from abc import ABCMeta, abstractmethod + +import torch + +from ..builder import MASK_ASSIGNERS, build_match_cost + +try: + from scipy.optimize import linear_sum_assignment +except ImportError: + linear_sum_assignment = None + + +class AssignResult(metaclass=ABCMeta): + """Collection of assign results.""" + + def __init__(self, num_gts, gt_inds, labels): + self.num_gts = num_gts + self.gt_inds = gt_inds + self.labels = labels + + @property + def info(self): + info = { + "num_gts": self.num_gts, + "gt_inds": self.gt_inds, + "labels": self.labels, + } + return info + + +class BaseAssigner(metaclass=ABCMeta): + """Base assigner that assigns boxes to ground truth boxes.""" + + @abstractmethod + def assign(self, masks, gt_masks, gt_masks_ignore=None, gt_labels=None): + """Assign boxes to either a ground truth boxes or a negative boxes.""" + pass + + +@MASK_ASSIGNERS.register_module() +class MaskHungarianAssigner(BaseAssigner): + """Computes one-to-one matching between predictions and ground truth for + mask. + + This class computes an assignment between the targets and the predictions + based on the costs. The costs are weighted sum of three components: + classification cost, regression L1 cost and regression iou cost. The + targets don't include the no_object, so generally there are more + predictions than targets. After the one-to-one matching, the un-matched + are treated as backgrounds. Thus each query prediction will be assigned + with `0` or a positive integer indicating the ground truth index: + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + cls_cost (obj:`mmcv.ConfigDict`|dict): Classification cost config. + mask_cost (obj:`mmcv.ConfigDict`|dict): Mask cost config. + dice_cost (obj:`mmcv.ConfigDict`|dict): Dice cost config. + """ + + def __init__( + self, + cls_cost=dict(type="ClassificationCost", weight=1.0), + dice_cost=dict(type="DiceCost", weight=1.0), + mask_cost=dict(type="MaskFocalCost", weight=1.0), + ): + self.cls_cost = build_match_cost(cls_cost) + self.dice_cost = build_match_cost(dice_cost) + self.mask_cost = build_match_cost(mask_cost) + + def assign(self, cls_pred, mask_pred, gt_labels, gt_masks, img_meta, gt_masks_ignore=None, eps=1e-7): + """Computes one-to-one matching based on the weighted costs. + + This method assign each query prediction to a ground truth or + background. The `assigned_gt_inds` with -1 means don't care, + 0 means negative sample, and positive number is the index (1-based) + of assigned gt. + The assignment is done in the following steps, the order matters. + + 1. assign every prediction to -1 + 2. compute the weighted costs + 3. do Hungarian matching on CPU based on the costs + 4. assign all to 0 (background) first, then for each matched pair + between predictions and gts, treat this prediction as foreground + and assign the corresponding gt index (plus 1) to it. + + Args: + mask_pred (Tensor): Predicted mask, shape [num_query, h, w] + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_masks (Tensor): Ground truth mask, shape [num_gt, h, w]. + gt_labels (Tensor): Label of `gt_masks`, shape (num_gt,). + img_meta (dict): Meta information for current image. + gt_masks_ignore (Tensor, optional): Ground truth masks that are + labelled as `ignored`. Default None. + eps (int | float, optional): A value added to the denominator for + numerical stability. Default 1e-7. + + Returns: + :obj:`AssignResult`: The assigned result. 
+ """ + assert gt_masks_ignore is None, "Only case when gt_masks_ignore is None is supported." + num_gts, num_queries = gt_labels.shape[0], cls_pred.shape[0] + + # 1. assign -1 by default + assigned_gt_inds = cls_pred.new_full((num_queries,), -1, dtype=torch.long) + assigned_labels = cls_pred.new_full((num_queries,), -1, dtype=torch.long) + if num_gts == 0 or num_queries == 0: + # No ground truth or boxes, return empty assignment + if num_gts == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels) + + # 2. compute the weighted costs + # classification and maskcost. + if self.cls_cost.weight != 0 and cls_pred is not None: + cls_cost = self.cls_cost(cls_pred, gt_labels) + else: + cls_cost = 0 + + if self.mask_cost.weight != 0: + # mask_pred shape = [nq, h, w] + # gt_mask shape = [ng, h, w] + # mask_cost shape = [nq, ng] + mask_cost = self.mask_cost(mask_pred, gt_masks) + else: + mask_cost = 0 + + if self.dice_cost.weight != 0: + dice_cost = self.dice_cost(mask_pred, gt_masks) + else: + dice_cost = 0 + cost = cls_cost + mask_cost + dice_cost + + # 3. do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' "to install scipy first.") + + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) + matched_row_inds = torch.from_numpy(matched_row_inds).to(cls_pred.device) + matched_col_inds = torch.from_numpy(matched_col_inds).to(cls_pred.device) + + # 4. assign backgrounds and foregrounds + # assign all indices to backgrounds first + assigned_gt_inds[:] = 0 + # assign foregrounds based on matching results + assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 + assigned_labels[matched_row_inds] = gt_labels[matched_col_inds] + return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/point_sample.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/point_sample.py new file mode 100755 index 0000000..9f11340 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/point_sample.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +from mmcv.ops import point_sample + + +def get_uncertainty(mask_pred, labels): + """Estimate uncertainty based on pred logits. + + We estimate uncertainty as L1 distance between 0.0 and the logits + prediction in 'mask_pred' for the foreground class in `classes`. + + Args: + mask_pred (Tensor): mask predication logits, shape (num_rois, + num_classes, mask_height, mask_width). + + labels (list[Tensor]): Either predicted or ground truth label for + each predicted mask, of length num_rois. 
+ + Returns: + scores (Tensor): Uncertainty scores with the most uncertain + locations having the highest uncertainty score, + shape (num_rois, 1, mask_height, mask_width) + """ + if mask_pred.shape[1] == 1: + gt_class_logits = mask_pred.clone() + else: + inds = torch.arange(mask_pred.shape[0], device=mask_pred.device) + gt_class_logits = mask_pred[inds, labels].unsqueeze(1) + return -torch.abs(gt_class_logits) + + +def get_uncertain_point_coords_with_randomness( + mask_pred, labels, num_points, oversample_ratio, importance_sample_ratio +): + """Get ``num_points`` most uncertain points with random points during + train. + + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + 'get_uncertainty()' function that takes point's logit prediction as + input. + + Args: + mask_pred (Tensor): A tensor of shape (num_rois, num_classes, + mask_height, mask_width) for class-specific or class-agnostic + prediction. + labels (list): The ground truth class for each instance. + num_points (int): The number of points to sample. + oversample_ratio (int): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled + via importnace sampling. + + Returns: + point_coords (Tensor): A tensor of shape (num_rois, num_points, 2) + that contains the coordinates sampled points. + """ + assert oversample_ratio >= 1 + assert 0 <= importance_sample_ratio <= 1 + batch_size = mask_pred.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand(batch_size, num_sampled, 2, device=mask_pred.device) + point_logits = point_sample(mask_pred, point_coords) + # It is crucial to calculate uncertainty based on the sampled + # prediction value for the points. Calculating uncertainties of the + # coarse predictions first and sampling them for points leads to + # incorrect results. To illustrate this: assume uncertainty func( + # logits)=-abs(logits), a sampled point between two coarse + # predictions with -1 and 1 logits has 0 logits, and therefore 0 + # uncertainty value. However, if we calculate uncertainties for the + # coarse predictions first, both will have -1 uncertainty, + # and sampled point will get -1 uncertainty. + point_uncertainties = get_uncertainty(point_logits, labels) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange(batch_size, dtype=torch.long, device=mask_pred.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(batch_size, num_uncertain_points, 2) + if num_random_points > 0: + rand_roi_coords = torch.rand(batch_size, num_random_points, 2, device=mask_pred.device) + point_coords = torch.cat((point_coords, rand_roi_coords), dim=1) + return point_coords diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py new file mode 100755 index 0000000..bf5d6fa --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
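# --- Illustrative sketch (not part of this patch): the sine/cosine positional encoding that
# SinePositionalEncoding in this file produces. Per pixel, cumulative x/y positions over the
# valid (non-padded) region are divided by a geometric series of temperatures and passed
# through sin/cos, giving a [bs, 2*num_feats, h, w] embedding. Values below are toy values;
# num_feats and temperature mirror the class defaults.
import torch

def toy_sine_pos_embed(mask, num_feats=128, temperature=10000):
    # mask: [bs, h, w], non-zero = padding; positions are counted over valid pixels only
    not_mask = (~mask.bool()).float()
    y_embed = not_mask.cumsum(1)                    # running row index per valid pixel
    x_embed = not_mask.cumsum(2)                    # running column index per valid pixel
    dim_t = torch.arange(num_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / num_feats)
    pos_x = x_embed[..., None] / dim_t
    pos_y = y_embed[..., None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
    return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)  # [bs, 2*num_feats, h, w]

print(toy_sine_pos_embed(torch.zeros(1, 8, 8)).shape)  # torch.Size([1, 256, 8, 8])
# --- End of illustrative sketch. ---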
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING +from mmcv.runner import BaseModule + + +@POSITIONAL_ENCODING.register_module() +class SinePositionalEncoding(BaseModule): + """Position encoding with sine and cosine functions. + + See `End-to-End Object Detection with Transformers + `_ for details. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. Note the final returned dimension + for each position is 2 times of this value. + temperature (int, optional): The temperature used for scaling + the position embedding. Defaults to 10000. + normalize (bool, optional): Whether to normalize the position + embedding. Defaults to False. + scale (float, optional): A scale factor that scales the position + embedding. The scale will be used only when `normalize` is True. + Defaults to 2*pi. + eps (float, optional): A value added to the denominator for + numerical stability. Defaults to 1e-6. + offset (float): offset add to embed when do the normalization. + Defaults to 0. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__( + self, num_feats, temperature=10000, normalize=False, scale=2 * math.pi, eps=1e-6, offset=0.0, init_cfg=None + ): + super(SinePositionalEncoding, self).__init__(init_cfg) + if normalize: + assert isinstance(scale, (float, int)), ( + "when normalize is set," "scale should be provided and in float or int type, " f"found {type(scale)}" + ) + self.num_feats = num_feats + self.temperature = temperature + self.normalize = normalize + self.scale = scale + self.eps = eps + self.offset = offset + + def forward(self, mask): + """Forward function for `SinePositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + # For convenience of exporting to ONNX, it's required to convert + # `masks` from bool to int. 
+ mask = mask.to(torch.int) + not_mask = 1 - mask # logical_not + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + y_embed = (y_embed + self.offset) / (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = (x_embed + self.offset) / (x_embed[:, :, -1:] + self.eps) * self.scale + dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + # use `view` instead of `flatten` for dynamically exporting to ONNX + B, H, W = mask.size() + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f"(num_feats={self.num_feats}, " + repr_str += f"temperature={self.temperature}, " + repr_str += f"normalize={self.normalize}, " + repr_str += f"scale={self.scale}, " + repr_str += f"eps={self.eps})" + return repr_str + + +@POSITIONAL_ENCODING.register_module() +class LearnedPositionalEncoding(BaseModule): + """Position embedding with learnable embedding weights. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. The final returned dimension for + each position is 2 times of this value. + row_num_embed (int, optional): The dictionary size of row embeddings. + Default 50. + col_num_embed (int, optional): The dictionary size of col embeddings. + Default 50. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, num_feats, row_num_embed=50, col_num_embed=50, init_cfg=dict(type="Uniform", layer="Embedding")): + super(LearnedPositionalEncoding, self).__init__(init_cfg) + self.row_embed = nn.Embedding(row_num_embed, num_feats) + self.col_embed = nn.Embedding(col_num_embed, num_feats) + self.num_feats = num_feats + self.row_num_embed = row_num_embed + self.col_num_embed = col_num_embed + + def forward(self, mask): + """Forward function for `LearnedPositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. 
+ """ + h, w = mask.shape[-2:] + x = torch.arange(w, device=mask.device) + y = torch.arange(h, device=mask.device) + x_embed = self.col_embed(x) + y_embed = self.row_embed(y) + pos = ( + torch.cat((x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(1, w, 1)), dim=-1) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(mask.shape[0], 1, 1, 1) + ) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f"(num_feats={self.num_feats}, " + repr_str += f"row_num_embed={self.row_num_embed}, " + repr_str += f"col_num_embed={self.col_num_embed})" + return repr_str diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/transformer.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/transformer.py new file mode 100755 index 0000000..8befe60 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/models/utils/transformer.py @@ -0,0 +1,989 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import warnings +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import Linear, build_activation_layer, build_norm_layer, xavier_init +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.registry import FEEDFORWARD_NETWORK, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE +from mmcv.cnn.bricks.transformer import BaseTransformerLayer, TransformerLayerSequence, build_transformer_layer_sequence +from mmcv.runner.base_module import BaseModule, Sequential +from mmcv.utils import deprecated_api_warning, to_2tuple +from torch.nn.init import normal_ + +from ..builder import TRANSFORMER + +try: + from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention + +except ImportError: + warnings.warn( + "`MultiScaleDeformableAttention` in MMCV has been moved to " + "`mmcv.ops.multi_scale_deform_attn`, please update your MMCV" + ) + from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1 + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". 
+ Example: + >>> kernel_size = 16 + >>> stride = 16 + >>> dilation = 1 + >>> input = torch.rand(1, 1, 15, 17) + >>> adap_pad = AdaptivePadding( + >>> kernel_size=kernel_size, + >>> stride=stride, + >>> dilation=dilation, + >>> padding="corner") + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + >>> input = torch.rand(1, 1, 16, 17) + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + """ + + def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"): + + super(AdaptivePadding, self).__init__() + + assert padding in ("same", "corner") + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x): + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == "corner": + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == "same": + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return x + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + to gets fully covered by filter and stride you specified.. + Default: True. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Default: None. 
+ """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding="corner", + dilation=1, + bias=False, + norm_cfg=dict(type="LN"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding + ) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f"Expect " f"input_size is " f"`Sequence` " f"but get {input_size}" + + H, W = input_size + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = ( + H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1) - 1 + ) // self.sampler.stride[0] + 1 + out_w = ( + W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1) - 1 + ) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size + + +def inverse_sigmoid(x, eps=1e-5): + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the + inverse. + eps (float): EPS avoid numerical + overflow. Defaults 1e-5. + Returns: + Tensor: The x has passed the inverse + function of sigmoid, has same + shape with input. + """ + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +@FEEDFORWARD_NETWORK.register_module(force=True) +class FFN(BaseModule): + """Implements feed-forward networks (FFNs) with identity connection. + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int, optional): The number of fully-connected layers in + FFNs. Default: 2. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. 
+ add_identity (bool, optional): Whether to add the + identity connection. Default: `True`. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + @deprecated_api_warning({"dropout": "ffn_drop", "add_residual": "add_identity"}, cls_name="FFN") + def __init__( + self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + act_cfg=dict(type="ReLU", inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True, + init_cfg=None, + with_cp=False, + **kwargs, + ): + super().__init__(init_cfg) + assert num_fcs >= 2, "num_fcs should be no less " f"than 2. got {num_fcs}." + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + self.with_cp = with_cp + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append(Sequential(Linear(in_channels, feedforward_channels), self.activate, nn.Dropout(ffn_drop))) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout(dropout_layer) if dropout_layer else torch.nn.Identity() + self.add_identity = add_identity + + @deprecated_api_warning({"residual": "identity"}, cls_name="FFN") + def forward(self, x, identity=None): + """Forward function for `FFN`. + The function would add x to the output tensor if residue is None. + """ + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.layers, x) + else: + out = self.layers(x) + + if not self.add_identity: + return self.dropout_layer(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +@TRANSFORMER_LAYER.register_module() +class DetrTransformerDecoderLayer(BaseTransformerLayer): + """Implements decoder layer in DETR transformer. + + Args: + attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )): + Configs for self_attention or cross_attention, the order + should be consistent with it in `operation_order`. If it is + a dict, it would be expand to the number of attention in + `operation_order`. + feedforward_channels (int): The hidden dimension for FFNs. + ffn_dropout (float): Probability of an element to be zeroed + in ffn. Default 0.0. + operation_order (tuple[str]): The execution order of operation + in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). + Default:None + act_cfg (dict): The activation config for FFNs. Default: `LN` + norm_cfg (dict): Config dict for normalization layer. + Default: `LN`. + ffn_num_fcs (int): The number of fully-connected layers in FFNs. + Default:2. + """ + + def __init__( + self, + attn_cfgs, + feedforward_channels, + ffn_dropout=0.0, + operation_order=None, + act_cfg=dict(type="ReLU", inplace=True), + norm_cfg=dict(type="LN"), + ffn_num_fcs=2, + **kwargs, + ): + super(DetrTransformerDecoderLayer, self).__init__( + attn_cfgs=attn_cfgs, + feedforward_channels=feedforward_channels, + ffn_dropout=ffn_dropout, + operation_order=operation_order, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + ffn_num_fcs=ffn_num_fcs, + **kwargs, + ) + assert len(operation_order) == 6 + assert set(operation_order) == set(["self_attn", "norm", "cross_attn", "ffn"]) + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DetrTransformerEncoder(TransformerLayerSequence): + """TransformerEncoder of DETR. 
+ + Args: + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. Only used when `self.pre_norm` is `True` + """ + + def __init__(self, *args, post_norm_cfg=dict(type="LN"), **kwargs): + super(DetrTransformerEncoder, self).__init__(*args, **kwargs) + if post_norm_cfg is not None: + self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None + else: + assert not self.pre_norm, f"Use prenorm in " f"{self.__class__.__name__}," f"Please specify post_norm_cfg" + self.post_norm = None + + def forward(self, *args, **kwargs): + """Forward function for `TransformerCoder`. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + x = super(DetrTransformerEncoder, self).forward(*args, **kwargs) + if self.post_norm is not None: + x = self.post_norm(x) + return x + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DetrTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + + def __init__(self, *args, post_norm_cfg=dict(type="LN"), return_intermediate=False, **kwargs): + + super(DetrTransformerDecoder, self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + if post_norm_cfg is not None: + self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] + else: + self.post_norm = None + + def forward(self, query, *args, **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. + """ + if not self.return_intermediate: + x = super().forward(query, *args, **kwargs) + if self.post_norm: + x = self.post_norm(x)[None] + return x + + intermediate = [] + for layer in self.layers: + query = layer(query, *args, **kwargs) + if self.return_intermediate: + if self.post_norm is not None: + intermediate.append(self.post_norm(query)) + else: + intermediate.append(query) + return torch.stack(intermediate) + + +@TRANSFORMER.register_module() +class Transformer(BaseModule): + """Implements the DETR transformer. + + Following the official DETR implementation, this module copy-paste + from torch.nn.Transformer with modifications: + + * positional encodings are passed in MultiheadAttention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers + + See `paper: End-to-End Object Detection with Transformers + `_ for details. + + Args: + encoder (`mmcv.ConfigDict` | Dict): Config of + TransformerEncoder. Defaults to None. + decoder ((`mmcv.ConfigDict` | Dict)): Config of + TransformerDecoder. Defaults to None + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Defaults to None. 
+ """ + + def __init__(self, encoder=None, decoder=None, init_cfg=None): + super(Transformer, self).__init__(init_cfg=init_cfg) + self.encoder = build_transformer_layer_sequence(encoder) + self.decoder = build_transformer_layer_sequence(decoder) + self.embed_dims = self.encoder.embed_dims + + def init_weights(self): + # follow the official DETR to init parameters + for m in self.modules(): + if hasattr(m, "weight") and m.weight.dim() > 1: + xavier_init(m, distribution="uniform") + self._is_init = True + + def forward(self, x, mask, query_embed, pos_embed): + """Forward function for `Transformer`. + + Args: + x (Tensor): Input query with shape [bs, c, h, w] where + c = embed_dims. + mask (Tensor): The key_padding_mask used for encoder and decoder, + with shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, with shape + [num_query, c]. + pos_embed (Tensor): The positional encoding for encoder and + decoder, with the same shape as `x`. + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - out_dec: Output from decoder. If return_intermediate_dec \ + is True output has shape [num_dec_layers, bs, + num_query, embed_dims], else has shape [1, bs, \ + num_query, embed_dims]. + - memory: Output results from encoder, with shape \ + [bs, embed_dims, h, w]. + """ + bs, c, h, w = x.shape + # use `view` instead of `flatten` for dynamically exporting to ONNX + x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c] + pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] + mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w] + memory = self.encoder(query=x, key=None, value=None, query_pos=pos_embed, query_key_padding_mask=mask) + target = torch.zeros_like(query_embed) + # out_dec: [num_layers, num_query, bs, dim] + out_dec = self.decoder( + query=target, key=memory, value=memory, key_pos=pos_embed, query_pos=query_embed, key_padding_mask=mask + ) + out_dec = out_dec.transpose(1, 2) + memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) + return out_dec, memory + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DeformableDetrTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + coder_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + + def __init__(self, *args, return_intermediate=False, **kwargs): + + super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + + def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + reference_points (Tensor): The reference + points of offset. has shape + (bs, num_query, 4) when as_two_stage, + otherwise has shape ((bs, num_query, 2). + valid_ratios (Tensor): The radios of valid + points on the feature map, has shape + (bs, num_levels, 2) + reg_branch: (obj:`nn.ModuleList`): Used for + refining the regression results. Only would + be passed when with_box_refine is True, + otherwise would be passed a `None`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. 
+ """ + output = query + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = ( + reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] + ) + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * valid_ratios[:, None] + output = layer(output, *args, reference_points=reference_points_input, **kwargs) + output = output.permute(1, 0, 2) + + if reg_branches is not None: + tmp = reg_branches[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + output = output.permute(1, 0, 2) + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points) + + return output, reference_points + + +@TRANSFORMER.register_module() +class DeformableDetrTransformer(Transformer): + """Implements the DeformableDETR transformer. + + Args: + as_two_stage (bool): Generate query from encoder features. + Default: False. + num_feature_levels (int): Number of feature maps from FPN: + Default: 4. + two_stage_num_proposals (int): Number of proposals when set + `as_two_stage` as True. Default: 300. + """ + + def __init__(self, as_two_stage=False, num_feature_levels=4, two_stage_num_proposals=300, **kwargs): + super(DeformableDetrTransformer, self).__init__(**kwargs) + self.as_two_stage = as_two_stage + self.num_feature_levels = num_feature_levels + self.two_stage_num_proposals = two_stage_num_proposals + self.embed_dims = self.encoder.embed_dims + self.init_layers() + + def init_layers(self): + """Initialize layers of the DeformableDetrTransformer.""" + self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dims)) + + if self.as_two_stage: + self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) + self.enc_output_norm = nn.LayerNorm(self.embed_dims) + self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2) + self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2) + else: + self.reference_points = nn.Linear(self.embed_dims, 2) + + def init_weights(self): + """Initialize the transformer weights.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + if not self.as_two_stage: + xavier_init(self.reference_points, distribution="uniform", bias=0.0) + normal_(self.level_embeds) + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + """Generate proposals from encoded memory. + + Args: + memory (Tensor) : The output of encoder, + has shape (bs, num_key, embed_dim). num_key is + equal the number of points on feature map from + all level. + memory_padding_mask (Tensor): Padding mask for memory. + has shape (bs, num_key). + spatial_shapes (Tensor): The shape of all feature maps. + has shape (num_level, 2). + + Returns: + tuple: A tuple of feature map and bbox prediction. 
+ + - output_memory (Tensor): The input of decoder, \ + has shape (bs, num_key, embed_dim). num_key is \ + equal the number of points on feature map from \ + all levels. + - output_proposals (Tensor): The normalized proposal \ + after a inverse sigmoid, has shape \ + (bs, num_keys, 4). + """ + + N, S, C = memory.shape + proposals = [] + _cur = 0 + for lvl, (H, W) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H - 1, H, dtype=torch.float32, device=memory.device), + torch.linspace(0, W - 1, W, dtype=torch.float32, device=memory.device), + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(N, -1, 4) + proposals.append(proposal) + _cur += H * W + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf")) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) + + output_memory = memory + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + """Get the reference points used in decoder. + + Args: + spatial_shapes (Tensor): The shape of all + feature maps, has shape (num_level, 2). + valid_ratios (Tensor): The radios of valid + points on the feature map, has shape + (bs, num_levels, 2) + device (obj:`device`): The device where + reference_points should be. + + Returns: + Tensor: reference points used in decoder, has \ + shape (bs, num_keys, num_levels, 2). 
+ """ + reference_points_list = [] + for lvl, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device), + torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def get_valid_ratio(self, mask): + """Get the valid radios of feature maps of all level.""" + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000): + """Get the position embedding of proposal.""" + scale = 2 * math.pi + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def forward( + self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, reg_branches=None, cls_branches=None, **kwargs + ): + """Forward function for `Transformer`. + + Args: + mlvl_feats (list(Tensor)): Input queries from + different level. Each element has shape + [bs, embed_dims, h, w]. + mlvl_masks (list(Tensor)): The key_padding_mask from + different level used for encoder and decoder, + each element has shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, + with shape [num_query, c]. + mlvl_pos_embeds (list(Tensor)): The positional encoding + of feats from different level, has the shape + [bs, embed_dims, h, w]. + reg_branches (obj:`nn.ModuleList`): Regression heads for + feature maps from each decoder layer. Only would + be passed when + `with_box_refine` is True. Default to None. + cls_branches (obj:`nn.ModuleList`): Classification heads + for feature maps from each decoder layer. Only would + be passed when `as_two_stage` + is True. Default to None. + + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - inter_states: Outputs from decoder. If + return_intermediate_dec is True output has shape \ + (num_dec_layers, bs, num_query, embed_dims), else has \ + shape (1, bs, num_query, embed_dims). + - init_reference_out: The initial value of reference \ + points, has shape (bs, num_queries, 4). + - inter_references_out: The internal value of reference \ + points in decoder, has shape \ + (num_dec_layers, bs,num_query, embed_dims) + - enc_outputs_class: The classification score of \ + proposals generated from \ + encoder's feature maps, has shape \ + (batch, h*w, num_classes). \ + Only would be returned when `as_two_stage` is True, \ + otherwise None. + - enc_outputs_coord_unact: The regression results \ + generated from encoder's feature maps., has shape \ + (batch, h*w, 4). Only would \ + be returned when `as_two_stage` is True, \ + otherwise None. 
+ """ + assert self.as_two_stage or query_embed is not None + + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate(zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + feat_flatten.append(feat) + mask_flatten.append(mask) + feat_flatten = torch.cat(feat_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=feat_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1) + + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device) + + feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + memory = self.encoder( + query=feat_flatten, + key=None, + value=None, + query_pos=lvl_pos_embed_flatten, + query_key_padding_mask=mask_flatten, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + **kwargs, + ) + + memory = memory.permute(1, 0, 2) + bs, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes) + enc_outputs_class = cls_branches[self.decoder.num_layers](output_memory) + enc_outputs_coord_unact = reg_branches[self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) + query_pos, query = torch.split(pos_trans_out, c, dim=2) + else: + query_pos, query = torch.split(query_embed, c, dim=1) + query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) + query = query.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_pos).sigmoid() + init_reference_out = reference_points + + # decoder + query = query.permute(1, 0, 2) + memory = memory.permute(1, 0, 2) + query_pos = query_pos.permute(1, 0, 2) + inter_states, inter_references = self.decoder( + query=query, + key=None, + value=memory, + query_pos=query_pos, + key_padding_mask=mask_flatten, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + **kwargs, + ) + + inter_references_out = inter_references + if self.as_two_stage: + return inter_states, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact + return inter_states, init_reference_out, inter_references_out, None, None + + +@TRANSFORMER.register_module() +class DynamicConv(BaseModule): 
+ """Implements Dynamic Convolution. + + This module generate parameters for each sample and + use bmm to implement 1*1 convolution. Code is modified + from the `official github repo `_ . + + Args: + in_channels (int): The input feature channel. + Defaults to 256. + feat_channels (int): The inner feature channel. + Defaults to 64. + out_channels (int, optional): The output feature channel. + When not specified, it will be set to `in_channels` + by default + input_feat_shape (int): The shape of input feature. + Defaults to 7. + with_proj (bool): Project two-dimentional feature to + one-dimentional feature. Default to True. + act_cfg (dict): The activation config for DynamicConv. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__( + self, + in_channels=256, + feat_channels=64, + out_channels=None, + input_feat_shape=7, + with_proj=True, + act_cfg=dict(type="ReLU", inplace=True), + norm_cfg=dict(type="LN"), + init_cfg=None, + ): + super(DynamicConv, self).__init__(init_cfg) + self.in_channels = in_channels + self.feat_channels = feat_channels + self.out_channels_raw = out_channels + self.input_feat_shape = input_feat_shape + self.with_proj = with_proj + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.in_channels * self.feat_channels + self.num_params_out = self.out_channels * self.feat_channels + self.dynamic_layer = nn.Linear(self.in_channels, self.num_params_in + self.num_params_out) + + self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1] + + self.activation = build_activation_layer(act_cfg) + + num_output = self.out_channels * input_feat_shape**2 + if self.with_proj: + self.fc_layer = nn.Linear(num_output, self.out_channels) + self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] + + def forward(self, param_feature, input_feature): + """Forward function for `DynamicConv`. + + Args: + param_feature (Tensor): The feature can be used + to generate the parameter, has shape + (num_all_proposals, in_channels). + input_feature (Tensor): Feature that + interact with parameters, has shape + (num_all_proposals, in_channels, H, W). + + Returns: + Tensor: The output feature has shape + (num_all_proposals, out_channels). 
+ """ + input_feature = input_feature.flatten(2).permute(2, 0, 1) + + input_feature = input_feature.permute(1, 0, 2) + parameters = self.dynamic_layer(param_feature) + + param_in = parameters[:, : self.num_params_in].view(-1, self.in_channels, self.feat_channels) + param_out = parameters[:, -self.num_params_out :].view(-1, self.feat_channels, self.out_channels) + + # input_feature has shape (num_all_proposals, H*W, in_channels) + # param_in has shape (num_all_proposals, in_channels, feat_channels) + # feature has shape (num_all_proposals, H*W, feat_channels) + features = torch.bmm(input_feature, param_in) + features = self.norm_in(features) + features = self.activation(features) + + # param_out has shape (batch_size, feat_channels, out_channels) + features = torch.bmm(features, param_out) + features = self.norm_out(features) + features = self.activation(features) + + if self.with_proj: + features = features.flatten(1) + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/ops/modules/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/ops/modules/__init__.py new file mode 100755 index 0000000..49aa8fe --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/ops/modules/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/fundamentalvision/Deformable-DETR/tree/main/models/ops/modules +# https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 + +from .ms_deform_attn import MSDeformAttn diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py new file mode 100755 index 0000000..d8b4fa2 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import math +import warnings + +import torch +import torch.nn.functional as F +from torch import nn +from torch.autograd import Function +from torch.cuda.amp import custom_fwd +from torch.nn.init import constant_, xavier_uniform_ + + +class MSDeformAttnFunction(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward( + ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step + ): + output = ms_deform_attn_core_pytorch( + value, + value_spatial_shapes, + # value_level_start_index, + sampling_locations, + attention_weights, + ) + return output + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_) + return output.transpose(1, 2).contiguous() + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, ratio=1.0): + """Multi-Scale Deformable Attention Module. + + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError("d_model must be divisible by n_heads, " "but got {} and {}".format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 + # which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn( + "You'd better set d_model in MSDeformAttn to make " + "the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation." 
+ ) + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + self.ratio = ratio + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, int(d_model * ratio)) + self.output_proj = nn.Linear(int(d_model * ratio), d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + def forward( + self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None, + ): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + # print(query.shape) + # print(reference_points.shape) + # print(input_flatten.shape) + # print(input_spatial_shapes.shape) + # print(input_level_start_index.shape) + # print(input_spatial_shapes) + # print(input_level_start_index) + + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + + value = value.view(N, Len_in, self.n_heads, int(self.ratio * self.d_model) // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, 
:2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1]) + ) + output = MSDeformAttnFunction.apply( + value, + input_spatial_shapes, + input_level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) + return output diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/setup.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/setup.py new file mode 100755 index 0000000..959128c --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/setup.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +from typing import Any, List, Optional, Tuple + +import torch +import torch.backends.cudnn as cudnn + +from dinov2.models import build_model_from_cfg +from dinov2.utils.config import setup +import dinov2.utils.utils as dinov2_utils + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parser = argparse.ArgumentParser( + description=description, + parents=parents or [], + add_help=add_help, + ) + parser.add_argument( + "--config-file", + type=str, + help="Model configuration file", + ) + parser.add_argument( + "--pretrained-weights", + type=str, + help="Pretrained model weights", + ) + parser.add_argument( + "--output-dir", + default="", + type=str, + help="Output directory to write results and logs", + ) + parser.add_argument( + "--opts", + help="Extra configuration options", + default=[], + nargs="+", + ) + return parser + + +def get_autocast_dtype(config): + teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype + if teacher_dtype_str == "fp16": + return torch.half + elif teacher_dtype_str == "bf16": + return torch.bfloat16 + else: + return torch.float + + +def build_model_for_eval(config, pretrained_weights): + model, _ = build_model_from_cfg(config, only_teacher=True) + dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher") + model.eval() + model.cuda() + return model + + +def setup_and_build_model(args) -> Tuple[Any, torch.dtype]: + cudnn.benchmark = True + config = setup(args) + model = build_model_for_eval(config, args.pretrained_weights) + autocast_dtype = get_autocast_dtype(config) + return model, autocast_dtype diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/utils.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/utils.py new file mode 100755 index 0000000..c50576b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/eval/utils.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
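The forward pass of MSDeformAttn above consumes the flattened multi-level feature map together with the per-level spatial shapes and start offsets described in its docstring. A minimal usage sketch of those shape conventions; the pyramid sizes, query count and import path below are illustrative assumptions rather than values used anywhere in this diff:

import torch

# Hypothetical import path; MSDeformAttn is the module defined in the diff above.
from ms_deform_attn import MSDeformAttn

# Assumed two-level feature pyramid, batch size 1, d_model 256 (the module defaults except n_levels).
spatial_shapes = torch.as_tensor([[32, 32], [16, 16]], dtype=torch.long)  # (n_levels, 2) as (H_l, W_l)
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1,)), (spatial_shapes[:, 0] * spatial_shapes[:, 1]).cumsum(0)[:-1])
)  # tensor([0, 1024])
num_tokens = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())  # 32*32 + 16*16 = 1280

attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4)
query = torch.rand(1, 100, 256)                  # (N, Len_q, C)
input_flatten = torch.rand(1, num_tokens, 256)   # (N, sum_l H_l*W_l, C)
reference_points = torch.rand(1, 100, 2, 2)      # (N, Len_q, n_levels, 2), normalized to [0, 1]

out = attn(query, reference_points, input_flatten, spatial_shapes, level_start_index)
print(out.shape)  # torch.Size([1, 100, 256])

With ratio=1.0 the value and output projections keep d_model, so the result has the same (N, Len_q, C) layout as the query; this call runs the pure-PyTorch ms_deform_attn_core_pytorch path that the comments above mark as a debug substitute for the CUDA kernel.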
+ +import logging +from typing import Dict, Optional + +import torch +from torch import nn +from torchmetrics import MetricCollection + +from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader +import dinov2.distributed as distributed +from dinov2.logging import MetricLogger + + +logger = logging.getLogger("dinov2") + + +class ModelWithNormalize(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, samples): + return nn.functional.normalize(self.model(samples), dim=1, p=2) + + +class ModelWithIntermediateLayers(nn.Module): + def __init__(self, feature_model, n_last_blocks, autocast_ctx): + super().__init__() + self.feature_model = feature_model + self.feature_model.eval() + self.n_last_blocks = n_last_blocks + self.autocast_ctx = autocast_ctx + + def forward(self, images): + with torch.inference_mode(): + with self.autocast_ctx(): + features = self.feature_model.get_intermediate_layers( + images, self.n_last_blocks, return_class_token=True + ) + return features + + +@torch.inference_mode() +def evaluate( + model: nn.Module, + data_loader, + postprocessors: Dict[str, nn.Module], + metrics: Dict[str, MetricCollection], + device: torch.device, + criterion: Optional[nn.Module] = None, +): + model.eval() + if criterion is not None: + criterion.eval() + + for metric in metrics.values(): + metric = metric.to(device) + + metric_logger = MetricLogger(delimiter=" ") + header = "Test:" + + for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header): + outputs = model(samples.to(device)) + targets = targets.to(device) + + if criterion is not None: + loss = criterion(outputs, targets) + metric_logger.update(loss=loss.item()) + + for k, metric in metrics.items(): + metric_inputs = postprocessors[k](outputs, targets) + metric.update(**metric_inputs) + + metric_logger.synchronize_between_processes() + logger.info(f"Averaged stats: {metric_logger}") + + stats = {k: metric.compute() for k, metric in metrics.items()} + metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + return metric_logger_stats, stats + + +def all_gather_and_flatten(tensor_rank): + tensor_all_ranks = torch.empty( + distributed.get_global_size(), + *tensor_rank.shape, + dtype=tensor_rank.dtype, + device=tensor_rank.device, + ) + tensor_list = list(tensor_all_ranks.unbind(0)) + torch.distributed.all_gather(tensor_list, tensor_rank.contiguous()) + return tensor_all_ranks.flatten(end_dim=1) + + +def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False): + dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset) + sample_count = len(dataset_with_enumerated_targets) + data_loader = make_data_loader( + dataset=dataset_with_enumerated_targets, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + ) + return extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu) + + +@torch.inference_mode() +def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False): + gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda") + metric_logger = MetricLogger(delimiter=" ") + features, all_labels = None, None + for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10): + samples = samples.cuda(non_blocking=True) + labels_rank = labels_rank.cuda(non_blocking=True) + index = index.cuda(non_blocking=True) + 
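# forward pass on this rank's local batch; per-rank features, labels and indices are all-gathered below and copied into the full-size storage tensors +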
features_rank = model(samples).float() + + # init storage feature matrix + if features is None: + features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device) + labels_shape = list(labels_rank.shape) + labels_shape[0] = sample_count + all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device) + logger.info(f"Storing features into tensor of shape {features.shape}") + + # share indexes, features and labels between processes + index_all = all_gather_and_flatten(index).to(gather_device) + features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device) + labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device) + + # update storage feature matrix + if len(index_all) > 0: + features.index_copy_(0, index_all, features_all_ranks) + all_labels.index_copy_(0, index_all, labels_all_ranks) + + logger.info(f"Features shape: {tuple(features.shape)}") + logger.info(f"Labels shape: {tuple(all_labels.shape)}") + + assert torch.all(all_labels > -1) + + return features, all_labels diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/fsdp/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/fsdp/__init__.py new file mode 100755 index 0000000..ed45448 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/fsdp/__init__.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Any + +import torch +import dinov2.distributed as distributed +from functools import partial +from fvcore.common.checkpoint import Checkpointer +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp import StateDictType +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.distributed.fsdp._runtime_utils import _reshard + + +def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()): + sharding_strategy_dict = { + "NO_SHARD": ShardingStrategy.NO_SHARD, + "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP, + "FULL_SHARD": ShardingStrategy.FULL_SHARD, + } + + dtype_dict = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + } + + mixed_precision_config = MixedPrecision( + param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype], + reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype], + buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype], + ) + + sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy] + + local_rank = distributed.get_local_rank() + + fsdp_wrapper = partial( + FSDP, + sharding_strategy=sharding_strategy_config, + mixed_precision=mixed_precision_config, + device_id=local_rank, + sync_module_states=True, + use_orig_params=True, + auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap), + ) + return fsdp_wrapper + + +def is_fsdp(x): + return isinstance(x, FSDP) + + +def is_sharded_fsdp(x): + return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD + + +def free_if_fsdp(x): + if is_sharded_fsdp(x): + handles = x._handles + true_list = [True for h in handles] + _reshard(x, handles, true_list) + + +def get_fsdp_modules(x): + return FSDP.fsdp_modules(x) + + +def 
reshard_fsdp_model(x): + for m in get_fsdp_modules(x): + free_if_fsdp(m) + + +def rankstr(): + return f"rank_{distributed.get_global_rank()}" + + +class FSDPCheckpointer(Checkpointer): + def save(self, name: str, **kwargs: Any) -> None: + """ + Dump model and checkpointables to a file. + + Args: + name (str): name of the file. + kwargs (dict): extra arbitrary data to save. + """ + if not self.save_dir or not self.save_to_disk: + return + + data = {} + with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + data["model"] = self.model.state_dict() + + # data["model"] = self.model.state_dict() + for key, obj in self.checkpointables.items(): + data[key] = obj.state_dict() + data.update(kwargs) + + basename = f"{name}.{rankstr()}.pth" + save_file = os.path.join(self.save_dir, basename) + assert os.path.basename(save_file) == basename, basename + self.logger.info("Saving checkpoint to {}".format(save_file)) + with self.path_manager.open(save_file, "wb") as f: + torch.save(data, f) + self.tag_last_checkpoint(basename) + + def load(self, *args, **kwargs): + with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + return super().load(*args, **kwargs) + + def has_checkpoint(self) -> bool: + """ + Returns: + bool: whether a checkpoint exists in the target directory. + """ + save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") + return self.path_manager.exists(save_file) + + def get_checkpoint_file(self) -> str: + """ + Returns: + str: The latest checkpoint file in target directory. + """ + save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") + try: + with self.path_manager.open(save_file, "r") as f: + last_saved = f.read().strip() + except IOError: + # if file doesn't exist, maybe because it has just been + # deleted by a separate process + return "" + # pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got + # `Union[bytes, str]`. + return os.path.join(self.save_dir, last_saved) + + def tag_last_checkpoint(self, last_filename_basename: str) -> None: + """ + Tag the last checkpoint. + + Args: + last_filename_basename (str): the basename of the last filename. + """ + if distributed.is_enabled(): + torch.distributed.barrier() + save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") + with self.path_manager.open(save_file, "w") as f: + f.write(last_filename_basename) # pyre-ignore + + +ShardedGradScaler = ShardedGradScaler diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/__init__.py new file mode 100755 index 0000000..b88da6b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/backbones.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/backbones.py new file mode 100755 index 0000000..53fe837 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/backbones.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
+ """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/classifiers.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/classifiers.py new file mode 100755 index 0000000..3f0841e --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/classifiers.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch +import torch.nn as nn + +from .backbones import _make_dinov2_model +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + IMAGENET1K = "IMAGENET1K" + + +def _make_dinov2_linear_classification_head( + *, + arch_name: str = "vit_large", + patch_size: int = 14, + embed_dim: int = 1024, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + linear_head = nn.Linear((1 + layers) * embed_dim, 1_000) + + if pretrained: + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + layers_str = str(layers) if layers == 4 else "" + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + linear_head.load_state_dict(state_dict, strict=True) + + return linear_head + + +class _LinearClassifierWrapper(nn.Module): + def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4): + super().__init__() + self.backbone = backbone + self.linear_head = linear_head + self.layers = layers + + def forward(self, x): + if self.layers == 1: + x = self.backbone.forward_features(x) + cls_token = x["x_norm_clstoken"] + patch_tokens = x["x_norm_patchtokens"] + # fmt: off + linear_input = torch.cat([ + cls_token, + patch_tokens.mean(dim=1), + ], dim=1) + # fmt: on + elif self.layers == 4: + x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True) + # fmt: off + linear_input = 
torch.cat([ + x[0][1], + x[1][1], + x[2][1], + x[3][1], + x[3][0].mean(dim=1), + ], dim=1) + # fmt: on + else: + assert False, f"Unsupported number of layers: {self.layers}" + return self.linear_head(linear_input) + + +def _make_dinov2_linear_classifier( + *, + arch_name: str = "vit_large", + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + **kwargs, +): + backbone = _make_dinov2_model( + arch_name=arch_name, + pretrained=pretrained, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + **kwargs, + ) + + embed_dim = backbone.embed_dim + patch_size = backbone.patch_size + linear_head = _make_dinov2_linear_classification_head( + arch_name=arch_name, + patch_size=patch_size, + embed_dim=embed_dim, + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=num_register_tokens, + ) + + return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers) + + +def dinov2_vits14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitb14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitl14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitg14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vits14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 
+ """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depth/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depth/__init__.py new file mode 100755 index 0000000..91716e5 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depth/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .decode_heads import BNHead, DPTHead +from .encoder_decoder import DepthEncoderDecoder diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depth/decode_heads.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depth/decode_heads.py new file mode 100755 index 0000000..f455acc --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depth/decode_heads.py @@ -0,0 +1,747 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import copy +from functools import partial +import math +import warnings + +import torch +import torch.nn as nn + +from .ops import resize + + +# XXX: (Untested) replacement for mmcv.imdenormalize() +def _imdenormalize(img, mean, std, to_bgr=True): + import numpy as np + + mean = mean.reshape(1, -1).astype(np.float64) + std = std.reshape(1, -1).astype(np.float64) + img = (img * std) + mean + if to_bgr: + img = img[::-1] + return img + + +class DepthBaseDecodeHead(nn.Module): + """Base class for BaseDecodeHead. + + Args: + in_channels (List): Input channels. + channels (int): Channels after modules, before conv_depth. + conv_layer (nn.Module): Conv layers. Default: None. + act_layer (nn.Module): Activation layers. Default: nn.ReLU. + loss_decode (dict): Config of decode loss. + Default: (). + sampler (dict|None): The config of depth map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + min_depth (int): Min depth in dataset setting. + Default: 1e-3. + max_depth (int): Max depth in dataset setting. + Default: None. + norm_layer (dict|None): Norm layers. + Default: None. + classify (bool): Whether predict depth in a cls.-reg. manner. + Default: False. + n_bins (int): The number of bins used in cls. step. + Default: 256. + bins_strategy (str): The discrete strategy used in cls. step. + Default: 'UD'. + norm_strategy (str): The norm strategy on cls. probability + distribution. Default: 'linear' + scale_up (str): Whether predict depth in a scale-up manner. + Default: False. + """ + + def __init__( + self, + in_channels, + conv_layer=None, + act_layer=nn.ReLU, + channels=96, + loss_decode=(), + sampler=None, + align_corners=False, + min_depth=1e-3, + max_depth=None, + norm_layer=None, + classify=False, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + scale_up=False, + ): + super(DepthBaseDecodeHead, self).__init__() + + self.in_channels = in_channels + self.channels = channels + self.conf_layer = conv_layer + self.act_layer = act_layer + self.loss_decode = loss_decode + self.align_corners = align_corners + self.min_depth = min_depth + self.max_depth = max_depth + self.norm_layer = norm_layer + self.classify = classify + self.n_bins = n_bins + self.scale_up = scale_up + + if self.classify: + assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID" + assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid" + + self.bins_strategy = bins_strategy + self.norm_strategy = norm_strategy + self.softmax = nn.Softmax(dim=1) + self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1) + else: + self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1) + + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, inputs, img_metas): + """Placeholder of forward function.""" + pass + + def forward_train(self, img, inputs, img_metas, depth_gt): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. 
+ depth_gt (Tensor): GT depth + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + depth_pred = self.forward(inputs, img_metas) + losses = self.losses(depth_pred, depth_gt) + + log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0]) + losses.update(**log_imgs) + + return losses + + def forward_test(self, inputs, img_metas): + """Forward function for testing. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + + Returns: + Tensor: Output depth map. + """ + return self.forward(inputs, img_metas) + + def depth_pred(self, feat): + """Prediction each pixel.""" + if self.classify: + logit = self.conv_depth(feat) + + if self.bins_strategy == "UD": + bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + elif self.bins_strategy == "SID": + bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + + # following Adabins, default linear + if self.norm_strategy == "linear": + logit = torch.relu(logit) + eps = 0.1 + logit = logit + eps + logit = logit / logit.sum(dim=1, keepdim=True) + elif self.norm_strategy == "softmax": + logit = torch.softmax(logit, dim=1) + elif self.norm_strategy == "sigmoid": + logit = torch.sigmoid(logit) + logit = logit / logit.sum(dim=1, keepdim=True) + + output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) + + else: + if self.scale_up: + output = self.sigmoid(self.conv_depth(feat)) * self.max_depth + else: + output = self.relu(self.conv_depth(feat)) + self.min_depth + return output + + def losses(self, depth_pred, depth_gt): + """Compute depth loss.""" + loss = dict() + depth_pred = resize( + input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False + ) + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt) + else: + loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt) + return loss + + def log_images(self, img_path, depth_pred, depth_gt, img_meta): + import numpy as np + + show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0)) + show_img = show_img.numpy().astype(np.float32) + show_img = _imdenormalize( + show_img, + img_meta["img_norm_cfg"]["mean"], + img_meta["img_norm_cfg"]["std"], + img_meta["img_norm_cfg"]["to_rgb"], + ) + show_img = np.clip(show_img, 0, 255) + show_img = show_img.astype(np.uint8) + show_img = show_img[:, :, ::-1] + show_img = show_img.transpose(0, 2, 1) + show_img = show_img.transpose(1, 0, 2) + + depth_pred = depth_pred / torch.max(depth_pred) + depth_gt = depth_gt / torch.max(depth_gt) + + depth_pred_color = copy.deepcopy(depth_pred.detach().cpu()) + depth_gt_color = copy.deepcopy(depth_gt.detach().cpu()) + + return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color} + + +class BNHead(DepthBaseDecodeHead): + """Just a batchnorm.""" + + def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs): + super().__init__(**kwargs) + self.input_transform = input_transform + 
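# in_index selects which backbone feature levels feed this head; with input_transform="resize_concat" they are resized and concatenated in _transform_inputs below +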
self.in_index = in_index + self.upsample = upsample + # self.bn = nn.SyncBatchNorm(self.in_channels) + if self.classify: + self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1) + else: + self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + Tensor: The transformed inputs + """ + + if "concat" in self.input_transform: + inputs = [inputs[i] for i in self.in_index] + if "resize" in self.input_transform: + inputs = [ + resize( + input=x, + size=[s * self.upsample for s in inputs[0].shape[2:]], + mode="bilinear", + align_corners=self.align_corners, + ) + for x in inputs + ] + inputs = torch.cat(inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _forward_feature(self, inputs, img_metas=None, **kwargs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + # accept lists (for cls token) + inputs = list(inputs) + for i, x in enumerate(inputs): + if len(x) == 2: + x, cls_token = x[0], x[1] + if len(x.shape) == 2: + x = x[:, :, None, None] + cls_token = cls_token[:, :, None, None].expand_as(x) + inputs[i] = torch.cat((x, cls_token), 1) + else: + x = x[0] + if len(x.shape) == 2: + x = x[:, :, None, None] + inputs[i] = x + x = self._transform_inputs(inputs) + # feats = self.bn(x) + return x + + def forward(self, inputs, img_metas=None, **kwargs): + """Forward function.""" + output = self._forward_feature(inputs, img_metas=img_metas, **kwargs) + output = self.depth_pred(output) + return output + + +class ConvModule(nn.Module): + """A conv block that bundles conv/norm/activation layers. + + This block simplifies the usage of convolution layers, which are commonly + used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + It is based upon three build methods: `build_conv_layer()`, + `build_norm_layer()` and `build_activation_layer()`. + + Besides, we add some additional features in this module. + 1. Automatically set `bias` of the conv layer. + 2. Spectral norm is supported. + 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only + supports zero and circular padding, and we add "reflect" padding mode. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_layer. Bias will be set as True if `norm_layer` is None, otherwise + False. 
Default: "auto". + conv_layer (nn.Module): Convolution layer. Default: None, + which means using conv2d. + norm_layer (nn.Module): Normalization layer. Default: None. + act_layer (nn.Module): Activation layer. Default: nn.ReLU. + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. + padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). + """ + + _abbr_ = "conv_block" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias="auto", + conv_layer=nn.Conv2d, + norm_layer=None, + act_layer=nn.ReLU, + inplace=True, + with_spectral_norm=False, + padding_mode="zeros", + order=("conv", "norm", "act"), + ): + super(ConvModule, self).__init__() + official_padding_mode = ["zeros", "circular"] + self.conv_layer = conv_layer + self.norm_layer = norm_layer + self.act_layer = act_layer + self.inplace = inplace + self.with_spectral_norm = with_spectral_norm + self.with_explicit_padding = padding_mode not in official_padding_mode + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == set(["conv", "norm", "act"]) + + self.with_norm = norm_layer is not None + self.with_activation = act_layer is not None + # if the conv layer is before a norm layer, bias is unnecessary. 
+ if bias == "auto": + bias = not self.with_norm + self.with_bias = bias + + if self.with_explicit_padding: + if padding_mode == "zeros": + padding_layer = nn.ZeroPad2d + else: + raise AssertionError(f"Unsupported padding mode: {padding_mode}") + self.pad = padding_layer(padding) + + # reset padding to 0 for conv module + conv_padding = 0 if self.with_explicit_padding else padding + # build convolution layer + self.conv = self.conv_layer( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=conv_padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + # export the attributes of self.conv to a higher level for convenience + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + # build normalization layers + if self.with_norm: + # norm layer is after conv layer + if order.index("norm") > order.index("conv"): + norm_channels = out_channels + else: + norm_channels = in_channels + norm = partial(norm_layer, num_features=norm_channels) + self.add_module("norm", norm) + if self.with_bias: + from torch.nnModules.batchnorm import _BatchNorm + from torch.nnModules.instancenorm import _InstanceNorm + + if isinstance(norm, (_BatchNorm, _InstanceNorm)): + warnings.warn("Unnecessary conv bias before batch/instance norm") + else: + self.norm_name = None + + # build activation layer + if self.with_activation: + # nn.Tanh has no 'inplace' argument + # (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU) + if not isinstance(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)): + act_layer = partial(act_layer, inplace=inplace) + self.activate = act_layer() + + # Use msra init by default + self.init_weights() + + @property + def norm(self): + if self.norm_name: + return getattr(self, self.norm_name) + else: + return None + + def init_weights(self): + # 1. It is mainly for customized conv layers with their own + # initialization manners by calling their own ``init_weights()``, + # and we do not want ConvModule to override the initialization. + # 2. For customized conv layers without their own initialization + # manners (that is, they don't have their own ``init_weights()``) + # and PyTorch's conv layers, they will be initialized by + # this method with default ``kaiming_init``. + # Note: For PyTorch's conv layers, they will be overwritten by our + # initialization implementation using default ``kaiming_init``. 
+ if not hasattr(self.conv, "init_weights"): + if self.with_activation and isinstance(self.act_layer, nn.LeakyReLU): + nonlinearity = "leaky_relu" + a = 0.01 # XXX: default negative_slope + else: + nonlinearity = "relu" + a = 0 + if hasattr(self.conv, "weight") and self.conv.weight is not None: + nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity) + if hasattr(self.conv, "bias") and self.conv.bias is not None: + nn.init.constant_(self.conv.bias, 0) + if self.with_norm: + if hasattr(self.norm, "weight") and self.norm.weight is not None: + nn.init.constant_(self.norm.weight, 1) + if hasattr(self.norm, "bias") and self.norm.bias is not None: + nn.init.constant_(self.norm.bias, 0) + + def forward(self, x, activate=True, norm=True): + for layer in self.order: + if layer == "conv": + if self.with_explicit_padding: + x = self.pad(x) + x = self.conv(x) + elif layer == "norm" and norm and self.with_norm: + x = self.norm(x) + elif layer == "act" and activate and self.with_activation: + x = self.activate(x) + return x + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, align_corners=False): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + return x + + +class HeadDepth(nn.Module): + def __init__(self, features): + super(HeadDepth, self).__init__() + self.head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, x): + x = self.head(x) + return x + + +class ReassembleBlocks(nn.Module): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. 
+ """ + + def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16): + super(ReassembleBlocks, self).__init__() + + assert readout_type in ["ignore", "add", "project"] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList( + [ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_layer=None, + ) + for out_channel in out_channels + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + if self.readout_type == "project": + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == "project": + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == "add": + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(nn.Module): + """ResidualConvUnit, pre-activate residual unit. + Args: + in_channels (int): number of channels in the input feature map. + act_layer (nn.Module): activation layer. + norm_layer (nn.Module): norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + """ + + def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1): + super(PreActResidualConvUnit, self).__init__() + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + act_layer=act_layer, + bias=False, + order=("act", "conv", "norm"), + ) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_layer=norm_layer, + act_layer=act_layer, + bias=False, + order=("act", "conv", "norm"), + ) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(nn.Module): + """FeatureFusionBlock, merge feature map from different stages. + Args: + in_channels (int): Input channels. + act_layer (nn.Module): activation layer for ResidualConvUnit. + norm_layer (nn.Module): normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. 
+ """ + + def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True): + super(FeatureFusionBlock, self).__init__() + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit( + in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer + ) + self.res_conv_unit2 = PreActResidualConvUnit( + in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer + ) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) + x = self.project(x) + return x + + +class DPTHead(DepthBaseDecodeHead): + """Vision Transformers for Dense Prediction. + This head is implemented of `DPT `_. + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. + """ + + def __init__( + self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type="ignore", + patch_size=16, + expand_channels=False, + **kwargs, + ): + super(DPTHead, self).__init__(**kwargs) + + self.in_channels = self.in_channels + self.expand_channels = expand_channels + self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size) + + self.post_process_channels = [ + channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False)) + self.fusion_blocks = nn.ModuleList() + for _ in range(len(self.convs)): + self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + self.conv_depth = HeadDepth(self.channels) + + def forward(self, inputs, img_metas): + assert len(inputs) == self.num_reassemble_blocks + x = [inp for inp in inputs] + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + out = self.project(out) + out = self.depth_pred(out) + return out diff --git 
a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depth/encoder_decoder.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depth/encoder_decoder.py new file mode 100755 index 0000000..eb29ced --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depth/encoder_decoder.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ops import resize + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs + + +class DepthEncoderDecoder(nn.Module): + """Encoder Decoder depther. + + EncoderDecoder typically consists of backbone and decode_head. + """ + + def __init__(self, backbone, decode_head): + super(DepthEncoderDecoder, self).__init__() + + self.backbone = backbone + self.decode_head = decode_head + self.align_corners = self.decode_head.align_corners + + def extract_feat(self, img): + """Extract features from images.""" + return self.backbone(img) + + def encode_decode(self, img, img_metas, rescale=True, size=None): + """Encode images with backbone and decode into a depth estimation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + # crop the pred depth to the certain range. + out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) + if rescale: + if size is None: + if img_metas is not None: + size = img_metas[0]["ori_shape"][:2] + else: + size = img.shape[2:] + out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs) + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + depth_pred = self.decode_head.forward_test(x, img_metas) + return depth_pred + + def forward_dummy(self, img): + """Dummy forward function.""" + depth = self.encode_decode(img, None) + + return depth + + def forward_train(self, img, img_metas, depth_gt, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): Depth gt + used if the architecture supports depth estimation task. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + # the last of x saves the info from neck + loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) + + losses.update(loss_decode) + + return losses + + def whole_inference(self, img, img_meta, rescale, size=None): + """Inference with full image.""" + return self.encode_decode(img, img_meta, rescale, size=size) + + def slide_inference(self, img, img_meta, rescale, stride, crop_size): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = stride + h_crop, w_crop = crop_size + batch_size, _, h_img, w_img = img.size() + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, 1, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + depth_pred = self.encode_decode(crop_img, img_meta, rescale) + preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + return preds + + def inference(self, img, img_meta, rescale, size=None, mode="whole"): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output depth map. + """ + + assert mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if mode == "slide": + depth_pred = self.slide_inference(img, img_meta, rescale) + else: + depth_pred = self.whole_inference(img, img_meta, rescale, size=size) + output = depth_pred + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + depth_pred = self.inference(img, img_meta, rescale) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + depth_pred = depth_pred.unsqueeze(0) + return depth_pred + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. 
+ """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented depth logit inplace + depth_pred = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:]) + depth_pred += cur_depth_pred + depth_pred /= len(imgs) + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: + if not isinstance(var, list): + raise TypeError(f"{name} must be a list, but got " f"{type(var)}") + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})") + # all images in the same aug batch all of the same ori_shape and pad + # shape + for img_meta in img_metas: + ori_shapes = [_["ori_shape"] for _ in img_meta] + assert all(shape == ori_shapes[0] for shape in ori_shapes) + img_shapes = [_["img_shape"] for _ in img_meta] + assert all(shape == img_shapes[0] for shape in img_shapes) + pad_shapes = [_["pad_shape"] for _ in img_meta] + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. 
+ """ + losses = self(**data_batch) + + # split losses and images + real_losses = {} + log_imgs = {} + for k, v in losses.items(): + if "img" in k: + log_imgs[k] = v + else: + real_losses[k] = v + + loss, log_vars = self._parse_losses(real_losses) + + outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs) + + return outputs + + def val_step(self, data_batch, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + output = self(**data_batch, **kwargs) + return output + + @staticmethod + def _parse_losses(losses): + import torch.distributed as dist + + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError(f"{loss_name} is not a tensor or list of tensors") + + loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) + + log_vars["loss"] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depth/ops.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depth/ops.py new file mode 100755 index 0000000..15880ee --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depth/ops.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
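+# Reviewer note (not in the upstream DINOv2 sources): `resize` below is a thin
+# wrapper around F.interpolate that can optionally warn about `align_corners`
+# settings. A minimal, commented usage sketch (tensor shapes are illustrative):
+#
+#   import torch
+#   depth = torch.rand(1, 1, 60, 80)                       # N, C, H, W
+#   up = resize(depth, size=(480, 640), mode="bilinear", align_corners=False)
+#   assert up.shape == (1, 1, 480, 640)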
+ +import warnings + +import torch.nn.functional as F + + +def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depthers.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depthers.py new file mode 100755 index 0000000..f88b7e9 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/depthers.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from functools import partial +from typing import Optional, Tuple, Union + +import torch + +from .backbones import _make_dinov2_model +from .depth import BNHead, DepthEncoderDecoder, DPTHead +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding + + +class Weights(Enum): + NYU = "NYU" + KITTI = "KITTI" + + +def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]: + if not pretrained: # Default + return (0.001, 10.0) + + # Pretrained, set according to the training dataset for the provided weights + if weights == Weights.KITTI: + return (0.001, 80.0) + + if weights == Weights.NYU: + return (0.001, 10.0) + + return (0.001, 10.0) + + +def _make_dinov2_linear_depth_head( + *, + embed_dim: int, + layers: int, + min_depth: float, + max_depth: float, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + + if layers == 1: + in_index = [0] + else: + assert layers == 4 + in_index = [0, 1, 2, 3] + + return BNHead( + classify=True, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + upsample=4, + in_channels=[embed_dim] * len(in_index), + in_index=in_index, + input_transform="resize_concat", + channels=embed_dim * len(in_index) * 2, + align_corners=False, + min_depth=0.001, + max_depth=80, + loss_decode=(), + ) + + +def _make_dinov2_linear_depther( + *, + arch_name: str = "vit_large", + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.NYU, + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + if depth_range is None: + depth_range = _get_depth_range(pretrained, weights) + min_depth, max_depth = depth_range + + backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + + embed_dim = backbone.embed_dim + patch_size = backbone.patch_size + model_name = _make_dinov2_model_name(arch_name, patch_size) + linear_depth_head = 
_make_dinov2_linear_depth_head( + embed_dim=embed_dim, + layers=layers, + min_depth=min_depth, + max_depth=max_depth, + ) + + layer_count = { + "vit_small": 12, + "vit_base": 12, + "vit_large": 24, + "vit_giant2": 40, + }[arch_name] + + if layers == 4: + out_index = { + "vit_small": [2, 5, 8, 11], + "vit_base": [2, 5, 8, 11], + "vit_large": [4, 11, 17, 23], + "vit_giant2": [9, 19, 29, 39], + }[arch_name] + else: + assert layers == 1 + out_index = [layer_count - 1] + + model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head) + model.backbone.forward = partial( + backbone.get_intermediate_layers, + n=out_index, + reshape=True, + return_class_token=True, + norm=False, + ) + model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0])) + + if pretrained: + layers_str = str(layers) if layers == 4 else "" + weights_str = weights.value.lower() + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth" + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, strict=False) + + return model + + +def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + ) + + +def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float): + return DPTHead( + in_channels=[embed_dim] * 4, + channels=256, + embed_dims=embed_dim, + post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)], + readout_type="project", + min_depth=min_depth, + max_depth=max_depth, + loss_decode=(), + ) + + +def _make_dinov2_dpt_depther( + *, + arch_name: str = "vit_large", + pretrained: bool = True, + weights: Union[Weights, str] = Weights.NYU, + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + if depth_range is None: + depth_range = _get_depth_range(pretrained, weights) + min_depth, max_depth = depth_range + + backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + + model_name = _make_dinov2_model_name(arch_name, backbone.patch_size) + dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth) + + out_index = { + "vit_small": [2, 5, 8, 11], + "vit_base": [2, 5, 8, 11], + "vit_large": [4, 11, 17, 23], + "vit_giant2": [9, 19, 29, 39], + }[arch_name] + + model = 
DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head) + model.backbone.forward = partial( + backbone.get_intermediate_layers, + n=out_index, + reshape=True, + return_class_token=True, + norm=False, + ) + model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0])) + + if pretrained: + weights_str = weights.value.lower() + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth" + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, strict=False) + + return model + + +def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther( + arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + ) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/utils.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/utils.py new file mode 100755 index 0000000..9c66414 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/__init__.py new file mode 100755 index 0000000..05a0b61 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/attention.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/attention.py new file mode 100755 index 0000000..0fb76ef --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Attention)") + else: + warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/block.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/block.py new file mode 100755 index 0000000..930787b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/block.py @@ -0,0 +1,260 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Block)") + else: + warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - 
sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + 
return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/dino_head.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/dino_head.py new file mode 100755 index 0000000..0ace8ff --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
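+# Reviewer note (not in the upstream DINOv2 sources): a minimal, commented usage
+# sketch for the DINOHead defined below; the dimensions are illustrative only:
+#
+#   import torch
+#   head = DINOHead(in_dim=384, out_dim=65536, bottleneck_dim=256)
+#   cls_tokens = torch.rand(8, 384)       # e.g. a batch of ViT-S CLS embeddings
+#   logits = head(cls_tokens)             # L2-normalized bottleneck -> (8, 65536)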
+ +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/drop_path.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/drop_path.py new file mode 100755 index 0000000..1d640e0 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/layer_scale.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/layer_scale.py new file mode 100755 index 0000000..51df0d7 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/mlp.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/mlp.py new file mode 100755 index 0000000..bbf9432 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/patch_embed.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/patch_embed.py new file mode 100755 index 0000000..8b7c080 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/swiglu_ffn.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/swiglu_ffn.py new file mode 100755 index 0000000..5e9dafa --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
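+# Reviewer note (not in the upstream DINOv2 sources): SwiGLUFFN below computes
+# w3(silu(x1) * x2), where w12 produces the concatenation [x1, x2]. A minimal,
+# commented sketch with illustrative shapes:
+#
+#   import torch
+#   ffn = SwiGLUFFN(in_features=384, hidden_features=1024)
+#   tokens = torch.rand(2, 197, 384)      # B, N, D
+#   out = ffn(tokens)                     # shape preserved: (2, 197, 384)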
+ +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (SwiGLU)") + else: + warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/logging/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/logging/__init__.py new file mode 100755 index 0000000..04a7f02 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/logging/__init__.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import functools +import logging +import os +import sys +from typing import Optional + +import dinov2.distributed as distributed +from .helpers import MetricLogger, SmoothedValue + + +# So that calling _configure_logger multiple times won't add many handlers +@functools.lru_cache() +def _configure_logger( + name: Optional[str] = None, + *, + level: int = logging.DEBUG, + output: Optional[str] = None, +): + """ + Configure a logger. + + Adapted from Detectron2. + + Args: + name: The name of the logger to configure. + level: The logging level to use. + output: A file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + + Returns: + The configured logger. 
+ """ + + logger = logging.getLogger(name) + logger.setLevel(level) + logger.propagate = False + + # Loosely match Google glog format: + # [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg + # but use a shorter timestamp and include the logger name: + # [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg + fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] " + fmt_message = "%(message)s" + fmt = fmt_prefix + fmt_message + datefmt = "%Y%m%d %H:%M:%S" + formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) + + # stdout logging for main worker only + if distributed.is_main_process(): + handler = logging.StreamHandler(stream=sys.stdout) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + logger.addHandler(handler) + + # file logging for all workers + if output: + if os.path.splitext(output)[-1] in (".txt", ".log"): + filename = output + else: + filename = os.path.join(output, "logs", "log.txt") + + if not distributed.is_main_process(): + global_rank = distributed.get_global_rank() + filename = filename + ".rank{}".format(global_rank) + + os.makedirs(os.path.dirname(filename), exist_ok=True) + + handler = logging.StreamHandler(open(filename, "a")) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + logger.addHandler(handler) + + return logger + + +def setup_logging( + output: Optional[str] = None, + *, + name: Optional[str] = None, + level: int = logging.DEBUG, + capture_warnings: bool = True, +) -> None: + """ + Setup logging. + + Args: + output: A file name or a directory to save log files. If None, log + files will not be saved. If output ends with ".txt" or ".log", it + is assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + name: The name of the logger to configure, by default the root logger. + level: The logging level to use. + capture_warnings: Whether warnings should be captured as logs. + """ + logging.captureWarnings(capture_warnings) + _configure_logger(name, level=level, output=output) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/logging/helpers.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/logging/helpers.py new file mode 100755 index 0000000..c6e70bb --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/logging/helpers.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
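+# Reviewer note (not in the upstream DINOv2 sources): MetricLogger/SmoothedValue
+# below pair with setup_logging from this package's __init__. A commented sketch
+# of the intended training-loop pattern (`loader` and the loss value are
+# placeholders, not names defined in this diff):
+#
+#   import logging
+#   from dinov2.logging import setup_logging, MetricLogger
+#   setup_logging(output="runs/exp", level=logging.INFO)
+#   metric_logger = MetricLogger(delimiter="  ")
+#   for batch in metric_logger.log_every(loader, print_freq=10, header="Train:"):
+#       loss = ...                        # one training step
+#       metric_logger.update(loss=float(loss))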
+ +from collections import defaultdict, deque +import datetime +import json +import logging +import time + +import torch + +import dinov2.distributed as distributed + + +logger = logging.getLogger("dinov2") + + +class MetricLogger(object): + def __init__(self, delimiter="\t", output_file=None): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + self.output_file = output_file + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def dump_in_output_file(self, iteration, iter_time, data_time): + if self.output_file is None or not distributed.is_main_process(): + return + dict_to_dump = dict( + iteration=iteration, + iter_time=iter_time, + data_time=data_time, + ) + dict_to_dump.update({k: v.median for k, v in self.meters.items()}) + with open(self.output_file, "a") as f: + f.write(json.dumps(dict_to_dump) + "\n") + pass + + def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0): + i = start_iteration + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.6f}") + data_time = SmoothedValue(fmt="{avg:.6f}") + + if n_iterations is None: + n_iterations = len(iterable) + + space_fmt = ":" + str(len(str(n_iterations))) + "d" + + log_list = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_list += ["max mem: {memory:.0f}"] + + log_msg = self.delimiter.join(log_list) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == n_iterations - 1: + self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg) + eta_seconds = iter_time.global_avg * (n_iterations - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + logger.info( + log_msg.format( + i, + n_iterations, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + logger.info( + log_msg.format( + i, + n_iterations, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + if i >= n_iterations: + break + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, total_time / n_iterations)) + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. 
+ """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, num=1): + self.deque.append(value) + self.count += num + self.total += value * num + + def synchronize_between_processes(self): + """ + Distributed synchronization of the metric + Warning: does not synchronize the deque! + """ + if not distributed.is_enabled(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + torch.distributed.barrier() + torch.distributed.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/loss/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/loss/__init__.py new file mode 100755 index 0000000..d6b0115 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/loss/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_clstoken_loss import DINOLoss +from .ibot_patch_loss import iBOTPatchLoss +from .koleo_loss import KoLeoLoss diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/loss/dino_clstoken_loss.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/loss/dino_clstoken_loss.py new file mode 100755 index 0000000..c31808e --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/loss/dino_clstoken_loss.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
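+# Reviewer note (not in the upstream DINOv2 sources): a minimal, commented sketch
+# of how DINOLoss below can be driven on a single process; the batch size and
+# prototype count are illustrative, not values used by this project:
+#
+#   import torch
+#   dino_loss = DINOLoss(out_dim=4096)
+#   student_out = torch.randn(8, 4096)    # student CLS logits for one crop
+#   teacher_out = torch.randn(8, 4096)    # teacher CLS logits for one crop
+#   teacher_sm = dino_loss.softmax_center_teacher(teacher_out, teacher_temp=0.07)
+#   loss = dino_loss([student_out], [teacher_sm])
+#   dino_loss.update_center(teacher_out)  # EMA center update for the next step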
+ +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + + +class DINOLoss(nn.Module): + def __init__( + self, + out_dim, + student_temp=0.1, + center_momentum=0.9, + ): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.register_buffer("center", torch.zeros(1, out_dim)) + self.updated = True + self.reduce_handle = None + self.len_teacher_output = None + self.async_batch_center = None + + @torch.no_grad() + def softmax_center_teacher(self, teacher_output, teacher_temp): + self.apply_center_update() + # teacher centering and sharpening + return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1) + + @torch.no_grad() + def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3): + teacher_output = teacher_output.float() + world_size = dist.get_world_size() if dist.is_initialized() else 1 + Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper + B = Q.shape[1] * world_size # number of samples to assign + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + if dist.is_initialized(): + dist.all_reduce(sum_Q) + Q /= sum_Q + + for it in range(n_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + def forward(self, student_output_list, teacher_out_softmaxed_centered_list): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + """ + # TODO: Use cross_entropy_distribution here + total_loss = 0 + for s in student_output_list: + lsm = F.log_softmax(s / self.student_temp, dim=-1) + for t in teacher_out_softmaxed_centered_list: + loss = torch.sum(t * lsm, dim=-1) + total_loss -= loss.mean() + return total_loss + + @torch.no_grad() + def update_center(self, teacher_output): + self.reduce_center_update(teacher_output) + + @torch.no_grad() + def reduce_center_update(self, teacher_output): + self.updated = False + self.len_teacher_output = len(teacher_output) + self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + if dist.is_initialized(): + self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) + + @torch.no_grad() + def apply_center_update(self): + if self.updated is False: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if self.reduce_handle is not None: + self.reduce_handle.wait() + _t = self.async_batch_center / (self.len_teacher_output * world_size) + + self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) + + self.updated = True diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/loss/ibot_patch_loss.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/loss/ibot_patch_loss.py new file mode 100755 index 0000000..6732cda --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/loss/ibot_patch_loss.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
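+# Reviewer note (not in the upstream DINOv2 sources): a minimal, commented sketch
+# for the masked patch loss below, with illustrative (B, N, D) patch logits and a
+# boolean (B, N) mask marking the patches hidden from the student:
+#
+#   import torch
+#   ibot_loss = iBOTPatchLoss(patch_out_dim=4096)
+#   student = torch.randn(2, 64, 4096)
+#   teacher = torch.randn(2, 64, 4096)
+#   masks = torch.rand(2, 64) > 0.5
+#   teacher_sm = ibot_loss.softmax_center_teacher(teacher, teacher_temp=0.07)
+#   loss = ibot_loss(student, teacher_sm, masks)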
+ +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + +import logging + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import cross_entropy + + def lossfunc(t, s, temp): + s = s.float() + t = t.float() + if s.ndim == 2: + return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0) + elif s.ndim == 3: + return -cross_entropy(s, t, temp, bw_inplace=True) + +except ImportError: + + def lossfunc(t, s, temp): + return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1) + + +class iBOTPatchLoss(nn.Module): + def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.register_buffer("center", torch.zeros(1, 1, patch_out_dim)) + self.updated = True + self.reduce_handle = None + self.len_teacher_patch_tokens = None + self.async_batch_center = None + + @torch.no_grad() + def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp): + self.apply_center_update() + # teacher centering and sharpening + # + # WARNING: + # as self.center is a float32, everything gets casted to float32 afterwards + # + # teacher_patch_tokens = teacher_patch_tokens.float() + # return F.softmax((teacher_patch_tokens.sub_(self.center.to(teacher_patch_tokens.dtype))).mul_(1 / teacher_temp), dim=-1) + + return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1) + + # this is experimental, keep everything in float16 and let's see what happens: + # return F.softmax((teacher_patch_tokens.sub_(self.center)) / teacher_temp, dim=-1) + + @torch.no_grad() + def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3): + teacher_output = teacher_output.float() + # world_size = dist.get_world_size() if dist.is_initialized() else 1 + Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper + # B = Q.shape[1] * world_size # number of samples to assign + B = n_masked_patches_tensor + dist.all_reduce(B) + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + if dist.is_initialized(): + dist.all_reduce(sum_Q) + Q /= sum_Q + + for it in range(n_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat): + """ + Cross-entropy between softmax outputs of the teacher and student networks. 
+ student_patch_tokens: (B, N, D) tensor + teacher_patch_tokens: (B, N, D) tensor + student_masks_flat: (B, N) tensor + """ + t = teacher_patch_tokens + s = student_patch_tokens + loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) + loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0) + return -loss.mean() + + def forward_masked( + self, + student_patch_tokens_masked, + teacher_patch_tokens_masked, + student_masks_flat, + n_masked_patches=None, + masks_weight=None, + ): + t = teacher_patch_tokens_masked + s = student_patch_tokens_masked + # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) + loss = lossfunc(t, s, self.student_temp) + if masks_weight is None: + masks_weight = ( + (1 / student_masks_flat.sum(-1).clamp(min=1.0)) + .unsqueeze(-1) + .expand_as(student_masks_flat)[student_masks_flat] + ) + if n_masked_patches is not None: + loss = loss[:n_masked_patches] + loss = loss * masks_weight + return -loss.sum() / student_masks_flat.shape[0] + + @torch.no_grad() + def update_center(self, teacher_patch_tokens): + self.reduce_center_update(teacher_patch_tokens) + + @torch.no_grad() + def reduce_center_update(self, teacher_patch_tokens): + self.updated = False + self.len_teacher_patch_tokens = len(teacher_patch_tokens) + self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True) + if dist.is_initialized(): + self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) + + @torch.no_grad() + def apply_center_update(self): + if self.updated is False: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if self.reduce_handle is not None: + self.reduce_handle.wait() + _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size) + + self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) + + self.updated = True diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/loss/koleo_loss.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/loss/koleo_loss.py new file mode 100755 index 0000000..b5cbcd9 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/loss/koleo_loss.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# import torch.distributed as dist + + +logger = logging.getLogger("dinov2") + + +class KoLeoLoss(nn.Module): + """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search""" + + def __init__(self): + super().__init__() + self.pdist = nn.PairwiseDistance(2, eps=1e-8) + + def pairwise_NNs_inner(self, x): + """ + Pairwise nearest neighbors for L2-normalized vectors. + Uses Torch rather than Faiss to remain on GPU. 
+ """ + # parwise dot products (= inverse distance) + dots = torch.mm(x, x.t()) + n = x.shape[0] + dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1 + # max inner prod -> min distance + _, I = torch.max(dots, dim=1) # noqa: E741 + return I + + def forward(self, student_output, eps=1e-8): + """ + Args: + student_output (BxD): backbone output of student + """ + with torch.cuda.amp.autocast(enabled=False): + student_output = F.normalize(student_output, eps=eps, p=2, dim=-1) + I = self.pairwise_NNs_inner(student_output) # noqa: E741 + distances = self.pdist(student_output, student_output[I]) # BxD, BxD -> B + loss = -torch.log(distances + eps).mean() + return loss diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/models/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/models/__init__.py new file mode 100755 index 0000000..3fdff20 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/models/vision_transformer.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/models/vision_transformer.py new file mode 100755 index 0000000..13b44ae --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/models/vision_transformer.py @@ -0,0 +1,396 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
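+#
+# Illustrative usage of the factory helpers defined at the bottom of this file
+# (argument values below are examples, not project defaults):
+#   model = vit_small(patch_size=14, num_register_tokens=4)
+#   feats = model.get_intermediate_layers(image_batch, n=1, reshape=True)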
+ +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = 
nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is 
not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/__init__.py new file mode 100755 index 0000000..b88da6b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/eval/knn.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/eval/knn.py new file mode 100755 index 0000000..d119184 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/eval/knn.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from dinov2.eval.knn import get_args_parser as get_knn_args_parser +from dinov2.logging import setup_logging +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.eval.knn import main as knn_main + + self._setup_args() + knn_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 k-NN evaluation" + knn_args_parser = get_knn_args_parser(add_help=False) + parents = [knn_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Evaluator, args, name="dinov2:knn") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/eval/linear.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/eval/linear.py new file mode 100755 index 0000000..e1dc329 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/eval/linear.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
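+#
+# Submitit launcher for DINOv2 linear evaluation: it reuses the argument parser
+# from dinov2.eval.linear and queues an Evaluator on SLURM via submit_jobs().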
+ +import logging +import os +import sys + +from dinov2.eval.linear import get_args_parser as get_linear_args_parser +from dinov2.logging import setup_logging +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.eval.linear import main as linear_main + + self._setup_args() + linear_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 linear evaluation" + linear_args_parser = get_linear_args_parser(add_help=False) + parents = [linear_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Evaluator, args, name="dinov2:linear") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/eval/log_regression.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/eval/log_regression.py new file mode 100755 index 0000000..cdf0218 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/eval/log_regression.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from dinov2.eval.log_regression import get_args_parser as get_log_regression_args_parser +from dinov2.logging import setup_logging +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.eval.log_regression import main as log_regression_main + + self._setup_args() + log_regression_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 logistic evaluation" + log_regression_args_parser = get_log_regression_args_parser(add_help=False) + parents = [log_regression_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" 
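+    # submit_jobs() queues the Evaluator on SLURM; any "%j" in output_dir is
+    # replaced with the actual job id once the job has been scheduled.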
+ submit_jobs(Evaluator, args, name="dinov2:logreg") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/submit.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/submit.py new file mode 100755 index 0000000..4d1f718 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/submit.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os +from pathlib import Path +from typing import List, Optional + +import submitit + +from dinov2.utils.cluster import ( + get_slurm_executor_parameters, + get_slurm_partition, + get_user_checkpoint_path, +) + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +) -> argparse.ArgumentParser: + parents = parents or [] + slurm_partition = get_slurm_partition() + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--ngpus", + "--gpus", + "--gpus-per-node", + default=8, + type=int, + help="Number of GPUs to request on each node", + ) + parser.add_argument( + "--nodes", + "--nnodes", + default=1, + type=int, + help="Number of nodes to request", + ) + parser.add_argument( + "--timeout", + default=2800, + type=int, + help="Duration of the job", + ) + parser.add_argument( + "--partition", + default=slurm_partition, + type=str, + help="Partition where to submit", + ) + parser.add_argument( + "--use-volta32", + action="store_true", + help="Request V100-32GB GPUs", + ) + parser.add_argument( + "--comment", + default="", + type=str, + help="Comment to pass to scheduler, e.g. 
priority message", + ) + parser.add_argument( + "--exclude", + default="", + type=str, + help="Nodes to exclude", + ) + return parser + + +def get_shared_folder() -> Path: + user_checkpoint_path = get_user_checkpoint_path() + if user_checkpoint_path is None: + raise RuntimeError("Path to user checkpoint cannot be determined") + path = user_checkpoint_path / "experiments" + path.mkdir(exist_ok=True) + return path + + +def submit_jobs(task_class, args, name: str): + if not args.output_dir: + args.output_dir = str(get_shared_folder() / "%j") + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) + + kwargs = {} + if args.use_volta32: + kwargs["slurm_constraint"] = "volta32gb" + if args.comment: + kwargs["slurm_comment"] = args.comment + if args.exclude: + kwargs["slurm_exclude"] = args.exclude + + executor_params = get_slurm_executor_parameters( + nodes=args.nodes, + num_gpus_per_node=args.ngpus, + timeout_min=args.timeout, # max is 60 * 72 + slurm_signal_delay_s=120, + slurm_partition=args.partition, + **kwargs, + ) + executor.update_parameters(name=name, **executor_params) + + task = task_class(args) + job = executor.submit(task) + + logger.info(f"Submitted job_id: {job.job_id}") + str_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id)) + logger.info(f"Logs and checkpoints will be saved at: {str_output_dir}") diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/train/train.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/train/train.py new file mode 100755 index 0000000..c2366e9 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/run/train/train.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from dinov2.logging import setup_logging +from dinov2.train import get_args_parser as get_train_args_parser +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Trainer(object): + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.train import main as train_main + + self._setup_args() + train_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 training" + train_args_parser = get_train_args_parser(add_help=False) + parents = [train_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" 
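+    # Trainer.checkpoint() above returns a DelayedSubmission, which lets
+    # submitit requeue the run with the same arguments after preemption or timeout.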
+ submit_jobs(Trainer, args, name="dinov2:train") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/train/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/train/__init__.py new file mode 100755 index 0000000..5f17529 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/train/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .train import get_args_parser, main +from .ssl_meta_arch import SSLMetaArch diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/train/ssl_meta_arch.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/train/ssl_meta_arch.py new file mode 100755 index 0000000..3ccf15e --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/train/ssl_meta_arch.py @@ -0,0 +1,400 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from functools import partial +import logging + +import torch +from torch import nn + +from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss +from dinov2.models import build_model_from_cfg +from dinov2.layers import DINOHead +from dinov2.utils.utils import has_batchnorms +from dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups +from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model + +from dinov2.models.vision_transformer import BlockChunk + + +try: + from xformers.ops import fmha +except ImportError: + raise AssertionError("xFormers is required for training") + + +logger = logging.getLogger("dinov2") + + +class SSLMetaArch(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.fp16_scaler = ShardedGradScaler() if cfg.compute_precision.grad_scaler else None + + student_model_dict = dict() + teacher_model_dict = dict() + + student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg) + student_model_dict["backbone"] = student_backbone + teacher_model_dict["backbone"] = teacher_backbone + logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}") + + if cfg.student.pretrained_weights: + chkpt = torch.load(cfg.student.pretrained_weights) + logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}") + student_backbone.load_state_dict(chkpt["model"], strict=False) + + self.embed_dim = embed_dim + self.dino_out_dim = cfg.dino.head_n_prototypes + + self.do_dino = cfg.dino.loss_weight > 0 + self.do_koleo = cfg.dino.koleo_loss_weight > 0 + self.do_ibot = cfg.ibot.loss_weight > 0 + self.ibot_separate_head = cfg.ibot.separate_head + + logger.info("OPTIONS -- DINO") + if self.do_dino: + logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}") + logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}") + logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}") + logger.info(f"OPTIONS -- DINO -- head_hidden_dim: {cfg.dino.head_hidden_dim}") + self.dino_loss_weight = cfg.dino.loss_weight + dino_head = partial( + DINOHead, + in_dim=embed_dim, + out_dim=cfg.dino.head_n_prototypes, + 
hidden_dim=cfg.dino.head_hidden_dim, + bottleneck_dim=cfg.dino.head_bottleneck_dim, + nlayers=cfg.dino.head_nlayers, + ) + self.dino_loss = DINOLoss(self.dino_out_dim) + if self.do_koleo: + logger.info("OPTIONS -- DINO -- applying KOLEO regularization") + self.koleo_loss = KoLeoLoss() + + else: + logger.info("OPTIONS -- DINO -- not using DINO") + + if self.do_dino or self.do_ibot: + student_model_dict["dino_head"] = dino_head() + teacher_model_dict["dino_head"] = dino_head() + + logger.info("OPTIONS -- IBOT") + logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}") + if self.do_ibot: + self.ibot_loss_weight = cfg.ibot.loss_weight + assert max(cfg.ibot.mask_ratio_min_max) > 0, "please provide a positive mask ratio tuple for ibot" + assert cfg.ibot.mask_sample_probability > 0, "please provide a positive mask probability for ibot" + self.ibot_out_dim = cfg.ibot.head_n_prototypes if self.ibot_separate_head else cfg.dino.head_n_prototypes + self.ibot_patch_loss = iBOTPatchLoss(self.ibot_out_dim) + if self.ibot_separate_head: + logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") + logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}") + logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}") + logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}") + ibot_head = partial( + DINOHead, + in_dim=embed_dim, + out_dim=cfg.ibot.head_n_prototypes, + hidden_dim=cfg.ibot.head_hidden_dim, + bottleneck_dim=cfg.ibot.head_bottleneck_dim, + nlayers=cfg.ibot.head_nlayers, + ) + student_model_dict["ibot_head"] = ibot_head() + teacher_model_dict["ibot_head"] = ibot_head() + else: + logger.info("OPTIONS -- IBOT -- head shared with DINO") + + self.need_to_synchronize_fsdp_streams = True + + self.student = nn.ModuleDict(student_model_dict) + self.teacher = nn.ModuleDict(teacher_model_dict) + + # there is no backpropagation through the teacher, so no need for gradients + for p in self.teacher.parameters(): + p.requires_grad = False + logger.info(f"Student and Teacher are built: they are both {cfg.student.arch} network.") + + def forward(self, inputs): + raise NotImplementedError + + def backprop_loss(self, loss): + if self.fp16_scaler is not None: + self.fp16_scaler.scale(loss).backward() + else: + loss.backward() + + def forward_backward(self, images, teacher_temp): + n_global_crops = 2 + assert n_global_crops == 2 + n_local_crops = self.cfg.crops.local_crops_number + + global_crops = images["collated_global_crops"].cuda(non_blocking=True) + local_crops = images["collated_local_crops"].cuda(non_blocking=True) + + masks = images["collated_masks"].cuda(non_blocking=True) + mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True) + n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True) + n_masked_patches = mask_indices_list.shape[0] + upperbound = images["upperbound"] + masks_weight = images["masks_weight"].cuda(non_blocking=True) + + n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1) + n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops + + do_dino = self.do_dino + do_ibot = self.do_ibot + + # loss scales + ibot_loss_scale = 1.0 / n_global_crops + + # teacher output + @torch.no_grad() + def get_teacher_output(): + x, 
n_global_crops_teacher = global_crops, n_global_crops + teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True) + teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"] + teacher_cls_tokens = teacher_cls_tokens.chunk(n_global_crops_teacher) + # watch out: these are chunked and cat'd in reverse so A is matched to B in the global crops dino loss + teacher_cls_tokens = torch.cat((teacher_cls_tokens[1], teacher_cls_tokens[0])) + ibot_teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"] + _dim = ibot_teacher_patch_tokens.shape[-1] + n_cls_tokens = teacher_cls_tokens.shape[0] + + if do_ibot and not self.ibot_separate_head: + buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound + n_cls_tokens, _dim) + buffer_tensor_teacher[:n_cls_tokens].copy_(teacher_cls_tokens) + torch.index_select( + ibot_teacher_patch_tokens.flatten(0, 1), + dim=0, + index=mask_indices_list, + out=buffer_tensor_teacher[n_cls_tokens : n_cls_tokens + n_masked_patches], + ) + tokens_after_head = self.teacher.dino_head(buffer_tensor_teacher) + teacher_cls_tokens_after_head = tokens_after_head[:n_cls_tokens] + masked_teacher_patch_tokens_after_head = tokens_after_head[ + n_cls_tokens : n_cls_tokens + n_masked_patches + ] + elif do_ibot and self.ibot_separate_head: + buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound, _dim) + torch.index_select( + ibot_teacher_patch_tokens.flatten(0, 1), + dim=0, + index=mask_indices_list, + out=buffer_tensor_teacher[:n_masked_patches], + ) + teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) + masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher)[ + :n_masked_patches + ] + else: + teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) + masked_teacher_ibot_softmaxed_centered = None + + if self.cfg.train.centering == "centering": + teacher_dino_softmaxed_centered_list = self.dino_loss.softmax_center_teacher( + teacher_cls_tokens_after_head, teacher_temp=teacher_temp + ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) + self.dino_loss.update_center(teacher_cls_tokens_after_head) + if do_ibot: + masked_teacher_patch_tokens_after_head = masked_teacher_patch_tokens_after_head.unsqueeze(0) + masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.softmax_center_teacher( + masked_teacher_patch_tokens_after_head[:, :n_masked_patches], teacher_temp=teacher_temp + ) + masked_teacher_ibot_softmaxed_centered = masked_teacher_ibot_softmaxed_centered.squeeze(0) + self.ibot_patch_loss.update_center(masked_teacher_patch_tokens_after_head[:n_masked_patches]) + + elif self.cfg.train.centering == "sinkhorn_knopp": + teacher_dino_softmaxed_centered_list = self.dino_loss.sinkhorn_knopp_teacher( + teacher_cls_tokens_after_head, teacher_temp=teacher_temp + ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) + + if do_ibot: + masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher( + masked_teacher_patch_tokens_after_head, + teacher_temp=teacher_temp, + n_masked_patches_tensor=n_masked_patches_tensor, + ) + + else: + raise NotImplementedError + + return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered + + teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output() + reshard_fsdp_model(self.teacher) + + loss_dict = {} + + loss_accumulator = 0 # for backprop + student_global_backbone_output_dict, 
student_local_backbone_output_dict = self.student.backbone( + [global_crops, local_crops], masks=[masks, None], is_training=True + ) + + inputs_for_student_head_list = [] + + # 1a: local crops cls tokens + student_local_cls_tokens = student_local_backbone_output_dict["x_norm_clstoken"] + inputs_for_student_head_list.append(student_local_cls_tokens.unsqueeze(0)) + + # 1b: global crops cls tokens + student_global_cls_tokens = student_global_backbone_output_dict["x_norm_clstoken"] + inputs_for_student_head_list.append(student_global_cls_tokens.unsqueeze(0)) + + # 1c: global crops patch tokens + if do_ibot: + _dim = student_global_backbone_output_dict["x_norm_clstoken"].shape[-1] + ibot_student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"] + buffer_tensor_patch_tokens = ibot_student_patch_tokens.new_zeros(upperbound, _dim) + buffer_tensor_patch_tokens[:n_masked_patches].copy_( + torch.index_select(ibot_student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list) + ) + if not self.ibot_separate_head: + inputs_for_student_head_list.append(buffer_tensor_patch_tokens.unsqueeze(0)) + else: + student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens)[ + :n_masked_patches + ] + + # 2: run + _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list) + outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs)) + + # 3a: local crops cls tokens + student_local_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) + + # 3b: global crops cls tokens + student_global_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) + + # 3c: global crops patch tokens + if do_ibot and not self.ibot_separate_head: + student_global_masked_patch_tokens_after_head = outputs_list.pop(0).squeeze(0)[:n_masked_patches] + + if n_local_crops > 0: + dino_local_crops_loss = self.dino_loss( + student_output_list=student_local_cls_tokens_after_head.chunk(n_local_crops), + teacher_out_softmaxed_centered_list=teacher_dino_softmaxed_centered_list, + ) / (n_global_crops_loss_terms + n_local_crops_loss_terms) + + # store for display + loss_dict["dino_local_crops_loss"] = dino_local_crops_loss + + # accumulate loss + loss_accumulator += self.dino_loss_weight * dino_local_crops_loss + + # process global crops + loss_scales = 2 # this is here since we process global crops together + + if do_dino: + # compute loss + dino_global_crops_loss = ( + self.dino_loss( + student_output_list=[student_global_cls_tokens_after_head], + teacher_out_softmaxed_centered_list=[ + teacher_dino_softmaxed_centered_list.flatten(0, 1) + ], # these were chunked and stacked in reverse so A is matched to B + ) + * loss_scales + / (n_global_crops_loss_terms + n_local_crops_loss_terms) + ) + + loss_dict["dino_global_crops_loss"] = dino_global_crops_loss + + # accumulate loss + loss_accumulator += self.dino_loss_weight * dino_global_crops_loss + + student_cls_tokens = student_global_cls_tokens + + if self.do_koleo: + koleo_loss = self.cfg.dino.koleo_loss_weight * sum( + self.koleo_loss(p) for p in student_cls_tokens.chunk(2) + ) # we don't apply koleo loss between cls tokens of a same image + loss_accumulator += koleo_loss + loss_dict["koleo_loss"] = ( + koleo_loss / loss_scales + ) # this is to display the same losses as before but we can remove eventually + + if do_ibot: + # compute loss + ibot_patch_loss = ( + self.ibot_patch_loss.forward_masked( + student_global_masked_patch_tokens_after_head, + masked_teacher_ibot_softmaxed_centered, + 
student_masks_flat=masks, + n_masked_patches=n_masked_patches, + masks_weight=masks_weight, + ) + * loss_scales + * ibot_loss_scale + ) + + # store for display + loss_dict["ibot_loss"] = ibot_patch_loss / 2 + + # accumulate loss + loss_accumulator += self.ibot_loss_weight * ibot_patch_loss + + self.backprop_loss(loss_accumulator) + + self.fsdp_synchronize_streams() + + return loss_dict + + def fsdp_synchronize_streams(self): + if self.need_to_synchronize_fsdp_streams: + torch.cuda.synchronize() + self.student.dino_head._streams = ( + self.teacher.dino_head._streams + ) = self.student.backbone._streams = self.teacher.backbone._streams + self.need_to_synchronize_fsdp_streams = False + + def update_teacher(self, m): + student_param_list = [] + teacher_param_list = [] + with torch.no_grad(): + for k in self.student.keys(): + for ms, mt in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.teacher[k])): + student_param_list += ms.params + teacher_param_list += mt.params + torch._foreach_mul_(teacher_param_list, m) + torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m) + + def train(self): + super().train() + self.teacher.eval() + + def get_maybe_fused_params_for_submodel(self, m): + params_groups = get_params_groups_with_decay( + model=m, + lr_decay_rate=self.cfg.optim.layerwise_decay, + patch_embed_lr_mult=self.cfg.optim.patch_embed_lr_mult, + ) + fused_params_groups = fuse_params_groups(params_groups) + logger.info("fusing param groups") + + for g in fused_params_groups: + g["foreach"] = True + return fused_params_groups + + def get_params_groups(self): + all_params_groups = [] + for m in self.student.values(): + all_params_groups += self.get_maybe_fused_params_for_submodel(m) + return all_params_groups + + def prepare_for_distributed_training(self): + logger.info("DISTRIBUTED FSDP -- preparing model for distributed training") + if has_batchnorms(self.student): + raise NotImplementedError + # below will synchronize all student subnetworks across gpus: + for k, v in self.student.items(): + self.teacher[k].load_state_dict(self.student[k].state_dict()) + student_model_cfg = self.cfg.compute_precision.student[k] + self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k]) + teacher_model_cfg = self.cfg.compute_precision.teacher[k] + self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k]) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/train/train.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/train/train.py new file mode 100755 index 0000000..473b8d0 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/train/train.py @@ -0,0 +1,318 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
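+#
+# Illustrative local invocation (paths are placeholders):
+#   python -m dinov2.train.train --config-file <config.yaml> --output-dir <out_dir>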
+ +import argparse +import logging +import math +import os +from functools import partial + +from fvcore.common.checkpoint import PeriodicCheckpointer +import torch + +from dinov2.data import SamplerType, make_data_loader, make_dataset +from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator +import dinov2.distributed as distributed +from dinov2.fsdp import FSDPCheckpointer +from dinov2.logging import MetricLogger +from dinov2.utils.config import setup +from dinov2.utils.utils import CosineScheduler + +from dinov2.train.ssl_meta_arch import SSLMetaArch + + +torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default +logger = logging.getLogger("dinov2") + + +def get_args_parser(add_help: bool = True): + parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help) + parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") + parser.add_argument( + "--no-resume", + action="store_true", + help="Whether to not attempt to resume from the checkpoint directory. ", + ) + parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") + parser.add_argument("--eval", type=str, default="", help="Eval type to perform") + parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--output-dir", + "--output_dir", + default="", + type=str, + help="Output directory to save logs and checkpoints", + ) + + return parser + + +def build_optimizer(cfg, params_groups): + return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2)) + + +def build_schedulers(cfg): + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH + lr = dict( + base_value=cfg.optim["lr"], + final_value=cfg.optim["min_lr"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH, + start_warmup_value=0, + ) + wd = dict( + base_value=cfg.optim["weight_decay"], + final_value=cfg.optim["weight_decay_end"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + ) + momentum = dict( + base_value=cfg.teacher["momentum_teacher"], + final_value=cfg.teacher["final_momentum_teacher"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + ) + teacher_temp = dict( + base_value=cfg.teacher["teacher_temp"], + final_value=cfg.teacher["teacher_temp"], + total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, + warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, + start_warmup_value=cfg.teacher["warmup_teacher_temp"], + ) + + lr_schedule = CosineScheduler(**lr) + wd_schedule = CosineScheduler(**wd) + momentum_schedule = CosineScheduler(**momentum) + teacher_temp_schedule = CosineScheduler(**teacher_temp) + last_layer_lr_schedule = CosineScheduler(**lr) + + last_layer_lr_schedule.schedule[ + : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH + ] = 0 # mimicking the original schedules + + logger.info("Schedulers ready.") + + return ( + lr_schedule, + wd_schedule, + momentum_schedule, + teacher_temp_schedule, + last_layer_lr_schedule, + ) + + +def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr): + for param_group in optimizer.param_groups: + is_last_layer = param_group["is_last_layer"] + lr_multiplier = 
param_group["lr_multiplier"] + wd_multiplier = param_group["wd_multiplier"] + param_group["weight_decay"] = wd * wd_multiplier + param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier + + +def do_test(cfg, model, iteration): + new_state_dict = model.teacher.state_dict() + + if distributed.is_main_process(): + iterstring = str(iteration) + eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring) + os.makedirs(eval_dir, exist_ok=True) + # save teacher checkpoint + teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth") + torch.save({"teacher": new_state_dict}, teacher_ckp_path) + + +def do_train(cfg, model, resume=False): + model.train() + inputs_dtype = torch.half + fp16_scaler = model.fp16_scaler # for mixed precision training + + # setup optimizer + + optimizer = build_optimizer(cfg, model.get_params_groups()) + ( + lr_schedule, + wd_schedule, + momentum_schedule, + teacher_temp_schedule, + last_layer_lr_schedule, + ) = build_schedulers(cfg) + + # checkpointer + checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True) + + start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 + + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH + max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH + + periodic_checkpointer = PeriodicCheckpointer( + checkpointer, + period=3 * OFFICIAL_EPOCH_LENGTH, + max_iter=max_iter, + max_to_keep=3, + ) + + # setup data preprocessing + + img_size = cfg.crops.global_crops_size + patch_size = cfg.student.patch_size + n_tokens = (img_size // patch_size) ** 2 + mask_generator = MaskingGenerator( + input_size=(img_size // patch_size, img_size // patch_size), + max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, + ) + + data_transform = DataAugmentationDINO( + cfg.crops.global_crops_scale, + cfg.crops.local_crops_scale, + cfg.crops.local_crops_number, + global_crops_size=cfg.crops.global_crops_size, + local_crops_size=cfg.crops.local_crops_size, + ) + + collate_fn = partial( + collate_data_and_cast, + mask_ratio_tuple=cfg.ibot.mask_ratio_min_max, + mask_probability=cfg.ibot.mask_sample_probability, + n_tokens=n_tokens, + mask_generator=mask_generator, + dtype=inputs_dtype, + ) + + # setup data loader + + dataset = make_dataset( + dataset_str=cfg.train.dataset_path, + transform=data_transform, + target_transform=lambda _: (), + ) + # sampler_type = SamplerType.INFINITE + sampler_type = SamplerType.SHARDED_INFINITE + data_loader = make_data_loader( + dataset=dataset, + batch_size=cfg.train.batch_size_per_gpu, + num_workers=cfg.train.num_workers, + shuffle=True, + seed=start_iter, # TODO: Fix this -- cfg.train.seed + sampler_type=sampler_type, + sampler_advance=0, # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu, + drop_last=True, + collate_fn=collate_fn, + ) + + # training loop + + iteration = start_iter + + logger.info("Starting training from iteration {}".format(start_iter)) + metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json") + metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file) + header = "Training" + + for data in metric_logger.log_every( + data_loader, + 10, + header, + max_iter, + start_iter, + ): + current_batch_size = data["collated_global_crops"].shape[0] / 2 + if iteration > max_iter: + return + + # apply schedules + + lr = lr_schedule[iteration] + wd = wd_schedule[iteration] + mom = momentum_schedule[iteration] + teacher_temp = 
teacher_temp_schedule[iteration] + last_layer_lr = last_layer_lr_schedule[iteration] + apply_optim_scheduler(optimizer, lr, wd, last_layer_lr) + + # compute losses + + optimizer.zero_grad(set_to_none=True) + loss_dict = model.forward_backward(data, teacher_temp=teacher_temp) + + # clip gradients + + if fp16_scaler is not None: + if cfg.optim.clip_grad: + fp16_scaler.unscale_(optimizer) + for v in model.student.values(): + v.clip_grad_norm_(cfg.optim.clip_grad) + fp16_scaler.step(optimizer) + fp16_scaler.update() + else: + if cfg.optim.clip_grad: + for v in model.student.values(): + v.clip_grad_norm_(cfg.optim.clip_grad) + optimizer.step() + + # perform teacher EMA update + + model.update_teacher(mom) + + # logging + + if distributed.get_global_size() > 1: + for v in loss_dict.values(): + torch.distributed.all_reduce(v) + loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()} + + if math.isnan(sum(loss_dict_reduced.values())): + logger.info("NaN detected") + raise AssertionError + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + + metric_logger.update(lr=lr) + metric_logger.update(wd=wd) + metric_logger.update(mom=mom) + metric_logger.update(last_layer_lr=last_layer_lr) + metric_logger.update(current_batch_size=current_batch_size) + metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced) + + # checkpointing and testing + + if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0: + do_test(cfg, model, f"training_{iteration}") + torch.cuda.synchronize() + periodic_checkpointer.step(iteration) + + iteration = iteration + 1 + metric_logger.synchronize_between_processes() + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +def main(args): + cfg = setup(args) + + model = SSLMetaArch(cfg).to(torch.device("cuda")) + model.prepare_for_distributed_training() + + logger.info("Model:\n{}".format(model)) + if args.eval_only: + iteration = ( + FSDPCheckpointer(model, save_dir=cfg.train.output_dir) + .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume) + .get("iteration", -1) + + 1 + ) + return do_test(cfg, model, f"manual_{iteration}") + + do_train(cfg, model, resume=not args.no_resume) + + +if __name__ == "__main__": + args = get_args_parser(add_help=True).parse_args() + main(args) diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/__init__.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/__init__.py new file mode 100755 index 0000000..b88da6b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/cluster.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/cluster.py new file mode 100755 index 0000000..3df87dc --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/cluster.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
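+#
+# Cluster detection below is heuristic: AWS is inferred from a "-aws" kernel
+# suffix, RSC from an "rsc" hostname prefix, and anything else falls back to FAIR.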
+ +from enum import Enum +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/config.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/config.py new file mode 100755 index 0000000..c9de578 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
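+#
+# Config precedence (lowest to highest): dinov2_default_config, then the YAML
+# passed via --config-file, then "key=value" overrides given on the command line.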
+ +import math +import logging +import os + +from omegaconf import OmegaConf + +import dinov2.distributed as distributed +from dinov2.logging import setup_logging +from dinov2.utils import utils +from dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/dtype.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/dtype.py new file mode 100755 index 0000000..80f4cd7 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/dtype.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
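
# Minimal worked example of the "sqrt_wrt_1024" rule implemented in
# apply_scaling_rules_to_cfg above: lr = base_lr * sqrt(total_batch_size / 1024).
# The numbers are illustrative.
import math

base_lr = 0.004
batch_size_per_gpu, world_size = 32, 8               # total batch size = 256
lr = base_lr * math.sqrt(batch_size_per_gpu * world_size / 1024.0)
print(lr)                                            # 0.004 * sqrt(0.25) = 0.002
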
+ + +from typing import Dict, Union + +import numpy as np +import torch + + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/param_groups.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/param_groups.py new file mode 100755 index 0000000..9a5d2ff --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/param_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." 
not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/utils.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/utils.py new file mode 100755 index 0000000..68f8e2c --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/dinov2/utils/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
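
# Illustrative check of get_vit_lr_decay_rate above (assuming the vendored
# dinov2 package is importable): embeddings get the smallest lr multiplier,
# later blocks get progressively larger ones, and non-backbone params get 1.0.
from dinov2.utils.param_groups import get_vit_lr_decay_rate

for name in (
    "backbone.patch_embed.proj.weight",   # layer_id 0  -> 0.9 ** 13
    "backbone.blocks.0.attn.qkv.weight",  # layer_id 1  -> 0.9 ** 12
    "backbone.blocks.11.mlp.fc1.weight",  # layer_id 12 -> 0.9 ** 1
    "head.last_layer.weight",             # layer_id 13 -> 0.9 ** 0 = 1.0
):
    print(name, get_vit_lr_decay_rate(name, lr_decay_rate=0.9, num_layers=12))
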
+ +import logging +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/src/active_grasp/active_perception/modules/module_lib/dinov2/hubconf.py b/src/active_grasp/active_perception/modules/module_lib/dinov2/hubconf.py new file mode 100755 index 0000000..d3664e2 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/dinov2/hubconf.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
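
# Illustrative sketch of the CosineScheduler defined above (assuming the
# vendored dinov2 package is importable): linear warmup, cosine decay, then a
# constant final value once total_iters is exceeded.
from dinov2.utils.utils import CosineScheduler

sched = CosineScheduler(base_value=1e-3, final_value=1e-6, total_iters=100,
                        warmup_iters=10, start_warmup_value=0.0)
print(sched[0])    # 0.0   (warmup starts from start_warmup_value)
print(sched[10])   # 1e-3  (cosine segment starts at base_value)
print(sched[55])   # ~5e-4 (midway down the cosine curve)
print(sched[500])  # 1e-6  (clamped to final_value past total_iters)
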
+ + +from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14 +from dinov2.hub.backbones import dinov2_vitb14_reg, dinov2_vitg14_reg, dinov2_vitl14_reg, dinov2_vits14_reg +from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc +from dinov2.hub.classifiers import dinov2_vitb14_reg_lc, dinov2_vitg14_reg_lc, dinov2_vitl14_reg_lc, dinov2_vits14_reg_lc +from dinov2.hub.depthers import dinov2_vitb14_ld, dinov2_vitg14_ld, dinov2_vitl14_ld, dinov2_vits14_ld +from dinov2.hub.depthers import dinov2_vitb14_dd, dinov2_vitg14_dd, dinov2_vitl14_dd, dinov2_vits14_dd + + +dependencies = ["torch"] diff --git a/src/active_grasp/active_perception/modules/module_lib/fusion_layer.py b/src/active_grasp/active_perception/modules/module_lib/fusion_layer.py new file mode 100755 index 0000000..0f44ef7 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/fusion_layer.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn + + +class FeatureFusion(nn.Module): + def __init__(self, rgb_dim, pts_dim, output_dim): + super(FeatureFusion, self).__init__() + self.pts_embedding = nn.Linear(pts_dim, output_dim) + + + # B * patch_size * patch_size * C => B * 1 * 1 * C => B * C + self.rgb_embedding = nn.Sequential( + nn.Conv2d(rgb_dim, 512, kernel_size=3, stride=2, padding=1), # Bx17x17x512 + nn.ReLU(), + nn.Conv2d(512, output_dim, kernel_size=3, stride=2, padding=1), # # Bx9x9xoutput_dim + nn.ReLU(), + nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=2, padding=1), # Bx5x5xoutput_dim + nn.ReLU(), + nn.Conv2d(output_dim, output_dim, kernel_size=5, stride=1, padding=0), # Bx1x1xoutput_dim + nn.ReLU() + ) + self.fc_fusion = nn.Linear(output_dim * 2, output_dim) + self.relu = nn.ReLU() + + def forward(self, img_feat, pts_feat): + # img_feat = torch.mean(img_feat, dim=1) + patch_length = img_feat.size(1) + patch_size = int(patch_length ** 0.5) + # B * patch_size * patch_size * C = > B * C * patch_size * patch_size + img_feat = img_feat.view(-1, patch_size, patch_size, img_feat.size(2)) + img_feat = img_feat.permute(0, 3, 2, 1) + rgb_embedding = self.rgb_embedding(img_feat) + rgb_embedding = rgb_embedding.view(rgb_embedding.size(0), -1) + pts_embedding = self.relu(self.pts_embedding(pts_feat)) + fusion_feat = torch.cat((rgb_embedding, pts_embedding), dim=1) + output = self.fc_fusion(fusion_feat) + return output + +if __name__ == "__main__": + B = 64 + C = 1024 + img_feat_dim = 384 + pts_feat_dim = 1024 + img_feat = torch.randn(B, 1156, 384).cuda() + pts_feat = torch.randn(B, 1024).cuda() + fusion_model = FeatureFusion(img_feat_dim,pts_feat_dim,output_dim=C).cuda() + output = fusion_model(img_feat, pts_feat) + print(output.shape) \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/gaussian_fourier_projection.py b/src/active_grasp/active_perception/modules/module_lib/gaussian_fourier_projection.py new file mode 100755 index 0000000..c1713ab --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/gaussian_fourier_projection.py @@ -0,0 +1,17 @@ +import torch +import numpy as np +import torch.nn as nn + + +class GaussianFourierProjection(nn.Module): + """Gaussian random features for encoding time steps.""" + + def __init__(self, embed_dim, scale=30.): + super().__init__() + # Randomly sample weights during initialization. These weights are fixed + # during optimization and are not trainable. 
+ self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False) + + def forward(self, x): + x_proj = x[:, None] * self.W[None, :] * 2 * np.pi + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) diff --git a/src/active_grasp/active_perception/modules/module_lib/linear.py b/src/active_grasp/active_perception/modules/module_lib/linear.py new file mode 100755 index 0000000..a5ae402 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/linear.py @@ -0,0 +1,30 @@ +import torch +import numpy as np + + +def weight_init(shape, mode, fan_in, fan_out): + if mode == 'xavier_uniform': + return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) + if mode == 'xavier_normal': + return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) + if mode == 'kaiming_uniform': + return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) + if mode == 'kaiming_normal': + return np.sqrt(1 / fan_in) * torch.randn(*shape) + raise ValueError(f'Invalid init mode "{mode}"') + + +class Linear(torch.nn.Module): + def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) + self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight) + self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None + + def forward(self, x): + x = x @ self.weight.to(x.dtype).t() + if self.bias is not None: + x = x.add_(self.bias.to(x.dtype)) + return x diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/.gitignore b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/.gitignore new file mode 100755 index 0000000..cf42194 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/.gitignore @@ -0,0 +1,4 @@ +pointnet2/build/ +pointnet2/dist/ +pointnet2/pointnet2.egg-info/ +__pycache__/ diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/LICENSE b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/LICENSE new file mode 100755 index 0000000..77c8ebe --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Shaoshuai Shi + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
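
# Illustrative usage sketch for the GaussianFourierProjection module above: it
# maps a batch of scalar timesteps t to [sin(2*pi*W*t), cos(2*pi*W*t)] features
# with fixed (non-trainable) random frequencies W. The import path is a guess
# based on the directory layout of this diff.
import torch
from active_perception.modules.module_lib.gaussian_fourier_projection import GaussianFourierProjection

embed = GaussianFourierProjection(embed_dim=128, scale=30.0)
t = torch.rand(16)                  # one scalar timestep per sample
features = embed(t)
print(features.shape)               # torch.Size([16, 128])
assert not embed.W.requires_grad    # the random frequencies stay frozen
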
diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/README.md b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/README.md new file mode 100755 index 0000000..c5a43f0 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/README.md @@ -0,0 +1,51 @@ +# Pointnet2.PyTorch + +* PyTorch implementation of [PointNet++](https://arxiv.org/abs/1706.02413) based on [erikwijmans/Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch). +* Faster than the original code thanks to re-implemented CUDA operations. + +## Installation +### Requirements +* Linux (tested on Ubuntu 14.04/16.04) +* Python 3.6+ +* PyTorch 1.0 + +### Install +Install this library by running the following commands: + +```shell +cd pointnet2 +python setup.py install +cd ../ +``` + +## Examples +Here is a simple example of using this library for KITTI outdoor foreground point cloud segmentation; refer to the paper [PointRCNN](https://arxiv.org/abs/1812.04244) for details of the task description and foreground label generation. + +1. Download the training data from the [KITTI 3D object detection](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) website and organize the downloaded files as follows: +``` +Pointnet2.PyTorch +├── pointnet2 +├── tools +│ ├──data +│ │ ├── KITTI +│ │ │ ├── ImageSets +│ │ │ ├── object +│ │ │ │ ├──training +│ │ │ │ ├──calib & velodyne & label_2 & image_2 +│ │ train_and_eval.py +``` + +2. Run the following command to train and evaluate: +```shell +cd tools +python train_and_eval.py --batch_size 8 --epochs 100 --ckpt_save_interval 2 +``` + + + +## Projects using this repo +* [PointRCNN](https://github.com/sshaoshuai/PointRCNN): 3D object detector from raw point clouds. + +## Acknowledgement +* [charlesq34/pointnet2](https://github.com/charlesq34/pointnet2): Official code repo from the paper authors. +* [erikwijmans/Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch): Initial PyTorch implementation of PointNet++. diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_modules.py b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_modules.py new file mode 100755 index 0000000..4b94326 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_modules.py @@ -0,0 +1,162 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from . import pointnet2_utils +from . 
import pytorch_utils as pt_utils +from typing import List + + +class _PointnetSAModuleBase(nn.Module): + + def __init__(self): + super().__init__() + self.npoint = None + self.groupers = None + self.mlps = None + self.pool_method = 'max_pool' + + def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor): + """ + :param xyz: (B, N, 3) tensor of the xyz coordinates of the features + :param features: (B, N, C) tensor of the descriptors of the the features + :param new_xyz: + :return: + new_xyz: (B, npoint, 3) tensor of the new features' xyz + new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors + """ + new_features_list = [] + + xyz_flipped = xyz.transpose(1, 2).contiguous() + if new_xyz is None: + new_xyz = pointnet2_utils.gather_operation( + xyz_flipped, + pointnet2_utils.furthest_point_sample(xyz, self.npoint) + ).transpose(1, 2).contiguous() if self.npoint is not None else None + + for i in range(len(self.groupers)): + new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample) + + new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) + + if self.pool_method == 'max_pool': + new_features = F.max_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ) # (B, mlp[-1], npoint, 1) + elif self.pool_method == 'avg_pool': + new_features = F.avg_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ) # (B, mlp[-1], npoint, 1) + else: + raise NotImplementedError + + new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) + new_features_list.append(new_features) + + return new_xyz, torch.cat(new_features_list, dim=1) + + +class PointnetSAModuleMSG(_PointnetSAModuleBase): + """Pointnet set abstraction layer with multiscale grouping""" + + def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True, + use_xyz: bool = True, pool_method='max_pool', instance_norm=False): + """ + :param npoint: int + :param radii: list of float, list of radii to group with + :param nsamples: list of int, number of samples in each ball query + :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale + :param bn: whether to use batchnorm + :param use_xyz: + :param pool_method: max_pool / avg_pool + :param instance_norm: whether to use instance_norm + """ + super().__init__() + + assert len(radii) == len(nsamples) == len(mlps) + + self.npoint = npoint + self.groupers = nn.ModuleList() + self.mlps = nn.ModuleList() + for i in range(len(radii)): + radius = radii[i] + nsample = nsamples[i] + self.groupers.append( + pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) + if npoint is not None else pointnet2_utils.GroupAll(use_xyz) + ) + mlp_spec = mlps[i] + if use_xyz: + mlp_spec[0] += 3 + + self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm)) + self.pool_method = pool_method + + +class PointnetSAModule(PointnetSAModuleMSG): + """Pointnet set abstraction layer""" + + def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None, + bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False): + """ + :param mlp: list of int, spec of the pointnet before the global max_pool + :param npoint: int, number of features + :param radius: float, radius of ball + :param nsample: int, number of samples in the ball query + :param bn: whether to use batchnorm + :param use_xyz: + :param 
pool_method: max_pool / avg_pool + :param instance_norm: whether to use instance_norm + """ + super().__init__( + mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz, + pool_method=pool_method, instance_norm=instance_norm + ) + + +class PointnetFPModule(nn.Module): + r"""Propagates the features of one set to another""" + + def __init__(self, *, mlp: List[int], bn: bool = True): + """ + :param mlp: list of int + :param bn: whether to use batchnorm + """ + super().__init__() + self.mlp = pt_utils.SharedMLP(mlp, bn=bn) + + def forward( + self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor + ) -> torch.Tensor: + """ + :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features + :param known: (B, m, 3) tensor of the xyz positions of the known features + :param unknow_feats: (B, C1, n) tensor of the features to be propagated to + :param known_feats: (B, C2, m) tensor of features to be propagated + :return: + new_features: (B, mlp[-1], n) tensor of the features of the unknown features + """ + if known is not None: + dist, idx = pointnet2_utils.three_nn(unknown, known) + dist_recip = 1.0 / (dist + 1e-8) + norm = torch.sum(dist_recip, dim=2, keepdim=True) + weight = dist_recip / norm + + interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight) + else: + interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1)) + + if unknow_feats is not None: + new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n) + else: + new_features = interpolated_feats + + new_features = new_features.unsqueeze(-1) + + new_features = self.mlp(new_features) + + return new_features.squeeze(-1) + + +if __name__ == "__main__": + pass diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_utils.py b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_utils.py new file mode 100755 index 0000000..97a5466 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_utils.py @@ -0,0 +1,291 @@ +import torch +from torch.autograd import Variable +from torch.autograd import Function +import torch.nn as nn +from typing import Tuple +import sys + +import pointnet2_cuda as pointnet2 + + +class FurthestPointSampling(Function): + @staticmethod + def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: + """ + Uses iterative furthest point sampling to select a set of npoint features that have the largest + minimum distance + :param ctx: + :param xyz: (B, N, 3) where N > npoint + :param npoint: int, number of features in the sampled set + :return: + output: (B, npoint) tensor containing the set + """ + assert xyz.is_contiguous() + + B, N, _ = xyz.size() + output = torch.cuda.IntTensor(B, npoint) + temp = torch.cuda.FloatTensor(B, N).fill_(1e10) + + pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output) + return output + + @staticmethod + def backward(xyz, a=None): + return None, None + + +furthest_point_sample = FurthestPointSampling.apply + + +class GatherOperation(Function): + + @staticmethod + def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: + """ + :param ctx: + :param features: (B, C, N) + :param idx: (B, npoint) index tensor of the features to gather + :return: + output: (B, C, npoint) + """ + assert features.is_contiguous() + assert idx.is_contiguous() + + B, npoint = 
idx.size() + _, C, N = features.size() + output = torch.cuda.FloatTensor(B, C, npoint) + + pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output) + + ctx.for_backwards = (idx, C, N) + return output + + @staticmethod + def backward(ctx, grad_out): + idx, C, N = ctx.for_backwards + B, npoint = idx.size() + + grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) + grad_out_data = grad_out.data.contiguous() + pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data) + return grad_features, None + + +gather_operation = GatherOperation.apply + + +class ThreeNN(Function): + + @staticmethod + def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Find the three nearest neighbors of unknown in known + :param ctx: + :param unknown: (B, N, 3) + :param known: (B, M, 3) + :return: + dist: (B, N, 3) l2 distance to the three nearest neighbors + idx: (B, N, 3) index of 3 nearest neighbors + """ + assert unknown.is_contiguous() + assert known.is_contiguous() + + B, N, _ = unknown.size() + m = known.size(1) + dist2 = torch.cuda.FloatTensor(B, N, 3) + idx = torch.cuda.IntTensor(B, N, 3) + + pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx) + return torch.sqrt(dist2), idx + + @staticmethod + def backward(ctx, a=None, b=None): + return None, None + + +three_nn = ThreeNN.apply + + +class ThreeInterpolate(Function): + + @staticmethod + def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Performs weight linear interpolation on 3 features + :param ctx: + :param features: (B, C, M) Features descriptors to be interpolated from + :param idx: (B, n, 3) three nearest neighbors of the target features in features + :param weight: (B, n, 3) weights + :return: + output: (B, C, N) tensor of the interpolated features + """ + assert features.is_contiguous() + assert idx.is_contiguous() + assert weight.is_contiguous() + + B, c, m = features.size() + n = idx.size(1) + ctx.three_interpolate_for_backward = (idx, weight, m) + output = torch.cuda.FloatTensor(B, c, n) + + pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output) + return output + + @staticmethod + def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + :param ctx: + :param grad_out: (B, C, N) tensor with gradients of outputs + :return: + grad_features: (B, C, M) tensor with gradients of features + None: + None: + """ + idx, weight, m = ctx.three_interpolate_for_backward + B, c, n = grad_out.size() + + grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_()) + grad_out_data = grad_out.data.contiguous() + + pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data) + return grad_features, None, None + + +three_interpolate = ThreeInterpolate.apply + + +class GroupingOperation(Function): + + @staticmethod + def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: + """ + :param ctx: + :param features: (B, C, N) tensor of features to group + :param idx: (B, npoint, nsample) tensor containing the indicies of features to group with + :return: + output: (B, C, npoint, nsample) tensor + """ + assert features.is_contiguous() + assert idx.is_contiguous() + + B, nfeatures, nsample = idx.size() + _, C, N = features.size() + output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) + + pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, 
output) + + ctx.for_backwards = (idx, N) + return output + + @staticmethod + def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + :param ctx: + :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward + :return: + grad_features: (B, C, N) gradient of the features + """ + idx, N = ctx.for_backwards + + B, C, npoint, nsample = grad_out.size() + grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) + + grad_out_data = grad_out.data.contiguous() + pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data) + return grad_features, None + + +grouping_operation = GroupingOperation.apply + + +class BallQuery(Function): + + @staticmethod + def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor: + """ + :param ctx: + :param radius: float, radius of the balls + :param nsample: int, maximum number of features in the balls + :param xyz: (B, N, 3) xyz coordinates of the features + :param new_xyz: (B, npoint, 3) centers of the ball query + :return: + idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls + """ + assert new_xyz.is_contiguous() + assert xyz.is_contiguous() + + B, N, _ = xyz.size() + npoint = new_xyz.size(1) + idx = torch.cuda.IntTensor(B, npoint, nsample).zero_() + + pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx) + return idx + + @staticmethod + def backward(ctx, a=None): + return None, None, None, None + + +ball_query = BallQuery.apply + + +class QueryAndGroup(nn.Module): + def __init__(self, radius: float, nsample: int, use_xyz: bool = True): + """ + :param radius: float, radius of ball + :param nsample: int, maximum number of features to gather in the ball + :param use_xyz: + """ + super().__init__() + self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz + + def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]: + """ + :param xyz: (B, N, 3) xyz coordinates of the features + :param new_xyz: (B, npoint, 3) centroids + :param features: (B, C, N) descriptors of the features + :return: + new_features: (B, 3 + C, npoint, nsample) + """ + idx = ball_query(self.radius, self.nsample, xyz, new_xyz) + xyz_trans = xyz.transpose(1, 2).contiguous() + grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) + grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) + + if features is not None: + grouped_features = grouping_operation(features, idx) + if self.use_xyz: + new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample) + else: + new_features = grouped_features + else: + assert self.use_xyz, "Cannot have not features and not use xyz as a feature!" 
+ new_features = grouped_xyz + + return new_features + + +class GroupAll(nn.Module): + def __init__(self, use_xyz: bool = True): + super().__init__() + self.use_xyz = use_xyz + + def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None): + """ + :param xyz: (B, N, 3) xyz coordinates of the features + :param new_xyz: ignored + :param features: (B, C, N) descriptors of the features + :return: + new_features: (B, C + 3, 1, N) + """ + grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) + if features is not None: + grouped_features = features.unsqueeze(2) + if self.use_xyz: + new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N) + else: + new_features = grouped_features + else: + new_features = grouped_xyz + + return new_features diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/pytorch_utils.py b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/pytorch_utils.py new file mode 100755 index 0000000..09cb7bc --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/pytorch_utils.py @@ -0,0 +1,236 @@ +import torch.nn as nn +from typing import List, Tuple + + +class SharedMLP(nn.Sequential): + + def __init__( + self, + args: List[int], + *, + bn: bool = False, + activation=nn.ReLU(inplace=True), + preact: bool = False, + first: bool = False, + name: str = "", + instance_norm: bool = False, + ): + super().__init__() + + for i in range(len(args) - 1): + self.add_module( + name + 'layer{}'.format(i), + Conv2d( + args[i], + args[i + 1], + bn=(not first or not preact or (i != 0)) and bn, + activation=activation + if (not first or not preact or (i != 0)) else None, + preact=preact, + instance_norm=instance_norm + ) + ) + + +class _ConvBase(nn.Sequential): + + def __init__( + self, + in_size, + out_size, + kernel_size, + stride, + padding, + activation, + bn, + init, + conv=None, + batch_norm=None, + bias=True, + preact=False, + name="", + instance_norm=False, + instance_norm_func=None + ): + super().__init__() + + bias = bias and (not bn) + conv_unit = conv( + in_size, + out_size, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias + ) + init(conv_unit.weight) + if bias: + nn.init.constant_(conv_unit.bias, 0) + + if bn: + if not preact: + bn_unit = batch_norm(out_size) + else: + bn_unit = batch_norm(in_size) + if instance_norm: + if not preact: + in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False) + else: + in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False) + + if preact: + if bn: + self.add_module(name + 'bn', bn_unit) + + if activation is not None: + self.add_module(name + 'activation', activation) + + if not bn and instance_norm: + self.add_module(name + 'in', in_unit) + + self.add_module(name + 'conv', conv_unit) + + if not preact: + if bn: + self.add_module(name + 'bn', bn_unit) + + if activation is not None: + self.add_module(name + 'activation', activation) + + if not bn and instance_norm: + self.add_module(name + 'in', in_unit) + + +class _BNBase(nn.Sequential): + + def __init__(self, in_size, batch_norm=None, name=""): + super().__init__() + self.add_module(name + "bn", batch_norm(in_size)) + + nn.init.constant_(self[0].weight, 1.0) + nn.init.constant_(self[0].bias, 0) + + +class BatchNorm1d(_BNBase): + + def __init__(self, in_size: int, *, name: str = ""): + super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) + + +class 
BatchNorm2d(_BNBase): + + def __init__(self, in_size: int, name: str = ""): + super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) + + +class Conv1d(_ConvBase): + + def __init__( + self, + in_size: int, + out_size: int, + *, + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + activation=nn.ReLU(inplace=True), + bn: bool = False, + init=nn.init.kaiming_normal_, + bias: bool = True, + preact: bool = False, + name: str = "", + instance_norm=False + ): + super().__init__( + in_size, + out_size, + kernel_size, + stride, + padding, + activation, + bn, + init, + conv=nn.Conv1d, + batch_norm=BatchNorm1d, + bias=bias, + preact=preact, + name=name, + instance_norm=instance_norm, + instance_norm_func=nn.InstanceNorm1d + ) + + +class Conv2d(_ConvBase): + + def __init__( + self, + in_size: int, + out_size: int, + *, + kernel_size: Tuple[int, int] = (1, 1), + stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), + activation=nn.ReLU(inplace=True), + bn: bool = False, + init=nn.init.kaiming_normal_, + bias: bool = True, + preact: bool = False, + name: str = "", + instance_norm=False + ): + super().__init__( + in_size, + out_size, + kernel_size, + stride, + padding, + activation, + bn, + init, + conv=nn.Conv2d, + batch_norm=BatchNorm2d, + bias=bias, + preact=preact, + name=name, + instance_norm=instance_norm, + instance_norm_func=nn.InstanceNorm2d + ) + + +class FC(nn.Sequential): + + def __init__( + self, + in_size: int, + out_size: int, + *, + activation=nn.ReLU(inplace=True), + bn: bool = False, + init=None, + preact: bool = False, + name: str = "" + ): + super().__init__() + + fc = nn.Linear(in_size, out_size, bias=not bn) + if init is not None: + init(fc.weight) + if not bn: + nn.init.constant(fc.bias, 0) + + if preact: + if bn: + self.add_module(name + 'bn', BatchNorm1d(in_size)) + + if activation is not None: + self.add_module(name + 'activation', activation) + + self.add_module(name + 'fc', fc) + + if not preact: + if bn: + self.add_module(name + 'bn', BatchNorm1d(out_size)) + + if activation is not None: + self.add_module(name + 'activation', activation) + diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/setup.py b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/setup.py new file mode 100755 index 0000000..99e59e3 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/setup.py @@ -0,0 +1,23 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='pointnet2', + ext_modules=[ + CUDAExtension('pointnet2_cuda', [ + 'src/pointnet2_api.cpp', + + 'src/ball_query.cpp', + 'src/ball_query_gpu.cu', + 'src/group_points.cpp', + 'src/group_points_gpu.cu', + 'src/interpolate.cpp', + 'src/interpolate_gpu.cu', + 'src/sampling.cpp', + 'src/sampling_gpu.cu', + ], + extra_compile_args={'cxx': ['-g'], + 'nvcc': ['-O2']}) + ], + cmdclass={'build_ext': BuildExtension} +) diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query.cpp b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query.cpp new file mode 100755 index 0000000..21f787e --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query.cpp @@ -0,0 +1,28 @@ +#include +#include +// #include +#include +#include +#include "ball_query_gpu.h" +#include +#include + +// extern THCState *state; + +#define 
CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") +#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) + +int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, + at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { + CHECK_INPUT(new_xyz_tensor); + CHECK_INPUT(xyz_tensor); + const float *new_xyz = new_xyz_tensor.data(); + const float *xyz = xyz_tensor.data(); + int *idx = idx_tensor.data(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream); + return 1; +} \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query_gpu.cu b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query_gpu.cu new file mode 100755 index 0000000..f8840aa --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query_gpu.cu @@ -0,0 +1,67 @@ +#include +#include +#include + +#include "ball_query_gpu.h" +#include "cuda_utils.h" + + +__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, + const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { + // new_xyz: (B, M, 3) + // xyz: (B, N, 3) + // output: + // idx: (B, M, nsample) + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || pt_idx >= m) return; + + new_xyz += bs_idx * m * 3 + pt_idx * 3; + xyz += bs_idx * n * 3; + idx += bs_idx * m * nsample + pt_idx * nsample; + + float radius2 = radius * radius; + float new_x = new_xyz[0]; + float new_y = new_xyz[1]; + float new_z = new_xyz[2]; + + int cnt = 0; + for (int k = 0; k < n; ++k) { + float x = xyz[k * 3 + 0]; + float y = xyz[k * 3 + 1]; + float z = xyz[k * 3 + 2]; + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); + if (d2 < radius2){ + if (cnt == 0){ + for (int l = 0; l < nsample; ++l) { + idx[l] = k; + } + } + idx[cnt] = k; + ++cnt; + if (cnt >= nsample) break; + } + } +} + + +void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \ + const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) { + // new_xyz: (B, M, 3) + // xyz: (B, N, 3) + // output: + // idx: (B, M, nsample) + + cudaError_t err; + + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + ball_query_kernel_fast<<>>(b, n, m, radius, nsample, new_xyz, xyz, idx); + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query_gpu.h b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query_gpu.h new file mode 100755 index 0000000..ffc831a --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query_gpu.h @@ -0,0 +1,15 @@ +#ifndef _BALL_QUERY_GPU_H +#define _BALL_QUERY_GPU_H + +#include +#include +#include +#include + +int 
ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, + at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor); + +void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, + const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream); + +#endif diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/cuda_utils.h b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/cuda_utils.h new file mode 100755 index 0000000..7fe2796 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/cuda_utils.h @@ -0,0 +1,15 @@ +#ifndef _CUDA_UTILS_H +#define _CUDA_UTILS_H + +#include + +#define TOTAL_THREADS 1024 +#define THREADS_PER_BLOCK 256 +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) + +inline int opt_n_threads(int work_size) { + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + + return max(min(1 << pow_2, TOTAL_THREADS), 1); +} +#endif diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/group_points.cpp b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/group_points.cpp new file mode 100755 index 0000000..f0e74e9 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/group_points.cpp @@ -0,0 +1,37 @@ +#include +#include +#include +#include +// #include +#include "group_points_gpu.h" +#include +#include +// extern THCState *state; + + +int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, + at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { + + float *grad_points = grad_points_tensor.data(); + const int *idx = idx_tensor.data(); + const float *grad_out = grad_out_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream); + return 1; +} + + +int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, + at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) { + + const float *points = points_tensor.data(); + const int *idx = idx_tensor.data(); + float *out = out_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream); + return 1; +} \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/group_points_gpu.cu b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/group_points_gpu.cu new file mode 100755 index 0000000..c015a81 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/group_points_gpu.cu @@ -0,0 +1,86 @@ +#include +#include + +#include "cuda_utils.h" +#include "group_points_gpu.h" + + +__global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample, + const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int index = blockIdx.x * 
blockDim.x + threadIdx.x; + int pt_idx = index / nsample; + if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; + + int sample_idx = index % nsample; + grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; + idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; + + atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]); +} + +void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, + const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + cudaError_t err; + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + group_points_grad_kernel_fast<<>>(b, c, n, npoints, nsample, grad_out, idx, grad_points); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +__global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample, + const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, npoints, nsample) + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int pt_idx = index / nsample; + if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; + + int sample_idx = index % nsample; + + idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; + int in_idx = bs_idx * c * n + c_idx * n + idx[0]; + int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; + + out[out_idx] = points[in_idx]; +} + + +void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, + const float *points, const int *idx, float *out, cudaStream_t stream) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, npoints, nsample) + cudaError_t err; + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + group_points_kernel_fast<<>>(b, c, n, npoints, nsample, points, idx, out); + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/group_points_gpu.h b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/group_points_gpu.h new file mode 100755 index 0000000..76c73ca --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/group_points_gpu.h @@ -0,0 +1,22 @@ +#ifndef _GROUP_POINTS_GPU_H +#define _GROUP_POINTS_GPU_H + +#include +#include +#include +#include + + +int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, + at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); + +void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, + const float *points, const int *idx, float *out, cudaStream_t stream); + +int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, + at::Tensor 
grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); + +void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, + const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); + +#endif diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate.cpp b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate.cpp new file mode 100755 index 0000000..d01f045 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate.cpp @@ -0,0 +1,59 @@ +#include +#include +// #include +#include +#include +#include +#include +#include +#include +#include +#include "interpolate_gpu.h" + +// extern THCState *state; + + +void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, + at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) { + const float *unknown = unknown_tensor.data(); + const float *known = known_tensor.data(); + float *dist2 = dist2_tensor.data(); + int *idx = idx_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream); +} + + +void three_interpolate_wrapper_fast(int b, int c, int m, int n, + at::Tensor points_tensor, + at::Tensor idx_tensor, + at::Tensor weight_tensor, + at::Tensor out_tensor) { + + const float *points = points_tensor.data(); + const float *weight = weight_tensor.data(); + float *out = out_tensor.data(); + const int *idx = idx_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream); +} + +void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, + at::Tensor grad_out_tensor, + at::Tensor idx_tensor, + at::Tensor weight_tensor, + at::Tensor grad_points_tensor) { + + const float *grad_out = grad_out_tensor.data(); + const float *weight = weight_tensor.data(); + float *grad_points = grad_points_tensor.data(); + const int *idx = idx_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream); +} \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate_gpu.cu b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate_gpu.cu new file mode 100755 index 0000000..a123dd8 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate_gpu.cu @@ -0,0 +1,161 @@ +#include +#include +#include + +#include "cuda_utils.h" +#include "interpolate_gpu.h" + + +__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, + const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) { + // unknown: (B, N, 3) + // known: (B, M, 3) + // output: + // dist2: (B, N, 3) + // idx: (B, N, 3) + + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || pt_idx >= n) return; + + unknown += bs_idx * n * 3 + pt_idx * 3; + known += bs_idx * m * 3; + dist2 += bs_idx * n * 3 + pt_idx * 3; + idx += 
bs_idx * n * 3 + pt_idx * 3; + + float ux = unknown[0]; + float uy = unknown[1]; + float uz = unknown[2]; + + double best1 = 1e40, best2 = 1e40, best3 = 1e40; + int besti1 = 0, besti2 = 0, besti3 = 0; + for (int k = 0; k < m; ++k) { + float x = known[k * 3 + 0]; + float y = known[k * 3 + 1]; + float z = known[k * 3 + 2]; + float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); + if (d < best1) { + best3 = best2; besti3 = besti2; + best2 = best1; besti2 = besti1; + best1 = d; besti1 = k; + } + else if (d < best2) { + best3 = best2; besti3 = besti2; + best2 = d; besti2 = k; + } + else if (d < best3) { + best3 = d; besti3 = k; + } + } + dist2[0] = best1; dist2[1] = best2; dist2[2] = best3; + idx[0] = besti1; idx[1] = besti2; idx[2] = besti3; +} + + +void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, + const float *known, float *dist2, int *idx, cudaStream_t stream) { + // unknown: (B, N, 3) + // known: (B, M, 3) + // output: + // dist2: (B, N, 3) + // idx: (B, N, 3) + + cudaError_t err; + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + three_nn_kernel_fast<<>>(b, n, m, unknown, known, dist2, idx); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +__global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points, + const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) { + // points: (B, C, M) + // idx: (B, N, 3) + // weight: (B, N, 3) + // output: + // out: (B, C, N) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; + + weight += bs_idx * n * 3 + pt_idx * 3; + points += bs_idx * c * m + c_idx * m; + idx += bs_idx * n * 3 + pt_idx * 3; + out += bs_idx * c * n + c_idx * n; + + out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]]; +} + +void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, + const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) { + // points: (B, C, M) + // idx: (B, N, 3) + // weight: (B, N, 3) + // output: + // out: (B, C, N) + + cudaError_t err; + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + three_interpolate_kernel_fast<<>>(b, c, m, n, points, idx, weight, out); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +__global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, + const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) { + // grad_out: (B, C, N) + // weight: (B, N, 3) + // output: + // grad_points: (B, C, M) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; + + grad_out += bs_idx * c * n + c_idx * n + pt_idx; + weight += bs_idx * n * 3 + pt_idx * 3; + grad_points += bs_idx * c * m + c_idx * m; + idx += bs_idx * n * 3 + pt_idx * 3; + + + atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]); + atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]); + atomicAdd(grad_points 
+ idx[2], grad_out[0] * weight[2]); +} + +void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, + const int *idx, const float *weight, float *grad_points, cudaStream_t stream) { + // grad_out: (B, C, N) + // weight: (B, N, 3) + // output: + // grad_points: (B, C, M) + + cudaError_t err; + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + three_interpolate_grad_kernel_fast<<>>(b, c, n, m, grad_out, idx, weight, grad_points); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate_gpu.h b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate_gpu.h new file mode 100755 index 0000000..f177108 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate_gpu.h @@ -0,0 +1,30 @@ +#ifndef _INTERPOLATE_GPU_H +#define _INTERPOLATE_GPU_H + +#include +#include +#include +#include + + +void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, + at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor); + +void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, + const float *known, float *dist2, int *idx, cudaStream_t stream); + + +void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor, + at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor); + +void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, + const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream); + + +void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor, + at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor); + +void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, + const int *idx, const float *weight, float *grad_points, cudaStream_t stream); + +#endif diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/pointnet2_api.cpp b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/pointnet2_api.cpp new file mode 100755 index 0000000..d91f0f2 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/pointnet2_api.cpp @@ -0,0 +1,24 @@ +#include +#include + +#include "ball_query_gpu.h" +#include "group_points_gpu.h" +#include "sampling_gpu.h" +#include "interpolate_gpu.h" + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast"); + + m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast"); + m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast"); + + m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast"); + m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast"); + + m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); + + m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast"); + m.def("three_interpolate_wrapper", 
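
For readers sanity-checking the CUDA path above, the following is a minimal pure-PyTorch reference of what the three_nn / three_interpolate kernels compute: squared distances and indices of the three nearest known points, then a weighted sum of per-point features. It is an illustrative sketch using the (B, N, 3) / (B, C, M) layouts stated in the kernel comments, not the code the extension runs; the inverse-distance weighting in the demo follows the usual PointNet++ convention and is an assumption, since the Python wrapper that turns dist2 into weights is not part of this hunk.

import torch

def three_nn_reference(unknown, known):
    # unknown: (B, N, 3), known: (B, M, 3)
    # returns squared distances and indices of the 3 nearest known points: (B, N, 3), (B, N, 3)
    d2 = torch.cdist(unknown, known) ** 2                      # (B, N, M)
    dist2, idx = torch.topk(d2, k=3, dim=-1, largest=False)
    return dist2, idx

def three_interpolate_reference(points, idx, weight):
    # points: (B, C, M), idx/weight: (B, N, 3) -> out: (B, C, N)
    B, C, M = points.shape
    N = idx.shape[1]
    idx_expanded = idx.unsqueeze(1).expand(B, C, N, 3)         # (B, C, N, 3)
    gathered = torch.gather(points.unsqueeze(2).expand(B, C, N, M), 3, idx_expanded)
    return (gathered * weight.unsqueeze(1)).sum(-1)            # (B, C, N)

if __name__ == "__main__":
    B, N, M, C = 2, 64, 16, 8
    unknown, known = torch.rand(B, N, 3), torch.rand(B, M, 3)
    feats = torch.rand(B, C, M)
    dist2, idx = three_nn_reference(unknown, known)
    # usual PointNet++ inverse-distance weights (assumed convention, not shown in this hunk)
    w = 1.0 / (dist2 + 1e-8)
    w = w / w.sum(dim=2, keepdim=True)
    out = three_interpolate_reference(feats, idx, w)
    print(out.shape)   # torch.Size([2, 8, 64])
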
&three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast"); + m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast"); +} diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/sampling.cpp b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/sampling.cpp new file mode 100755 index 0000000..fbb277a --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/sampling.cpp @@ -0,0 +1,51 @@ +#include +#include +#include +// #include + +#include "sampling_gpu.h" +#include +#include + +// extern THCState *state; + + +int gather_points_wrapper_fast(int b, int c, int n, int npoints, + at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){ + const float *points = points_tensor.data(); + const int *idx = idx_tensor.data(); + float *out = out_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream); + return 1; +} + + +int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, + at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { + + const float *grad_out = grad_out_tensor.data(); + const int *idx = idx_tensor.data(); + float *grad_points = grad_points_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream); + return 1; +} + + +int furthest_point_sampling_wrapper(int b, int n, int m, + at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { + + const float *points = points_tensor.data(); + float *temp = temp_tensor.data(); + int *idx = idx_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream); + return 1; +} diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/sampling_gpu.cu b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/sampling_gpu.cu new file mode 100755 index 0000000..9e49a60 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/sampling_gpu.cu @@ -0,0 +1,253 @@ +#include +#include + +#include "cuda_utils.h" +#include "sampling_gpu.h" + + +__global__ void gather_points_kernel_fast(int b, int c, int n, int m, + const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { + // points: (B, C, N) + // idx: (B, M) + // output: + // out: (B, C, M) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; + + out += bs_idx * c * m + c_idx * m + pt_idx; + idx += bs_idx * m + pt_idx; + points += bs_idx * c * n + c_idx * n; + out[0] = points[idx[0]]; +} + +void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, + const float *points, const int *idx, float *out, cudaStream_t stream) { + // points: (B, C, N) + // idx: (B, npoints) + // output: + // out: (B, C, npoints) + + cudaError_t err; + dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), 
blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + gather_points_kernel_fast<<>>(b, c, n, npoints, points, idx, out); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + +__global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, + const int *__restrict__ idx, float *__restrict__ grad_points) { + // grad_out: (B, C, M) + // idx: (B, M) + // output: + // grad_points: (B, C, N) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; + + grad_out += bs_idx * c * m + c_idx * m + pt_idx; + idx += bs_idx * m + pt_idx; + grad_points += bs_idx * c * n + c_idx * n; + + atomicAdd(grad_points + idx[0], grad_out[0]); +} + +void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, + const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { + // grad_out: (B, C, npoints) + // idx: (B, npoints) + // output: + // grad_points: (B, C, N) + + cudaError_t err; + dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + gather_points_grad_kernel_fast<<>>(b, c, n, npoints, grad_out, idx, grad_points); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){ + const float v1 = dists[idx1], v2 = dists[idx2]; + const int i1 = dists_i[idx1], i2 = dists_i[idx2]; + dists[idx1] = max(v1, v2); + dists_i[idx1] = v2 > v1 ? i2 : i1; +} + +template +__global__ void furthest_point_sampling_kernel(int b, int n, int m, + const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) { + // dataset: (B, N, 3) + // tmp: (B, N) + // output: + // idx: (B, M) + + if (m <= 0) return; + __shared__ float dists[block_size]; + __shared__ int dists_i[block_size]; + + int batch_index = blockIdx.x; + dataset += batch_index * n * 3; + temp += batch_index * n; + idxs += batch_index * m; + + int tid = threadIdx.x; + const int stride = block_size; + + int old = 0; + if (threadIdx.x == 0) + idxs[0] = old; + + __syncthreads(); + for (int j = 1; j < m; j++) { + int besti = 0; + float best = -1; + float x1 = dataset[old * 3 + 0]; + float y1 = dataset[old * 3 + 1]; + float z1 = dataset[old * 3 + 2]; + for (int k = tid; k < n; k += stride) { + float x2, y2, z2; + x2 = dataset[k * 3 + 0]; + y2 = dataset[k * 3 + 1]; + z2 = dataset[k * 3 + 2]; + // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); + // if (mag <= 1e-3) + // continue; + + float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); + float d2 = min(d, temp[k]); + temp[k] = d2; + besti = d2 > best ? k : besti; + best = d2 > best ? 
d2 : best; + } + dists[tid] = best; + dists_i[tid] = besti; + __syncthreads(); + + if (block_size >= 1024) { + if (tid < 512) { + __update(dists, dists_i, tid, tid + 512); + } + __syncthreads(); + } + + if (block_size >= 512) { + if (tid < 256) { + __update(dists, dists_i, tid, tid + 256); + } + __syncthreads(); + } + if (block_size >= 256) { + if (tid < 128) { + __update(dists, dists_i, tid, tid + 128); + } + __syncthreads(); + } + if (block_size >= 128) { + if (tid < 64) { + __update(dists, dists_i, tid, tid + 64); + } + __syncthreads(); + } + if (block_size >= 64) { + if (tid < 32) { + __update(dists, dists_i, tid, tid + 32); + } + __syncthreads(); + } + if (block_size >= 32) { + if (tid < 16) { + __update(dists, dists_i, tid, tid + 16); + } + __syncthreads(); + } + if (block_size >= 16) { + if (tid < 8) { + __update(dists, dists_i, tid, tid + 8); + } + __syncthreads(); + } + if (block_size >= 8) { + if (tid < 4) { + __update(dists, dists_i, tid, tid + 4); + } + __syncthreads(); + } + if (block_size >= 4) { + if (tid < 2) { + __update(dists, dists_i, tid, tid + 2); + } + __syncthreads(); + } + if (block_size >= 2) { + if (tid < 1) { + __update(dists, dists_i, tid, tid + 1); + } + __syncthreads(); + } + + old = dists_i[0]; + if (tid == 0) + idxs[j] = old; + } +} + +void furthest_point_sampling_kernel_launcher(int b, int n, int m, + const float *dataset, float *temp, int *idxs, cudaStream_t stream) { + // dataset: (B, N, 3) + // tmp: (B, N) + // output: + // idx: (B, M) + + cudaError_t err; + unsigned int n_threads = opt_n_threads(n); + + switch (n_threads) { + case 1024: + furthest_point_sampling_kernel<1024><<>>(b, n, m, dataset, temp, idxs); break; + case 512: + furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, temp, idxs); break; + case 256: + furthest_point_sampling_kernel<256><<>>(b, n, m, dataset, temp, idxs); break; + case 128: + furthest_point_sampling_kernel<128><<>>(b, n, m, dataset, temp, idxs); break; + case 64: + furthest_point_sampling_kernel<64><<>>(b, n, m, dataset, temp, idxs); break; + case 32: + furthest_point_sampling_kernel<32><<>>(b, n, m, dataset, temp, idxs); break; + case 16: + furthest_point_sampling_kernel<16><<>>(b, n, m, dataset, temp, idxs); break; + case 8: + furthest_point_sampling_kernel<8><<>>(b, n, m, dataset, temp, idxs); break; + case 4: + furthest_point_sampling_kernel<4><<>>(b, n, m, dataset, temp, idxs); break; + case 2: + furthest_point_sampling_kernel<2><<>>(b, n, m, dataset, temp, idxs); break; + case 1: + furthest_point_sampling_kernel<1><<>>(b, n, m, dataset, temp, idxs); break; + default: + furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, temp, idxs); + } + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/sampling_gpu.h b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/sampling_gpu.h new file mode 100755 index 0000000..6200c59 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/pointnet2/src/sampling_gpu.h @@ -0,0 +1,29 @@ +#ifndef _SAMPLING_GPU_H +#define _SAMPLING_GPU_H + +#include +#include +#include + + +int gather_points_wrapper_fast(int b, int c, int n, int npoints, + at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); + +void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, + const float *points, const int *idx, 
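
As with the interpolation ops, a compact PyTorch reference makes the sampling kernels above easier to follow: furthest point sampling keeps, for every point, its squared distance to the already-selected set (the temp buffer) and repeatedly picks the argmax, and gather_points is an index-select along the point dimension. This is an illustrative sketch only; it mirrors the kernel's behaviour (first index fixed to 0, squared distances), not its shared-memory reduction.

import torch

def furthest_point_sampling_reference(xyz, m):
    # xyz: (B, N, 3) -> idx: (B, M)
    B, N, _ = xyz.shape
    idx = torch.zeros(B, m, dtype=torch.long, device=xyz.device)
    temp = torch.full((B, N), 1e10, device=xyz.device)
    farthest = torch.zeros(B, dtype=torch.long, device=xyz.device)
    batch = torch.arange(B, device=xyz.device)
    for j in range(m):
        idx[:, j] = farthest
        centroid = xyz[batch, farthest].unsqueeze(1)           # (B, 1, 3)
        d = ((xyz - centroid) ** 2).sum(-1)                    # (B, N)
        temp = torch.minimum(temp, d)                          # distance to the selected set
        farthest = temp.argmax(dim=1)                          # next furthest point
    return idx

def gather_points_reference(points, idx):
    # points: (B, C, N), idx: (B, M) -> (B, C, M)
    B, C, _ = points.shape
    return torch.gather(points, 2, idx.unsqueeze(1).expand(B, C, idx.shape[1]))

if __name__ == "__main__":
    xyz = torch.rand(4, 1024, 3)
    feats = torch.rand(4, 32, 1024)
    sampled_idx = furthest_point_sampling_reference(xyz, 128)
    print(gather_points_reference(feats, sampled_idx).shape)   # torch.Size([4, 32, 128])
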
float *out, cudaStream_t stream); + + +int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, + at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); + +void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, + const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); + + +int furthest_point_sampling_wrapper(int b, int n, int m, + at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); + +void furthest_point_sampling_kernel_launcher(int b, int n, int m, + const float *dataset, float *temp, int *idxs, cudaStream_t stream); + +#endif diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/_init_path.py b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/_init_path.py new file mode 100755 index 0000000..c6c4565 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/_init_path.py @@ -0,0 +1,2 @@ +import os, sys +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')) diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/dataset.py b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/dataset.py new file mode 100755 index 0000000..deca8ec --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/dataset.py @@ -0,0 +1,188 @@ +import os +import numpy as np +import torch.utils.data as torch_data +import kitti_utils +import cv2 +from PIL import Image + + +USE_INTENSITY = False + + +class KittiDataset(torch_data.Dataset): + def __init__(self, root_dir, split='train', mode='TRAIN'): + self.split = split + self.mode = mode + self.classes = ['Car'] + is_test = self.split == 'test' + self.imageset_dir = os.path.join(root_dir, 'KITTI', 'object', 'testing' if is_test else 'training') + + split_dir = os.path.join(root_dir, 'KITTI', 'ImageSets', split + '.txt') + self.image_idx_list = [x.strip() for x in open(split_dir).readlines()] + self.sample_id_list = [int(sample_id) for sample_id in self.image_idx_list] + self.num_sample = self.image_idx_list.__len__() + + self.npoints = 16384 + + self.image_dir = os.path.join(self.imageset_dir, 'image_2') + self.lidar_dir = os.path.join(self.imageset_dir, 'velodyne') + self.calib_dir = os.path.join(self.imageset_dir, 'calib') + self.label_dir = os.path.join(self.imageset_dir, 'label_2') + self.plane_dir = os.path.join(self.imageset_dir, 'planes') + + def get_image(self, idx): + img_file = os.path.join(self.image_dir, '%06d.png' % idx) + assert os.path.exists(img_file) + return cv2.imread(img_file) # (H, W, 3) BGR mode + + def get_image_shape(self, idx): + img_file = os.path.join(self.image_dir, '%06d.png' % idx) + assert os.path.exists(img_file) + im = Image.open(img_file) + width, height = im.size + return height, width, 3 + + def get_lidar(self, idx): + lidar_file = os.path.join(self.lidar_dir, '%06d.bin' % idx) + assert os.path.exists(lidar_file) + return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4) + + def get_calib(self, idx): + calib_file = os.path.join(self.calib_dir, '%06d.txt' % idx) + assert os.path.exists(calib_file) + return kitti_utils.Calibration(calib_file) + + def get_label(self, idx): + label_file = os.path.join(self.label_dir, '%06d.txt' % idx) + assert os.path.exists(label_file) + return kitti_utils.get_objects_from_label(label_file) + + @staticmethod + def get_valid_flag(pts_rect, pts_img, pts_rect_depth, 
img_shape): + val_flag_1 = np.logical_and(pts_img[:, 0] >= 0, pts_img[:, 0] < img_shape[1]) + val_flag_2 = np.logical_and(pts_img[:, 1] >= 0, pts_img[:, 1] < img_shape[0]) + val_flag_merge = np.logical_and(val_flag_1, val_flag_2) + pts_valid_flag = np.logical_and(val_flag_merge, pts_rect_depth >= 0) + return pts_valid_flag + + def filtrate_objects(self, obj_list): + type_whitelist = self.classes + if self.mode == 'TRAIN': + type_whitelist = list(self.classes) + if 'Car' in self.classes: + type_whitelist.append('Van') + + valid_obj_list = [] + for obj in obj_list: + if obj.cls_type not in type_whitelist: + continue + + valid_obj_list.append(obj) + return valid_obj_list + + def __len__(self): + return len(self.sample_id_list) + + def __getitem__(self, index): + sample_id = int(self.sample_id_list[index]) + calib = self.get_calib(sample_id) + img_shape = self.get_image_shape(sample_id) + pts_lidar = self.get_lidar(sample_id) + + # get valid point (projected points should be in image) + pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3]) + pts_intensity = pts_lidar[:, 3] + + pts_img, pts_rect_depth = calib.rect_to_img(pts_rect) + pts_valid_flag = self.get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape) + + pts_rect = pts_rect[pts_valid_flag][:, 0:3] + pts_intensity = pts_intensity[pts_valid_flag] + + if self.npoints < len(pts_rect): + pts_depth = pts_rect[:, 2] + pts_near_flag = pts_depth < 40.0 + far_idxs_choice = np.where(pts_near_flag == 0)[0] + near_idxs = np.where(pts_near_flag == 1)[0] + near_idxs_choice = np.random.choice(near_idxs, self.npoints - len(far_idxs_choice), replace=False) + + choice = np.concatenate((near_idxs_choice, far_idxs_choice), axis=0) \ + if len(far_idxs_choice) > 0 else near_idxs_choice + np.random.shuffle(choice) + else: + choice = np.arange(0, len(pts_rect), dtype=np.int32) + if self.npoints > len(pts_rect): + extra_choice = np.random.choice(choice, self.npoints - len(pts_rect), replace=False) + choice = np.concatenate((choice, extra_choice), axis=0) + np.random.shuffle(choice) + + ret_pts_rect = pts_rect[choice, :] + ret_pts_intensity = pts_intensity[choice] - 0.5 # translate intensity to [-0.5, 0.5] + + pts_features = [ret_pts_intensity.reshape(-1, 1)] + ret_pts_features = np.concatenate(pts_features, axis=1) if pts_features.__len__() > 1 else pts_features[0] + + sample_info = {'sample_id': sample_id} + + if self.mode == 'TEST': + if USE_INTENSITY: + pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1) # (N, C) + else: + pts_input = ret_pts_rect + sample_info['pts_input'] = pts_input + sample_info['pts_rect'] = ret_pts_rect + sample_info['pts_features'] = ret_pts_features + return sample_info + + gt_obj_list = self.filtrate_objects(self.get_label(sample_id)) + + gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list) + + # prepare input + if USE_INTENSITY: + pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1) # (N, C) + else: + pts_input = ret_pts_rect + + # generate training labels + cls_labels = self.generate_training_labels(ret_pts_rect, gt_boxes3d) + sample_info['pts_input'] = pts_input + sample_info['pts_rect'] = ret_pts_rect + sample_info['cls_labels'] = cls_labels + return sample_info + + @staticmethod + def generate_training_labels(pts_rect, gt_boxes3d): + cls_label = np.zeros((pts_rect.shape[0]), dtype=np.int32) + gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, rotate=True) + extend_gt_boxes3d = kitti_utils.enlarge_box3d(gt_boxes3d, extra_width=0.2) + extend_gt_corners = 
kitti_utils.boxes3d_to_corners3d(extend_gt_boxes3d, rotate=True) + for k in range(gt_boxes3d.shape[0]): + box_corners = gt_corners[k] + fg_pt_flag = kitti_utils.in_hull(pts_rect, box_corners) + cls_label[fg_pt_flag] = 1 + + # enlarge the bbox3d, ignore nearby points + extend_box_corners = extend_gt_corners[k] + fg_enlarge_flag = kitti_utils.in_hull(pts_rect, extend_box_corners) + ignore_flag = np.logical_xor(fg_pt_flag, fg_enlarge_flag) + cls_label[ignore_flag] = -1 + + return cls_label + + def collate_batch(self, batch): + batch_size = batch.__len__() + ans_dict = {} + + for key in batch[0].keys(): + if isinstance(batch[0][key], np.ndarray): + ans_dict[key] = np.concatenate([batch[k][key][np.newaxis, ...] for k in range(batch_size)], axis=0) + + else: + ans_dict[key] = [batch[k][key] for k in range(batch_size)] + if isinstance(batch[0][key], int): + ans_dict[key] = np.array(ans_dict[key], dtype=np.int32) + elif isinstance(batch[0][key], float): + ans_dict[key] = np.array(ans_dict[key], dtype=np.float32) + + return ans_dict diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/kitti_utils.py b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/kitti_utils.py new file mode 100755 index 0000000..43f06b3 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/kitti_utils.py @@ -0,0 +1,229 @@ +import numpy as np +from scipy.spatial import Delaunay +import scipy + + +def cls_type_to_id(cls_type): + type_to_id = {'Car': 1, 'Pedestrian': 2, 'Cyclist': 3, 'Van': 4} + if cls_type not in type_to_id.keys(): + return -1 + return type_to_id[cls_type] + + +class Object3d(object): + def __init__(self, line): + label = line.strip().split(' ') + self.src = line + self.cls_type = label[0] + self.cls_id = cls_type_to_id(self.cls_type) + self.trucation = float(label[1]) + self.occlusion = float(label[2]) # 0:fully visible 1:partly occluded 2:largely occluded 3:unknown + self.alpha = float(label[3]) + self.box2d = np.array((float(label[4]), float(label[5]), float(label[6]), float(label[7])), dtype=np.float32) + self.h = float(label[8]) + self.w = float(label[9]) + self.l = float(label[10]) + self.pos = np.array((float(label[11]), float(label[12]), float(label[13])), dtype=np.float32) + self.dis_to_cam = np.linalg.norm(self.pos) + self.ry = float(label[14]) + self.score = float(label[15]) if label.__len__() == 16 else -1.0 + self.level_str = None + self.level = self.get_obj_level() + + def get_obj_level(self): + height = float(self.box2d[3]) - float(self.box2d[1]) + 1 + + if height >= 40 and self.trucation <= 0.15 and self.occlusion <= 0: + self.level_str = 'Easy' + return 1 # Easy + elif height >= 25 and self.trucation <= 0.3 and self.occlusion <= 1: + self.level_str = 'Moderate' + return 2 # Moderate + elif height >= 25 and self.trucation <= 0.5 and self.occlusion <= 2: + self.level_str = 'Hard' + return 3 # Hard + else: + self.level_str = 'UnKnown' + return 4 + + def generate_corners3d(self): + """ + generate corners3d representation for this object + :return corners_3d: (8, 3) corners of box3d in camera coord + """ + l, h, w = self.l, self.h, self.w + x_corners = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2] + y_corners = [0, 0, 0, 0, -h, -h, -h, -h] + z_corners = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2] + + R = np.array([[np.cos(self.ry), 0, np.sin(self.ry)], + [0, 1, 0], + [-np.sin(self.ry), 0, np.cos(self.ry)]]) + corners3d = np.vstack([x_corners, y_corners, z_corners]) # 
(3, 8) + corners3d = np.dot(R, corners3d).T + corners3d = corners3d + self.pos + return corners3d + + def to_str(self): + print_str = '%s %.3f %.3f %.3f box2d: %s hwl: [%.3f %.3f %.3f] pos: %s ry: %.3f' \ + % (self.cls_type, self.trucation, self.occlusion, self.alpha, self.box2d, self.h, self.w, self.l, + self.pos, self.ry) + return print_str + + def to_kitti_format(self): + kitti_str = '%s %.2f %d %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f' \ + % (self.cls_type, self.trucation, int(self.occlusion), self.alpha, self.box2d[0], self.box2d[1], + self.box2d[2], self.box2d[3], self.h, self.w, self.l, self.pos[0], self.pos[1], self.pos[2], + self.ry) + return kitti_str + + +def get_calib_from_file(calib_file): + with open(calib_file) as f: + lines = f.readlines() + + obj = lines[2].strip().split(' ')[1:] + P2 = np.array(obj, dtype=np.float32) + obj = lines[3].strip().split(' ')[1:] + P3 = np.array(obj, dtype=np.float32) + obj = lines[4].strip().split(' ')[1:] + R0 = np.array(obj, dtype=np.float32) + obj = lines[5].strip().split(' ')[1:] + Tr_velo_to_cam = np.array(obj, dtype=np.float32) + + return {'P2': P2.reshape(3, 4), + 'P3': P3.reshape(3, 4), + 'R0': R0.reshape(3, 3), + 'Tr_velo2cam': Tr_velo_to_cam.reshape(3, 4)} + + +class Calibration(object): + def __init__(self, calib_file): + if isinstance(calib_file, str): + calib = get_calib_from_file(calib_file) + else: + calib = calib_file + + self.P2 = calib['P2'] # 3 x 4 + self.R0 = calib['R0'] # 3 x 3 + self.V2C = calib['Tr_velo2cam'] # 3 x 4 + + def cart_to_hom(self, pts): + """ + :param pts: (N, 3 or 2) + :return pts_hom: (N, 4 or 3) + """ + pts_hom = np.hstack((pts, np.ones((pts.shape[0], 1), dtype=np.float32))) + return pts_hom + + def lidar_to_rect(self, pts_lidar): + """ + :param pts_lidar: (N, 3) + :return pts_rect: (N, 3) + """ + pts_lidar_hom = self.cart_to_hom(pts_lidar) + pts_rect = np.dot(pts_lidar_hom, np.dot(self.V2C.T, self.R0.T)) + return pts_rect + + def rect_to_img(self, pts_rect): + """ + :param pts_rect: (N, 3) + :return pts_img: (N, 2) + """ + pts_rect_hom = self.cart_to_hom(pts_rect) + pts_2d_hom = np.dot(pts_rect_hom, self.P2.T) + pts_img = (pts_2d_hom[:, 0:2].T / pts_rect_hom[:, 2]).T # (N, 2) + pts_rect_depth = pts_2d_hom[:, 2] - self.P2.T[3, 2] # depth in rect camera coord + return pts_img, pts_rect_depth + + def lidar_to_img(self, pts_lidar): + """ + :param pts_lidar: (N, 3) + :return pts_img: (N, 2) + """ + pts_rect = self.lidar_to_rect(pts_lidar) + pts_img, pts_depth = self.rect_to_img(pts_rect) + return pts_img, pts_depth + + +def get_objects_from_label(label_file): + with open(label_file, 'r') as f: + lines = f.readlines() + objects = [Object3d(line) for line in lines] + return objects + + +def objs_to_boxes3d(obj_list): + boxes3d = np.zeros((obj_list.__len__(), 7), dtype=np.float32) + for k, obj in enumerate(obj_list): + boxes3d[k, 0:3], boxes3d[k, 3], boxes3d[k, 4], boxes3d[k, 5], boxes3d[k, 6] \ + = obj.pos, obj.h, obj.w, obj.l, obj.ry + return boxes3d + + +def boxes3d_to_corners3d(boxes3d, rotate=True): + """ + :param boxes3d: (N, 7) [x, y, z, h, w, l, ry] + :param rotate: + :return: corners3d: (N, 8, 3) + """ + boxes_num = boxes3d.shape[0] + h, w, l = boxes3d[:, 3], boxes3d[:, 4], boxes3d[:, 5] + x_corners = np.array([l / 2., l / 2., -l / 2., -l / 2., l / 2., l / 2., -l / 2., -l / 2.], dtype=np.float32).T # (N, 8) + z_corners = np.array([w / 2., -w / 2., -w / 2., w / 2., w / 2., -w / 2., -w / 2., w / 2.], dtype=np.float32).T # (N, 8) + + y_corners = np.zeros((boxes_num, 8), dtype=np.float32) + 
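
The Calibration class above chains three frames: LiDAR -> rectified camera (via Tr_velo2cam and R0) -> image plane (via P2). A small NumPy sketch of the same chain may help; the matrix shapes match get_calib_from_file, but the values are hypothetical, made up purely for illustration.

import numpy as np

# hypothetical calibration values, shaped like the entries returned by get_calib_from_file
P2 = np.array([[700., 0., 600., 45.],
               [0., 700., 180., 0.],
               [0., 0., 1., 0.005]], dtype=np.float32)                                  # (3, 4)
R0 = np.eye(3, dtype=np.float32)                                                        # (3, 3)
V2C = np.hstack([np.eye(3), np.array([[0.27], [-0.08], [-0.1]])]).astype(np.float32)    # (3, 4)

pts_lidar = (np.random.rand(5, 3) * 20.0).astype(np.float32)    # (N, 3) points in the LiDAR frame

# lidar_to_rect: homogenise, then apply V2C and R0 (right-multiplied, as in the class)
pts_hom = np.hstack([pts_lidar, np.ones((pts_lidar.shape[0], 1), dtype=np.float32)])    # (N, 4)
pts_rect = pts_hom @ V2C.T @ R0.T                                                       # (N, 3)

# rect_to_img: project with P2, divide by the rectified depth, recover the depth term
pts_rect_hom = np.hstack([pts_rect, np.ones((pts_rect.shape[0], 1), dtype=np.float32)])
pts_2d_hom = pts_rect_hom @ P2.T                                                        # (N, 3)
pts_img = pts_2d_hom[:, 0:2] / pts_rect_hom[:, 2:3]                                     # (N, 2) pixels
pts_rect_depth = pts_2d_hom[:, 2] - P2.T[3, 2]                                          # depth in rect camera coords
print(pts_img.shape, pts_rect_depth.shape)
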
y_corners[:, 4:8] = -h.reshape(boxes_num, 1).repeat(4, axis=1) # (N, 8) + + if rotate: + ry = boxes3d[:, 6] + zeros, ones = np.zeros(ry.size, dtype=np.float32), np.ones(ry.size, dtype=np.float32) + rot_list = np.array([[np.cos(ry), zeros, -np.sin(ry)], + [zeros, ones, zeros], + [np.sin(ry), zeros, np.cos(ry)]]) # (3, 3, N) + R_list = np.transpose(rot_list, (2, 0, 1)) # (N, 3, 3) + + temp_corners = np.concatenate((x_corners.reshape(-1, 8, 1), y_corners.reshape(-1, 8, 1), + z_corners.reshape(-1, 8, 1)), axis=2) # (N, 8, 3) + rotated_corners = np.matmul(temp_corners, R_list) # (N, 8, 3) + x_corners, y_corners, z_corners = rotated_corners[:, :, 0], rotated_corners[:, :, 1], rotated_corners[:, :, 2] + + x_loc, y_loc, z_loc = boxes3d[:, 0], boxes3d[:, 1], boxes3d[:, 2] + + x = x_loc.reshape(-1, 1) + x_corners.reshape(-1, 8) + y = y_loc.reshape(-1, 1) + y_corners.reshape(-1, 8) + z = z_loc.reshape(-1, 1) + z_corners.reshape(-1, 8) + + corners = np.concatenate((x.reshape(-1, 8, 1), y.reshape(-1, 8, 1), z.reshape(-1, 8, 1)), axis=2) + + return corners.astype(np.float32) + + +def enlarge_box3d(boxes3d, extra_width): + """ + :param boxes3d: (N, 7) [x, y, z, h, w, l, ry] + """ + if isinstance(boxes3d, np.ndarray): + large_boxes3d = boxes3d.copy() + else: + large_boxes3d = boxes3d.clone() + large_boxes3d[:, 3:6] += extra_width * 2 + large_boxes3d[:, 1] += extra_width + return large_boxes3d + + +def in_hull(p, hull): + """ + :param p: (N, K) test points + :param hull: (M, K) M corners of a box + :return (N) bool + """ + try: + if not isinstance(hull, Delaunay): + hull = Delaunay(hull) + flag = hull.find_simplex(p) >= 0 + except scipy.spatial.qhull.QhullError: + print('Warning: not a hull %s' % str(hull)) + flag = np.zeros(p.shape[0], dtype=np.bool) + + return flag diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/pointnet2_msg.py b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/pointnet2_msg.py new file mode 100755 index 0000000..59a2207 --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/pointnet2_msg.py @@ -0,0 +1,102 @@ +import torch +import torch.nn as nn +import sys +sys.path.append('..') +from pointnet2.pointnet2_modules import PointnetFPModule, PointnetSAModuleMSG +import pointnet2.pytorch_utils as pt_utils + + +def get_model(input_channels=0): + return Pointnet2MSG(input_channels=input_channels) + + +NPOINTS = [4096, 1024, 256, 64] +RADIUS = [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]] +NSAMPLE = [[16, 32], [16, 32], [16, 32], [16, 32]] +MLPS = [[[16, 16, 32], [32, 32, 64]], [[64, 64, 128], [64, 96, 128]], + [[128, 196, 256], [128, 196, 256]], [[256, 256, 512], [256, 384, 512]]] +FP_MLPS = [[128, 128], [256, 256], [512, 512], [512, 512]] +CLS_FC = [128] +DP_RATIO = 0.5 + + +class Pointnet2MSG(nn.Module): + def __init__(self, input_channels=6): + super().__init__() + + self.SA_modules = nn.ModuleList() + channel_in = input_channels + + skip_channel_list = [input_channels] + for k in range(NPOINTS.__len__()): + mlps = MLPS[k].copy() + channel_out = 0 + for idx in range(mlps.__len__()): + mlps[idx] = [channel_in] + mlps[idx] + channel_out += mlps[idx][-1] + + self.SA_modules.append( + PointnetSAModuleMSG( + npoint=NPOINTS[k], + radii=RADIUS[k], + nsamples=NSAMPLE[k], + mlps=mlps, + use_xyz=True, + bn=True + ) + ) + skip_channel_list.append(channel_out) + channel_in = channel_out + + self.FP_modules = nn.ModuleList() + + for k in range(FP_MLPS.__len__()): + pre_channel = FP_MLPS[k + 
1][-1] if k + 1 < len(FP_MLPS) else channel_out + self.FP_modules.append( + PointnetFPModule(mlp=[pre_channel + skip_channel_list[k]] + FP_MLPS[k]) + ) + + cls_layers = [] + pre_channel = FP_MLPS[0][-1] + for k in range(0, CLS_FC.__len__()): + cls_layers.append(pt_utils.Conv1d(pre_channel, CLS_FC[k], bn=True)) + pre_channel = CLS_FC[k] + cls_layers.append(pt_utils.Conv1d(pre_channel, 1, activation=None)) + cls_layers.insert(1, nn.Dropout(0.5)) + self.cls_layer = nn.Sequential(*cls_layers) + + def _break_up_pc(self, pc): + xyz = pc[..., 0:3].contiguous() + features = ( + pc[..., 3:].transpose(1, 2).contiguous() + if pc.size(-1) > 3 else None + ) + + return xyz, features + + def forward(self, pointcloud: torch.cuda.FloatTensor): + xyz, features = self._break_up_pc(pointcloud) + + l_xyz, l_features = [xyz], [features] + for i in range(len(self.SA_modules)): + li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) + + print(li_xyz.shape, li_features.shape) + + l_xyz.append(li_xyz) + l_features.append(li_features) + + for i in range(-1, -(len(self.FP_modules) + 1), -1): + l_features[i - 1] = self.FP_modules[i]( + l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i] + ) + + pred_cls = self.cls_layer(l_features[0]).transpose(1, 2).contiguous() # (B, N, 1) + return pred_cls + +if __name__ == '__main__': + net = Pointnet2MSG(0).cuda() + pts = torch.randn(2, 1024, 3).cuda() + + pre = net(pts) + print(pre.shape) diff --git a/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/train_and_eval.py b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/train_and_eval.py new file mode 100755 index 0000000..d35502b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/pointnet2_utils/tools/train_and_eval.py @@ -0,0 +1,217 @@ +import _init_path +import numpy as np +import os +import torch +import torch.nn as nn +import torch.optim as optim +import torch.optim.lr_scheduler as lr_sched +from torch.nn.utils import clip_grad_norm_ +from torch.utils.data import DataLoader +import tensorboard_logger as tb_log +from dataset import KittiDataset +import argparse +import importlib + +parser = argparse.ArgumentParser(description="Arg parser") +parser.add_argument("--batch_size", type=int, default=8) +parser.add_argument("--epochs", type=int, default=100) +parser.add_argument("--ckpt_save_interval", type=int, default=5) +parser.add_argument('--workers', type=int, default=4) +parser.add_argument("--mode", type=str, default='train') +parser.add_argument("--ckpt", type=str, default='None') + +parser.add_argument("--net", type=str, default='pointnet2_msg') + +parser.add_argument('--lr', type=float, default=0.002) +parser.add_argument('--lr_decay', type=float, default=0.2) +parser.add_argument('--lr_clip', type=float, default=0.000001) +parser.add_argument('--decay_step_list', type=list, default=[50, 70, 80, 90]) +parser.add_argument('--weight_decay', type=float, default=0.001) + +parser.add_argument("--output_dir", type=str, default='output') +parser.add_argument("--extra_tag", type=str, default='default') + +args = parser.parse_args() + +FG_THRESH = 0.3 + + +def log_print(info, log_f=None): + print(info) + if log_f is not None: + print(info, file=log_f) + + +class DiceLoss(nn.Module): + def __init__(self, ignore_target=-1): + super().__init__() + self.ignore_target = ignore_target + + def forward(self, input, target): + """ + :param input: (N), logit + :param target: (N), {0, 1} + :return: + """ + input = torch.sigmoid(input.view(-1)) + 
target = target.float().view(-1) + mask = (target != self.ignore_target).float() + return 1.0 - (torch.min(input, target) * mask).sum() / torch.clamp((torch.max(input, target) * mask).sum(), min=1.0) + + +def train_one_epoch(model, train_loader, optimizer, epoch, lr_scheduler, total_it, tb_log, log_f): + model.train() + log_print('===============TRAIN EPOCH %d================' % epoch, log_f=log_f) + loss_func = DiceLoss(ignore_target=-1) + + for it, batch in enumerate(train_loader): + optimizer.zero_grad() + + pts_input, cls_labels = batch['pts_input'], batch['cls_labels'] + pts_input = torch.from_numpy(pts_input).cuda(non_blocking=True).float() + cls_labels = torch.from_numpy(cls_labels).cuda(non_blocking=True).long().view(-1) + + pred_cls = model(pts_input) + pred_cls = pred_cls.view(-1) + + loss = loss_func(pred_cls, cls_labels) + loss.backward() + clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + + total_it += 1 + + pred_class = (torch.sigmoid(pred_cls) > FG_THRESH) + fg_mask = cls_labels > 0 + correct = ((pred_class.long() == cls_labels) & fg_mask).float().sum() + union = fg_mask.sum().float() + (pred_class > 0).sum().float() - correct + iou = correct / torch.clamp(union, min=1.0) + + cur_lr = lr_scheduler.get_lr()[0] + tb_log.log_value('learning_rate', cur_lr, epoch) + if tb_log is not None: + tb_log.log_value('train_loss', loss, total_it) + tb_log.log_value('train_fg_iou', iou, total_it) + + log_print('training epoch %d: it=%d/%d, total_it=%d, loss=%.5f, fg_iou=%.3f, lr=%f' % + (epoch, it, len(train_loader), total_it, loss.item(), iou.item(), cur_lr), log_f=log_f) + + return total_it + + +def eval_one_epoch(model, eval_loader, epoch, tb_log=None, log_f=None): + model.train() + log_print('===============EVAL EPOCH %d================' % epoch, log_f=log_f) + + iou_list = [] + for it, batch in enumerate(eval_loader): + pts_input, cls_labels = batch['pts_input'], batch['cls_labels'] + pts_input = torch.from_numpy(pts_input).cuda(non_blocking=True).float() + cls_labels = torch.from_numpy(cls_labels).cuda(non_blocking=True).long().view(-1) + + pred_cls = model(pts_input) + pred_cls = pred_cls.view(-1) + + pred_class = (torch.sigmoid(pred_cls) > FG_THRESH) + fg_mask = cls_labels > 0 + correct = ((pred_class.long() == cls_labels) & fg_mask).float().sum() + union = fg_mask.sum().float() + (pred_class > 0).sum().float() - correct + iou = correct / torch.clamp(union, min=1.0) + + iou_list.append(iou.item()) + log_print('EVAL: it=%d/%d, iou=%.3f' % (it, len(eval_loader), iou), log_f=log_f) + + iou_list = np.array(iou_list) + avg_iou = iou_list.mean() + if tb_log is not None: + tb_log.log_value('eval_fg_iou', avg_iou, epoch) + + log_print('\nEpoch %d: Average IoU (samples=%d): %.6f' % (epoch, iou_list.__len__(), avg_iou), log_f=log_f) + return avg_iou + + +def save_checkpoint(model, epoch, ckpt_name): + if isinstance(model, torch.nn.DataParallel): + model_state = model.module.state_dict() + else: + model_state = model.state_dict() + + state = {'epoch': epoch, 'model_state': model_state} + ckpt_name = '{}.pth'.format(ckpt_name) + torch.save(state, ckpt_name) + + +def load_checkpoint(model, filename): + if os.path.isfile(filename): + log_print("==> Loading from checkpoint %s" % filename) + checkpoint = torch.load(filename) + epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['model_state']) + log_print("==> Done") + else: + raise FileNotFoundError + + return epoch + + +def train_and_eval(model, train_loader, eval_loader, tb_log, ckpt_dir, log_f): + model.cuda() + optimizer 
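
The DiceLoss above is a soft IoU: with sigmoid probabilities p and binary targets t it computes 1 - sum(min(p, t)) / sum(max(p, t)), masking out points labelled -1 (the ignore ring around enlarged boxes). A tiny self-contained check of that behaviour, independent of the training script (the class body is repeated here only so the snippet runs standalone):

import torch
import torch.nn as nn

class DiceLoss(nn.Module):   # same definition as in train_and_eval.py, restated for a standalone check
    def __init__(self, ignore_target=-1):
        super().__init__()
        self.ignore_target = ignore_target

    def forward(self, input, target):
        input = torch.sigmoid(input.view(-1))
        target = target.float().view(-1)
        mask = (target != self.ignore_target).float()
        return 1.0 - (torch.min(input, target) * mask).sum() / \
            torch.clamp((torch.max(input, target) * mask).sum(), min=1.0)

if __name__ == "__main__":
    logits = torch.tensor([10.0, -10.0, 10.0, 0.0])   # confident fg, confident bg, confident fg, uncertain
    labels = torch.tensor([1, 0, 1, -1])              # last point sits in the ignored ring
    print(DiceLoss()(logits, labels))                 # close to 0: predictions match the labels
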
= optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + def lr_lbmd(cur_epoch): + cur_decay = 1 + for decay_step in args.decay_step_list: + if cur_epoch >= decay_step: + cur_decay = cur_decay * args.lr_decay + return max(cur_decay, args.lr_clip / args.lr) + + lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd) + + total_it = 0 + for epoch in range(1, args.epochs + 1): + lr_scheduler.step(epoch) + total_it = train_one_epoch(model, train_loader, optimizer, epoch, lr_scheduler, total_it, tb_log, log_f) + + if epoch % args.ckpt_save_interval == 0: + with torch.no_grad(): + avg_iou = eval_one_epoch(model, eval_loader, epoch, tb_log, log_f) + ckpt_name = os.path.join(ckpt_dir, 'checkpoint_epoch_%d' % epoch) + save_checkpoint(model, epoch, ckpt_name) + + +if __name__ == '__main__': + MODEL = importlib.import_module(args.net) # import network module + model = MODEL.get_model(input_channels=0) + + eval_set = KittiDataset(root_dir='./data', mode='EVAL', split='val') + eval_loader = DataLoader(eval_set, batch_size=args.batch_size, shuffle=False, pin_memory=True, + num_workers=args.workers, collate_fn=eval_set.collate_batch) + + if args.mode == 'train': + train_set = KittiDataset(root_dir='./data', mode='TRAIN', split='train') + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, pin_memory=True, + num_workers=args.workers, collate_fn=train_set.collate_batch) + # output dir config + output_dir = os.path.join(args.output_dir, args.extra_tag) + os.makedirs(output_dir, exist_ok=True) + tb_log.configure(os.path.join(output_dir, 'tensorboard')) + ckpt_dir = os.path.join(output_dir, 'ckpt') + os.makedirs(ckpt_dir, exist_ok=True) + + log_file = os.path.join(output_dir, 'log.txt') + log_f = open(log_file, 'w') + + for key, val in vars(args).items(): + log_print("{:16} {}".format(key, val), log_f=log_f) + + # train and eval + train_and_eval(model, train_loader, eval_loader, tb_log, ckpt_dir, log_f) + log_f.close() + elif args.mode == 'eval': + epoch = load_checkpoint(model, args.ckpt) + model.cuda() + with torch.no_grad(): + avg_iou = eval_one_epoch(model, eval_loader, epoch) + else: + raise NotImplementedError + diff --git a/src/active_grasp/active_perception/modules/module_lib/position_embedding.py b/src/active_grasp/active_perception/modules/module_lib/position_embedding.py new file mode 100755 index 0000000..682bf5c --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/position_embedding.py @@ -0,0 +1,17 @@ +import torch + + +class PositionalEmbedding(torch.nn.Module): + def __init__(self, num_channels, max_positions=10000, endpoint=False): + super().__init__() + self.num_channels = num_channels + self.max_positions = max_positions + self.endpoint = endpoint + + def forward(self, x): + freqs = torch.arange(start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device) + freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) + freqs = (1 / self.max_positions) ** freqs + x = x.ger(freqs.to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x diff --git a/src/active_grasp/active_perception/modules/module_lib/rot_head.py b/src/active_grasp/active_perception/modules/module_lib/rot_head.py new file mode 100755 index 0000000..819036b --- /dev/null +++ b/src/active_grasp/active_perception/modules/module_lib/rot_head.py @@ -0,0 +1,41 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F + + +class RotHead(nn.Module): + def __init__(self, in_feat_dim, out_dim=3): + super(RotHead, 
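
position_embedding.py above maps a 1-D tensor of scalars (for example, one noise level or timestep per sample) to sinusoidal features of shape (B, num_channels) via an outer product with a geometric frequency ladder. A usage sketch, assuming the active_perception package root is on sys.path as the other modules in this patch arrange; the channel count 128 is an arbitrary choice for illustration:

import torch
from modules.module_lib.position_embedding import PositionalEmbedding  # path as added in this patch

emb = PositionalEmbedding(num_channels=128)
t = torch.rand(8)          # one scalar per sample; the input must be 1-D because forward uses x.ger(freqs)
feat = emb(t)
print(feat.shape)          # torch.Size([8, 128]): 64 cosine channels followed by 64 sine channels
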
self).__init__() + self.f = in_feat_dim + self.k = out_dim + + self.conv1 = torch.nn.Conv1d(self.f, 1024, 1) + self.conv2 = torch.nn.Conv1d(1024, 256, 1) + self.conv3 = torch.nn.Conv1d(256, 256, 1) + self.conv4 = torch.nn.Conv1d(256, self.k, 1) + self.drop1 = nn.Dropout(0.2) + self.bn1 = nn.BatchNorm1d(1024) + self.bn2 = nn.BatchNorm1d(256) + self.bn3 = nn.BatchNorm1d(256) + + def forward(self, x): + x = F.relu(self.bn1(self.conv1(x))) + x = F.relu(self.bn2(self.conv2(x))) + + x = torch.max(x, 2, keepdim=True)[0] + + x = F.relu(self.bn3(self.conv3(x))) + x = self.drop1(x) + x = self.conv4(x) + + x = x.squeeze(2) + x = x.contiguous() + + return x + + +if __name__ == "__main__": + points = torch.rand(2, 1350, 1024) # batch_size x feature x num_of_point + rot_head = RotHead(in_feat_dim=1350, out_dim=3) + rot = rot_head(points) + print(rot.shape) diff --git a/src/active_grasp/active_perception/modules/pipeline.py b/src/active_grasp/active_perception/modules/pipeline.py new file mode 100755 index 0000000..20c9858 --- /dev/null +++ b/src/active_grasp/active_perception/modules/pipeline.py @@ -0,0 +1,144 @@ +import sys +import os +sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') + +import torch +from torch import nn +import inspect + +from configs.config import ConfigManager + +from modules.pts_encoder.pts_encoder_factory import PointsEncoderFactory +from modules.view_finder.view_finder_factory import ViewFinderFactory +from modules.module_lib.fusion_layer import FeatureFusion +from modules.rgb_encoder.rgb_encoder_factory import RGBEncoderFactory + + +class Pipeline(nn.Module): + TRAIN_MODE: str = "train" + TEST_MODE: str = "test" + + def __init__(self, pipeline_config): + super(Pipeline, self).__init__() + + self.modules_config = ConfigManager.get("modules") + self.device = ConfigManager.get("settings", "general", "device") + self.rgb_feat_cache = ConfigManager.get("datasets", "general", "rgb_feat_cache") + self.pts_encoder = PointsEncoderFactory.create(pipeline_config["pts_encoder"], self.modules_config) + self.view_finder = ViewFinderFactory.create(pipeline_config["view_finder"], self.modules_config) + self.has_rgb_encoder = "rgb_encoder" in pipeline_config + if self.has_rgb_encoder and not self.rgb_feat_cache: + self.rgb_encoder = RGBEncoderFactory.create(pipeline_config["rgb_encoder"], self.modules_config) + self.eps = 1e-5 + self.fusion_layer = FeatureFusion(rgb_dim=384, pts_dim=1024,output_dim=1024) + + self.to(self.device) + + def forward(self, data, mode): + if mode == self.TRAIN_MODE: + return self.forward_gradient(data) + elif mode == self.TEST_MODE: + return self.forward_view(data) + raise ValueError("Unknown mode: {}".format(self.mode)) + + def forward_gradient(self, data): + target_pts = data["target_pts"] + scene_pts = data["scene_pts"] + gt_delta_rot_6d = data["delta_rot_6d"] + + if hasattr(self,"rgb_encoder"): + if "rgb" in data: + rgb_feat = self.rgb_encoder.encode_rgb(data["rgb"]) + else: + rgb_feat = data["rgb_feat"] + if "rgb_feat" not in inspect.signature(self.pts_encoder.encode_points).parameters: + target_feat = self.pts_encoder.encode_points(target_pts) + scene_feat = self.pts_encoder.encode_points(scene_pts) + target_feat = self.fusion_layer(rgb_feat, target_feat) + scene_feat = self.fusion_layer(rgb_feat, scene_feat) + else: + target_feat = self.pts_encoder.encode_points(target_pts, rgb_feat) + scene_feat = self.pts_encoder.encode_points(scene_pts, rgb_feat) + else: + target_feat = self.pts_encoder.encode_points(target_pts) + scene_feat = 
self.pts_encoder.encode_points(scene_pts) + ''' get std ''' + bs = target_pts.shape[0] + random_t = torch.rand(bs, device=self.device) * (1. - self.eps) + self.eps + random_t = random_t.unsqueeze(-1) + mu, std = self.view_finder.marginal_prob(gt_delta_rot_6d, random_t) + std = std.view(-1, 1) + + ''' perturb data and get estimated score ''' + z = torch.randn_like(gt_delta_rot_6d) + perturbed_x = mu + z * std + input_data = { + "sampled_pose": perturbed_x, + "t": random_t, + "scene_feat": scene_feat, + "target_feat": target_feat + } + estimated_score = self.view_finder(input_data) + + ''' get target score ''' + target_score = - z * std / (std ** 2) + + result = { + "estimated_score": estimated_score, + "target_score": target_score, + "std": std + } + return result + + def forward_view(self, data): + target_pts = data["target_pts"] + scene_pts = data["scene_pts"] + + if self.has_rgb_encoder : + if self.rgb_feat_cache: + rgb_feat = data["rgb_feat"] + else: + rgb = data["rgb"] + rgb_feat = self.rgb_encoder.encode_rgb(rgb) + if "rgb_feat" not in inspect.signature(self.pts_encoder.encode_points).parameters: + target_feat = self.pts_encoder.encode_points(target_pts) + scene_feat = self.pts_encoder.encode_points(scene_pts) + target_feat = self.fusion_layer(rgb_feat, target_feat) + scene_feat = self.fusion_layer(rgb_feat, scene_feat) + else: + target_feat = self.pts_encoder.encode_points(target_pts, rgb_feat) + scene_feat = self.pts_encoder.encode_points(scene_pts, rgb_feat) + else: + target_feat = self.pts_encoder.encode_points(target_pts) + scene_feat = self.pts_encoder.encode_points(scene_pts) + estimated_delta_rot_6d, in_process_sample = self.view_finder.next_best_view(scene_feat, target_feat) + result = { + "estimated_delta_rot_6d": estimated_delta_rot_6d, + "in_process_sample": in_process_sample + } + return result + + +if __name__ == '__main__': + ConfigManager.load_config_with('../configs/server_train_config.yaml') + ConfigManager.print_config() + test_pipeline_config = ConfigManager.get("settings", "pipeline") + pipeline = Pipeline(test_pipeline_config) + test_scene = torch.rand(32, 1024, 3).to("cuda:0") + test_target = torch.rand(32, 1024, 3).to("cuda:0") + test_delta_rot_6d = torch.rand(32, 6).to("cuda:0") + a = test_delta_rot_6d[:, :3] + b = test_delta_rot_6d[:, 3:] + a_norm = a / a.norm(dim=1, keepdim=True) + b_norm = b / b.norm(dim=1, keepdim=True) + normalized_test_delta_rot_6d = torch.cat((a_norm, b_norm), dim=1) + test_data = { + 'target_pts': test_target, + 'scene_pts': test_scene, + 'delta_rot_6d': normalized_test_delta_rot_6d + } + # out_data = pipeline(test_data, "train") + # print(out_data.keys()) + out_data_test = pipeline(test_data, "test") + print(out_data_test.keys()) + print(out_data_test["estimated_delta_rot_6d"]) diff --git a/src/active_grasp/active_perception/modules/pts_encoder/__init__.py b/src/active_grasp/active_perception/modules/pts_encoder/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/src/active_grasp/active_perception/modules/pts_encoder/abstract_pts_encoder.py b/src/active_grasp/active_perception/modules/pts_encoder/abstract_pts_encoder.py new file mode 100755 index 0000000..f094892 --- /dev/null +++ b/src/active_grasp/active_perception/modules/pts_encoder/abstract_pts_encoder.py @@ -0,0 +1,12 @@ +from abc import abstractmethod + +from torch import nn + + +class PointsEncoder(nn.Module): + def __init__(self): + super(PointsEncoder, self).__init__() + + @abstractmethod + def encode_points(self, pts): + pass diff --git 
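
The training path in forward_gradient above is a denoising-score-matching setup: a per-sample noise level t in (eps, 1] is drawn, the ground-truth 6-D rotation delta is perturbed as mu + std(t) * z with z ~ N(0, I) using the view finder's marginal_prob, and the regression target is -z * std / std^2 = -z / std. The sketch below isolates just that perturbation/target step; marginal_prob_sketch is a hypothetical stand-in (the real schedule lives in the view finder, not in this hunk), and the sigma^2-weighted MSE at the end is only the loss shape typically used with these targets, also not shown in this diff.

import torch

def marginal_prob_sketch(x, t, sigma_max=25.0):
    # hypothetical stand-in for view_finder.marginal_prob: a VE-SDE style schedule (assumption)
    std = sigma_max ** t                       # (B, 1)
    return x, std

eps = 1e-5
gt_delta_rot_6d = torch.rand(32, 6)
random_t = torch.rand(32, 1) * (1.0 - eps) + eps

mu, std = marginal_prob_sketch(gt_delta_rot_6d, random_t)
std = std.view(-1, 1)

z = torch.randn_like(gt_delta_rot_6d)
perturbed_x = mu + z * std                     # noised sample fed to the view finder
target_score = -z * std / (std ** 2)           # = -z / std, the score the network should regress

# assumed loss shape: sigma^2-weighted MSE between the network output and target_score
estimated_score = torch.zeros_like(target_score)   # placeholder for view_finder({...})
loss = torch.mean(torch.sum(((estimated_score - target_score) * std) ** 2, dim=-1))
print(perturbed_x.shape, target_score.shape, loss.item())
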
a/src/active_grasp/active_perception/modules/pts_encoder/pointnet2_encoder.py b/src/active_grasp/active_perception/modules/pts_encoder/pointnet2_encoder.py new file mode 100755 index 0000000..51c1914 --- /dev/null +++ b/src/active_grasp/active_perception/modules/pts_encoder/pointnet2_encoder.py @@ -0,0 +1,117 @@ +import torch +import torch.nn as nn +import os +import sys +path = os.path.abspath(__file__) +for i in range(3): + path = os.path.dirname(path) +PROJECT_ROOT = path +sys.path.append(PROJECT_ROOT) +from modules.module_lib.pointnet2_utils.pointnet2.pointnet2_modules import PointnetSAModuleMSG +from modules.pts_encoder.abstract_pts_encoder import PointsEncoder + +ClsMSG_CFG_Dense = { + 'NPOINTS': [512, 256, 128, None], + 'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]], + 'NSAMPLE': [[32, 64], [16, 32], [8, 16], [None, None]], + 'MLPS': [[[16, 16, 32], [32, 32, 64]], + [[64, 64, 128], [64, 96, 128]], + [[128, 196, 256], [128, 196, 256]], + [[256, 256, 512], [256, 384, 512]]], + 'DP_RATIO': 0.5, +} + +ClsMSG_CFG_Light = { + 'NPOINTS': [512, 256, 128, None], + 'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]], + 'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]], + 'MLPS': [[[16, 16, 32], [32, 32, 64]], + [[64, 64, 128], [64, 96, 128]], + [[128, 196, 256], [128, 196, 256]], + [[256, 256, 512], [256, 384, 512]]], + 'DP_RATIO': 0.5, +} + +ClsMSG_CFG_Lighter = { + 'NPOINTS': [512, 256, 128, 64, None], + 'RADIUS': [[0.01], [0.02], [0.04], [0.08], [None]], + 'NSAMPLE': [[64], [32], [16], [8], [None]], + 'MLPS': [[[32, 32, 64]], + [[64, 64, 128]], + [[128, 196, 256]], + [[256, 256, 512]], + [[512, 512, 1024]]], + 'DP_RATIO': 0.5, +} + + +def select_params(name): + if name == 'light': + return ClsMSG_CFG_Light + elif name == 'lighter': + return ClsMSG_CFG_Lighter + elif name == 'dense': + return ClsMSG_CFG_Dense + else: + raise NotImplementedError + + +def break_up_pc(pc): + xyz = pc[..., 0:3].contiguous() + features = ( + pc[..., 3:].transpose(1, 2).contiguous() + if pc.size(-1) > 3 else None + ) + + return xyz, features + + +class PointNet2Encoder(PointsEncoder): + def encode_points(self, pts): + return self.forward(pts) + + def __init__(self, input_channels=6, params_name="light"): + super().__init__() + + self.SA_modules = nn.ModuleList() + channel_in = input_channels + selected_params = select_params(params_name) + for k in range(selected_params['NPOINTS'].__len__()): + mlps = selected_params['MLPS'][k].copy() + channel_out = 0 + for idx in range(mlps.__len__()): + mlps[idx] = [channel_in] + mlps[idx] + channel_out += mlps[idx][-1] + + self.SA_modules.append( + PointnetSAModuleMSG( + npoint=selected_params['NPOINTS'][k], + radii=selected_params['RADIUS'][k], + nsamples=selected_params['NSAMPLE'][k], + mlps=mlps, + use_xyz=True, + bn=True + ) + ) + channel_in = channel_out + + def forward(self, point_cloud: torch.cuda.FloatTensor): + xyz, features = break_up_pc(point_cloud) + + l_xyz, l_features = [xyz], [features] + for i in range(len(self.SA_modules)): + li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) + l_xyz.append(li_xyz) + l_features.append(li_features) + return l_features[-1].squeeze(-1) + + +if __name__ == '__main__': + seed = 100 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + net = PointNet2Encoder(0).cuda() + pts = torch.randn(2, 1024, 3).cuda() + print(torch.mean(pts, dim=1)) + pre = net.encode_points(pts) + print(pre.shape) diff --git 
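
One detail worth keeping in mind when editing the ClsMSG_CFG_* dicts above: each level's output width is the sum of the last entries of its branch MLPs, so the final global feature is 512 + 512 = 1024 for 'light'/'dense' and 1024 for 'lighter'. All three presets therefore match the pts_dim=1024 that Pipeline's FeatureFusion layer expects. A standalone sketch that re-declares only the MLP lists to verify this without building any CUDA modules:

# standalone check: re-declares just the 'MLPS' entries from pointnet2_encoder.py
MLPS = {
    "light":   [[[16, 16, 32], [32, 32, 64]],
                [[64, 64, 128], [64, 96, 128]],
                [[128, 196, 256], [128, 196, 256]],
                [[256, 256, 512], [256, 384, 512]]],
    "lighter": [[[32, 32, 64]], [[64, 64, 128]], [[128, 196, 256]],
                [[256, 256, 512]], [[512, 512, 1024]]],
    "dense":   [[[16, 16, 32], [32, 32, 64]],
                [[64, 64, 128], [64, 96, 128]],
                [[128, 196, 256], [128, 196, 256]],
                [[256, 256, 512], [256, 384, 512]]],
}

for name, levels in MLPS.items():
    # mirrors the constructor: a level's output channels = sum of the last width of every branch MLP
    out_per_level = [sum(branch[-1] for branch in level) for level in levels]
    print(name, out_per_level, "-> final feature dim:", out_per_level[-1])
# light   [96, 256, 512, 1024] -> 1024
# lighter [64, 128, 256, 512, 1024] -> 1024
# dense   [96, 256, 512, 1024] -> 1024
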
a/src/active_grasp/active_perception/modules/pts_encoder/pointnet3_encoder.py b/src/active_grasp/active_perception/modules/pts_encoder/pointnet3_encoder.py new file mode 100755 index 0000000..05110fa --- /dev/null +++ b/src/active_grasp/active_perception/modules/pts_encoder/pointnet3_encoder.py @@ -0,0 +1,117 @@ +import torch +import torch.nn as nn +from modules.module_lib.pointnet2_utils.pointnet2.pointnet2_modules import PointnetSAModuleMSG +from modules.pts_encoder.abstract_pts_encoder import PointsEncoder + +ClsMSG_CFG_Dense = { + 'NPOINTS': [512, 256, 128, None], + 'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]], + 'NSAMPLE': [[32, 64], [16, 32], [8, 16], [None, None]], + 'MLPS': [[[16, 16, 32], [32, 32, 64]], + [[64, 64, 128], [64, 96, 128]], + [[128, 196, 256], [128, 196, 256]], + [[256, 256, 512], [256, 384, 512]]], + 'DP_RATIO': 0.5, +} + +ClsMSG_CFG_Light = { + 'NPOINTS': [512, 256, 128, None], + 'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]], + 'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]], + 'MLPS': [[[16, 16, 32], [32, 32, 64]], + [[64, 64, 128], [64, 96, 128]], + [[128, 196, 256], [128, 196, 256]], + [[256, 256, 512], [256, 384, 512]]], + 'DP_RATIO': 0.5, +} + +ClsMSG_CFG_Lighter = { + 'NPOINTS': [512, 256, 128, 64, None], + 'RADIUS': [[0.01], [0.02], [0.04], [0.08], [None]], + 'NSAMPLE': [[64], [32], [16], [8], [None]], + 'MLPS': [[[32, 32, 64]], + [[64, 64, 128]], + [[128, 196, 256]], + [[256, 256, 512]], + [[512, 512, 1024]]], + 'DP_RATIO': 0.5, +} + + +def select_params(name): + if name == 'light': + return ClsMSG_CFG_Light + elif name == 'lighter': + return ClsMSG_CFG_Lighter + elif name == 'dense': + return ClsMSG_CFG_Dense + else: + raise NotImplementedError + + +def break_up_pc(pc): + xyz = pc[..., 0:3].contiguous() + features = ( + pc[..., 3:].transpose(1, 2).contiguous() + if pc.size(-1) > 3 else None + ) + + return xyz, features + + +class PointNet3Encoder(PointsEncoder): + def encode_points(self, pts, rgb_feat): + return self.forward(pts,rgb_feat) + + def __init__(self, input_channels=6, params_name="light",target_layer=2, rgb_feat_dim=384): + super().__init__() + self.SA_modules = nn.ModuleList() + channel_in = input_channels + self.target_layer = target_layer + selected_params = select_params(params_name) + for k in range(selected_params['NPOINTS'].__len__()): + mlps = selected_params['MLPS'][k].copy() + channel_out = 0 + if k==target_layer: + channel_in += rgb_feat_dim + for idx in range(mlps.__len__()): + mlps[idx] = [channel_in] + mlps[idx] + channel_out += mlps[idx][-1] + + self.SA_modules.append( + PointnetSAModuleMSG( + npoint=selected_params['NPOINTS'][k], + radii=selected_params['RADIUS'][k], + nsamples=selected_params['NSAMPLE'][k], + mlps=mlps, + use_xyz=True, + bn=True + ) + ) + channel_in = channel_out + + def forward(self, point_cloud: torch.cuda.FloatTensor, rgb_feat): + xyz, features = break_up_pc(point_cloud) + + l_xyz, l_features = [xyz], [features] + for i in range(len(self.SA_modules)): + if i==self.target_layer: + rgb_feat = torch.mean(rgb_feat, dim=1) + rgb_feat = rgb_feat.unsqueeze(-1).repeat(1,1,l_xyz[i].shape[1]) + l_features[-1] = torch.cat([l_features[-1], rgb_feat], dim=1) + li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) + l_xyz.append(li_xyz) + l_features.append(li_features) + return l_features[-1].squeeze(-1) + + +if __name__ == '__main__': + seed = 100 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + net = PointNet3Encoder(0).cuda() + pts = torch.randn(2, 
1024, 3).cuda() + rgb_feat = torch.randn(2, 384).cuda() + print(torch.mean(pts, dim=1)) + pre = net.encode_points(pts,rgb_feat) + print(pre.shape) diff --git a/src/active_grasp/active_perception/modules/pts_encoder/pointnet_encoder.py b/src/active_grasp/active_perception/modules/pts_encoder/pointnet_encoder.py new file mode 100755 index 0000000..fe9dc3c --- /dev/null +++ b/src/active_grasp/active_perception/modules/pts_encoder/pointnet_encoder.py @@ -0,0 +1,110 @@ +from __future__ import print_function +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.utils.data +from torch.autograd import Variable +import numpy as np +import torch.nn.functional as F + +from modules.pts_encoder.abstract_pts_encoder import PointsEncoder + + +class STNkd(nn.Module): + def __init__(self, k=64): + super(STNkd, self).__init__() + self.conv1 = torch.nn.Conv1d(k, 64, 1) + self.conv2 = torch.nn.Conv1d(64, 128, 1) + self.conv3 = torch.nn.Conv1d(128, 1024, 1) + self.fc1 = nn.Linear(1024, 512) + self.fc2 = nn.Linear(512, 256) + self.fc3 = nn.Linear(256, k * k) + self.relu = nn.ReLU() + + self.k = k + + def forward(self, x): + batchsize = x.size()[0] + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + x = torch.max(x, 2, keepdim=True)[0] + x = x.view(-1, 1024) + + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + + iden = ( + Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))) + .view(1, self.k * self.k) + .repeat(batchsize, 1) + ) + if x.is_cuda: + iden = iden.to(x.get_device()) + x = x + iden + x = x.view(-1, self.k, self.k) + return x + + +# NOTE: removed BN +class PointNetEncoder(PointsEncoder): + + def __init__(self, global_feat=True, in_dim=3, out_dim=1024, feature_transform=False): + super(PointNetEncoder, self).__init__() + self.out_dim = out_dim + self.feature_transform = feature_transform + self.stn = STNkd(k=in_dim) + self.conv1 = torch.nn.Conv1d(in_dim, 64, 1) + self.conv2 = torch.nn.Conv1d(64, 128, 1) + self.conv3 = torch.nn.Conv1d(128, 512, 1) + self.conv4 = torch.nn.Conv1d(512, out_dim, 1) + self.global_feat = global_feat + if self.feature_transform: + self.f_stn = STNkd(k=64) + + def forward(self, x): + n_pts = x.shape[2] + trans = self.stn(x) + x = x.transpose(2, 1) + x = torch.bmm(x, trans) + x = x.transpose(2, 1) + x = F.relu(self.conv1(x)) + + if self.feature_transform: + trans_feat = self.f_stn(x) + x = x.transpose(2, 1) + x = torch.bmm(x, trans_feat) + x = x.transpose(2, 1) + + point_feat = x + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + x = self.conv4(x) + x = torch.max(x, 2, keepdim=True)[0] + x = x.view(-1, self.out_dim) + if self.global_feat: + return x + else: + x = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts) + return torch.cat([x, point_feat], 1) + + def encode_points(self, pts): + pts = pts.transpose(2, 1) + if not self.global_feat: + pts_feature = self(pts).transpose(2, 1) + else: + pts_feature = self(pts) + return pts_feature + + +if __name__ == "__main__": + sim_data = Variable(torch.rand(32, 2500, 3)) + + pointnet_global = PointNetEncoder(global_feat=True) + out = pointnet_global.encode_points(sim_data) + print("global feat", out.size()) + + pointnet = PointNetEncoder(global_feat=False) + out = pointnet.encode_points(sim_data) + print("point feat", out.size()) diff --git a/src/active_grasp/active_perception/modules/pts_encoder/pts_encoder_factory.py b/src/active_grasp/active_perception/modules/pts_encoder/pts_encoder_factory.py new file mode 100755 index 0000000..41c570d 
--- /dev/null +++ b/src/active_grasp/active_perception/modules/pts_encoder/pts_encoder_factory.py @@ -0,0 +1,56 @@ +import sys +import os +path = os.path.abspath(__file__) +for i in range(3): + path = os.path.dirname(path) +PROJECT_ROOT = path +sys.path.append(PROJECT_ROOT) + +from modules.pts_encoder.abstract_pts_encoder import PointsEncoder +from modules.pts_encoder.pointnet_encoder import PointNetEncoder +from modules.pts_encoder.pointnet2_encoder import PointNet2Encoder +from modules.pts_encoder.pointnet3_encoder import PointNet3Encoder + +class PointsEncoderFactory: + @staticmethod + def create(name, config) -> PointsEncoder: + general_config = config["general"] + pts_encoder_config = config["pts_encoder"][name] + if name == "pointnet": + return PointNetEncoder( + in_dim=general_config["pts_channels"], + out_dim=general_config["feature_dim"], + global_feat=not general_config["per_point_feature"] + ) + elif name == "pointnet++": + return PointNet2Encoder( + input_channels=general_config["pts_channels"] - 3, + params_name=pts_encoder_config["params_name"] + ) + elif name == "pointnet++rgb": + return PointNet3Encoder( + input_channels=general_config["pts_channels"] - 3, + params_name=pts_encoder_config["params_name"], + target_layer=pts_encoder_config["target_layer"], + rgb_feat_dim=pts_encoder_config["rgb_feat_dim"] + ) + else: + raise ValueError(f"Unknown encoder name: {name}") + + +''' ------------ Debug ------------ ''' +if __name__ == "__main__": + from configs.config import ConfigManager + import torch + + pts = torch.rand(32, 1200, 3) # BxNxC + ConfigManager.load_config_with('configs/local_train_config.yaml') + ConfigManager.print_config() + pts_encoder = PointsEncoderFactory.create(name="pointnet++", config=ConfigManager.get("modules")) + print(pts_encoder) + pts = pts.to("cuda") + pts_encoder = pts_encoder.to("cuda") + + pts_feat = pts_encoder.encode_points(pts) + + print(pts_feat.shape) diff --git a/src/active_grasp/active_perception/modules/rgb_encoder/__init__.py b/src/active_grasp/active_perception/modules/rgb_encoder/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/src/active_grasp/active_perception/modules/rgb_encoder/abstract_rgb_encoder.py b/src/active_grasp/active_perception/modules/rgb_encoder/abstract_rgb_encoder.py new file mode 100755 index 0000000..355773a --- /dev/null +++ b/src/active_grasp/active_perception/modules/rgb_encoder/abstract_rgb_encoder.py @@ -0,0 +1,51 @@ +from abc import abstractmethod +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt +import torch +from torch import nn +import numpy as np + + +class RGBEncoder(nn.Module): + def __init__(self): + super(RGBEncoder, self).__init__() + + @abstractmethod + def encode_rgb(self, rgb): + pass + + @staticmethod + def visualize_features(features, save_path=None): + patch,feat_dim = features.shape + patch_h = int(patch ** 0.5) + patch_w = patch_h + total_features = features.reshape(patch_h * patch_w, feat_dim) + pca = PCA(n_components=3) + if isinstance(total_features, torch.Tensor): + total_features = total_features.cpu().numpy() + pca.fit(total_features) + pca_features = pca.transform(total_features) + pca_features[:, 0] = (pca_features[:, 0] - pca_features[:, 0].min()) / \ + (pca_features[:, 0].max() - pca_features[:, 0].min()) + plt.subplot(1, 3, 1) + plt.imshow(pca_features[:,0].reshape(patch_h, patch_w)) + pca_features_bg = pca_features[:, 0] > 0.5 # from first histogram + pca_features_fg = np.ones_like(pca_features_bg) + plt.subplot(1, 3, 2) + 
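# second panel: background mask (pixels whose first PCA component exceeds 0.5); these pixels are blacked out in the final RGB visualization +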
plt.imshow(pca_features_bg.reshape(patch_h, patch_w)) + pca.fit(total_features[pca_features_fg]) + pca_features_left = pca.transform(total_features[pca_features_fg]) + for i in range(3): + pca_features_left[:, i] = (pca_features_left[:, i] - pca_features_left[:, i].min()) / (pca_features_left[:, i].max() - pca_features_left[:, i].min()) + + pca_features_rgb = pca_features.copy() + pca_features_rgb[pca_features_bg] = 0 + pca_features_rgb[pca_features_fg] = pca_features_left + pca_features_rgb = pca_features_rgb.reshape(1, patch_h, patch_w, 3) + + plt.subplot(1, 3, 3) + if save_path: + plt.imsave(save_path, pca_features_rgb[0]) + else: + plt.imshow(pca_features_rgb[0]) + plt.show() \ No newline at end of file diff --git a/src/active_grasp/active_perception/modules/rgb_encoder/dinov2_encoder.py b/src/active_grasp/active_perception/modules/rgb_encoder/dinov2_encoder.py new file mode 100755 index 0000000..1ab3ee0 --- /dev/null +++ b/src/active_grasp/active_perception/modules/rgb_encoder/dinov2_encoder.py @@ -0,0 +1,20 @@ + +import torch +from modules.rgb_encoder.abstract_rgb_encoder import RGBEncoder +from annotations.external_module import external_freeze + +@external_freeze +class Dinov2Encoder(RGBEncoder): + def __init__(self, model_name): + super(Dinov2Encoder, self).__init__() + self.model_name = model_name + self.load() + + def load(self): + self.dinov2 = torch.hub.load('modules/module_lib/dinov2', self.model_name, source='local').cuda() + + def encode_rgb(self, rgb): + with torch.no_grad(): + features_dict = self.dinov2.forward_features(rgb) + features = features_dict['x_norm_patchtokens'] + return features diff --git a/src/active_grasp/active_perception/modules/rgb_encoder/rgb_encoder_factory.py b/src/active_grasp/active_perception/modules/rgb_encoder/rgb_encoder_factory.py new file mode 100755 index 0000000..f85fed4 --- /dev/null +++ b/src/active_grasp/active_perception/modules/rgb_encoder/rgb_encoder_factory.py @@ -0,0 +1,59 @@ +import sys +import os +path = os.path.abspath(__file__) +for i in range(3): + path = os.path.dirname(path) +PROJECT_ROOT = path +sys.path.append(PROJECT_ROOT) + +from modules.rgb_encoder.abstract_rgb_encoder import RGBEncoder +from modules.rgb_encoder.dinov2_encoder import Dinov2Encoder + + +class RGBEncoderFactory: + @staticmethod + def create(name, config) -> RGBEncoder: + general_config = config["general"] + rgb_encoder_config = config["rgb_encoder"][name] + if name == "dinov2": + return Dinov2Encoder( + model_name=rgb_encoder_config["model_name"] + ) + else: + raise ValueError(f"Unknown encoder name: {name}") + + +''' ------------ Debug ------------ ''' +if __name__ == "__main__": + from configs.config import ConfigManager + import torch + from PIL import Image + import cv2 + from torchvision import transforms + ConfigManager.load_config_with('configs/local_train_config.yaml') + ConfigManager.print_config() + image_size = 480 + path = "/mnt/h/BaiduSyncdisk/workspace/ws_active_pose/project/ActivePerception/test/img0.jpg" + img = cv2.imread(path) + img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + + transform = transforms.Compose([ + transforms.Resize(image_size), + transforms.CenterCrop(int(image_size//14)*14), + transforms.ToTensor(), + transforms.Normalize(mean=0.5, std=0.2) + ]) + + rgb = transform(img) + print(rgb.shape) + rgb_encoder = RGBEncoderFactory.create(name="dinov2", config=ConfigManager.get("modules")) + rgb_encoder.load() + print(rgb_encoder) + rgb = rgb.to("cuda:0") + rgb = rgb.unsqueeze(0) + rgb_encoder = 
rgb_encoder.to("cuda:0") + + rgb_feat = rgb_encoder.encode_rgb(rgb) + + print(rgb_feat.shape) + rgb_encoder.visualize_features(rgb_feat[0]) diff --git a/src/active_grasp/active_perception/modules/view_finder/__init__.py b/src/active_grasp/active_perception/modules/view_finder/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/src/active_grasp/active_perception/modules/view_finder/abstract_view_finder.py b/src/active_grasp/active_perception/modules/view_finder/abstract_view_finder.py new file mode 100755 index 0000000..b688c16 --- /dev/null +++ b/src/active_grasp/active_perception/modules/view_finder/abstract_view_finder.py @@ -0,0 +1,12 @@ +from abc import abstractmethod + +from torch import nn + + +class ViewFinder(nn.Module): + def __init__(self): + super(ViewFinder, self).__init__() + + @abstractmethod + def next_best_view(self, scene_pts_feat, target_pts_feat): + pass diff --git a/src/active_grasp/active_perception/modules/view_finder/gf_view_finder.py b/src/active_grasp/active_perception/modules/view_finder/gf_view_finder.py new file mode 100755 index 0000000..b0fe790 --- /dev/null +++ b/src/active_grasp/active_perception/modules/view_finder/gf_view_finder.py @@ -0,0 +1,165 @@ +import torch +import torch.nn as nn +from utils.pose_util import PoseUtil +from modules.view_finder.abstract_view_finder import ViewFinder +import modules.module_lib as mlib +import modules.func_lib as flib + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class GradientFieldViewFinder(ViewFinder): + def __init__(self, pose_mode='rot_matrix', regression_head='Rx_Ry', per_point_feature=False, + sample_mode="ode", sampling_steps=None, sde_mode="ve"): + + super(GradientFieldViewFinder, self).__init__() + self.regression_head = regression_head + self.per_point_feature = per_point_feature + self.act = nn.ReLU(True) + self.sample_mode = sample_mode + self.pose_mode = pose_mode + pose_dim = PoseUtil.get_pose_dim(pose_mode) + self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(sde_mode) + self.sampling_steps = sampling_steps + + ''' encode pose ''' + self.pose_encoder = nn.Sequential( + nn.Linear(pose_dim, 256), + self.act, + nn.Linear(256, 256), + self.act, + ) + + ''' encode t ''' + self.t_encoder = nn.Sequential( + mlib.GaussianFourierProjection(embed_dim=128), + nn.Linear(128, 128), + self.act, + ) + + ''' fusion tail ''' + if self.regression_head == 'Rx_Ry': + if pose_mode != 'rot_matrix': + raise NotImplementedError + if not per_point_feature: + ''' rotation_x_axis regress head ''' + self.fusion_tail_rot_x = nn.Sequential( + nn.Linear(128 + 256 + 1024 + 1024, 256), + self.act, + zero_module(nn.Linear(256, 3)), + ) + self.fusion_tail_rot_y = nn.Sequential( + nn.Linear(128 + 256 + 1024 + 1024, 256), + self.act, + zero_module(nn.Linear(256, 3)), + ) + else: + raise NotImplementedError + else: + raise NotImplementedError + + def forward(self, data): + """ + Args: + data, dict { + 'target_pts_feat': [bs, c] + 'scene_pts_feat': [bs, c] + 'pose_sample': [bs, pose_dim] + 't': [bs, 1] + } + """ + + scene_pts_feat = data['scene_feat'] + target_pts_feat = data['target_feat'] + sampled_pose = data['sampled_pose'] + t = data['t'] + t_feat = self.t_encoder(t.squeeze(1)) + pose_feat = self.pose_encoder(sampled_pose) + + if self.per_point_feature: + raise NotImplementedError + else: + total_feat = torch.cat([scene_pts_feat, target_pts_feat, t_feat, 
pose_feat], dim=-1) + _, std = self.marginal_prob_fn(total_feat, t) + + if self.regression_head == 'Rx_Ry': + rot_x = self.fusion_tail_rot_x(total_feat) + rot_y = self.fusion_tail_rot_y(total_feat) + out_score = torch.cat([rot_x, rot_y], dim=-1) / (std + 1e-7) # normalisation + else: + raise NotImplementedError + + return out_score + + def marginal_prob(self, x, t): + return self.marginal_prob_fn(x,t) + + def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None): + + if self.sample_mode == 'pc': + in_process_sample, res = flib.cond_pc_sampler( + score_model=self, + data=data, + prior=self.prior_fn, + sde_coeff=self.sde_fn, + num_steps=self.sampling_steps, + snr=snr, + eps=self.sampling_eps, + pose_mode=self.pose_mode, + init_x=init_x + ) + + elif self.sample_mode == 'ode': + T0 = self.T if T0 is None else T0 + in_process_sample, res = flib.cond_ode_sampler( + score_model=self, + data=data, + prior=self.prior_fn, + sde_coeff=self.sde_fn, + atol=atol, + rtol=rtol, + eps=self.sampling_eps, + T=T0, + num_steps=self.sampling_steps, + pose_mode=self.pose_mode, + denoise=denoise, + init_x=init_x + ) + else: + raise NotImplementedError + + return in_process_sample, res + + def next_best_view(self, scene_pts_feat, target_pts_feat): + data = { + 'scene_feat': scene_pts_feat, + 'target_feat': target_pts_feat, + } + in_process_sample, res = self.sample(data) + return res.to(dtype=torch.float32), in_process_sample + + +''' ----------- DEBUG -----------''' +if __name__ == "__main__": + test_scene_feat = torch.rand(32, 1024).to("cuda:0") + test_target_feat = torch.rand(32, 1024).to("cuda:0") + test_pose = torch.rand(32, 6).to("cuda:0") + test_t = torch.rand(32, 1).to("cuda:0") + view_finder = GradientFieldViewFinder().to("cuda:0") + test_data = { + 'target_feat': test_target_feat, + 'scene_feat': test_scene_feat, + 'sampled_pose': test_pose, + 't': test_t + } + score = view_finder(test_data) + + result = view_finder.next_best_view(test_scene_feat, test_target_feat) + print(result) diff --git a/src/active_grasp/active_perception/modules/view_finder/view_finder_factory.py b/src/active_grasp/active_perception/modules/view_finder/view_finder_factory.py new file mode 100755 index 0000000..92adfee --- /dev/null +++ b/src/active_grasp/active_perception/modules/view_finder/view_finder_factory.py @@ -0,0 +1,45 @@ +from modules.view_finder.abstract_view_finder import ViewFinder +from modules.view_finder.gf_view_finder import GradientFieldViewFinder + + +class ViewFinderFactory: + @staticmethod + def create(name, config) -> ViewFinder: + general_config = config["general"] + view_finder_config = config["view_finder"][name] + if name == "gradient_field": + return GradientFieldViewFinder( + pose_mode=view_finder_config["pose_mode"], + regression_head=view_finder_config["regression_head"], + per_point_feature=general_config["per_point_feature"], + sample_mode=view_finder_config["sample_mode"], + sampling_steps=view_finder_config.get("sampling_steps", None), + sde_mode=view_finder_config["sde_mode"] + ) + else: + raise ValueError(f"Unknown next-best-view finder name: {name}") + + +''' ------------ Debug ------------ ''' +if __name__ == "__main__": + from configs.config import ConfigManager + import torch + + ConfigManager.load_config_with('../../configs/local_train_config.yaml') + ConfigManager.print_config() + view_finder = ViewFinderFactory.create(name="gradient_field", config=ConfigManager.get("modules")) + test_scene_feat = torch.rand(32, 1024).to("cuda:0") + test_target_feat = 
torch.rand(32, 1024).to("cuda:0") + test_pose = torch.rand(32, 6).to("cuda:0") + test_t = torch.rand(32, 1).to("cuda:0") + view_finder = view_finder.to("cuda:0") + test_data = { + 'target_feat': test_target_feat, + 'scene_feat': test_scene_feat, + 'sampled_pose': test_pose, + 't': test_t + } + score = view_finder(test_data) + print(score.shape) + pose_6d = view_finder.next_best_view(scene_pts_feat=test_data["scene_feat"], target_pts_feat=test_data["target_feat"]) + print(pose_6d.shape) \ No newline at end of file diff --git a/src/active_grasp/active_perception/utils/__init__.py b/src/active_grasp/active_perception/utils/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/src/active_grasp/active_perception/utils/cache_util.py b/src/active_grasp/active_perception/utils/cache_util.py new file mode 100755 index 0000000..3226d37 --- /dev/null +++ b/src/active_grasp/active_perception/utils/cache_util.py @@ -0,0 +1,19 @@ +from collections import OrderedDict + +class LRUCache: + def __init__(self, capacity: int): + self.cache = OrderedDict() + self.capacity = capacity + + def get(self, key): + if key not in self.cache: + return None + self.cache.move_to_end(key) + return self.cache[key] + + def put(self, key, value): + if key in self.cache: + self.cache.move_to_end(key) + elif len(self.cache) >= self.capacity: + self.cache.popitem(last=False) + self.cache[key] = value diff --git a/src/active_grasp/active_perception/utils/file_util.py b/src/active_grasp/active_perception/utils/file_util.py new file mode 100755 index 0000000..200ee8b --- /dev/null +++ b/src/active_grasp/active_perception/utils/file_util.py @@ -0,0 +1,83 @@ +import os +import pickle +import json + +import numpy as np + + +class FileUtil: + @staticmethod + def get_path(file_name, target_dir=None): + if target_dir is None: + file_path = file_name + else: + file_path = os.path.join(target_dir, file_name) + return file_path + + @staticmethod + def load_pickle(file_name, target_dir=None): + file_path = FileUtil.get_path(file_name, target_dir) + with open(file_path, "rb") as f: + return pickle.load(f) + + @staticmethod + def save_pickle(data, file_name, target_dir=None): + file_path = FileUtil.get_path(file_name, target_dir) + with open(file_path, "wb") as f: + pickle.dump(data, f) + return True + + @staticmethod + def load_json(file_name, target_dir=None): + file_path = FileUtil.get_path(file_name, target_dir) + with open(file_path, "r") as f: + return json.load(f) + + @staticmethod + def save_json(data, file_name, target_dir=None): + file_path = FileUtil.get_path(file_name, target_dir) + with open(file_path, "w") as f: + json.dump(data, f) + return True + + @staticmethod + def save_np_txt(np_data, file_name, target_dir=None): + if len(np_data.shape) > 2: + raise ValueError("Only 2D arrays are supported.") + file_path = FileUtil.get_path(file_name, target_dir) + np.savetxt(file_path, np_data) + + @staticmethod + def load_np_txt(file_name, target_dir=None, shuffle=False): + file_path = FileUtil.get_path(file_name, target_dir) + np_data = np.loadtxt(file_path) + if shuffle: + indices = np.arange(np_data.shape[0]) + np.random.shuffle(indices) + np_data_shuffled = np_data[indices] + return np_data_shuffled + else: + return np_data + + @staticmethod + def find_object_models(path): + obj_files = {} + for root, dirs, files in os.walk(path): + for file in files: + if file.endswith(".obj"): + full_path = os.path.join(root, file) + modified_name = full_path.replace(path, "").replace(os.sep, "_").rstrip(".obj") + if 
modified_name.startswith("_"): + modified_name = modified_name[1:] + obj_files[modified_name] = full_path + return obj_files + + +''' ------------ Debug ------------ ''' +if __name__ == "__main__": + arr2d = np.random.random((4, 3)) + print(arr2d) + np.savetxt("test.txt", arr2d) + loaded_arr2d = FileUtil.load_np_txt("test.txt") + print() + print(loaded_arr2d) diff --git a/src/active_grasp/active_perception/utils/metric_util.py b/src/active_grasp/active_perception/utils/metric_util.py new file mode 100755 index 0000000..de532d9 --- /dev/null +++ b/src/active_grasp/active_perception/utils/metric_util.py @@ -0,0 +1,124 @@ +import numpy as np + + +class MetricUtil: + + @staticmethod + def rotate_around(axis, angle_deg): + angle = angle_deg * np.pi / 180 + if axis == "x": + return np.array([[1, 0, 0], + [0, np.cos(angle), -np.sin(angle)], + [0, np.sin(angle), np.cos(angle)]]) + elif axis == "y": + return np.array([[np.cos(angle), 0, np.sin(angle)], + [0, 1, 0], + [-np.sin(angle), 0, np.cos(angle)]]) + elif axis == "z": + return np.array([[np.cos(angle), -np.sin(angle), 0], + [np.sin(angle), np.cos(angle), 0], + [0, 0, 1]]) + else: + raise ValueError("Invalid axis") + + @staticmethod + def basic_rot_diff(r0, r1): + mat_diff = np.matmul(r0, r1.swapaxes(-1, -2)) + diff = np.trace(mat_diff) - 1 + return np.arccos(np.clip(diff / 2.0, a_min=-1.0, a_max=1.0)) + + @staticmethod + def axis_rot_diff(r0, r1, axis): + axis1, axis2 = r0[..., axis], r1[..., axis] + diff = np.sum(axis1 * axis2, axis=-1) + return np.arccos(np.clip(diff, a_min=-1.0, a_max=1.0)) + + @staticmethod + def turn_rot_diff(r0, r1, axis, turn_degrees): + diffs = [] + for i in turn_degrees: + rotation_matrix = MetricUtil.rotate_around(axis, i) + diffs.append(MetricUtil.basic_rot_diff(np.matmul(r0, rotation_matrix), r1)) + return np.min(diffs, axis=0) + + @staticmethod + def rot_diff_rad(r0, r1, sym): + + axis_map = {0: "x", 1: "y", 2: "z"} + if sym is None or sym == 0: # no symmetry + return MetricUtil.basic_rot_diff(r0, r1) + elif sym in [1, 2, 3]: # free rotation around axis + return MetricUtil.axis_rot_diff(r0, r1, sym - 1) + else: # symmetry + turns = 0 + axis_idx = 0 + if sym in [4, 5, 6]: # half turn + axis_idx = sym - 4 + turns = 2 + elif sym in [7, 8, 9]: # quarter turn + axis_idx = sym - 7 + turns = 4 + turn_degrees = np.arange(0, 360, 360 / turns) + return MetricUtil.turn_rot_diff(r0, r1, axis_map[axis_idx], turn_degrees) + + @staticmethod + def collect_metric(pred_pose_mat, gt_pose_mat, sym): + pred_rot_mat = pred_pose_mat[:, :3, :3] + gt_rot_mat = gt_pose_mat[:, :3, :3] + pred_trans = pred_pose_mat[:, :3, 3] + gt_trans = gt_pose_mat[:, :3, 3] + + trans_error = [] + rot_error = [] + for i in range(pred_rot_mat.shape[0]): + tdiff = np.linalg.norm(pred_trans[i] - gt_trans[i], ord=2) * 100 + rdiff = MetricUtil.rot_diff_rad(pred_rot_mat[i], gt_rot_mat[i], sym[i]) / np.pi * 180.0 + trans_error.append(tdiff) + rot_error.append(rdiff) + + rot_error = { + 'mean': np.mean(rot_error), + 'median': np.median(rot_error), + 'item': rot_error, + } + trans_error = { + 'mean': np.mean(trans_error), + 'median': np.median(trans_error), + 'item': trans_error, + } + error = {'rot_error': rot_error, + 'trans_error': trans_error} + return error + + +# -------------- Debug --------------- + +def test_MetricUtil(): + print("test case 0: no rotation") + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), 0) * 180 / np.pi) + print("test case 1: 29 degree rotation around x-axis") + 
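# sym codes (see rot_diff_rad): 0 = no symmetry, 1/2/3 = free rotation about x/y/z, 4-6 = half-turn symmetry, 7-9 = quarter-turn symmetry +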
rotation_matrix = MetricUtil.rotate_around("x", 29) + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), rotation_matrix, 0) * 180 / np.pi) + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), rotation_matrix, 1) * 180 / np.pi) + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), rotation_matrix, 8) * 180 / np.pi) + print("test case 2: 90 degree rotation around y-axis") + rotation_matrix = MetricUtil.rotate_around("y", 90) + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), rotation_matrix, 0) * 180 / np.pi) + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), rotation_matrix, 2) * 180 / np.pi) + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), rotation_matrix, 8) * 180 / np.pi) + print("test case 3: 60 degree rotation around y-axis") + rotation_matrix = MetricUtil.rotate_around("y", 60) + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), rotation_matrix, 0) * 180 / np.pi) + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), rotation_matrix, 2) * 180 / np.pi) + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), rotation_matrix, 8) * 180 / np.pi) + print("test case 4: 78 degree rotation around z-axis and 60 degree rotation around x-axis") + rotation_matrix = MetricUtil.rotate_around("z", 78) @ MetricUtil.rotate_around("x", 60) + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), rotation_matrix, 0) * 180 / np.pi) + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), rotation_matrix, 2) * 180 / np.pi) + print(MetricUtil.rot_diff_rad(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), rotation_matrix, 8) * 180 / np.pi) + + +if __name__ == "__main__": + pass + test_MetricUtil() diff --git a/src/active_grasp/active_perception/utils/omni_util.py b/src/active_grasp/active_perception/utils/omni_util.py new file mode 100755 index 0000000..d407588 --- /dev/null +++ b/src/active_grasp/active_perception/utils/omni_util.py @@ -0,0 +1,439 @@ +import numpy as np +import pickle +import json +import pickle +import cv2 +import os +import re +from scipy.spatial.transform import Rotation as R + +class DepthToPCL: + + def __new__(cls, *args, **kwargs): + raise RuntimeError( + "Use init_from_disk or init_from_memory to create an instance" + ) + + @classmethod + def _initialize( + cls, + distance_to_camera_path=None, + rgb_path=None, + camera_params_path=None, + seg_path=None, + seg_label_path=None, + depth=None, + rgb=None, + seg=None, + seg_label=None, + camera_params=None, + ): + instance = super().__new__(cls) + instance._distance_to_camera_path = distance_to_camera_path + instance._rgb_path = rgb_path + instance._camera_params_path = camera_params_path + instance._seg_path = seg_path + instance._seg_label_path = seg_label_path + instance._depth = depth + instance._rgb = rgb + instance._seg = seg + instance._seg_label = seg_label + instance._camera_params = camera_params + + if any( + path is not None + for path in [ + distance_to_camera_path, + rgb_path, + camera_params_path, + seg_path, + seg_label_path, + ] + ): + instance._load_from_disk() + + instance._setup() + return instance + + @classmethod + def init_from_disk( + cls, + distance_to_camera_path, + rgb_path, + camera_params_path, + seg_path, + seg_label_path, + ): + return cls._initialize( + distance_to_camera_path=distance_to_camera_path, + rgb_path=rgb_path, + 
camera_params_path=camera_params_path, + seg_path=seg_path, + seg_label_path=seg_label_path, + ) + + @classmethod + def init_from_memory(cls, depth, rgb, seg, seg_label, camera_params): + return cls._initialize( + depth=depth, + rgb=rgb, + seg=seg, + seg_label=seg_label, + camera_params=camera_params, + ) + + def _load_from_disk(self): + self._depth = np.load(self._distance_to_camera_path) + self._seg = cv2.imread(self._seg_path, cv2.IMREAD_UNCHANGED) + + with open(self._seg_label_path, "r") as f: + self._seg_label = json.load(f) + with open(self._camera_params_path) as f: + self._camera_params = json.load(f) + + def _setup(self): + self._read_camera_params() + self._get_intrinsic_matrix() + + def _read_camera_params(self): + self._h_aperture = self._camera_params["cameraAperture"][0] + self._v_aperture = self._camera_params["cameraAperture"][1] + self._h_aperture_offset = self._camera_params["cameraApertureOffset"][0] + self._v_aperture_offset = self._camera_params["cameraApertureOffset"][1] + self._focal_length = self._camera_params["cameraFocalLength"] + self._h_resolution = self._camera_params["renderProductResolution"][0] + self._v_resolution = self._camera_params["renderProductResolution"][1] + self._cam_t = self._camera_params["cameraViewTransform"] + + def _get_intrinsic_matrix(self): + self._focal_x = self._h_resolution * self._focal_length / self._h_aperture + self._focal_y = self._v_resolution * self._focal_length / self._v_aperture + self._center_x = self._h_resolution / 2 + self._center_y = self._v_resolution / 2 + self.intrinsic_matrix = np.array( + [ + [self._focal_x, 0, self._center_x], + [0, self._focal_y, self._center_y], + [0, 0, 1], + ] + ) + return self.intrinsic_matrix + + def _get_extrinsic_matrix(self): + self._cam_pose = np.linalg.inv(np.resize(self._cam_t, (4, 4))).T.dot( + np.mat([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) + ) + return self._cam_pose + + + + def get_pcd(self, target_name=None): + u_indices, v_indices = np.meshgrid( + np.arange(self._h_resolution), np.arange(self._v_resolution) + ) + x_factors = (u_indices - self._center_x) / self._focal_x + y_factors = (v_indices - self._center_y) / self._focal_y + if target_name is not None: + if target_name == OmniUtil.FOREGROUND: + unlabelled_mask = self.get_mask_rgba( + self._seg_label, OmniUtil.UNLABELLED + ) + background_mask = self.get_mask_rgba( + self._seg_label, OmniUtil.BACKGROUND + ) + if unlabelled_mask is None: + target_mask = (self._seg != background_mask).any(axis=2) + else: + target_mask = (self._seg != unlabelled_mask).any(axis=2) & ( + self._seg != background_mask + ).any(axis=2) + else: + target_mask = ( + self._seg == self.get_mask_rgba(self._seg_label, target_name) + ).all(axis=2) + else: + target_mask = np.ones((self._v_resolution, self._h_resolution), dtype=bool) + valid_x_factors = x_factors[target_mask] + valid_y_factors = y_factors[target_mask] + valid_z_factors = self._depth[target_mask] + points = np.stack([valid_x_factors, valid_y_factors, valid_z_factors], axis=1) + return points + + @staticmethod + def get_mask_rgba(mask_labels, mask_name): + name_list = [name_dict["class"] for name_dict in list(mask_labels.values())] + if mask_name not in name_list: + return None + rgba_list = list(mask_labels.keys()) + mask_rgba_str = rgba_list[name_list.index(mask_name)] + r, g, b, a = re.findall("\d+", mask_rgba_str) + r, g, b, a = int(b), int(g), int(r), int(a) + return r, g, b, a + + def get_segmented_pcd(self, target_list, N=15000): + u_indices, v_indices = np.meshgrid( + 
np.arange(self._h_resolution), np.arange(self._v_resolution) + ) + x_factors = (u_indices - self._center_x) / self._focal_x + y_factors = (v_indices - self._center_y) / self._focal_y + points_dict = {} + total_points_with_label = [] + for target_idx in range(len(target_list)): + target_name = target_list[target_idx] + target_mask = ( + self._seg == self.get_mask_rgba(self._seg_label, target_name) + ).all(axis=2) + valid_x_factors = x_factors[target_mask] + valid_y_factors = y_factors[target_mask] + valid_z_factors = self._depth[target_mask] + label = np.ones_like(valid_x_factors) * target_idx + target_points_with_label = np.stack( + [valid_x_factors, valid_y_factors, valid_z_factors, label], axis=1 + ) + total_points_with_label.append(target_points_with_label) + total_points_with_label = np.concatenate(total_points_with_label, axis=0) + total_points_with_label = self.sample_pcl(total_points_with_label, N) + total_points = total_points_with_label[:, :3] + for target_idx in range(len(target_list)): + target_name = target_list[target_idx] + pts_seg = total_points_with_label[:, 3] == target_idx + points_dict[target_name] = total_points_with_label[pts_seg, :3] + + return total_points, points_dict + + def get_rgb(self): + return self._rgb + + @staticmethod + def sample_pcl(pcl, n_pts=1024): + indices = np.random.choice(pcl.shape[0], n_pts, replace=pcl.shape[0] < n_pts) + return pcl[indices, :] + + +class OmniUtil: + FOREGROUND = "FOREGROUND" + BACKGROUND = "BACKGROUND" + UNLABELLED = "UNLABELLED" + NON_OBJECT_LIST = ['chair_028', 'chair_029', 'chair_026', 'chair_027', 'table_025', 'table_027', 'table_026', 'table_028', 'sofa_014', 'sofa_013', 'picnic_basket_010', 'picnic_basket_011', 'cabinet_009', 'flower_pot_023', 'flower_pot_022', 'flower_pot_021', 'chair_017', 'chair_020', 'chair_012', 'chair_010', 'chair_018', 'chair_025', 'chair_024', 'chair_011', 'chair_001', 'chair_013', 'chair_004', 'chair_021', 'chair_023', 'chair_006', 'chair_014', 'chair_007', 'chair_003', 'chair_009', 'chair_022', 'chair_015', 'chair_016', 'chair_008', 'chair_005', 'chair_019', 'chair_002', 'table_004', 'table_023', 'table_014', 'table_024', 'table_019', 'table_022', 'table_007', 'table_017', 'table_013', 'table_002', 'table_016', 'table_009', 'table_008', 'table_003', 'table_015', 'table_001', 'table_018', 'table_005', 'table_020', 'table_021', 'sofa_001', 'sofa_005', 'sofa_012', 'sofa_009', 'sofa_006', 'sofa_008', 'sofa_011', 'sofa_004', 'sofa_003', 'sofa_002', 'sofa_007', 'sofa_010', 'picnic_basket_005', 'picnic_basket_004', 'picnic_basket_001', 'picnic_basket_008', 'picnic_basket_002', 'picnic_basket_009', 'picnic_basket_006', 'picnic_basket_003', 'picnic_basket_007', 'cabinet_006', 'cabinet_008', 'cabinet_002', 'cabinet_001', 'cabinet_005', 'cabinet_007', 'flower_pot_013', 'flower_pot_005', 'flower_pot_008', 'flower_pot_001', 'flower_pot_003', 'flower_pot_020', 'flower_pot_006', 'flower_pot_012', 'flower_pot_018', 'flower_pot_007', 'flower_pot_002', 'flower_pot_011', 'flower_pot_010', 'flower_pot_016', 'flower_pot_004', 'flower_pot_014', 'flower_pot_017', 'flower_pot_019'] + CAMERA_PARAMS_TEMPLATE = "camera_params_{}.json" + DISTANCE_TEMPLATE = "distance_to_image_plane_{}.npy" + RGB_TEMPLATE = "rgb_{}.png" + MASK_TEMPLATE = "semantic_segmentation_{}.png" + MASK_LABELS_TEMPLATE = "semantic_segmentation_labels_{}.json" + SCORE_LABEL_TEMPLATE = "label_{}.json" + RGB_FEAT_TEMPLATE = "rgb_feat_{}.npy" + + @staticmethod + def get_depth_to_pointcloud_instance(path): + root, idx = path[:-4], path[-4:] + 
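# a frame path ends with its 4-character frame id (e.g. .../scene_0/0050); per-frame files are resolved from the *_{} templates above +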
distance2plane_path = os.path.join(root, OmniUtil.DISTANCE_TEMPLATE.format(idx)) + rgb_path = os.path.join(root, OmniUtil.RGB_TEMPLATE.format(idx)) + cam_params_path = os.path.join( + root, OmniUtil.CAMERA_PARAMS_TEMPLATE.format(idx) + ) + seg_path = os.path.join(root, OmniUtil.MASK_TEMPLATE.format(idx)) + seg_labels_path = os.path.join(root, OmniUtil.MASK_LABELS_TEMPLATE.format(idx)) + depth_to_pcd = DepthToPCL.init_from_disk( + distance2plane_path, rgb_path, cam_params_path, seg_path, seg_labels_path + ) + return depth_to_pcd + + @staticmethod + def get_points(path, object_name=None): + depth_to_pcd = OmniUtil.get_depth_to_pointcloud_instance(path) + pcd = depth_to_pcd.get_pcd(object_name) + points = np.asarray(pcd, dtype=np.float32) + return points + + @staticmethod + def get_segmented_points(path, target_list): + depth_to_pcd = OmniUtil.get_depth_to_pointcloud_instance(path) + total_points, target_points_dict = depth_to_pcd.get_segmented_pcd(target_list) + return total_points, target_points_dict + + @staticmethod + def get_object_list(path, contains_non_obj=False): + root, idx = path[:-4], path[-4:] + seg_labels_path = os.path.join(root, OmniUtil.MASK_LABELS_TEMPLATE.format(idx)) + with open(seg_labels_path, "r") as f: + seg_labels = json.load(f) + object_list = [v["class"] for v in seg_labels.values()] + + object_list.remove(OmniUtil.BACKGROUND) + if OmniUtil.UNLABELLED in object_list: + object_list.remove(OmniUtil.UNLABELLED) + occluder_list = pickle.load(open(os.path.join(root,"occluder.pickle"), "rb")) + fall_objects_list = pickle.load(open(os.path.join(root,"fall_objects.pickle"), "rb")) + non_obj_list = occluder_list + fall_objects_list + if not contains_non_obj: + for non_obj in non_obj_list: + if non_obj in object_list: + object_list.remove(non_obj) + return object_list + + @staticmethod + def get_rotation_mat(path): + root, idx = os.path.split(path) + camera_params_path = os.path.join( + root, OmniUtil.CAMERA_PARAMS_TEMPLATE.format(idx) + ) + with open(camera_params_path, "r") as f: + raw_camera_params = json.load(f) + cam_transform = np.asarray(raw_camera_params["cameraViewTransform"]).reshape( + (4, 4) + ) + cam_rot_mat = cam_transform[:3, :3].dot( + np.mat([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) + ) + return cam_rot_mat + + @staticmethod + def get_rgb(path): + root, idx = os.path.split(path) + rgb_path = os.path.join(root, OmniUtil.RGB_TEMPLATE.format(idx)) + rgb = cv2.imread(rgb_path) + return cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB) + + @staticmethod + def get_depth(path): + root, idx = os.path.split(path) + depth_path = os.path.join(root, OmniUtil.DISTANCE_TEMPLATE.format(idx)) + depth = np.load(depth_path) + return depth + + @staticmethod + def get_seg_data(path): + root, idx = os.path.split(path) + seg_labels_path = os.path.join(root, OmniUtil.MASK_LABELS_TEMPLATE.format(idx)) + with open(seg_labels_path, "r") as f: + seg_labels = json.load(f) + seg_path = os.path.join(root, OmniUtil.MASK_TEMPLATE.format(idx)) + seg = cv2.imread(seg_path, cv2.IMREAD_UNCHANGED) + return seg, seg_labels + + @staticmethod + def get_single_seg(path, object_name): + root, idx = os.path.split(path) + seg_labels_path = os.path.join(root, OmniUtil.MASK_LABELS_TEMPLATE.format(idx)) + with open(seg_labels_path, "r") as f: + seg_labels = json.load(f) + seg_path = os.path.join(root, OmniUtil.MASK_TEMPLATE.format(idx)) + seg = cv2.imread(seg_path, cv2.IMREAD_UNCHANGED) + object_mask = ( + seg == OmniUtil.get_mask_rgba(seg_labels, object_name) + ).all(axis=2) + return object_mask + + + @staticmethod + def 
get_mask_rgba(mask_labels, mask_name): + name_list = [name_dict["class"] for name_dict in list(mask_labels.values())] + if mask_name not in name_list: + return None + rgba_list = list(mask_labels.keys()) + mask_rgba_str = rgba_list[name_list.index(mask_name)] + r, g, b, a = re.findall("\d+", mask_rgba_str) + r, g, b, a = int(b), int(g), int(r), int(a) + return r, g, b, a + + @staticmethod + def get_rgb_feat(path): + root, idx = os.path.split(path) + rgb_feat_path = os.path.join(root, OmniUtil.RGB_FEAT_TEMPLATE.format(idx)) + rgb_feat = np.load(rgb_feat_path) + return rgb_feat + + @staticmethod + def get_target_object_list(path): + return OmniUtil.get_object_list(path, contains_non_obj=False) # TODO: generalize this + + + @staticmethod + def get_transform_mat(path): + root, idx = os.path.split(path) + camera_params_path = os.path.join( + root, OmniUtil.CAMERA_PARAMS_TEMPLATE.format(idx) + ) + with open(camera_params_path, "r") as f: + raw_camera_params = json.load(f) + cam_transform = np.asarray(raw_camera_params["cameraViewTransform"]).reshape( + (4, 4) + ) + real_cam_transform = np.linalg.inv(cam_transform).T + real_cam_transform = real_cam_transform.dot( + np.mat([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) + ) + return real_cam_transform + + @staticmethod + def get_intrinsic_matrix(path): + root, idx = os.path.split(path) + camera_params_path = os.path.join( + root, OmniUtil.CAMERA_PARAMS_TEMPLATE.format(idx) + ) + with open(camera_params_path, "r") as f: + raw_camera_params = json.load(f) + h_aperture = raw_camera_params["cameraAperture"][0] + v_aperture = raw_camera_params["cameraAperture"][1] + focal_length = raw_camera_params["cameraFocalLength"] + h_resolution = raw_camera_params["renderProductResolution"][0] + v_resolution = raw_camera_params["renderProductResolution"][1] + focal_x = h_resolution * focal_length / h_aperture + focal_y = v_resolution * focal_length / v_aperture + center_x = h_resolution / 2 + center_y = v_resolution / 2 + intrinsic_matrix = np.array( + [ + [focal_x, 0, center_x], + [0, focal_y, center_y], + [0, 0, 1], + ] + ) + return intrinsic_matrix + + @staticmethod + def get_extrinsic_matrix(path): + root, idx = os.path.split(path) + camera_params_path = os.path.join( + root, OmniUtil.CAMERA_PARAMS_TEMPLATE.format(idx) + ) + with open(camera_params_path, "r") as f: + raw_camera_params = json.load(f) + cam_transform = np.asarray(raw_camera_params["cameraViewTransform"]).reshape( + (4, 4) + ) + real_cam_transform = np.linalg.inv(cam_transform).T + real_cam_transform = real_cam_transform.dot( + np.mat([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) + ) + return real_cam_transform + + @staticmethod + def get_scene_data(path): + root, _ = os.path.split(path) + scene_data_path = os.path.join( + root, "scene.pickle" + ) + with open(scene_data_path, "rb") as f: + scene_data = pickle.load(f) + return scene_data + + @staticmethod + def get_o2c_pose(path, object_name): + scene_data = OmniUtil.get_scene_data(path) + cam_pose = OmniUtil.get_extrinsic_matrix(path) + pos = scene_data[object_name]["position"] + quat = scene_data[object_name]["rotation"] + rot = R.from_quat(quat).as_matrix() + obj_pose = np.eye(4) + obj_pose[:3, :3] = rot + obj_pose[:3, 3] = pos + obj_cam_pose = np.linalg.inv(cam_pose) @ obj_pose + return np.asarray(obj_cam_pose) + +if __name__ == "__main__": + test_path = r"/mnt/h/AI/Datasets/nbv1/sample_one/scene_0/0050" + obj_list = OmniUtil.get_object_list(test_path, contains_non_obj=True) + print(obj_list) + pts = 
OmniUtil.get_segmented_points(test_path, target_list=obj_list) + np.savetxt("pts1.txt", pts) diff --git a/src/active_grasp/active_perception/utils/pcl_util.py b/src/active_grasp/active_perception/utils/pcl_util.py new file mode 100755 index 0000000..a94cb5e --- /dev/null +++ b/src/active_grasp/active_perception/utils/pcl_util.py @@ -0,0 +1,78 @@ +import numpy as np +import torch +from scipy.spatial.distance import cdist + + +class PclUtil: + CHAMFER = 1 + + @staticmethod + def transform(pts, pose=np.eye(4), scale=np.ones(3), inverse=False): + aug_scale = np.ones(4) + aug_scale[:3] = scale + aug_scale_mat = np.diag(aug_scale) + scale_pose = pose @ aug_scale_mat + aug_pts = np.hstack((pts, np.ones((pts.shape[0], 1)))) + if inverse: + scale_pose = np.linalg.inv(scale_pose) + transformed_pts = scale_pose @ aug_pts.T + return transformed_pts.T[:, :3] + + @staticmethod + def cam2canonical(cam_pts, cam2canonical_pose): + aug_pts = np.hstack((cam_pts, np.ones((cam_pts.shape[0], 1)))) + transformed_pts = cam2canonical_pose @ aug_pts.T + return transformed_pts.T[:, :3] + + @staticmethod + def transform_batch(pts, pose, scale, inverse=False): + batch_size = pts.shape[0] + aug_scale_mat = torch.eye(4).unsqueeze(0).repeat(batch_size, 1, 1) + for i in range(3): + aug_scale_mat[..., i, i] = scale[..., i] + scale_pose = pose @ aug_scale_mat + aug_pts = torch.cat((pts, torch.ones_like(pts[..., :1])), dim=-1) + if inverse: + scale_pose = torch.inverse(scale_pose) + transformers_pts = scale_pose @ aug_pts.transpose(1, 2) + return transformers_pts.transpose(1, 2)[..., :3] + + @staticmethod + def transform_n_batch(pts, pose, scale=None, inverse=False): + transformed_pts_shape = (pts.shape[0], pose.shape[1], pts.shape[1], pts.shape[2]) + transformed_pts = np.zeros(transformed_pts_shape) + batch_size = pose.shape[0] + n = pose.shape[1] + if scale is None: + scale = np.ones((batch_size, n, 3)) + for batch_i in range(batch_size): + for i in range(n): + transformed_pts[batch_i, i, :, :] = PclUtil.transform(pts[batch_i], pose[batch_i, i], + scale[batch_i, i], inverse) + return transformed_pts + + @staticmethod + def chamfer_distance(pts1, pts2): + dist_matrix1 = cdist(pts1, pts2, 'euclidean') + dist_matrix2 = cdist(pts2, pts1, 'euclidean') + chamfer_dist = np.mean(np.min(dist_matrix1, axis=1)) + np.mean(np.min(dist_matrix2, axis=1)) + return chamfer_dist + + @staticmethod + def distance(pts1, pts2, eval_type=1): + if eval_type == PclUtil.CHAMFER: + return PclUtil.chamfer_distance(pts1, pts2) + else: + raise ValueError('Unknown evaluation type:', eval_type) + + @staticmethod + def sample_pcl(pcl, n_pts=1024): + indices = np.random.choice(pcl.shape[0], n_pts, replace=pcl.shape[0] < n_pts) + return pcl[indices, :] + + +if __name__ == '__main__': + batch_pts = np.random.random((2, 16, 3)) + batch_n_pose = np.random.random((2, 3, 4, 4)) + batch_n_scale = np.random.random((2, 3, 3)) + poses = PclUtil.transform_n_batch(batch_pts, batch_n_pose, batch_n_scale) diff --git a/src/active_grasp/active_perception/utils/pose_util.py b/src/active_grasp/active_perception/utils/pose_util.py new file mode 100755 index 0000000..0a88f33 --- /dev/null +++ b/src/active_grasp/active_perception/utils/pose_util.py @@ -0,0 +1,188 @@ +import numpy as np +import torch +import torch.nn.functional as F + +class PoseUtil: + ROTATION = 1 + TRANSLATION = 2 + SCALE = 3 + + @staticmethod + def get_uniform_translation(trans_m_min, trans_m_max, trans_unit, debug=False): + if isinstance(trans_m_min, list): + x_min, y_min, z_min = trans_m_min + x_max, 
y_max, z_max = trans_m_max + else: + x_min, y_min, z_min = trans_m_min, trans_m_min, trans_m_min + x_max, y_max, z_max = trans_m_max, trans_m_max, trans_m_max + + x = np.random.uniform(x_min, x_max) + y = np.random.uniform(y_min, y_max) + z = np.random.uniform(z_min, z_max) + translation = np.array([x, y, z]) + if trans_unit == "cm": + translation = translation / 100 + if debug: + print("uniform translation:", translation) + return translation + + @staticmethod + def get_uniform_rotation(rot_degree_min=0, rot_degree_max=180, debug=False): + axis = np.random.randn(3) + axis /= np.linalg.norm(axis) + theta = np.random.uniform(rot_degree_min / 180 * np.pi, rot_degree_max / 180 * np.pi) + + K = np.array([[0, -axis[2], axis[1]], + [axis[2], 0, -axis[0]], + [-axis[1], axis[0], 0]]) + R = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K) + if debug: + print("uniform rotation:", theta * 180 / np.pi) + return R + + @staticmethod + def get_uniform_pose(trans_min, trans_max, rot_min=0, rot_max=180, trans_unit="cm", debug=False): + translation = PoseUtil.get_uniform_translation(trans_min, trans_max, trans_unit, debug) + rotation = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug) + pose = np.eye(4) + pose[:3, :3] = rotation + pose[:3, 3] = translation + return pose + + @staticmethod + def get_n_uniform_pose(trans_min, trans_max, rot_min=0, rot_max=180, n=1, + trans_unit="cm", fix=None, contain_canonical=True, debug=False): + if fix == PoseUtil.ROTATION: + translations = np.zeros((n, 3)) + for i in range(n): + translations[i] = PoseUtil.get_uniform_translation(trans_min, trans_max, trans_unit, debug) + if contain_canonical: + translations[0] = np.zeros(3) + rotations = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug) + elif fix == PoseUtil.TRANSLATION: + rotations = np.zeros((n, 3, 3)) + for i in range(n): + rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug) + if contain_canonical: + rotations[0] = np.eye(3) + translations = PoseUtil.get_uniform_translation(trans_min, trans_max, trans_unit, debug) + else: + translations = np.zeros((n, 3)) + rotations = np.zeros((n, 3, 3)) + for i in range(n): + translations[i] = PoseUtil.get_uniform_translation(trans_min, trans_max, trans_unit, debug) + for i in range(n): + rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug) + if contain_canonical: + translations[0] = np.zeros(3) + rotations[0] = np.eye(3) + + pose = np.eye(4, 4, k=0)[np.newaxis, :].repeat(n, axis=0) + pose[:, :3, :3] = rotations + pose[:, :3, 3] = translations + + return pose + + @staticmethod + def get_n_uniform_pose_batch(trans_min, trans_max, rot_min=0, rot_max=180, n=1, batch_size=1, + trans_unit="cm", fix=None, contain_canonical=False, debug=False): + + batch_poses = [] + for i in range(batch_size): + pose = PoseUtil.get_n_uniform_pose(trans_min, trans_max, rot_min, rot_max, n, + trans_unit, fix, contain_canonical, debug) + batch_poses.append(pose) + pose_batch = np.stack(batch_poses, axis=0) + return pose_batch + + @staticmethod + def get_uniform_scale(scale_min, scale_max, debug=False): + if isinstance(scale_min, list): + x_min, y_min, z_min = scale_min + x_max, y_max, z_max = scale_max + else: + x_min, y_min, z_min = scale_min, scale_min, scale_min + x_max, y_max, z_max = scale_max, scale_max, scale_max + + x = np.random.uniform(x_min, x_max) + y = np.random.uniform(y_min, y_max) + z = np.random.uniform(z_min, z_max) + scale = np.array([x, y, z]) + if debug: + print("uniform scale:", scale) + return scale + + @staticmethod + def 
normalize_rotation(rotation, rotation_mode): + if rotation_mode == 'quat_wxyz' or rotation_mode == 'quat_xyzw': + rotation /= torch.norm(rotation, dim=-1, keepdim=True) + elif rotation_mode == 'rot_matrix': + rot_matrix = PoseUtil.rotation_6d_to_matrix_tensor_batch(rotation) + rotation[:, :3] = rot_matrix[:, 0, :] + rotation[:, 3:6] = rot_matrix[:, 1, :] + elif rotation_mode == 'euler_xyz_sx_cx': + rot_sin_theta = rotation[:, :3] + rot_cos_theta = rotation[:, 3:6] + theta = torch.atan2(rot_sin_theta, rot_cos_theta) + rotation[:, :3] = torch.sin(theta) + rotation[:, 3:6] = torch.cos(theta) + elif rotation_mode == 'euler_xyz': + pass + else: + raise NotImplementedError + return rotation + + @staticmethod + def get_pose_dim(rot_mode): + assert rot_mode in ['quat_wxyz', 'quat_xyzw', 'euler_xyz', 'euler_xyz_sx_cx', 'rot_matrix'], \ + f"the rotation mode {rot_mode} is not supported!" + + if rot_mode == 'quat_wxyz' or rot_mode == 'quat_xyzw': + pose_dim = 4 + elif rot_mode == 'euler_xyz': + pose_dim = 3 + elif rot_mode == 'euler_xyz_sx_cx' or rot_mode == 'rot_matrix': + pose_dim = 6 + else: + raise NotImplementedError + return pose_dim + + @staticmethod + def rotation_6d_to_matrix_tensor_batch(d6: torch.Tensor) -> torch.Tensor: + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + @staticmethod + def matrix_to_rotation_6d_tensor_batch(matrix: torch.Tensor) -> torch.Tensor: + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) + + @staticmethod + def rotation_6d_to_matrix_numpy(d6): + a1, a2 = d6[:3], d6[3:] + b1 = a1 / np.linalg.norm(a1) + b2 = a2 - np.dot(b1, a2) * b1 + b2 = b2 / np.linalg.norm(b2) + b3 = np.cross(b1, b2) + return np.stack((b1, b2, b3),axis=-2) + + @staticmethod + def matrix_to_rotation_6d_numpy(matrix): + return np.copy(matrix[:2, :]).reshape((6,)) + + + +''' ------------ Debug ------------ ''' + +if __name__ == '__main__': + for _ in range(1): + PoseUtil.get_uniform_pose(trans_min=[-25, -25, 10], trans_max=[25, 25, 60], + rot_min=0, rot_max=10, debug=True) + PoseUtil.get_uniform_scale(scale_min=0.25, scale_max=0.30, debug=True) + PoseUtil.get_n_uniform_pose_batch(trans_min=[-25, -25, 10], trans_max=[25, 25, 60], + rot_min=0, rot_max=10, batch_size=2, n=2, fix=PoseUtil.TRANSLATION, debug=True) diff --git a/src/active_grasp/active_perception/utils/tensorboard_util.py b/src/active_grasp/active_perception/utils/tensorboard_util.py new file mode 100755 index 0000000..8e635a7 --- /dev/null +++ b/src/active_grasp/active_perception/utils/tensorboard_util.py @@ -0,0 +1,47 @@ +import torch + + +class TensorboardWriter: + @staticmethod + def write_tensorboard(writer, panel, data_dict, step): + complex_dict = False + if "scalars" in data_dict: + scalar_data_dict = data_dict["scalars"] + TensorboardWriter.write_scalar_tensorboard(writer, panel, scalar_data_dict, step) + complex_dict = True + if "images" in data_dict: + image_data_dict = data_dict["images"] + TensorboardWriter.write_image_tensorboard(writer, panel, image_data_dict, step) + complex_dict = True + if "points" in data_dict: + point_data_dict = data_dict["points"] + TensorboardWriter.write_points_tensorboard(writer, panel, point_data_dict, step) + complex_dict = True + + if not complex_dict: + TensorboardWriter.write_scalar_tensorboard(writer, panel, data_dict, step) + + @staticmethod + def 
write_scalar_tensorboard(writer, panel, data_dict, step): + for key, value in data_dict.items(): + if isinstance(value, dict): + writer.add_scalars(f'{panel}/{key}', value, step) + else: + writer.add_scalar(f'{panel}/{key}', value, step) + + @staticmethod + def write_image_tensorboard(writer, panel, data_dict, step): + pass + + @staticmethod + def write_points_tensorboard(writer, panel, data_dict, step): + for key, value in data_dict.items(): + if value.shape[-1] == 3: + colors = torch.zeros_like(value) + vertices = torch.cat([value, colors], dim=-1) + elif value.shape[-1] == 6: + vertices = value + else: + raise ValueError(f'Unexpected value shape: {value.shape}') + faces = None + writer.add_mesh(f'{panel}/{key}', vertices=vertices, faces=faces, global_step=step) diff --git a/src/active_grasp/active_perception/utils/view_util.py b/src/active_grasp/active_perception/utils/view_util.py new file mode 100755 index 0000000..02858b6 --- /dev/null +++ b/src/active_grasp/active_perception/utils/view_util.py @@ -0,0 +1,239 @@ +import json +import numpy as np +import requests +import torch +from PIL import Image + +from utils.cache_util import LRUCache + + +class ViewUtil: + view_cache = LRUCache(1024) + def load_camera_pose_from_frame(camera_params_path): + with open(camera_params_path, "r") as f: + camera_params = json.load(f) + + view_transform = camera_params["cameraViewTransform"] + view_transform = np.resize(view_transform, (4,4)) + view_transform = np.linalg.inv(view_transform).T + offset = np.mat([[1,0,0,0],[0,-1,0,0],[0,0,-1,0],[0,0,0,1]]) + view_transform = view_transform.dot(offset) + return view_transform + + def save_image(rgb, filename): + if rgb.dtype != np.uint8: + rgb = rgb.astype(np.uint8) + img = Image.fromarray(rgb, 'RGB') + img.save(filename) + + def save_depth(depth, filename): + if depth.dtype != np.uint16: + depth = depth.astype(np.uint16) + depth_img = Image.fromarray(depth) + depth_img.save(filename) + + def save_segmentation(seg, filename): + if seg.dtype != np.uint8: + seg = seg.astype(np.uint8) + seg_img = Image.fromarray(seg) + seg_img.save(filename) + + @staticmethod + def get_view(camera_pose,source, data_type,scene,port): + camera_pose_tuple = tuple(map(tuple, camera_pose.tolist())) + cache_key = (camera_pose_tuple, source, data_type, scene, port) + cached_result = ViewUtil.view_cache.get(cache_key) + if cached_result: + print("Cache hit") + return cached_result + + url = f"http://127.0.0.1:{port}/get_images" + headers = { + 'Content-Type': 'application/json' + } + data = { + 'camera_pose': camera_pose.tolist(), + 'data_type': data_type, + 'source': source, + 'scene': scene + } + response = requests.post(url, headers=headers, data=json.dumps(data)) + + if response.status_code == 200: + results = response.json() + + rgb = np.asarray(results['rgb'],dtype=np.uint8) + depth = np.asarray(results['depth'])/1000 + seg = np.asarray(results['segmentation']) + seg_labels = results['segmentation_labels'] + camera_params = results['camera_params'] + ViewUtil.view_cache.put(cache_key, (rgb, depth, seg, seg_labels, camera_params)) + return rgb, depth, seg, seg_labels, camera_params + else: + return None + + @staticmethod + def get_object_pose_batch(K, mesh, rgb_batch, depth_batch, mask_batch, gt_pose_batch ,port): + url = f"http://127.0.0.1:{port}/predict_estimation_batch" + headers = { + 'Content-Type': 'application/json' + } + mesh_data = { + 'vertices': mesh.vertices.tolist(), + 'faces': mesh.faces.tolist() + } + data = { + 'K': K.tolist(), + 'rgb_batch': rgb_batch.tolist(), + 
'depth_batch': depth_batch.tolist(), + 'mask_batch': mask_batch.tolist(), + 'mesh': mesh_data, + 'gt_pose_batch': gt_pose_batch.tolist() + } + response = requests.post(url, headers=headers, data=json.dumps(data)) + + if response.status_code == 200: + results = response.json() + pose_batch = np.array(results['pose_batch']) + results_batch = results["eval_result_batch"] + return pose_batch, results_batch + else: + return None + + @staticmethod + def get_visualized_result(K, mesh, rgb, pose ,port): + url = f"http://127.0.0.1:{port}/get_visualized_result" + headers = { + 'Content-Type': 'application/json' + } + mesh_data = { + 'vertices': mesh.vertices.tolist(), + 'faces': mesh.faces.tolist() + } + data = { + 'K': K.tolist(), + 'rgb': rgb.tolist(), + 'mesh': mesh_data, + 'pose': pose.tolist() + } + response = requests.post(url, headers=headers, data=json.dumps(data)) + + if response.status_code == 200: + results = response.json() + vis_rgb = np.array(results['vis_rgb']) + return vis_rgb + else: + return None + + @staticmethod + def get_object_pose(K, mesh, rgb, depth, mask, gt_pose ,port): + url = f"http://127.0.0.1:{port}/predict_estimation" + headers = { + 'Content-Type': 'application/json' + } + mesh_data = { + 'vertices': mesh.vertices.tolist(), + 'faces': mesh.faces.tolist() + } + data = { + 'K': K.tolist(), + 'rgb': rgb.tolist(), + 'depth': depth.tolist(), + 'mask': mask.tolist(), + 'mesh': mesh_data, + 'gt_pose': gt_pose.tolist() + } + response = requests.post(url, headers=headers, data=json.dumps(data)) + + if response.status_code == 200: + results = response.json() + pose_batch = np.array(results['pose_batch']) + results_batch = results["eval_result_batch"] + return pose_batch, results_batch + else: + return None + + def get_pts_dict(depth, seg, seg_labels, camera_params): + cx = camera_params['cx'] + cy = camera_params['cy'] + fx = camera_params['fx'] + fy = camera_params['fy'] + width = camera_params['width'] + height = camera_params['height'] + pts_dict = {name: [] for name in seg_labels.values()} + u = np.arange(width) + v = np.arange(height) + u, v = np.meshgrid(u, v) + Z = depth + X = (u - cx) * Z / fx + Y = (v - cy) * Z / fy + points = np.stack((X, Y, Z), axis=-1).reshape(-1, 3) + labels = seg.reshape(-1) + for label, name in seg_labels.items(): + mask = labels == int(label) + pts_dict[name] = points[mask] + return pts_dict + + def get_object_center_from_pts_dict(obj,pts_dict): + if obj is None: + for _, pts in pts_dict.items(): + if pts.size != 0: + obj_pts = pts + break + else: + obj_pts = pts_dict[obj] + if obj_pts.size == 0: + for _, pts in pts_dict.items(): + if pts.size != 0: + obj_pts = pts + break + obj_center = obj_pts.mean(axis=0) + return obj_center + + def get_pts_center(pts): + pts_center = pts.mean(axis=0) + return pts_center + + def get_scene_pts(pts_dict): + if any(isinstance(pts, torch.Tensor) for pts in pts_dict.values()): + scene_pts = torch.cat([pts for _, pts in pts_dict.items()], dim=0) + return scene_pts + else: + scene_pts = np.concatenate([pts for _, pts in pts_dict.items()]) + return scene_pts + + def crop_pts(scene_pts, crop_center, radius=0.2): + if isinstance(scene_pts, torch.Tensor): + crop_mask = torch.norm(scene_pts - crop_center, dim=1) < radius + return scene_pts[crop_mask] + else: + crop_mask = np.linalg.norm(scene_pts - crop_center, axis=1) < radius + return scene_pts[crop_mask] + + def crop_pts_dict(pts_dict, crop_center, radius=0.2, min_pts_num = 5000): + crop_dict = {} + max_loop = 100 + loop = 0 + while(loop<=max_loop): + croped_length = 
0 + for obj, pts in pts_dict.items(): + if isinstance(pts, torch.Tensor): + crop_mask = torch.norm(pts - crop_center, dim=1) < radius + crop_dict[obj] = pts[crop_mask] + else: + crop_mask = np.linalg.norm(pts - crop_center, axis=1) < radius + crop_dict[obj] = pts[crop_mask] + croped_length += crop_dict[obj].shape[0] + if croped_length >= min_pts_num: + break + radius += 0.02 + loop += 1 + return crop_dict + + def get_cam_pose_focused_on_point(point_w, cam_pose_w, old_camera_center_w): + distance = np.linalg.norm(point_w-old_camera_center_w) + z_axis_camera = cam_pose_w[:3, 2].reshape(-1) + new_camera_position_w = point_w - distance * z_axis_camera + new_camera_pose_w = cam_pose_w.copy() + new_camera_pose_w[:3, 3] = new_camera_position_w.reshape((3,1)) + return new_camera_pose_w \ No newline at end of file
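The utilities above are only exercised piecemeal by their `__main__` debug blocks; nothing in the patch shows the predicted 6D rotation being turned back into a pose matrix and scored. Below is a minimal sketch of that round trip using `PoseUtil.rotation_6d_to_matrix_numpy` and `MetricUtil.collect_metric`. It assumes the `active_perception` directory is on `sys.path` (as the modules above arrange via `PROJECT_ROOT`); the random predictions, identity ground-truth poses, and zero translations are placeholders, not values produced by this patch.

```python
import numpy as np
import torch

from utils.pose_util import PoseUtil
from utils.metric_util import MetricUtil

# Stand-in for GradientFieldViewFinder.next_best_view() output:
# a batch of 6D rotations (first two rows of a rotation matrix), shape (B, 6).
pred_rot_6d = torch.rand(4, 6)

pred_poses = []
for rot_6d in pred_rot_6d.numpy():
    rot_mat = PoseUtil.rotation_6d_to_matrix_numpy(rot_6d)  # Gram-Schmidt -> (3, 3)
    pose = np.eye(4)
    pose[:3, :3] = rot_mat          # translation left at zero in this sketch
    pred_poses.append(pose)
pred_poses = np.stack(pred_poses)   # (B, 4, 4)

# Placeholder ground truth: identity poses, no object symmetry (sym code 0).
gt_poses = np.tile(np.eye(4), (pred_poses.shape[0], 1, 1))
sym = [0] * pred_poses.shape[0]

error = MetricUtil.collect_metric(pred_poses, gt_poses, sym)
print("rotation error (deg):", error["rot_error"]["mean"])
print("translation error (x100):", error["trans_error"]["mean"])
```

`collect_metric` converts rotation error to degrees and scales the translation L2 distance by 100 (centimetres if the poses are expressed in metres), which is why the sketch labels the printouts that way.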