Samuel Mueller
working locally
f50f696
raw
history blame contribute delete
No virus
209 Bytes
import torch
from torch import nn
class ScaledSoftmaxCE(nn.Module):
    """Unfinished cross-entropy-style loss over logits with 10 trailing scale values.

    The final axis of the input is split in two: the last 10 entries are
    taken as ``temp_scales`` and everything before them as class logits.

    NOTE(review): the loss itself is not implemented yet. ``forward`` does
    not use ``label`` or ``temp_scales`` and has no ``return`` statement, so
    calling the module currently yields ``None``.
    """

    def forward(self, x, label):
        """Split ``x`` into logits and scale values and normalise the logits.

        Args:
            x: Tensor whose final axis holds the logits followed by 10 extra
                values (presumably temperature scales, going by the variable
                name -- TODO confirm).
            label: Not used yet. Presumably the target classes -- TODO
                confirm once the loss is completed.

        Returns:
            None. The loss value is not computed yet.
        """
        logits = x[..., :-10]
        # TODO: the scale values are sliced off but not applied anywhere yet.
        temp_scales = x[..., -10:]

        # log_softmax (rather than softmax) so that the value really holds
        # log-probabilities, as the name says; it is also the numerically
        # stable way to obtain them for a cross-entropy computation.
        logprobs = logits.log_softmax(-1)