File size: 518 Bytes
626eca0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from dataclasses import dataclass
import torch
# Normalizes every accepted spelling of a precision setting to a torch dtype.
# Keys may be ``None``, an integer bit width, or a string alias; a ``None``
# precision resolves to single precision (``torch.float32``). Only half
# (float16) and single (float32) precision are represented here.
PRECISION_MAP = {
    # No precision given.
    None: torch.float32,
    # Bit widths as integers.
    16: torch.float16,
    32: torch.float32,
    # torch-style dtype names.
    "float16": torch.float16,
    "float32": torch.float32,
    # C-style names.
    "half": torch.float16,
    "float": torch.float32,
    # Bit widths as strings.
    "16": torch.float16,
    "32": torch.float32,
    # Abbreviated names.
    "fp16": torch.float16,
    "fp32": torch.float32,
}
@dataclass
class RetrievedSample:
    """
    Dataclass for a single item in the output of the GoldenRetriever model.

    ``@dataclass`` generates ``__init__(score, index, label)`` (positional
    order matches the field order below), a field-by-field ``__repr__``,
    and a field-by-field ``__eq__``. Instances are mutable.
    """

    # Score assigned to this sample by the retriever (presumably higher is
    # more relevant; confirm against the scoring code).
    score: float
    # Integer position of the sample (presumably in the retriever's document
    # index; verify against the caller).
    index: int
    # String associated with the retrieved sample (presumably its text or
    # label; verify against the caller).
    label: str
|