File size: 1,397 Bytes
a4b0060 cad3946 a4b0060 d3e85c8 474b6f1 d3e85c8 2101135 a4b0060 2101135 a4b0060 d3e85c8 de0f77a d3e85c8 fca1dff d3e85c8 cad3946 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import urllib3
from huggingface_hub import snapshot_download
from available_models import AVAILABLE_MODELS
def change_default_timeout(new_timeout: int) -> None:
    """
    Raise the timeout used when downloading repositories from the Hugging Face Hub.

    Intended to prevent errors such as:
        urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='huggingface.co', port=443):
        Read timed out. (read timeout=10)

    The read timeout in that message is presumably supplied explicitly by
    huggingface_hub on each request, so a urllib3-level default alone cannot
    change it. This function therefore also overrides huggingface_hub's own
    download-timeout setting when the installed version exposes one.

    Args:
        new_timeout: Timeout in seconds.
    """
    # Kept for backward compatibility. NOTE(review): urllib3 does not appear to
    # read a module-level ``DEFAULT_TIMEOUT`` (its default is an attribute of
    # the ``Timeout`` class), so this assignment on its own is likely a no-op.
    urllib3.util.timeout.DEFAULT_TIMEOUT = new_timeout

    # Function-scope import: huggingface_hub is already a dependency of this
    # file, and this avoids touching the top-of-file import block.
    from huggingface_hub import constants as hf_constants

    # huggingface_hub releases that look this attribute up at request time
    # will pick up the new value (presumably recent ones — verify against the
    # installed version). The hasattr guard leaves older releases untouched.
    if hasattr(hf_constants, "HF_HUB_DOWNLOAD_TIMEOUT"):
        hf_constants.HF_HUB_DOWNLOAD_TIMEOUT = new_timeout
def download_pytorch_model(name: str) -> None:
    """
    Download a model snapshot from the Hugging Face Hub, skipping non-PyTorch weights.

    Every file in the model repository ``name`` is fetched except those
    matching these glob patterns:
    - weights in other formats: ``*.h5``, ``*.tflite``, ``*.safetensors``,
      ``*.msgpack``, ``*.ot``;
    - Markdown files such as the README: ``*.md``.

    Filtering is by file pattern only, not by size. Any other file type is
    downloaded regardless of how large it is.

    NOTE(review): a repository that publishes its weights only as
    ``*.safetensors`` will be downloaded without any weight files.

    Args:
        name: Hub repository id of the model (presumably ``"owner/model"``
            form — confirm against AVAILABLE_MODELS).
    """
    # Effectively "never time out", so slow transfers of large files are not
    # aborted by a short read timeout.
    number_of_seconds_in_a_year: int = 60 * 60 * 24 * 365
    change_default_timeout(number_of_seconds_in_a_year)
    snapshot_download(
        repo_id=name,
        # Same generous limit for the metadata (ETag) request.
        etag_timeout=number_of_seconds_in_a_year,
        # NOTE(review): newer huggingface_hub releases deprecate this flag
        # (resuming is always enabled there) — confirm against the pinned version.
        resume_download=True,
        repo_type="model",
        library_name="pt",
        # Skip non-PyTorch weight formats and Markdown docs (e.g. README.md).
        ignore_patterns=[
            "*.h5",
            "*.tflite",
            "*.safetensors",
            "*.msgpack",
            "*.ot",
            "*.md",
        ],
    )
if __name__ == "__main__":
    # Script entry point: fetch each listed model in turn. Nothing here catches
    # errors, so the first failed download propagates and ends the run.
    for repo_id in AVAILABLE_MODELS:
        download_pytorch_model(repo_id)
|