ACR-Loss / acr_loss.py
daliprf
init
ad1e0a0
raw
history blame
914 Bytes
from config import DatasetName
import tensorflow as tf
class ACRLoss:
    """Piecewise regression loss that handles small and large errors separately."""

    def acr_loss(self, x_pr, x_gt, phi, lambda_weight, ds_name):
        """Compute the piecewise loss between predictions and targets.

        Elements whose absolute error is at most 1 contribute
        ``lambda_weight * log(1 + |x_pr - x_gt| ** (2 - phi))``; elements whose
        absolute error exceeds 1 contribute ``(x_pr - x_gt) ** 2 + C`` with
        ``C = phi * ln(2) - 1``. Each branch is masked to its own region,
        averaged with ``tf.reduce_mean`` over every element (masked-out
        elements count as zeros) and scaled by 100.

        Args:
            x_pr: Predicted values.
            x_gt: Ground-truth values, combined element-wise with ``x_pr``.
            phi: Value used both in the exponent ``2 - phi`` and in ``C``.
            lambda_weight: Multiplier applied to the small-error term.
            ds_name: Not read by this method.

        Returns:
            Tuple ``(loss_total, loss_low, loss_high)`` where ``loss_total``
            is ``loss_low + loss_high``.
        """
        residual = x_pr - x_gt
        magnitude = tf.abs(residual)

        # Complementary 0/1 masks selecting the two regions of the loss.
        small_mask = tf.cast(magnitude <= 1.0, dtype=tf.float32)
        large_mask = tf.cast(magnitude > 1.0, dtype=tf.float32)

        # Big errors: squared error shifted by C = phi * ln(2) - 1.
        # ln(2) is built in float32 and the product is evaluated in float64
        # before being cast back, so the precision path is kept as-is.
        ln2_like_pred = tf.math.multiply(
            tf.ones_like(x_pr, dtype=tf.float32), tf.math.log(2.0)
        )
        offset_f64 = (
            tf.cast(phi, dtype=tf.double) * tf.cast(ln2_like_pred, dtype=tf.double)
            - 1.0
        )
        offset = tf.cast(offset_f64, dtype=tf.float32)
        large_term = tf.square(residual) + offset
        loss_high = 100 * tf.reduce_mean(large_mask * large_term)

        # Small errors: lambda-weighted log of (1 + |error| ** (2 - phi)).
        exponent = tf.cast(2.0 - phi, tf.dtypes.float32)
        powered = tf.pow(magnitude, exponent)
        small_term = lambda_weight * tf.math.log(1.0 + powered)
        loss_low = 100 * tf.reduce_mean(small_mask * small_term)

        return loss_low + loss_high, loss_low, loss_high