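"""
A script that is run when the server starts.

It silences Hugging Face download progress output and pre-downloads the
models the app uses, so the first request does not have to wait for them.
"""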
|
import asyncio
from concurrent.futures import ThreadPoolExecutor

from huggingface_hub import logging as huggingface_hub_logging
from transformers import logging as transformers_logging

from download_repo import download_pytorch_model
|
|
|
|
|
def disable_progress_bar():
    """
    Quiet the output produced while Hugging Face models are downloaded.

    Turns off the ``transformers`` progress bar and stops the
    ``huggingface_hub`` logger from propagating its records upward.

    NOTE(review): ``disable_propagation`` only affects logging propagation; it
    does not by itself switch off ``huggingface_hub`` progress bars. Confirm
    this is the intended call.
    """
    silencers = (
        transformers_logging.disable_progress_bar,
        huggingface_hub_logging.disable_propagation,
    )
    for silence in silencers:
        silence()
|
|
|
|
|
def download_useful_models(models=None, downloader=None):
    """
    Download the models that are useful for this project.

    Fetching them ahead of time means the user does not have to wait for a
    download the first time they use the app.

    Downloads run concurrently in a thread pool. A failure for one model does
    not stop the others. Every failure is printed, and the ids of the models
    that failed are returned so the caller can react to them.

    Args:
        models: Iterable of model repo ids to download. Defaults to the
            project's built-in list.
        downloader: Callable invoked once per repo id. Defaults to
            ``download_pytorch_model``.

    Returns:
        List of repo ids whose download raised, in input order (empty when
        everything succeeded).
    """
    print("Downloading useful models...")

    if models is None:
        models = (
            "facebook/opt-125m",
            # NOTE(review): a ~30B-parameter model is a very large download;
            # confirm it should be fetched on every server start.
            "facebook/opt-iml-max-30b",
        )
    if downloader is None:
        downloader = download_pytorch_model

    failed = []
    with ThreadPoolExecutor() as executor:
        # Keep (model, future) pairs instead of using ``executor.map``:
        # exceptions from ``map`` only surface when its result iterator is
        # consumed, so an unconsumed ``map`` silently drops every failure.
        jobs = [(model, executor.submit(downloader, model)) for model in models]
        for model, future in jobs:
            error = future.exception()  # blocks until this job finishes
            if error is not None:
                print(f"Failed to download {model}: {error!r}")
                failed.append(model)
    return failed
|
|
|
|
|
async def main():
    """
    Start-up coroutine: quiet download progress output, then pre-fetch models.

    ``download_useful_models`` blocks until every download has finished, so it
    is run in the event loop's default executor. Awaiting it there keeps the
    loop free to run other tasks instead of stalling for the whole download.
    """
    disable_progress_bar()
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, download_useful_models)
|
|
|
|
|
if __name__ == "__main__":
    # ``main`` is a coroutine function: calling it bare only creates a
    # coroutine object that never runs, so drive it with an event loop.
    asyncio.run(main())
|
|