success
This commit is contained in:
0
losses/__init__.py
Executable file
0
losses/__init__.py
Executable file
11
losses/gf_loss.py
Executable file
11
losses/gf_loss.py
Executable file
@@ -0,0 +1,11 @@
|
||||
import torch
|
||||
|
||||
|
||||
def compute_loss(output, data):
|
||||
estimated_score = output['estimated_score']
|
||||
target_score = output['target_score']
|
||||
std = output['std']
|
||||
bs = estimated_score.shape[0]
|
||||
loss_weighting = std ** 2
|
||||
loss = torch.mean(torch.sum((loss_weighting * (estimated_score - target_score) ** 2).view(bs, -1), dim=-1))
|
||||
return loss
|
18
losses/loss_function_factory.py
Executable file
18
losses/loss_function_factory.py
Executable file
@@ -0,0 +1,18 @@
|
||||
import losses.gf_loss
|
||||
|
||||
|
||||
class LossFunctionFactory:
|
||||
@staticmethod
|
||||
def create(function_name):
|
||||
if function_name == "gf_loss":
|
||||
return losses.gf_loss.compute_loss
|
||||
else:
|
||||
raise ValueError("Unknown loss function {}".format(function_name))
|
||||
|
||||
|
||||
''' ------------ Debug ------------ '''
|
||||
if __name__ == "__main__":
|
||||
from configs.config import ConfigManager
|
||||
|
||||
ConfigManager.load_config_with('../configs/local_train_config.yaml')
|
||||
ConfigManager.print_config()
|
Reference in New Issue
Block a user