yonikremer
commited on
Commit
•
5967916
1
Parent(s):
70d3eba
checks if the model is downloaded before downloading it
Browse files- hanlde_form_submit.py +19 -6
hanlde_form_submit.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from time import time
|
2 |
|
3 |
import streamlit as st
|
@@ -11,6 +12,17 @@ from supported_models import get_supported_model_names
|
|
11 |
SUPPORTED_MODEL_NAMES = get_supported_model_names()
|
12 |
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
|
15 |
"""
|
16 |
Creates a pipeline with the given model name and group size.
|
@@ -18,12 +30,13 @@ def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine
|
|
18 |
:param group_size: The size of the groups to use.
|
19 |
:return: A pipeline with the given model name and group size.
|
20 |
"""
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
27 |
st.write(f"Starts creating pipeline with model: {model_name}")
|
28 |
pipeline_start_time = time()
|
29 |
pipeline = GroupedSamplingPipeLine(
|
|
|
1 |
+
import os
|
2 |
from time import time
|
3 |
|
4 |
import streamlit as st
|
|
|
12 |
SUPPORTED_MODEL_NAMES = get_supported_model_names()
|
13 |
|
14 |
|
15 |
+
def is_downloaded(model_name: str) -> bool:
|
16 |
+
"""
|
17 |
+
Checks if the model is downloaded.
|
18 |
+
:param model_name: The name of the model to check.
|
19 |
+
:return: True if the model is downloaded, False otherwise.
|
20 |
+
"""
|
21 |
+
models_dir = "/root/.cache/huggingface/hub"
|
22 |
+
model_dir = os.path.join(models_dir, f"models--{model_name.replace('/', '--')}")
|
23 |
+
return os.path.isdir(model_dir)
|
24 |
+
|
25 |
+
|
26 |
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
|
27 |
"""
|
28 |
Creates a pipeline with the given model name and group size.
|
|
|
30 |
:param group_size: The size of the groups to use.
|
31 |
:return: A pipeline with the given model name and group size.
|
32 |
"""
|
33 |
+
if not is_downloaded(model_name):
|
34 |
+
download_repository_start_time = time()
|
35 |
+
st.write(f"Starts downloading model: {model_name} from the internet.")
|
36 |
+
download_repository(model_name)
|
37 |
+
download_repository_end_time = time()
|
38 |
+
download_time = download_repository_end_time - download_repository_start_time
|
39 |
+
st.write(f"Finished downloading model: {model_name} from the internet in {download_time} seconds.")
|
40 |
st.write(f"Starts creating pipeline with model: {model_name}")
|
41 |
pipeline_start_time = time()
|
42 |
pipeline = GroupedSamplingPipeLine(
|