File size: 914 Bytes
ad1e0a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from config import DatasetName
import tensorflow as tf


class ACRLoss:
    """Holder for the piecewise ``acr_loss`` regression loss."""

    def acr_loss(self, x_pr, x_gt, phi, lambda_weight, ds_name):
        """Compute a two-regime loss on the element-wise error ``x_pr - x_gt``.

        Elements whose absolute error is at most 1 contribute
        ``lambda_weight * log(1 + |error| ** (2 - phi))``; elements whose
        absolute error exceeds 1 contribute ``error ** 2 + C`` with
        ``C = phi * ln(2) - 1``. Each regime is masked, averaged over all
        elements with ``tf.reduce_mean`` and multiplied by 100.

        Args:
            x_pr: Predicted values.
            x_gt: Ground-truth values; combined with ``x_pr`` by subtraction.
            phi: Weight used both in the exponent ``2 - phi`` and in the
                constant ``C``. It is multiplied against a tensor shaped like
                ``x_pr``, so it has to broadcast with it
                (scalar vs. per-element -- TODO confirm with the caller).
            lambda_weight: Multiplier applied to the logarithmic term of the
                small-error regime.
            ds_name: Not used by this method; kept in the signature.

        Returns:
            A tuple ``(loss_total, loss_low, loss_high)`` where
            ``loss_total`` is the sum of the other two.
        """
        residual = x_pr - x_gt
        magnitude = tf.abs(residual)

        # 0/1 masks selecting the regime each element belongs to.
        small_mask = tf.cast(magnitude <= 1.0, dtype=tf.float32)
        large_mask = tf.cast(magnitude > 1.0, dtype=tf.float32)

        # Large errors: squared error shifted by C = phi * ln(2) - 1.
        # C is evaluated in float64 and then narrowed back to float32.
        log_two = tf.ones_like(x_pr, dtype=tf.float32) * tf.math.log(2.0)
        offset_f64 = (
            tf.cast(phi, dtype=tf.double) * tf.cast(log_two, dtype=tf.double) - 1.0
        )
        offset = tf.cast(offset_f64, dtype=tf.float32)
        large_terms = tf.math.multiply(large_mask, (tf.square(residual) + offset))
        loss_high = 100 * tf.reduce_mean(large_terms)

        # Small errors: weighted log of one plus the error raised to (2 - phi).
        exponent = tf.cast(2.0 - phi, tf.dtypes.float32)
        powered = tf.pow(magnitude, exponent)
        small_terms = tf.math.multiply(
            small_mask, (lambda_weight * tf.math.log(1.0 + powered))
        )
        loss_low = 100 * tf.reduce_mean(small_terms)

        return loss_low + loss_high, loss_low, loss_high