success
This commit is contained in:
20
modules/rgb_encoder/dinov2_encoder.py
Executable file
20
modules/rgb_encoder/dinov2_encoder.py
Executable file
@@ -0,0 +1,20 @@
|
||||
|
||||
import torch
|
||||
from modules.rgb_encoder.abstract_rgb_encoder import RGBEncoder
|
||||
from annotations.external_module import external_freeze
|
||||
|
||||
@external_freeze
|
||||
class Dinov2Encoder(RGBEncoder):
|
||||
def __init__(self, model_name):
|
||||
super(Dinov2Encoder, self).__init__()
|
||||
self.model_name = model_name
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
self.dinov2 = torch.hub.load('modules/module_lib/dinov2', self.model_name, source='local').cuda()
|
||||
|
||||
def encode_rgb(self, rgb):
|
||||
with torch.no_grad():
|
||||
features_dict = self.dinov2.forward_features(rgb)
|
||||
features = features_dict['x_norm_patchtokens']
|
||||
return features
|
Reference in New Issue
Block a user