Spaces:
Running
on
T4
Running
on
T4
File size: 1,853 Bytes
565faca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from fam.llm.adapters.base import BaseDataAdapter
class TiltedEncodec(BaseDataAdapter):
    """Split token hierarchies into text ids and per-hierarchy audio ids.

    Token ids are classified by comparing them with the end-of-audio token id
    supplied at construction: ids below it are kept as audio ids, ids above it
    (first hierarchy only) are kept as text ids, and ids equal to it are
    dropped.
    """

    def __init__(self, end_of_audio_token):
        """Remember the token id used as the audio/text threshold.

        Args:
            end_of_audio_token: Id that ``decode`` compares every token
                against to decide whether it is audio, text, or discarded.
        """
        self._end_of_audio_token = end_of_audio_token

    def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
        """Separate text ids from audio ids across all hierarchies.

        Args:
            tokens: One list of token ids per hierarchy. At least two
                hierarchies are required. Only the first hierarchy is
                searched for text ids.

        Returns:
            A ``(text_ids, audio_ids)`` pair. ``text_ids`` holds the text ids
            found in the first hierarchy with the last one removed.
            ``audio_ids`` holds one list per hierarchy, in input order; if the
            hierarchies yield different numbers of audio ids, a warning is
            printed and every list is cut to the shortest length.

        Raises:
            AssertionError: If fewer than two hierarchies are passed (not
                checked when Python runs with assertions disabled).
        """
        assert len(tokens) > 1, "decode() needs at least two token hierarchies"

        threshold = self._end_of_audio_token

        # Only the first hierarchy is scanned for text tokens (ids above the
        # threshold).
        # TODO: maybe the first hierarchy doesn't need to be a special case.
        text_ids = [t for t in tokens[0] if t > threshold]

        # Audio ids are extracted the same way from every hierarchy, the first
        # one included: keep whatever lies strictly below the threshold.
        extracted_audio_ids = [
            [t for t in hierarchy if t < threshold] for hierarchy in tokens
        ]

        # Compute the lengths once; they drive both the check and the report.
        lengths = [len(ids) for ids in extracted_audio_ids]
        min_len = min(lengths)
        max_len = max(lengths)
        if min_len != max_len:
            print("WARNING: Number of tokens at each hierarchy must be of the same length!")
            print(f"Truncating to min length of {min_len} tokens from {max_len} max.")
            print(lengths)
            extracted_audio_ids = [ids[:min_len] for ids in extracted_audio_ids]

        # The final text id is always discarded, matching the established
        # behaviour. NOTE(review): presumably a trailing marker token - confirm.
        return text_ids[:-1], extracted_audio_ids

    def encode(self, text_tokens: list[int], audio_tokens: list[list[int]]):
        """Combine and pad text and audio tokens as needed (not implemented).

        Args:
            text_tokens: Text token ids.
            audio_tokens: One list of audio token ids per hierarchy.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|