Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
from typing import List, Dict | |
from .base_decoder import BaseDecoder | |
class ViterbiDecoder(BaseDecoder):
    """Greedy decoder that keeps the highest-scoring token at every step.

    Relies on ``self.blank`` being available on the instance.
    NOTE(review): presumably set by ``BaseDecoder`` -- confirm there.
    """

    def decode(
        self,
        emissions: torch.FloatTensor,
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        """Decode each entry of ``emissions`` into a single hypothesis.

        For every item obtained by iterating over ``emissions``:
        the argmax over the last dimension is taken, runs of identical
        consecutive indices are collapsed to one, and indices equal to
        ``self.blank`` are removed.

        Args:
            emissions: Score tensor; iterated along its first dimension.
                NOTE(review): looks like (batch, time, vocab) -- confirm
                against the caller.

        Returns:
            One list per item of ``emissions``, each holding exactly one
            dict with the surviving token indices under ``"tokens"`` and a
            constant ``0`` under ``"score"`` (no real score is computed).
        """
        results = []
        for sample in emissions:
            # Best index per step, with repeated neighbours merged.
            best_path = sample.argmax(dim=-1)
            collapsed = best_path.unique_consecutive()
            # Drop the blank symbol only after merging repeats, so equal
            # tokens separated by a blank remain distinct.
            not_blank = collapsed != self.blank
            results.append([{"tokens": collapsed[not_blank], "score": 0}])
        return results