translator_api / translator.py
jxtan's picture
Cache only specific files from SeamlessM4T
58ac03b
import os
import pathlib
import torch
from fairseq2.assets import InProcAssetMetadataProvider, asset_store
from seamless_communication.inference import Translator
# Directory expected to contain the local model/tokenizer files referenced
# further down. Override it with the CHECKPOINTS_PATH environment variable.
CHECKPOINTS_PATH = pathlib.Path(
    os.environ.get("CHECKPOINTS_PATH", "/home/user/app/models")
)

# Abort at import time when the directory is missing.
if not CHECKPOINTS_PATH.exists():
    # Nothing is downloaded here. A commented-out reference for fetching the
    # whole upstream repository into the directory:
    #   from huggingface_hub import snapshot_download
    #   snapshot_download(repo_id="facebook/seamless-m4t-v2-large", repo_type="model", local_dir=CHECKPOINTS_PATH)
    raise FileNotFoundError(f"Checkpoint path {CHECKPOINTS_PATH} does not exist")
# Asset-card overrides that replace the upstream download URLs with files
# located under CHECKPOINTS_PATH. Every entry name carries the "@demo" suffix.
demo_metadata = [
    # Upstream card:
    # https://github.com/facebookresearch/seamless_communication/blob/dd67e71317d66752ef16cf21bd842ca3273244c9/src/seamless_communication/cards/seamlessM4T_v2_large.yaml#L10
    # Upstream values being overridden:
    #   char_tokenizer: "https://huggingface.co/facebook/seamless-m4t-v2-large/resolve/main/spm_char_lang38_tc.model"
    #   checkpoint: "https://huggingface.co/facebook/seamless-m4t-v2-large/resolve/main/seamlessM4T_v2_large.pt"
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
    # Upstream card:
    # https://github.com/facebookresearch/seamless_communication/blob/dd67e71317d66752ef16cf21bd842ca3273244c9/src/seamless_communication/cards/unity_nllb-100.yaml#L9C1-L9C93
    # Upstream value being overridden:
    #   tokenizer: "https://huggingface.co/facebook/seamless-m4t-large/resolve/main/tokenizer.model"
    {
        "name": "unity_nllb-100@demo",
        "tokenizer": f"file://{CHECKPOINTS_PATH}/tokenizer.model",
    },
    # Upstream card:
    # https://github.com/facebookresearch/seamless_communication/blob/dd67e71317d66752ef16cf21bd842ca3273244c9/src/seamless_communication/cards/vocoder_v2.yaml#L10
    # Upstream value being overridden:
    #   checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/vocoder_v2.pt"
    {
        "name": "vocoder_v2@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
    },
]

# Drop every registered environment resolver and install a single one that
# always answers "demo".
# NOTE(review): presumably this is what makes the "@demo" entries above take
# effect -- confirm against the fairseq2 asset store documentation.
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")

# Make the overrides visible to the asset store.
asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))
# Choose the device and its precision together: the first CUDA device with
# float16 when CUDA is available, otherwise the CPU with float32.
device, dtype = (
    (torch.device("cuda:0"), torch.float16)
    if torch.cuda.is_available()
    else (torch.device("cpu"), torch.float32)
)
# Module-level Translator instance, constructed once when this module is
# imported. The model and vocoder are referenced by card name, and the device
# and dtype chosen earlier in this module are passed through unchanged.
# NOTE(review): apply_mintox=True presumably turns on the library's MinTox
# (toxicity mitigation) step -- confirm in the seamless_communication docs.
translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card="vocoder_v2",
    apply_mintox=True,
    dtype=dtype,
    device=device,
)
if __name__ == "__main__":
    # Manual smoke test: run one text-to-text ("T2TT") request from "eng" to
    # "zsm" through the module-level translator and print the first element
    # of whatever predict() returns.
    result = translator.predict(
        input="Hello, how are you today?",
        task_str="T2TT",
        src_lang="eng",
        tgt_lang="zsm",
    )
    # NOTE(review): if predict() returns a (texts, speech) pair, this prints
    # the whole list of texts rather than a single string -- confirm whether
    # result[0][0] was intended.
    print(result[0])