# File size: 3,682 Bytes | c6b1960
import logging
import os
import re
from typing import Optional
import huggingface_hub
import requests
from tqdm.auto import tqdm
# Model sizes accepted by download_model().
# The trailing comma is required: without it the parentheses are only grouping
# and the value is the plain string "medium", which turns `x in _MODELS` into a
# substring test and makes ", ".join(_MODELS) iterate over single characters.
_MODELS = ("medium",)
def get_assets_path():
    """Returns the absolute path of the "assets" directory beside this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "assets")
def get_logger():
    """Returns the logger registered under the name "faster_whisper"."""
    logger_name = "faster_whisper"
    return logging.getLogger(logger_name)
def download_model(
    size_or_id: str,
    output_dir: Optional[str] = None,
    local_files_only: bool = False,
    cache_dir: Optional[str] = None,
):
    """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.

    Args:
      size_or_id: Either a model size listed in ``_MODELS`` or a Hugging Face Hub
        repository ID. Any value containing a ``/`` is treated as a repository ID
        and used as-is; every accepted size resolves to the single repository
        ``DuyTa/Vietnamese_ASR``.
      output_dir: Directory where the model should be saved. If not set, the model
        is saved in the cache directory.
      local_files_only: If True, avoid downloading the file and return the path to
        the local cached file if it exists.
      cache_dir: Path to the folder where cached files are stored.

    Returns:
      The path returned by ``huggingface_hub.snapshot_download`` for the model.

    Raises:
      ValueError: if ``size_or_id`` is neither a repository ID nor a known size.
    """
    if "/" in size_or_id:
        repo_id = size_or_id
    else:
        if size_or_id not in _MODELS:
            raise ValueError(
                "Invalid model size '%s', expected one of: %s"
                % (size_or_id, ", ".join(_MODELS))
            )
        # Every supported size is served from the same repository.
        repo_id = "DuyTa/Vietnamese_ASR"

    # Restrict the snapshot to the files listed here.
    allow_patterns = [
        "config.json",
        "model.bin",
        "tokenizer.json",
        "vocabulary.*",
    ]

    kwargs = {
        "local_files_only": local_files_only,
        "allow_patterns": allow_patterns,
        "tqdm_class": disabled_tqdm,
    }

    if output_dir is not None:
        kwargs["local_dir"] = output_dir
        kwargs["local_dir_use_symlinks"] = False

    if cache_dir is not None:
        kwargs["cache_dir"] = cache_dir

    try:
        return huggingface_hub.snapshot_download(repo_id, **kwargs)
    except (
        huggingface_hub.utils.HfHubHTTPError,
        requests.exceptions.ConnectionError,
    ) as exception:
        if local_files_only:
            # The failed call was already limited to local files, so the
            # fallback below would repeat exactly the same request.
            raise

        logger = get_logger()
        logger.warning(
            "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
            repo_id,
            exception,
        )
        logger.warning(
            "Trying to load the model directly from the local cache, if it exists."
        )

        kwargs["local_files_only"] = True
        return huggingface_hub.snapshot_download(repo_id, **kwargs)
def format_timestamp(
    seconds: float,
    always_include_hours: bool = False,
    decimal_marker: str = ".",
) -> str:
    """Formats a duration in seconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
      seconds: Non-negative duration in seconds; rounded to the nearest millisecond.
      always_include_hours: If True, the hours field is written even when it is zero.
      decimal_marker: Separator placed between the seconds and the milliseconds.

    Returns:
      The formatted timestamp. Hours are zero-padded to at least two digits and
      are omitted when zero unless ``always_include_hours`` is set.

    Raises:
      AssertionError: if ``seconds`` is negative (while assertions are enabled).
    """
    assert seconds >= 0, "non-negative timestamp expected"

    total_ms = round(seconds * 1000.0)
    # Peel off each unit in turn; the remainder carries down to the next one.
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    whole_seconds, millis = divmod(remainder, 1_000)

    if always_include_hours or hours > 0:
        prefix = f"{hours:02d}:"
    else:
        prefix = ""

    return f"{prefix}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
class disabled_tqdm(tqdm):
    """A ``tqdm`` subclass that always constructs itself with ``disable=True``.

    download_model() passes this class as ``tqdm_class`` to
    ``huggingface_hub.snapshot_download``.
    """

    def __init__(self, *args, **kwargs):
        # Force the flag on, overriding any value supplied by the caller.
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)