|
import logging |
|
import os |
|
import re |
|
|
|
from typing import Optional |
|
|
|
import huggingface_hub |
|
import requests |
|
|
|
from tqdm.auto import tqdm |
|
|
|
# Model sizes accepted by `download_model`.
# The trailing comma is required: without it `("medium")` is just the string
# "medium", which turns membership tests into substring checks and makes
# `", ".join(_MODELS)` iterate over single characters.
_MODELS = ("medium",)
|
|
|
|
|
def get_assets_path():
    """Returns the absolute path of the "assets" directory next to this module."""
    module_file = os.path.abspath(__file__)
    module_dir = os.path.dirname(module_file)
    return os.path.join(module_dir, "assets")
|
|
|
|
|
def get_logger():
    """Returns the logger registered under the name "faster_whisper"."""
    logger_name = "faster_whisper"
    return logging.getLogger(logger_name)
|
|
|
|
|
def download_model(
    size_or_id: str,
    output_dir: Optional[str] = None,
    local_files_only: bool = False,
    cache_dir: Optional[str] = None,
):
    """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.

    Named model sizes are served from https://huggingface.co/DuyTa.

    Args:
      size_or_id: Either a model size that is in ``_MODELS`` (every accepted size
        resolves to the ``DuyTa/Vietnamese_ASR`` repository), or the Hugging Face
        Hub ID of a CTranslate2-converted model. Any value containing a "/" is
        treated as a repository ID and used as-is.
      output_dir: Directory where the model should be saved. If not set, the model
        is saved in the cache directory.
      local_files_only: If True, avoid downloading the file and return the path to
        the local cached file if it exists.
      cache_dir: Path to the folder where cached files are stored.

    Returns:
      The path to the downloaded model.

    Raises:
      ValueError: if ``size_or_id`` is neither a repository ID nor a valid size.
    """
    if "/" in size_or_id:
        # "owner/name" style values are passed to the Hub verbatim.
        repo_id = size_or_id
    else:
        if size_or_id not in _MODELS:
            raise ValueError(
                "Invalid model size '%s', expected one of: %s"
                % (size_or_id, ", ".join(_MODELS))
            )

        repo_id = "DuyTa/Vietnamese_ASR"

    # Only fetch the files matching these patterns from the repository.
    allow_patterns = [
        "config.json",
        "model.bin",
        "tokenizer.json",
        "vocabulary.*",
    ]

    kwargs = {
        "local_files_only": local_files_only,
        "allow_patterns": allow_patterns,
        # disabled_tqdm forces disable=True, silencing the progress bars.
        "tqdm_class": disabled_tqdm,
    }

    if output_dir is not None:
        kwargs["local_dir"] = output_dir
        kwargs["local_dir_use_symlinks"] = False

    if cache_dir is not None:
        kwargs["cache_dir"] = cache_dir

    try:
        return huggingface_hub.snapshot_download(repo_id, **kwargs)
    except (
        huggingface_hub.utils.HfHubHTTPError,
        requests.exceptions.ConnectionError,
    ) as exception:
        logger = get_logger()
        logger.warning(
            "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s",
            repo_id,
            exception,
        )
        logger.warning(
            "Trying to load the model directly from the local cache, if it exists."
        )

        # Best-effort fallback: retry once without touching the network.
        kwargs["local_files_only"] = True
        return huggingface_hub.snapshot_download(repo_id, **kwargs)
|
|
|
|
|
def format_timestamp(
    seconds: float,
    always_include_hours: bool = False,
    decimal_marker: str = ".",
) -> str:
    """Formats a duration in seconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
      seconds: Non-negative duration in seconds; rounded to the nearest millisecond.
      always_include_hours: If True, emit the hours field even when it is zero.
      decimal_marker: Separator placed between the seconds and the milliseconds.

    Returns:
      The formatted timestamp string.
    """
    assert seconds >= 0, "non-negative timestamp expected"

    total_ms = round(seconds * 1000.0)

    # Peel off each unit from largest to smallest.
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    whole_seconds, millis = divmod(remainder, 1_000)

    if always_include_hours or hours > 0:
        prefix = f"{hours:02d}:"
    else:
        prefix = ""

    return f"{prefix}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
|
|
|
|
|
class disabled_tqdm(tqdm):
    """tqdm subclass that always passes ``disable=True`` to the base constructor."""

    def __init__(self, *args, **kwargs):
        # Override whatever the caller supplied for "disable".
        forced = {**kwargs, "disable": True}
        super().__init__(*args, **forced)