# Add configuration file and support for custom models (aadnk, commit 44d964a)
import urllib
import os
from typing import List
from urllib.parse import urlparse
from tqdm import tqdm
from src.conversion.hf_converter import convert_hf_whisper
class ModelConfig:
    """Describes a single model and knows how to resolve it to a loadable path."""

    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
        """
        Initialize a model configuration.

        name: Name of the model
        url: URL to download the model from
        path: Path to the model file. If not set, the model will be downloaded from the URL.
        type: Type of model. Can be whisper or huggingface.
        """
        self.name = name
        self.url = url
        self.path = path
        self.type = type

    def download_url(self, root_dir: str):
        """
        Resolve the model to a path (or built-in Whisper model name) and remember it in self.path.

        root_dir: Directory used for converted/downloaded files. If None, ~/.cache/whisper is used.

        Returns the resolved path. Once self.path is set it is returned as-is on every call.
        Raises ValueError for an unknown model type, and RuntimeError if the download
        target exists but is not a regular file.
        """
        # See if path is already set
        if self.path is not None:
            return self.path

        if root_dir is None:
            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

        model_type = self.type.lower() if self.type is not None else "whisper"

        # NOTE: self.path is only assigned after a branch has fully succeeded. Assigning it
        # up front would make a failed conversion/download look resolved on the next call.
        if model_type in ["huggingface", "hf"]:
            destination_target = os.path.join(root_dir, self.name + ".pt")

            # Convert from HuggingFace format to Whisper format
            if os.path.exists(destination_target):
                print(f"File {destination_target} already exists, skipping conversion")
            else:
                # Make sure the cache directory exists before anything is written into it
                os.makedirs(root_dir, exist_ok=True)
                print("Saving HuggingFace model in Whisper format to " + destination_target)
                convert_hf_whisper(self.url, destination_target)

            self.path = destination_target

        elif model_type in ["whisper", "w"]:
            # Imported here so that pre-set paths and HuggingFace models do not need whisper
            import whisper

            if self.url in whisper._MODELS:
                # No need to download anything - Whisper will handle it
                resolved = self.url
            elif self.url.startswith("file://"):
                # Get file path
                resolved = urlparse(self.url).path
            elif self.url.startswith(("http://", "https://")):
                # Take the extension from the URL path only, so that a query string or
                # fragment does not end up in the local file name
                extension = os.path.splitext(urlparse(self.url).path)[-1]
                download_target = os.path.join(root_dir, self.name + extension)

                if os.path.exists(download_target) and not os.path.isfile(download_target):
                    raise RuntimeError(f"{download_target} exists and is not a regular file")

                if not os.path.isfile(download_target):
                    os.makedirs(root_dir, exist_ok=True)
                    self._download_file(self.url, download_target)
                else:
                    print(f"File {download_target} already exists, skipping download")

                resolved = download_target
            else:
                # Must be a local file
                resolved = self.url

            self.path = resolved

        else:
            raise ValueError(f"Unknown model type {model_type}")

        return self.path

    def _download_file(self, url: str, destination: str):
        """
        Download url to destination, showing a progress bar.

        The data is written to "<destination>.part" and only moved into place once the
        transfer has completed, so an interrupted download is never mistaken for a
        finished file. The partial file is removed on failure.
        """
        # "import urllib" alone does not load the request submodule
        import urllib.request

        partial_target = destination + ".part"
        completed = False

        try:
            with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
                # The server may omit Content-Length; tqdm accepts total=None in that case
                try:
                    total = int(source.info().get("Content-Length"))
                except (TypeError, ValueError):
                    total = None

                with tqdm(
                    total=total,
                    ncols=80,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as loop:
                    while True:
                        buffer = source.read(8192)
                        if not buffer:
                            break

                        output.write(buffer)
                        loop.update(len(buffer))

            os.replace(partial_target, destination)
            completed = True
        finally:
            if not completed and os.path.exists(partial_target):
                os.remove(partial_target)
class ApplicationConfig:
    """Holds the application settings and the list of configured models."""

    def __init__(self, models: List["ModelConfig"] = None, input_audio_max_duration: int = 600,
                 share: bool = False, server_name: str = None, server_port: int = 7860, default_model_name: str = "medium",
                 default_vad: str = "silero-vad", vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
                 auto_parallel: bool = False, output_dir: str = None):
        """
        Initialize the application configuration.

        models: Model configurations. If not set, a new empty list is used.

        Every other argument is stored unchanged on the attribute of the same name.
        """
        # A fresh list per instance: a mutable default in the signature would be shared
        # between every ApplicationConfig created without an explicit models argument
        self.models = models if models is not None else []
        self.input_audio_max_duration = input_audio_max_duration
        self.share = share
        self.server_name = server_name
        self.server_port = server_port
        self.default_model_name = default_model_name
        self.default_vad = default_vad
        self.vad_parallel_devices = vad_parallel_devices
        self.vad_cpu_cores = vad_cpu_cores
        self.vad_process_timeout = vad_process_timeout
        self.auto_parallel = auto_parallel
        self.output_dir = output_dir

    def get_model_names(self):
        """Return the names of the configured models, in order."""
        return [model.name for model in self.models]

    @staticmethod
    def parse_file(config_path: str):
        """
        Create an ApplicationConfig from a JSON5 file.

        config_path: Path to the configuration file.

        The "models" entry is turned into ModelConfig objects; all remaining top-level
        keys are passed to the constructor as keyword arguments, so an unknown key
        raises TypeError.
        """
        import json5

        # Read as UTF-8 rather than the platform default encoding
        with open(config_path, "r", encoding="utf-8") as f:
            # Load using json5
            data = json5.load(f)

        data_models = data.pop("models", [])
        models = [ModelConfig(**entry) for entry in data_models]

        return ApplicationConfig(models, **data)