jerome-revaud's picture
Initial commit
83ae704
raw
history blame
896 Bytes
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# losses for sparse ga
# --------------------------------------------------------
import torch
import numpy as np
def l05_loss(x, y):
    """Return the square root of the Euclidean distance between ``x`` and ``y``.

    The 2-norm of ``x - y`` is taken over the last dimension (which is reduced
    away), and the element-wise square root of that distance is returned.
    """
    residual = x - y
    distance = torch.linalg.norm(residual, dim=-1)
    return torch.sqrt(distance)
def l1_loss(x, y):
    """Return the Euclidean distance between ``x`` and ``y`` over the last dimension.

    Note: this is the 2-norm of the difference vector, not a sum of absolute
    differences. Presumably "l1" refers to the loss being linear (power 1) in
    that distance, as opposed to ``l05_loss`` which takes its square root.
    """
    residual = x - y
    return torch.linalg.norm(residual, dim=-1)
def gamma_loss(gamma, mul=1, offset=None, clip=np.inf):
    """Build the loss ``(mul * d + offset) ** gamma - offset ** gamma``.

    ``d`` is the distance returned by ``l1_loss(x, y)`` (2-norm of ``x - y``
    over the last dimension), clipped from above at ``clip`` before scaling.
    Subtracting ``offset ** gamma`` makes the loss zero when ``d == 0``.

    Args:
        gamma: exponent applied to the shifted, scaled distance.
        mul: scale applied to the clipped distance.
        offset: shift added before the power. When None, it is chosen so that
            ``t ** gamma`` has unit slope at ``t = offset``; the loss then has
            slope ``mul`` at ``d == 0``.
        clip: upper bound applied to the distance before scaling.

    Returns:
        A callable ``loss_func(x, y)``. ``mul`` and ``clip`` are honoured for
        every ``gamma``, including ``gamma == 1``.

    Raises:
        ValueError: if ``offset`` is None and ``gamma <= 0``. No real, positive
            unit-slope offset exists in that case.
    """
    if offset is None:
        if gamma == 1:
            # t ** 1 has unit slope everywhere and the offset cancels out, so
            # 0 is exact. Falling through to loss_func (rather than returning
            # l1_loss directly) keeps mul and clip applied.
            offset = 0
        elif gamma <= 0:
            raise ValueError(
                f"gamma must be > 0 when offset is None, got {gamma}")
        else:
            # d(x**p)/dx = 1 ==> p * x**(p-1) == 1 ==> x = (1/p)**(1/(p-1))
            offset = (1 / gamma) ** (1 / (gamma - 1))

    def loss_func(x, y):
        dist = l1_loss(x, y).clip(max=clip)
        return (mul * dist + offset) ** gamma - offset ** gamma

    return loss_func
def meta_gamma_loss():
    """Return a factory mapping an exponent ``alpha`` to ``gamma_loss(alpha)``.

    Only the exponent is forwarded. ``mul``, ``offset`` and ``clip`` keep
    ``gamma_loss``'s defaults.
    """
    def make_loss(alpha):
        return gamma_loss(alpha)

    return make_loss