16 lines
532 B
Python
Raw Permalink Normal View History

2024-08-21 17:26:28 +08:00
import torch
import PytorchBoot.stereotype as stereotype
@stereotype.loss_function("gf_loss")
class GFLoss:
    """Weighted squared-error loss between an estimated and a target score.

    The class is decorated with ``stereotype.loss_function("gf_loss")``;
    presumably this registers it with PytorchBoot under the name
    ``"gf_loss"`` (the decorator is defined outside this file -- confirm).
    """

    def __init__(self, _):
        """Create the loss object.

        Args:
            _: Accepted for interface compatibility and ignored
                (presumably a config object -- confirm against the caller).
        """

    def compute(self, output, _):
        """Return the batch-averaged, ``std**2``-weighted squared error.

        Args:
            output: Mapping that provides the tensors ``'estimated_score'``,
                ``'target_score'`` and ``'std'``. ``std`` must broadcast
                against the two score tensors.
            _: Ignored second argument (presumably the data batch -- confirm
                against the caller).

        Returns:
            A scalar tensor: for every sample (first dimension) the weighted
            squared error is summed over all remaining dimensions, and the
            per-sample sums are then averaged.
        """
        estimated_score = output['estimated_score']
        target_score = output['target_score']
        std = output['std']

        bs = estimated_score.shape[0]
        loss_weighting = std ** 2
        weighted_sq_err = loss_weighting * (estimated_score - target_score) ** 2

        # reshape() rather than view(): elementwise results can inherit
        # non-contiguous strides from their inputs, in which case view()
        # raises; reshape() falls back to a copy only when it has to.
        per_sample = torch.sum(weighted_sq_err.reshape(bs, -1), dim=-1)
        return torch.mean(per_sample)