21 lines
690 B
Python
21 lines
690 B
Python
|
|
||
|
import torch
|
||
|
from modules.rgb_encoder.abstract_rgb_encoder import RGBEncoder
|
||
|
from annotations.external_module import external_freeze
|
||
|
|
||
|
@external_freeze
class Dinov2Encoder(RGBEncoder):
    """RGB encoder backed by a DINOv2 backbone loaded through ``torch.hub``.

    The backbone is loaded from a local checkout of the DINOv2 repository
    and used for inference only: ``encode_rgb`` runs the backbone without
    gradient tracking and returns its normalized patch tokens.
    """

    # Local checkout of the DINOv2 hub repo passed to ``torch.hub.load``.
    # The path is relative, so it is resolved against the current working
    # directory: the process must start from the project root, or this
    # attribute must be overridden.
    HUB_REPO_DIR = 'modules/module_lib/dinov2'

    def __init__(self, model_name, device='cuda'):
        """Create the encoder and immediately load the backbone.

        Args:
            model_name: Name of the hub entry point to load from the local
                DINOv2 repo (presumably e.g. ``'dinov2_vits14'`` — confirm
                against the repo's ``hubconf.py``).
            device: Device the backbone is moved to. Defaults to ``'cuda'``,
                matching the previously hard-coded ``.cuda()`` call; pass
                ``'cpu'`` to run without a GPU.
        """
        # Zero-argument super() binds to the class being defined here, so it
        # keeps working even if @external_freeze rebinds the module-level name
        # ``Dinov2Encoder`` to a wrapper/subclass (the explicit two-argument
        # form would look that name up at call time).
        super().__init__()
        self.model_name = model_name
        # Private name so it cannot clash with a possible ``device`` property
        # on the base class.
        self._device = device
        self.load()

    def load(self):
        """(Re)load the DINOv2 backbone from ``HUB_REPO_DIR`` onto the target device.

        The loaded module is switched to eval mode because it is only used
        for inference. NOTE(review): if this object is an ``nn.Module`` and a
        parent later calls ``.train()``, the mode will propagate to the
        backbone — confirm that ``external_freeze`` handles that.
        """
        model = torch.hub.load(self.HUB_REPO_DIR, self.model_name, source='local')
        self.dinov2 = model.to(self._device).eval()

    def encode_rgb(self, rgb):
        """Encode an RGB batch into DINOv2 patch features.

        Args:
            rgb: Image tensor passed straight to the backbone's
                ``forward_features``. Presumably ``(B, 3, H, W)``, with H and W
                multiples of the model's patch size, already normalized and
                on the same device as the backbone — TODO confirm with callers.

        Returns:
            The ``'x_norm_patchtokens'`` entry of the ``forward_features``
            output (presumably ``(B, num_patches, embed_dim)``; verify).
        """
        # Frozen feature extractor: no autograd graph is needed.
        with torch.no_grad():
            features = self.dinov2.forward_features(rgb)['x_norm_patchtokens']
        return features
|