Spaces:
Runtime error
Runtime error
Vokturz
commited on
Commit
•
fddae32
1
Parent(s):
2d9aa2d
cache default model
Browse files- src/app.py +11 -11
src/app.py
CHANGED
@@ -22,6 +22,10 @@ st.markdown(
|
|
22 |
def get_gpu_specs():
|
23 |
return pd.read_csv("data/gpu_specs.csv")
|
24 |
|
|
|
|
|
|
|
|
|
25 |
|
26 |
def show_gpu_info(info, trainable_params=0):
|
27 |
for var in ['Inference', 'Full Training Adam', 'LoRa Fine-tuning']:
|
@@ -46,13 +50,6 @@ def get_name(index):
|
|
46 |
row = gpu_specs.iloc[index]
|
47 |
return f"{row['Product Name']} ({row['RAM (GB)']} GB, {row['Year']})"
|
48 |
|
49 |
-
def create_plot(memory_table, y, title, container):
|
50 |
-
fig = px.bar(memory_table, x=memory_table.index, y=y, color_continuous_scale="RdBu_r")
|
51 |
-
fig.update_layout(yaxis_title="Number of GPUs", title=dict(text=title, font=dict(size=25)))
|
52 |
-
fig.update_coloraxes(showscale=False)
|
53 |
-
|
54 |
-
container.plotly_chart(fig, use_container_width=True)
|
55 |
-
|
56 |
gpu_specs = get_gpu_specs()
|
57 |
|
58 |
access_token = st.sidebar.text_input("Access token")
|
@@ -61,16 +58,19 @@ if not model_name:
|
|
61 |
st.info("Please enter a model name")
|
62 |
st.stop()
|
63 |
|
64 |
-
|
65 |
-
|
66 |
model_name = extract_from_url(model_name)
|
67 |
if model_name not in st.session_state:
|
68 |
if 'actual_model' in st.session_state:
|
69 |
del st.session_state[st.session_state['actual_model']]
|
70 |
del st.session_state['actual_model']
|
71 |
gc.collect()
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
74 |
st.session_state['actual_model'] = model_name
|
75 |
|
76 |
|
|
|
22 |
def get_gpu_specs():
|
23 |
return pd.read_csv("data/gpu_specs.csv")
|
24 |
|
25 |
+
@st.cache_resource
|
26 |
+
def get_mistralai_table():
|
27 |
+
model = get_model("mistralai/Mistral-7B-v0.1", library="transformers", access_token="")
|
28 |
+
return calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
|
29 |
|
30 |
def show_gpu_info(info, trainable_params=0):
|
31 |
for var in ['Inference', 'Full Training Adam', 'LoRa Fine-tuning']:
|
|
|
50 |
row = gpu_specs.iloc[index]
|
51 |
return f"{row['Product Name']} ({row['RAM (GB)']} GB, {row['Year']})"
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
gpu_specs = get_gpu_specs()
|
54 |
|
55 |
access_token = st.sidebar.text_input("Access token")
|
|
|
58 |
st.info("Please enter a model name")
|
59 |
st.stop()
|
60 |
|
|
|
|
|
61 |
model_name = extract_from_url(model_name)
|
62 |
if model_name not in st.session_state:
|
63 |
if 'actual_model' in st.session_state:
|
64 |
del st.session_state[st.session_state['actual_model']]
|
65 |
del st.session_state['actual_model']
|
66 |
gc.collect()
|
67 |
+
if model_name == "mistralai/Mistral-7B-v0.1": # cache Mistral
|
68 |
+
st.session_state[model_name] = get_mistralai_table()
|
69 |
+
else:
|
70 |
+
model = get_model(model_name, library="transformers", access_token=access_token)
|
71 |
+
st.session_state[model_name] = calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
|
72 |
+
del model
|
73 |
+
gc.collect()
|
74 |
st.session_state['actual_model'] = model_name
|
75 |
|
76 |
|