|
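"""
Startup script that runs when the server starts.

It turns off the model-download progress bars and pre-downloads the models
the app uses, so users don't have to wait for a download the first time they
use the app.
"""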
|
from download_repo import download_repository |
|
|
|
|
|
def disable_progress_bar():
    """
    Turn off the progress bars shown while models are being downloaded.

    transformers is imported here, at call time, rather than at module level,
    so merely importing this module does not import transformers.
    """
    from transformers import logging as hf_logging

    hf_logging.disable_progress_bar()
|
|
|
|
|
def download_useful_models():
    """
    Pre-fetch every model this project relies on.

    Doing this ahead of time means the user does not have to wait for a model
    download the first time they use the app.
    """
    print("Downloading useful models...")

    # Hub repository ids to fetch, in order.
    models_to_fetch = ["facebook/opt-125m"]
    for repo_id in models_to_fetch:
        download_repository(repo_id)
|
|
|
|
|
def main():
    """Run the startup steps in order: silence download progress bars, then pre-fetch models."""
    for startup_step in (disable_progress_bar, download_useful_models):
        startup_step()


if __name__ == "__main__":
    main()
|
|