Spaces:
Runtime error
Runtime error
SoybeanMilk
commited on
Commit
·
a0db9dd
1
Parent(s):
4dcbce8
Add madlad400 support.
Browse files- src/config.py +10 -3
src/config.py
CHANGED
@@ -5,7 +5,7 @@ from typing import List, Dict, Literal
|
|
5 |
|
6 |
|
7 |
class ModelConfig:
|
8 |
-
def __init__(self, name: str, url: str, path: str = None, type: str = "whisper", tokenizer_url: str = None):
|
9 |
"""
|
10 |
Initialize a model configuration.
|
11 |
|
@@ -13,12 +13,19 @@ class ModelConfig:
|
|
13 |
url: URL to download the model from
|
14 |
path: Path to the model file. If not set, the model will be downloaded from the URL.
|
15 |
type: Type of model. Can be whisper or huggingface.
|
|
|
|
|
|
|
|
|
|
|
16 |
"""
|
17 |
self.name = name
|
18 |
self.url = url
|
19 |
self.path = path
|
20 |
self.type = type
|
21 |
self.tokenizer_url = tokenizer_url
|
|
|
|
|
22 |
|
23 |
VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
|
24 |
|
@@ -43,7 +50,7 @@ class VadInitialPromptMode(Enum):
|
|
43 |
return None
|
44 |
|
45 |
class ApplicationConfig:
|
46 |
-
def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]],
|
47 |
input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
|
48 |
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
49 |
whisper_implementation: str = "whisper", default_model_name: str = "medium",
|
@@ -169,7 +176,7 @@ class ApplicationConfig:
|
|
169 |
# Load using json5
|
170 |
data = json5.load(f)
|
171 |
data_models = data.pop("models", [])
|
172 |
-
models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]] = {
|
173 |
key: [ModelConfig(**item) for item in value]
|
174 |
for key, value in data_models.items()
|
175 |
}
|
|
|
5 |
|
6 |
|
7 |
class ModelConfig:
|
8 |
+
def __init__(self, name: str, url: str, path: str = None, type: str = "whisper", tokenizer_url: str = None, revision: str = None, model_file: str = None,):
|
9 |
"""
|
10 |
Initialize a model configuration.
|
11 |
|
|
|
13 |
url: URL to download the model from
|
14 |
path: Path to the model file. If not set, the model will be downloaded from the URL.
|
15 |
type: Type of model. Can be whisper or huggingface.
|
16 |
+
revision: [by transformers] The specific model version to use.
|
17 |
+
It can be a branch name, a tag name, or a commit id,
|
18 |
+
since we use a git-based system for storing models and other artifacts on huggingface.co,
|
19 |
+
so revision can be any identifier allowed by git.
|
20 |
+
model_file: The name of the model file in repo or directory.[from marella/ctransformers]
|
21 |
"""
|
22 |
self.name = name
|
23 |
self.url = url
|
24 |
self.path = path
|
25 |
self.type = type
|
26 |
self.tokenizer_url = tokenizer_url
|
27 |
+
self.revision = revision
|
28 |
+
self.model_file = model_file
|
29 |
|
30 |
VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
|
31 |
|
|
|
50 |
return None
|
51 |
|
52 |
class ApplicationConfig:
|
53 |
+
def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400"], List[ModelConfig]],
|
54 |
input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
|
55 |
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
56 |
whisper_implementation: str = "whisper", default_model_name: str = "medium",
|
|
|
176 |
# Load using json5
|
177 |
data = json5.load(f)
|
178 |
data_models = data.pop("models", [])
|
179 |
+
models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400"], List[ModelConfig]] = {
|
180 |
key: [ModelConfig(**item) for item in value]
|
181 |
for key, value in data_models.items()
|
182 |
}
|