|
import os |
|
|
|
import urllib3 |
|
|
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
def change_default_timeout(new_timeout: int) -> None:
    """
    Raise the read timeout used when downloading files from the Hugging Face Hub.

    Works around errors such as:

        urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='huggingface.co', port=443):
        Read timed out. (read timeout=10)

    The "read timeout=10" in that message is the timeout that huggingface_hub
    passes explicitly with each download request (its HF_HUB_DOWNLOAD_TIMEOUT
    setting). An explicit per-request timeout takes precedence over any urllib3
    module-level default, so this function overrides huggingface_hub's own
    setting as well.

    Args:
        new_timeout: new timeout, in seconds.
    """
    # Kept for backward compatibility with the original implementation. On its
    # own this does not affect requests that are sent with an explicit timeout.
    urllib3.util.timeout.DEFAULT_TIMEOUT = new_timeout

    # huggingface_hub reads HF_HUB_DOWNLOAD_TIMEOUT from the environment at
    # import time, so patch the already-loaded modules instead.
    # NOTE(review): newer releases look the value up on `constants` when each
    # request is made; older ones copied it into `file_download` via
    # `from .constants import ...`. Patch whichever copies exist; the hasattr
    # guard skips versions that lack the name.
    from huggingface_hub import constants, file_download

    for module in (constants, file_download):
        if hasattr(module, "HF_HUB_DOWNLOAD_TIMEOUT"):
            setattr(module, "HF_HUB_DOWNLOAD_TIMEOUT", new_timeout)
|
|
|
|
|
def download_pytorch_model(name: str) -> None:
    """
    Download a model repository from the Hugging Face Hub, skipping non-PyTorch weights.

    Files matching *.h5, *.tflite, *.safetensors, *.msgpack and *.ot (the
    TensorFlow / TFLite / safetensors / msgpack / ot weight formats) are not
    downloaded. Markdown files (*.md) are skipped too. Everything else in the
    repository is fetched.

    The snapshot is stored under "huggingface/models" in the directory that
    contains this file.

    Before downloading, this calls change_default_timeout with one day
    (86,400 seconds). The same value is used as the ETag timeout.

    Args:
        name: repository id on the Hub, e.g. "facebook/opt-125m".
    """
    one_day_in_seconds: int = 24 * 60 * 60

    change_default_timeout(one_day_in_seconds)

    models_dir = os.path.join(os.path.dirname(__file__), "huggingface", "models")

    skipped_patterns = [
        "*.h5",
        "*.tflite",
        "*.safetensors",
        "*.msgpack",
        "*.ot",
        "*.md",
    ]

    snapshot_download(
        repo_id=name,
        repo_type="model",
        cache_dir=models_dir,
        etag_timeout=one_day_in_seconds,
        # NOTE(review): newer huggingface_hub releases deprecate this flag
        # (they always resume); it is kept for older releases.
        resume_download=True,
        library_name="pt",
        ignore_patterns=skipped_patterns,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Small demo model used when this file is run as a script.
    demo_repo_id = "facebook/opt-125m"
    download_pytorch_model(demo_repo_id)
|
|