update loss
This commit is contained in:
16
core/loss.py
Normal file
16
core/loss.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import torch
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
@stereotype.loss_function("gf_loss")
|
||||
class GFLoss:
|
||||
def __init__(self, config):
|
||||
pass
|
||||
|
||||
def compute(self, output, _):
|
||||
estimated_score = output['estimated_score']
|
||||
target_score = output['target_score']
|
||||
std = output['std']
|
||||
bs = estimated_score.shape[0]
|
||||
loss_weighting = std ** 2
|
||||
loss = torch.mean(torch.sum((loss_weighting * (estimated_score - target_score) ** 2).view(bs, -1), dim=-1))
|
||||
return loss
|
@@ -1,12 +0,0 @@
|
||||
import torch
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
@stereotype.loss_function("gf_loss")
|
||||
def compute_loss(output, data):
|
||||
estimated_score = output['estimated_score']
|
||||
target_score = output['target_score']
|
||||
std = output['std']
|
||||
bs = estimated_score.shape[0]
|
||||
loss_weighting = std ** 2
|
||||
loss = torch.mean(torch.sum((loss_weighting * (estimated_score - target_score) ** 2).view(bs, -1), dim=-1))
|
||||
return loss
|
Reference in New Issue
Block a user