Spaces:
Sleeping
Sleeping
import urllib | |
import os | |
from typing import List | |
from urllib.parse import urlparse | |
from tqdm import tqdm | |
from src.conversion.hf_converter import convert_hf_whisper | |
class ModelConfig:
    """A named model together with the information needed to locate it on disk."""

    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
        """
        Initialize a model configuration.

        name: Name of the model
        url: URL to download the model from. May also be a "file://" URL; any
             other value is handed back unchanged by download_url (for example
             a built-in Whisper model name or a local file path).
        path: Path to the model file. If not set, the model will be downloaded from the URL.
        type: Type of model. Can be whisper or huggingface.
        """
        self.name = name
        self.url = url
        self.path = path
        self.type = type

    def download_url(self, root_dir: str):
        """
        Resolve this model to a loadable path, downloading or converting it if needed.

        root_dir: Directory used to store downloaded/converted models. If None,
                  "~/.cache/whisper" is used.

        Returns the resolved path (or the unchanged url when nothing has to be
        fetched). The result is cached in self.path, so later calls return
        immediately.

        Raises ValueError for an unknown model type and RuntimeError when the
        download target exists but is not a regular file.
        """
        # See if path is already set
        if self.path is not None:
            return self.path

        if root_dir is None:
            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

        model_type = self.type.lower() if self.type is not None else "whisper"

        if model_type in ["huggingface", "hf"]:
            resolved = self._convert_huggingface(root_dir)
        elif model_type in ["whisper", "w"]:
            resolved = self._resolve_whisper(root_dir)
        else:
            raise ValueError(f"Unknown model type {model_type}")

        # Only cache the result once resolution succeeded, so that a failed
        # download/conversion can be retried by calling this method again.
        self.path = resolved
        return self.path

    def _convert_huggingface(self, root_dir: str):
        """Convert the HuggingFace model at self.url to Whisper format and return the .pt path."""
        destination_target = os.path.join(root_dir, self.name + ".pt")

        if os.path.exists(destination_target):
            print(f"File {destination_target} already exists, skipping conversion")
        else:
            os.makedirs(root_dir, exist_ok=True)
            print("Saving HuggingFace model in Whisper format to " + destination_target)
            convert_hf_whisper(self.url, destination_target)

        return destination_target

    def _resolve_whisper(self, root_dir: str):
        """Turn self.url into a local path, downloading http(s) URLs into root_dir."""
        from urllib.request import url2pathname

        if self.url.startswith("file://"):
            # Decode percent-escapes and convert to the platform's path syntax
            return url2pathname(urlparse(self.url).path)

        if self.url.startswith(("http://", "https://")):
            # Take the extension from the URL path only, so that a query string
            # or fragment never ends up in the cached file name.
            extension = os.path.splitext(urlparse(self.url).path)[-1]
            download_target = os.path.join(root_dir, self.name + extension)

            if os.path.exists(download_target) and not os.path.isfile(download_target):
                raise RuntimeError(f"{download_target} exists and is not a regular file")

            if not os.path.isfile(download_target):
                os.makedirs(root_dir, exist_ok=True)
                self._download_file(self.url, download_target)
            else:
                print(f"File {download_target} already exists, skipping download")

            return download_target

        # Either a built-in Whisper model name (Whisper will handle the
        # download itself) or a local file: both are passed through unchanged,
        # so there is no need to import whisper just to tell them apart.
        return self.url

    def _download_file(self, url: str, destination: str):
        """
        Stream url into destination while showing a progress bar.

        The data is written to "<destination>.part" first and only moved into
        place once complete, so an interrupted download is never mistaken for a
        finished file on the next run.
        """
        # "import urllib" alone does not guarantee the request submodule is loaded
        import urllib.request

        partial_target = destination + ".part"

        try:
            with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
                # Content-Length is optional (e.g. chunked responses); tqdm
                # accepts total=None and then shows a counter without a bar.
                length = source.info().get("Content-Length")
                total = int(length) if length is not None and length.strip().isdigit() else None

                with tqdm(
                    total=total,
                    ncols=80,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as loop:
                    while True:
                        buffer = source.read(8192)
                        if not buffer:
                            break

                        output.write(buffer)
                        loop.update(len(buffer))

            os.replace(partial_target, destination)
        finally:
            # Still present only if the download did not complete
            if os.path.exists(partial_target):
                os.remove(partial_target)
class ApplicationConfig:
    """Application settings plus the list of configured models."""

    def __init__(self, models: List["ModelConfig"] = None, input_audio_max_duration: int = 600,
                 share: bool = False, server_name: str = None, server_port: int = 7860, default_model_name: str = "medium",
                 default_vad: str = "silero-vad", vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
                 auto_parallel: bool = False, output_dir: str = None):
        """
        Store the given settings as attributes of the same name.

        models: List of model configurations. When omitted, a new empty list is
                created for this instance.

        All other arguments are stored unchanged.
        """
        # A None default (instead of a literal []) keeps instances from sharing
        # one list object; a list passed by the caller is stored as-is.
        self.models = [] if models is None else models
        self.input_audio_max_duration = input_audio_max_duration
        self.share = share
        self.server_name = server_name
        self.server_port = server_port
        self.default_model_name = default_model_name
        self.default_vad = default_vad
        self.vad_parallel_devices = vad_parallel_devices
        self.vad_cpu_cores = vad_cpu_cores
        self.vad_process_timeout = vad_process_timeout
        self.auto_parallel = auto_parallel
        self.output_dir = output_dir

    def get_model_names(self):
        """Return the name of every configured model, in list order."""
        return [x.name for x in self.models]

    @staticmethod
    def parse_file(config_path: str):
        """
        Build an ApplicationConfig from a JSON5 file.

        The optional "models" entry is turned into ModelConfig objects (each
        item is passed as keyword arguments); every remaining top-level key is
        passed as a keyword argument to ApplicationConfig.

        config_path: Path to the JSON5 configuration file.
        """
        import json5

        with open(config_path, "r", encoding="utf-8") as f:
            # Load using json5
            data = json5.load(f)

        # The file is no longer needed once parsed
        data_models = data.pop("models", [])
        models = [ModelConfig(**x) for x in data_models]

        return ApplicationConfig(models, **data)