eyad-silx commited on
Commit
b95e4a2
·
verified ·
1 Parent(s): d606189

Update model_dtat.py

Browse files
Files changed (1) hide show
  1. model_dtat.py +1 -1
model_dtat.py CHANGED
@@ -317,7 +317,7 @@ class DTATTransformer(nn.Module):
317
  logits = logits.view(B*T, C)
318
  targets = targets.view(B*T)
319
  # Calculate loss directly in BPC instead of nats
320
- loss = F.cross_entropy(logits, targets) * math.log2(math.e)
321
 
322
  return logits, loss, importance_scores
323
 
 
317
  logits = logits.view(B*T, C)
318
  targets = targets.view(B*T)
319
  # Calculate loss directly in BPC instead of nats
320
+ loss = F.cross_entropy(logits, targets) / math.log(2)
321
 
322
  return logits, loss, importance_scores
323