Update model_dtat.py
Browse files- model_dtat.py +1 -1
model_dtat.py
CHANGED
@@ -317,7 +317,7 @@ class DTATTransformer(nn.Module):
|
|
317 |
logits = logits.view(B*T, C)
|
318 |
targets = targets.view(B*T)
|
319 |
# Calculate loss directly in BPC instead of nats
|
320 |
-
loss = F.cross_entropy(logits, targets)
|
321 |
|
322 |
return logits, loss, importance_scores
|
323 |
|
|
|
317 |
logits = logits.view(B*T, C)
|
318 |
targets = targets.view(B*T)
|
319 |
# Calculate loss directly in BPC instead of nats
|
320 |
+
loss = F.cross_entropy(logits, targets) / math.log(2)
|
321 |
|
322 |
return logits, loss, importance_scores
|
323 |
|