SoybeanMilk commited on
Commit
a0db9dd
·
1 Parent(s): 4dcbce8

Add madlad400 support.

Browse files
Files changed (1) hide show
  1. src/config.py +10 -3
src/config.py CHANGED
@@ -5,7 +5,7 @@ from typing import List, Dict, Literal
5
 
6
 
7
  class ModelConfig:
8
- def __init__(self, name: str, url: str, path: str = None, type: str = "whisper", tokenizer_url: str = None):
9
  """
10
  Initialize a model configuration.
11
 
@@ -13,12 +13,19 @@ class ModelConfig:
13
  url: URL to download the model from
14
  path: Path to the model file. If not set, the model will be downloaded from the URL.
15
  type: Type of model. Can be whisper or huggingface.
 
 
 
 
 
16
  """
17
  self.name = name
18
  self.url = url
19
  self.path = path
20
  self.type = type
21
  self.tokenizer_url = tokenizer_url
 
 
22
 
23
  VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
24
 
@@ -43,7 +50,7 @@ class VadInitialPromptMode(Enum):
43
  return None
44
 
45
  class ApplicationConfig:
46
- def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]],
47
  input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
48
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
49
  whisper_implementation: str = "whisper", default_model_name: str = "medium",
@@ -169,7 +176,7 @@ class ApplicationConfig:
169
  # Load using json5
170
  data = json5.load(f)
171
  data_models = data.pop("models", [])
172
- models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]] = {
173
  key: [ModelConfig(**item) for item in value]
174
  for key, value in data_models.items()
175
  }
 
5
 
6
 
7
  class ModelConfig:
8
+ def __init__(self, name: str, url: str, path: str = None, type: str = "whisper", tokenizer_url: str = None, revision: str = None, model_file: str = None,):
9
  """
10
  Initialize a model configuration.
11
 
 
13
  url: URL to download the model from
14
  path: Path to the model file. If not set, the model will be downloaded from the URL.
15
  type: Type of model. Can be whisper or huggingface.
16
+ revision: [by transformers] The specific model version to use.
17
+ It can be a branch name, a tag name, or a commit id,
18
+ since we use a git-based system for storing models and other artifacts on huggingface.co,
19
+ so revision can be any identifier allowed by git.
20
+ model_file: The name of the model file in repo or directory.[from marella/ctransformers]
21
  """
22
  self.name = name
23
  self.url = url
24
  self.path = path
25
  self.type = type
26
  self.tokenizer_url = tokenizer_url
27
+ self.revision = revision
28
+ self.model_file = model_file
29
 
30
  VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
31
 
 
50
  return None
51
 
52
  class ApplicationConfig:
53
+ def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400"], List[ModelConfig]],
54
  input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
55
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
56
  whisper_implementation: str = "whisper", default_model_name: str = "medium",
 
176
  # Load using json5
177
  data = json5.load(f)
178
  data_models = data.pop("models", [])
179
+ models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400"], List[ModelConfig]] = {
180
  key: [ModelConfig(**item) for item in value]
181
  for key, value in data_models.items()
182
  }