Files
nbv_grasping/modules/rgb_encoder/dinov2_encoder.py

21 lines
690 B
Python
Raw Normal View History

2024-10-09 16:13:22 +00:00
import torch
from modules.rgb_encoder.abstract_rgb_encoder import RGBEncoder
from annotations.external_module import external_freeze
@external_freeze
class Dinov2Encoder(RGBEncoder):
    """RGB encoder backed by a DINOv2 model obtained through ``torch.hub``.

    The model is loaded from the local hub directory
    ``modules/module_lib/dinov2`` and moved to CUDA as part of construction.
    """

    def __init__(self, model_name):
        """Remember the hub entry-point name and load the model right away.

        Args:
            model_name: Entry-point name forwarded to ``torch.hub.load``.
        """
        super(Dinov2Encoder, self).__init__()
        self.model_name = model_name
        self.load()

    def load(self):
        """Load the model named by ``self.model_name`` and put it on CUDA.

        The result is stored on ``self.dinov2``.
        """
        hub_model = torch.hub.load(
            'modules/module_lib/dinov2',
            self.model_name,
            source='local',
        )
        self.dinov2 = hub_model.cuda()

    def encode_rgb(self, rgb):
        """Return the ``'x_norm_patchtokens'`` features computed for ``rgb``.

        The forward pass runs under ``torch.no_grad()``, so no autograd
        graph is recorded for the returned features.

        Args:
            rgb: Input handed unchanged to the model's ``forward_features``.
                NOTE(review): presumably a batched image tensor already on
                the model's device -- confirm against callers.

        Returns:
            The ``'x_norm_patchtokens'`` entry of the dictionary produced by
            ``forward_features``.
        """
        with torch.no_grad():
            return self.dinov2.forward_features(rgb)['x_norm_patchtokens']