import os
import urllib.request
from typing import List
from urllib.parse import urlparse

import json5
import torch
from tqdm import tqdm

from src.conversion.hf_converter import convert_hf_whisper

class ModelConfig:
    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
        """
        Initialize a model configuration.

        name: Name of the model
        url: URL to download the model from
        path: Path to the model file. If not set, the model will be downloaded from the URL.
        type: Type of model. Can be whisper or huggingface.
        """
        self.name = name
        self.url = url
        self.path = path
        self.type = type
    def download_url(self, root_dir: str):
        import whisper

        # See if path is already set
        if self.path is not None:
            return self.path

        if root_dir is None:
            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

        # Make sure the cache directory exists before downloading or converting into it
        os.makedirs(root_dir, exist_ok=True)

        model_type = self.type.lower() if self.type is not None else "whisper"
        if model_type in ["huggingface", "hf"]:
            self.path = self.url
            destination_target = os.path.join(root_dir, self.name + ".pt")

            # Convert from HuggingFace format to Whisper format
            if os.path.exists(destination_target):
                print(f"File {destination_target} already exists, skipping conversion")
            else:
                print("Saving HuggingFace model in Whisper format to " + destination_target)
                convert_hf_whisper(self.url, destination_target)

            self.path = destination_target

        elif model_type in ["whisper", "w"]:
            self.path = self.url

            # See if URL is just a file
            if self.url in whisper._MODELS:
                # No need to download anything - Whisper will handle it
                self.path = self.url
            elif self.url.startswith("file://"):
                # Get file path
                self.path = urlparse(self.url).path
            # See if it is a URL
            elif self.url.startswith("http://") or self.url.startswith("https://"):
                # Extension (or file name)
                extension = os.path.splitext(self.url)[-1]
                download_target = os.path.join(root_dir, self.name + extension)

                if os.path.exists(download_target) and not os.path.isfile(download_target):
                    raise RuntimeError(f"{download_target} exists and is not a regular file")

                if not os.path.isfile(download_target):
                    self._download_file(self.url, download_target)
                else:
                    print(f"File {download_target} already exists, skipping download")

                self.path = download_target
            # Must be a local file
            else:
                self.path = self.url
        else:
            raise ValueError(f"Unknown model type {model_type}")

        return self.path
    def _download_file(self, url: str, destination: str):
        # Stream the file to disk in 8 KiB chunks, showing a progress bar
        with urllib.request.urlopen(url) as source, open(destination, "wb") as output:
            with tqdm(
                total=int(source.info().get("Content-Length")),
                ncols=80,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break

                    output.write(buffer)
                    loop.update(len(buffer))
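

# Usage sketch (illustrative only; the model name and URL below are hypothetical
# placeholders): resolving a ModelConfig to a local checkpoint path, downloading
# the file into the Whisper cache directory if it is not already present.
#
#   config = ModelConfig(name="example-large", url="https://example.com/models/example-large.pt")
#   local_path = config.download_url(root_dir=None)  # stored under ~/.cache/whisper
#   print(local_path)
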
class ApplicationConfig:
    def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
                 share: bool = False, server_name: str = None, server_port: int = 7860, delete_uploaded_files: bool = True,
                 default_model_name: str = "medium", default_vad: str = "silero-vad",
                 vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
                 auto_parallel: bool = False, output_dir: str = None,
                 model_dir: str = None, device: str = None,
                 verbose: bool = True, task: str = "transcribe", language: str = None,
                 vad_merge_window: float = 5, vad_max_merge_size: float = 30,
                 vad_padding: float = 1, vad_prompt_window: float = 3,
                 temperature: float = 0, best_of: int = 5, beam_size: int = 5,
                 patience: float = None, length_penalty: float = None,
                 suppress_tokens: str = "-1", initial_prompt: str = None,
                 condition_on_previous_text: bool = True, fp16: bool = True,
                 temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
                 logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.models = models

        # WebUI settings
        self.input_audio_max_duration = input_audio_max_duration
        self.share = share
        self.server_name = server_name
        self.server_port = server_port
        self.delete_uploaded_files = delete_uploaded_files

        self.default_model_name = default_model_name
        self.default_vad = default_vad
        self.vad_parallel_devices = vad_parallel_devices
        self.vad_cpu_cores = vad_cpu_cores
        self.vad_process_timeout = vad_process_timeout
        self.auto_parallel = auto_parallel
        self.output_dir = output_dir

        self.model_dir = model_dir
        self.device = device
        self.verbose = verbose
        self.task = task
        self.language = language

        self.vad_merge_window = vad_merge_window
        self.vad_max_merge_size = vad_max_merge_size
        self.vad_padding = vad_padding
        self.vad_prompt_window = vad_prompt_window

        self.temperature = temperature
        self.best_of = best_of
        self.beam_size = beam_size
        self.patience = patience
        self.length_penalty = length_penalty
        self.suppress_tokens = suppress_tokens
        self.initial_prompt = initial_prompt
        self.condition_on_previous_text = condition_on_previous_text
        self.fp16 = fp16
        self.temperature_increment_on_fallback = temperature_increment_on_fallback
        self.compression_ratio_threshold = compression_ratio_threshold
        self.logprob_threshold = logprob_threshold
        self.no_speech_threshold = no_speech_threshold
    def get_model_names(self):
        return [ x.name for x in self.models ]

    def update(self, **new_values):
        # Create a copy of the current configuration, then override the given values
        result = ApplicationConfig(**self.__dict__)

        for key, value in new_values.items():
            setattr(result, key, value)
        return result
    @staticmethod
    def create_default(**kwargs):
        app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))

        # Update with kwargs
        if len(kwargs) > 0:
            app_config = app_config.update(**kwargs)
        return app_config
    @staticmethod
    def parse_file(config_path: str):
        with open(config_path, "r") as f:
            # Load using json5, which allows comments and trailing commas in the config file
            data = json5.load(f)
            data_models = data.pop("models", [])
            models = [ ModelConfig(**x) for x in data_models ]

            return ApplicationConfig(models, **data)
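

# Usage sketch (illustrative only): loading the default configuration and listing the
# configured models. Assumes a config.json5 exists at the path in WHISPER_WEBUI_CONFIG
# (or "config.json5" in the working directory) whose "models" entries match
# ModelConfig's arguments, e.g.:
#
#   // config.json5 (hypothetical excerpt)
#   {
#     "models": [
#       { "name": "medium", "url": "medium" },
#       { "name": "example-hf", "url": "org/example-whisper", "type": "huggingface" }
#     ]
#   }
#
#   app_config = ApplicationConfig.create_default()
#   print(app_config.get_model_names())  # -> ["medium", "example-hf"]
#   print(app_config.device)             # "cuda" if available, otherwise "cpu"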