Downloaded multitask_unity_large.pt: how can I use it?
I downloaded the model multitask_unity_large.pt but I don't know how to use it. How do I run each of these tasks:
- Speech-to-speech translation (S2ST)
- Speech-to-text translation (S2TT)
- Text-to-speech translation (T2ST)
- Text-to-text translation (T2TT)
- Automatic speech recognition (ASR)
You can pass the path to the model checkpoint to the Translator class instead of the checkpoint name. See instructions-to-run-inference-with-seamlessm4t-models for details.
What should that look like?
I put all the weights in a local folder, "seamless_weights", and passed them to the Translator class:
import torch
from seamless_communication.models.inference import Translator
translator = Translator(
    "seamless_weights/multitask_unity_medium.pt",
    vocoder_name_or_card="seamless_weights/vocoder_36langs.pt",
    device=torch.device("cpu"),
)
It raises a ValueError exception:
ValueError: name must be a valid filename, but is 'seamless_weights/multitask_unity_medium.pt' instead.
What does "valid filename" mean in this context?
It is a valid file name:
os.path.exists('seamless_weights/multitask_unity_medium.pt')
True
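One plausible reading of the error (an assumption, not confirmed against the fairseq2 source) is that the string is treated as an asset card name and looked up as a file named <name>.yaml directly inside a cards directory. Under that assumption, "valid filename" means "a single file name in that directory", not "a path that exists". The helper below is hypothetical and only illustrates the distinction:

```python
from pathlib import PurePosixPath


def is_valid_card_name(name: str, cards_dir: str = "assets/cards") -> bool:
    """Hypothetical check: does `name` map to a .yaml file directly in cards_dir?"""
    base = PurePosixPath(cards_dir)
    # Assumed lookup: the card file would be <cards_dir>/<name>.yaml.
    card_path = (base / name).with_suffix(".yaml")
    # A name with a directory part lands in a subdirectory, so it fails.
    return card_path.parent == base


print(is_valid_card_name("seamlessM4T_medium"))
print(is_valid_card_name("seamless_weights/multitask_unity_medium.pt"))
```

If this reading is right, os.path.exists() returning True does not matter: the check is about where the card file would be located, not whether the .pt file is on disk.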
Next, I tried to edit the model card seamless_communication/assets/cards/seamlessM4T_medium.yaml
so that it loads the weights from the local folder:
#checkpoint: "https://huggingface.co/facebook/seamless-m4t-medium/resolve/main/multitask_unity_medium.pt"
checkpoint: "/home/local/seamless_communication/seamless_weights/multitask_unity_medium.pt"
but got an exception saying that this path must be a valid URI.
It is quite unclear how to use local weights.
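The checkpoint field appears to be read with as_uri(), and a bare absolute path has no URI scheme. One thing that may be worth trying, assuming the loader accepts file:// URIs (not verified here), is to write the checkpoint as a file URI instead of a plain path. Python's standard library can build one from a local path:

```python
import os
from pathlib import Path
from urllib.parse import urlparse


def checkpoint_uri(local_path: str) -> str:
    """Convert a local checkpoint path into a file:// URI for a model card."""
    # Path.as_uri() requires an absolute path, so make relative paths absolute.
    return Path(os.path.abspath(local_path)).as_uri()


uri = checkpoint_uri(
    "/home/local/seamless_communication/seamless_weights/multitask_unity_medium.pt"
)
print(uri)
# A bare path has an empty scheme; the converted URI has the "file" scheme.
print(repr(urlparse("/home/local/model.pt").scheme))
print(urlparse(uri).scheme)
```

Under that untested assumption, the card line would read checkpoint: "file:///home/local/seamless_communication/seamless_weights/multitask_unity_medium.pt".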
Update:
The seamless_communication models appear to be based on the fairseq2 library and use its classes to download weights.
So, to achieve the goal above, we can subclass two fairseq2 classes:
from fairseq2.models.utils.model_loader import ModelLoader
from fairseq2.models.nllb.loader import NllbTokenizerLoader
In these subclasses, skip the requirement that the checkpoint path be a URI, along these lines:
try:
    # Load the checkpoint.
    uri = card.field("checkpoint").as_uri()
    pathname = self.download_manager.download_checkpoint(
        uri, card.name, force=force, progress=progress
    )
except AssetCardError:
    # The field is not a valid URI, so use its raw value as a local path.
    pathname = card.field("checkpoint").data
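The fallback above can also be sketched without fairseq2 as a standalone function. Here resolve_checkpoint, REMOTE_SCHEMES and the download callable are hypothetical stand-ins (the callable plays the role of download_manager.download_checkpoint); the sketch only shows the control flow "download if the value is a remote URI, otherwise treat it as a local path":

```python
from pathlib import Path
from typing import Callable
from urllib.parse import urlparse
from urllib.request import url2pathname

# Hypothetical set of schemes that are handed to the downloader.
REMOTE_SCHEMES = {"http", "https"}


def resolve_checkpoint(value: str, download: Callable[[str], Path]) -> Path:
    """Return a local path for a checkpoint given either a URI or a path."""
    parsed = urlparse(value)
    if parsed.scheme in REMOTE_SCHEMES:
        # Remote checkpoint: delegate to the (hypothetical) downloader.
        return download(value)
    if parsed.scheme == "file":
        # file:// URI: decode the path part into a filesystem path.
        path = Path(url2pathname(parsed.path))
    else:
        # No scheme: use the value as a plain filesystem path.
        path = Path(value)
    if not path.exists():
        raise FileNotFoundError(f"checkpoint not found: {path}")
    return path
```

A share mounted into the local filesystem would simply take the plain-path branch of this sketch.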
Then use these subclasses in the corresponding seamless_communication classes.
However, this does not solve the problem when the weights are located on a remote share.