grouped-sampling-demo / on_server_start.py
yonikremer's picture
reimplemented model downloading
e67f273
raw
history blame
928 Bytes
"""
A script that is run when the server starts.
"""
from huggingface_hub import snapshot_download
def download_model(model_name: str, *, etag_timeout: float = 86_400) -> str:
    """
    Downloads a model from hugging face hub to the disk but not to the RAM.

    The repository snapshot is fetched with ``huggingface_hub.snapshot_download``.
    No ``cache_dir`` is passed, so the library's default cache location is used.

    :param model_name: The name (repo id) of the model to download.
    :param etag_timeout: Seconds to wait when fetching file metadata from the hub.
        Defaults to one day (86,400 seconds), the value previously hard-coded here.
    :return: The local folder path of the downloaded snapshot, as returned by
        ``snapshot_download`` (previously discarded).
    """
    # NOTE(review): ``resume_download`` is deprecated in recent huggingface_hub
    # releases, where interrupted downloads resume by default. It is kept here so
    # older versions still resume. Confirm against the pinned huggingface_hub version.
    return snapshot_download(
        repo_id=model_name,
        etag_timeout=etag_timeout,
        resume_download=True,
    )
def download_useful_models():
    """
    Pre-fetch every model this project relies on.

    Doing this ahead of time (this module is run when the server starts) means
    users don't have to wait for a download the first time they use the app.
    """
    print("Downloading useful models...")
    models_to_prefetch = ("facebook/opt-125m",)
    for repo_id in models_to_prefetch:
        download_model(repo_id)
def main() -> None:
    """Script entry point: pre-download the models this project uses."""
    download_useful_models()


if __name__ == "__main__":
    main()