21 lines
690 B
Python
Executable File
21 lines
690 B
Python
Executable File
|
|
import torch
|
|
from modules.rgb_encoder.abstract_rgb_encoder import RGBEncoder
|
|
from annotations.external_module import external_freeze
|
|
|
|
@external_freeze
class Dinov2Encoder(RGBEncoder):
    """RGB encoder backed by a DINOv2 backbone loaded from a local torch.hub repo.

    The backbone is loaded once at construction time and is only used for
    inference: ``encode_rgb`` runs under ``torch.no_grad()`` and returns the
    ``'x_norm_patchtokens'`` entry of ``forward_features``.
    """

    # Local checkout of the DINOv2 hub repo. With source='local', torch.hub.load
    # treats this as a filesystem path, so it is resolved relative to the
    # current working directory.
    HUB_REPO_DIR = 'modules/module_lib/dinov2'

    def __init__(self, model_name, device='cuda'):
        """Create the encoder and immediately load its backbone.

        Args:
            model_name: Entry-point name passed to ``torch.hub.load`` for the
                local DINOv2 repo (presumably e.g. 'dinov2_vitb14' — confirm
                against that repo's hubconf).
            device: Device the backbone is moved to. Defaults to 'cuda', which
                matches the previous hard-coded ``.cuda()`` behavior.
        """
        super().__init__()
        self.model_name = model_name
        # Named specifically (not ``device``) to avoid clashing with a possible
        # ``device`` property on the base class.
        self.dinov2_device = device
        self.load()

    def load(self):
        """Load the DINOv2 backbone named by ``self.model_name`` into ``self.dinov2``."""
        model = torch.hub.load(self.HUB_REPO_DIR, self.model_name, source='local')
        # The backbone is used purely as a feature extractor, so switch it to
        # eval mode. A freshly built nn.Module starts in training mode, and any
        # train-only behavior (dropout, stochastic depth) must be off at inference.
        self.dinov2 = model.to(self.dinov2_device).eval()

    def encode_rgb(self, rgb):
        """Return DINOv2 normalized patch-token features for ``rgb``.

        Args:
            rgb: Input passed straight to ``self.dinov2.forward_features``.
                Presumably a float image tensor of shape (B, 3, H, W) already on
                the backbone's device, with H and W multiples of the patch size.
                TODO: confirm with callers.

        Returns:
            The ``'x_norm_patchtokens'`` entry of the ``forward_features`` output,
            computed without gradient tracking. Presumably shaped
            (B, num_patches, embed_dim); confirm.
        """
        with torch.no_grad():
            features_dict = self.dinov2.forward_features(rgb)
        return features_dict['x_norm_patchtokens']
|