File size: 896 Bytes
83ae704 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# losses for sparse ga
# --------------------------------------------------------
import torch
import numpy as np
def l05_loss(x, y):
    """Return the square root of the Euclidean distance between ``x`` and ``y``.

    The distance is the 2-norm of ``x - y`` over the last dimension, so that
    dimension is reduced away. Taking the square root means residuals larger
    than 1 are penalized less than by the plain distance.
    """
    residual = x - y
    distance = torch.linalg.norm(residual, dim=-1)
    return torch.sqrt(distance)
def l1_loss(x, y):
    """Return the Euclidean distance between ``x`` and ``y`` over the last dimension.

    NOTE: despite its name, this computes the (un-squared) 2-norm of the
    residual ``x - y``, not the sum of absolute differences. The last
    dimension is reduced away.
    """
    residual = x - y
    distance = torch.linalg.norm(residual, dim=-1)
    return distance
def gamma_loss(gamma, mul=1, offset=None, clip=np.inf):
    """Build a loss ``(mul * d + offset) ** gamma - offset ** gamma``.

    ``d`` is ``l1_loss(x, y)`` (the 2-norm of ``x - y`` over the last
    dimension), clamped from above at ``clip``. Subtracting
    ``offset ** gamma`` makes the loss exactly 0 when ``d == 0``.

    Args:
        gamma: exponent applied to the shifted, scaled distance.
        mul: scale applied to the clipped distance before exponentiation.
        offset: shift added before exponentiation. When None, it is chosen
            so that the loss has slope ``mul`` at ``d == 0``. For
            ``gamma == 1`` the shift cancels out and 0 is used.
        clip: upper bound on the distance. Distances above it give a
            constant loss (zero gradient w.r.t. ``x``/``y``).

    Returns:
        A callable ``loss_func(x, y)``. For ``gamma == 1`` with the default
        ``mul`` and ``clip`` (and no ``offset``), this is ``l1_loss`` itself.

    Raises:
        ZeroDivisionError: if ``offset`` is None and ``gamma == 0`` (the
            automatic offset is undefined there).
    """
    if offset is None:
        if gamma == 1:
            if mul == 1 and clip == np.inf:
                # Plain distance: nothing to scale or clip.
                return l1_loss
            # For gamma == 1, (m + o) - o == m for any o, so 0 is as good as
            # the (undefined) slope-matching offset. Falling through here keeps
            # ``mul`` and ``clip`` honored; previously they were silently ignored.
            offset = 0
        else:
            # Pick offset so the slope at d == 0 equals mul:
            # d(x**p)/dx = 1 ==> p * x**(p-1) == 1 ==> x = (1/p)**(1/(p-1))
            offset = (1 / gamma) ** (1 / (gamma - 1))

    def loss_func(x, y):
        # offset ** gamma is evaluated per call (not hoisted) so that a
        # differentiable gamma gets a fresh autograd node on every call.
        return (mul * l1_loss(x, y).clip(max=clip) + offset) ** gamma - offset ** gamma

    return loss_func
def meta_gamma_loss():
    """Return a factory mapping an exponent ``alpha`` to ``gamma_loss(alpha)``.

    The returned callable builds a fresh loss function on each call, using
    ``gamma_loss``'s default ``mul``, ``offset`` and ``clip``.
    """
    def build_loss(alpha):
        return gamma_loss(alpha)

    return build_loss
|