|
|
|
|
|
import os |
|
from typing import List, Optional |
|
from src.simuleval_transcoder import SimulevalTranscoder |
|
import json |
|
import logging |
|
|
|
logger = logging.getLogger(name="socketio_server_pubsub")


# Target languages advertised by the default S2S/S2T agent registered in
# SimulevalAgentDirectory (three-letter language codes). "eng" comes first;
# the remaining codes follow in alphabetical order. Order is preserved by
# str.split(), so this is the same list as an explicit literal.
M4T_P0_LANGS = """
    eng
    arb ben cat ces cmn cym dan
    deu est fin fra hin ind ita
    jpn kor mlt nld pes pol por
    ron rus slk spa swe swh tel
    tgl tha tur ukr urd uzn vie
""".split()
|
|
|
|
|
|
|
class NoAvailableAgentException(Exception):
    """Signals that no registered agent matches a requested name.

    Raised by ``SimulevalAgentDirectory.get_agent_or_throw``.
    """
|
|
|
|
|
class AgentWithInfo:
    """A built simuleval agent bundled with the metadata advertised to clients.

    Attributes:
        agent: The object returned by ``SimulevalTranscoder.build_agent``.
        name: Identifier used to look the agent up.
        description: Human-readable description.
        modalities: Supported modality codes, e.g. ``"s2t"`` / ``"s2s"``.
        target_langs: Supported target-language codes.
        dynamic_params: Names of parameters exposed as adjustable.
        has_expressive: Whether the loaded model is expressive; ``None`` if
            not specified.
    """

    def __init__(
        self,
        agent,
        name: str,
        modalities: List[str],
        target_langs: List[str],
        dynamic_params: Optional[List[str]] = None,
        description="",
        has_expressive: Optional[bool] = None,
    ):
        self.agent = agent
        self.has_expressive = has_expressive
        self.name = name
        self.description = description
        self.modalities = modalities
        self.target_langs = target_langs
        # A fresh list per instance: a shared `[]` default would leak
        # mutations from one agent's params into every other such agent.
        self.dynamic_params = [] if dynamic_params is None else dynamic_params

    def get_capabilities_for_json(self):
        """Return a JSON-serializable dict describing this agent.

        Keys are camelCase. ``agent`` and ``has_expressive`` are not included.
        """
        return {
            "name": self.name,
            "description": self.description,
            "modalities": self.modalities,
            "targetLangs": self.target_langs,
            "dynamicParams": self.dynamic_params,
        }

    @classmethod
    def load_from_json(cls, config: str):
        """Build agents from a JSON array of model descriptors.

        Each entry requires ``name``, ``modalities`` and ``targetLangs``;
        ``description`` and ``dynamicParams`` are optional. Example::

            [{"name": "s2s_m4t_emma-unity2_multidomain_v0.1", "description": "M4T model that supports simultaneous S2S and S2T", "modalities": ["s2t", "s2s"], "targetLangs": ["en"]},
             {"name": "s2s_m4t_expr-emma_v0.1", "description": "ES-EN expressive model that supports S2S and S2T", "modalities": ["s2t", "s2s"], "targetLangs": ["en"]}]

        Args:
            config: JSON text encoding a list of descriptor objects.

        Returns:
            A list of agents, one per entry, in input order.

        Raises:
            json.JSONDecodeError: If ``config`` is not valid JSON.
            KeyError: If an entry lacks a required key.
            Any exception raised by ``SimulevalTranscoder.build_agent`` propagates.
        """
        agents = []
        for entry in json.loads(config):
            agent = SimulevalTranscoder.build_agent(entry["name"])
            agents.append(
                cls(
                    agent=agent,
                    name=entry["name"],
                    modalities=entry["modalities"],
                    target_langs=entry["targetLangs"],
                    # Previously ignored even though the documented format
                    # includes a description.
                    dynamic_params=entry.get("dynamicParams"),
                    description=entry.get("description", ""),
                )
            )
        return agents
|
|
|
|
|
class SimulevalAgentDirectory:
    """Registry of the simuleval agents available to the server, looked up by name."""

    # Model ids passed to SimulevalTranscoder.build_agent. The streaming id is
    # also the name under which the default agent is registered, regardless of
    # which of the two models was actually built.
    seamless_streaming_agent = "SeamlessStreaming"
    seamless_agent = "Seamless"

    def __init__(self):
        self.agents = []
        # Makes build_and_add_agents a no-op after its first successful run.
        self.did_build_and_add_agents = False

    def add_agent(self, agent: AgentWithInfo):
        """Register ``agent`` after any agents already registered."""
        self.agents.append(agent)

    def build_agent_if_available(self, model_id, config_name=None):
        """Build the agent for ``model_id`` via ``SimulevalTranscoder.build_agent``.

        Despite the name, failures are not swallowed. The error is logged, with
        an extra hint when it is a fairseq2 ``AssetError``, and then re-raised.

        Args:
            model_id: Model identifier passed to ``build_agent``.
            config_name: Config name; forwarded only when not None.

        Returns:
            Whatever ``build_agent`` returns.
        """
        kwargs = {} if config_name is None else {"config_name": config_name}
        try:
            return SimulevalTranscoder.build_agent(model_id, **kwargs)
        except Exception as e:
            logger.warning("Failed to build agent %s: %s", model_id, e)
            try:
                from fairseq2.assets.error import AssetError
            except ImportError:
                # Don't let a failed diagnostic import mask the real error.
                AssetError = None
            if AssetError is not None and isinstance(e, AssetError):
                logger.warning(
                    "Please download gated assets and set `gated_model_dir` in the config"
                )
            # Re-raises `e` (the exception being handled by this clause).
            raise

    def _build_and_add_default_agent(self):
        """Build the default S2S/S2T agent and register it if one was returned.

        The expressive model is chosen when the environment variable
        USE_EXPRESSIVE_MODEL equals "1"; otherwise the non-expressive one is used.
        """
        has_expressive = os.environ.get("USE_EXPRESSIVE_MODEL", "0") == "1"
        if has_expressive:
            logger.info("Building expressive model...")
            model_id = SimulevalAgentDirectory.seamless_agent
            config_name = "vad_s2st_sc_24khz_main.yaml"
        else:
            logger.info("Building non-expressive model...")
            model_id = SimulevalAgentDirectory.seamless_streaming_agent
            config_name = "vad_s2st_sc_main.yaml"

        s2s_agent = self.build_agent_if_available(model_id, config_name=config_name)
        if s2s_agent:
            self.add_agent(
                AgentWithInfo(
                    agent=s2s_agent,
                    name=SimulevalAgentDirectory.seamless_streaming_agent,
                    modalities=["s2t", "s2s"],
                    target_langs=M4T_P0_LANGS,
                    dynamic_params=["expressive"],
                    description="multilingual expressive model that supports S2S and S2T",
                    has_expressive=has_expressive,
                )
            )

    def build_and_add_agents(self, models_override=None):
        """Build and register agents once per directory instance.

        Args:
            models_override: Optional JSON string in the format accepted by
                ``AgentWithInfo.load_from_json``. When given, those agents are
                registered instead of the default agent.

        If an exception propagates out of agent building, the "done" flag is not
        set, so a later call will try again.
        """
        if self.did_build_and_add_agents:
            return

        if models_override is not None:
            for agent_info in AgentWithInfo.load_from_json(models_override):
                self.add_agent(agent_info)
        else:
            self._build_and_add_default_agent()

        if not self.agents:
            logger.error(
                "No agents were loaded. This likely means you are missing the actual model files specified in simuleval_agent_directory."
            )

        self.did_build_and_add_agents = True

    def get_agent(self, name):
        """Return the first registered agent whose ``name`` equals ``name``, else None."""
        return next((agent for agent in self.agents if agent.name == name), None)

    def get_agent_or_throw(self, name):
        """Like ``get_agent`` but raise ``NoAvailableAgentException`` when not found."""
        agent = self.get_agent(name)
        if agent is None:
            # f-string: `"%s" % (name)` would raise TypeError for a tuple name.
            raise NoAvailableAgentException(f"No agent found with name= {name}")
        return agent

    def get_agents_capabilities_list_for_json(self):
        """Return each registered agent's capability dict, in registration order."""
        return [agent.get_capabilities_for_json() for agent in self.agents]
|
|