success
This commit is contained in:
0
modules/rgb_encoder/__init__.py
Executable file
0
modules/rgb_encoder/__init__.py
Executable file
51
modules/rgb_encoder/abstract_rgb_encoder.py
Executable file
51
modules/rgb_encoder/abstract_rgb_encoder.py
Executable file
@@ -0,0 +1,51 @@
|
||||
from abc import abstractmethod
|
||||
from sklearn.decomposition import PCA
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
class RGBEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(RGBEncoder, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def encode_rgb(self, rgb):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def visualize_features(features, save_path=None):
|
||||
patch,feat_dim = features.shape
|
||||
patch_h = int(patch ** 0.5)
|
||||
patch_w = patch_h
|
||||
total_features = features.reshape(patch_h * patch_w, feat_dim)
|
||||
pca = PCA(n_components=3)
|
||||
if isinstance(total_features, torch.Tensor):
|
||||
total_features = total_features.cpu().numpy()
|
||||
pca.fit(total_features)
|
||||
pca_features = pca.transform(total_features)
|
||||
pca_features[:, 0] = (pca_features[:, 0] - pca_features[:, 0].min()) / \
|
||||
(pca_features[:, 0].max() - pca_features[:, 0].min())
|
||||
plt.subplot(1, 3, 1)
|
||||
plt.imshow(pca_features[:,0].reshape(patch_h, patch_w))
|
||||
pca_features_bg = pca_features[:, 0] > 0.5 # from first histogram
|
||||
pca_features_fg = np.ones_like(pca_features_bg)
|
||||
plt.subplot(1, 3, 2)
|
||||
plt.imshow(pca_features_bg.reshape(patch_h, patch_w))
|
||||
pca.fit(total_features[pca_features_fg])
|
||||
pca_features_left = pca.transform(total_features[pca_features_fg])
|
||||
for i in range(3):
|
||||
pca_features_left[:, i] = (pca_features_left[:, i] - pca_features_left[:, i].min()) / (pca_features_left[:, i].max() - pca_features_left[:, i].min())
|
||||
|
||||
pca_features_rgb = pca_features.copy()
|
||||
pca_features_rgb[pca_features_bg] = 0
|
||||
pca_features_rgb[pca_features_fg] = pca_features_left
|
||||
pca_features_rgb = pca_features_rgb.reshape(1, patch_h, patch_w, 3)
|
||||
|
||||
plt.subplot(1, 3, 3)
|
||||
if save_path:
|
||||
plt.imsave(save_path, pca_features_rgb[0])
|
||||
else:
|
||||
plt.imshow(pca_features_rgb[0])
|
||||
plt.show()
|
20
modules/rgb_encoder/dinov2_encoder.py
Executable file
20
modules/rgb_encoder/dinov2_encoder.py
Executable file
@@ -0,0 +1,20 @@
|
||||
|
||||
import torch
|
||||
from modules.rgb_encoder.abstract_rgb_encoder import RGBEncoder
|
||||
from annotations.external_module import external_freeze
|
||||
|
||||
@external_freeze
|
||||
class Dinov2Encoder(RGBEncoder):
|
||||
def __init__(self, model_name):
|
||||
super(Dinov2Encoder, self).__init__()
|
||||
self.model_name = model_name
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
self.dinov2 = torch.hub.load('modules/module_lib/dinov2', self.model_name, source='local').cuda()
|
||||
|
||||
def encode_rgb(self, rgb):
|
||||
with torch.no_grad():
|
||||
features_dict = self.dinov2.forward_features(rgb)
|
||||
features = features_dict['x_norm_patchtokens']
|
||||
return features
|
59
modules/rgb_encoder/rgb_encoder_factory.py
Executable file
59
modules/rgb_encoder/rgb_encoder_factory.py
Executable file
@@ -0,0 +1,59 @@
|
||||
import sys
|
||||
import os
|
||||
path = os.path.abspath(__file__)
|
||||
for i in range(3):
|
||||
path = os.path.dirname(path)
|
||||
PROJECT_ROOT = path
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
from modules.rgb_encoder.abstract_rgb_encoder import RGBEncoder
|
||||
from modules.rgb_encoder.dinov2_encoder import Dinov2Encoder
|
||||
|
||||
|
||||
class RGBEncoderFactory:
|
||||
@staticmethod
|
||||
def create(name, config) -> RGBEncoder:
|
||||
general_config = config["general"]
|
||||
rgb_encoder_config = config["rgb_encoder"][name]
|
||||
if name == "dinov2":
|
||||
return Dinov2Encoder(
|
||||
model_name=rgb_encoder_config["model_name"]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder name: {name}")
|
||||
|
||||
|
||||
''' ------------ Debug ------------ '''
|
||||
if __name__ == "__main__":
|
||||
from configs.config import ConfigManager
|
||||
import torch
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from torchvision import transforms
|
||||
ConfigManager.load_config_with('configs/local_train_config.yaml')
|
||||
ConfigManager.print_config()
|
||||
image_size = 480
|
||||
path = "/mnt/h/BaiduSyncdisk/workspace/ws_active_pose/project/ActivePerception/test/img0.jpg"
|
||||
img = cv2.imread(path)
|
||||
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize(image_size),
|
||||
transforms.CenterCrop(int(image_size//14)*14),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=0.5, std=0.2)
|
||||
])
|
||||
|
||||
rgb = transform(img)
|
||||
print(rgb.shape)
|
||||
rgb_encoder = RGBEncoderFactory.create(name="dinov2", config=ConfigManager.get("modules"))
|
||||
rgb_encoder.load()
|
||||
print(rgb_encoder)
|
||||
rgb = rgb.to("cuda:0")
|
||||
rgb = rgb.unsqueeze(0)
|
||||
rgb_encoder = rgb_encoder.to("cuda:0")
|
||||
|
||||
rgb_feat = rgb_encoder.encode_rgb(rgb)
|
||||
|
||||
print(rgb_feat.shape)
|
||||
rgb_encoder.visualize_features(rgb_feat[0])
|
Reference in New Issue
Block a user