""" | |
A script that is run when the server starts. | |
""" | |
from concurrent.futures import ThreadPoolExecutor | |
from transformers import logging as transformers_logging | |
from huggingface_hub import logging as huggingface_hub_logging | |
from available_models import AVAILABLE_MODELS | |
from download_repo import download_pytorch_model | |
def disable_progress_bar():
    """
    Disables the progress bars shown when downloading models.

    Turns off both the ``transformers`` progress bar and the
    ``huggingface_hub`` progress bars.
    """
    # Imported locally so this function needs no change to the module's
    # top-level import block. huggingface_hub manages its progress bars in
    # ``huggingface_hub.utils``. Its logging module's ``disable_propagation()``
    # only stops hub log records from reaching the root logger and does not
    # touch progress bars.
    from huggingface_hub.utils import disable_progress_bars

    transformers_logging.disable_progress_bar()
    disable_progress_bars()
def download_useful_models():
    """
    Downloads the models that are useful for this project.
    So that the user doesn't have to wait for the models to download when they first use the app.

    Each entry of ``AVAILABLE_MODELS`` is passed to ``download_pytorch_model``
    on a thread pool. Every download is attempted even if some of them fail.
    After all of them have finished, each failure is printed, and the first
    failure (in ``AVAILABLE_MODELS`` order) is re-raised so the caller notices.
    """
    print("Downloading useful models. It might take a while...")
    with ThreadPoolExecutor() as executor:
        # Keep a Future per model. A never-consumed ``executor.map`` iterator
        # would silently drop any exception raised inside a worker thread.
        jobs = [
            (model, executor.submit(download_pytorch_model, model))
            for model in AVAILABLE_MODELS
        ]
    # Leaving the ``with`` block waits for every job to finish, so
    # ``exception()`` below returns without blocking.
    failures = []
    for model, job in jobs:
        error = job.exception()
        if error is not None:
            failures.append((model, error))
    for model, error in failures:
        print(f"Failed to download {model!r}: {error!r}")
    if failures:
        # Re-raise the original exception object; its traceback is preserved.
        raise failures[0][1]
def main():
    """
    Entry point of the start-up script: pre-download the project's models.

    Download progress bars are left enabled here. Call
    ``disable_progress_bar()`` before downloading to silence them.
    """
    download_useful_models()
# Run the pre-download only when executed as a script, not when imported.
if __name__ == "__main__":
    main()