Vokturz committed
Commit
fddae32
Parent: 2d9aa2d

cache default model

Files changed (1): src/app.py +11 -11
src/app.py CHANGED
@@ -22,6 +22,10 @@ st.markdown(
 def get_gpu_specs():
     return pd.read_csv("data/gpu_specs.csv")
 
+@st.cache_resource
+def get_mistralai_table():
+    model = get_model("mistralai/Mistral-7B-v0.1", library="transformers", access_token="")
+    return calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
 
 def show_gpu_info(info, trainable_params=0):
     for var in ['Inference', 'Full Training Adam', 'LoRa Fine-tuning']:
@@ -46,13 +50,6 @@ def get_name(index):
     row = gpu_specs.iloc[index]
     return f"{row['Product Name']} ({row['RAM (GB)']} GB, {row['Year']})"
 
-def create_plot(memory_table, y, title, container):
-    fig = px.bar(memory_table, x=memory_table.index, y=y, color_continuous_scale="RdBu_r")
-    fig.update_layout(yaxis_title="Number of GPUs", title=dict(text=title, font=dict(size=25)))
-    fig.update_coloraxes(showscale=False)
-
-    container.plotly_chart(fig, use_container_width=True)
-
 gpu_specs = get_gpu_specs()
 
 access_token = st.sidebar.text_input("Access token")
@@ -61,16 +58,19 @@ if not model_name:
     st.info("Please enter a model name")
     st.stop()
 
-
-
 model_name = extract_from_url(model_name)
 if model_name not in st.session_state:
     if 'actual_model' in st.session_state:
         del st.session_state[st.session_state['actual_model']]
         del st.session_state['actual_model']
         gc.collect()
-    model = get_model(model_name, library="transformers", access_token=access_token)
-    st.session_state[model_name] = calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
+    if model_name == "mistralai/Mistral-7B-v0.1":  # cache Mistral
+        st.session_state[model_name] = get_mistralai_table()
+    else:
+        model = get_model(model_name, library="transformers", access_token=access_token)
+        st.session_state[model_name] = calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
+        del model
+        gc.collect()
     st.session_state['actual_model'] = model_name
 
 
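For context: st.cache_resource memoizes the decorated function once per server process and shares the result across all sessions and reruns, so the default model's memory table is computed once rather than on every visit. Below is a minimal, self-contained sketch of that behavior; compute_table and its return value are stand-ins for the app's get_model/calculate_memory pipeline, not the real implementations.

import time
import streamlit as st

@st.cache_resource  # one shared result per process, reused across sessions and reruns
def compute_table(model_name: str) -> dict:
    # stand-in for the expensive get_model(...) + calculate_memory(...) work
    time.sleep(2)
    return {"model": model_name, "float16/bfloat16 (GB)": 14.5}  # placeholder numbers

# Only the first page load pays the two-second cost; later visitors get the cached dict.
st.write(compute_table("mistralai/Mistral-7B-v0.1"))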
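The session_state logic around the new branch implements a keep-only-one eviction policy: 'actual_model' records which model's table is currently stored, and both entries are deleted before a different model's table is computed. Here is a sketch of that policy pulled out into a helper, assuming the same session_state keys as the diff (the function name remember_only is mine, not the app's):

import gc
import streamlit as st

def remember_only(key: str, value) -> None:
    # Store value under key, evicting whichever table was kept before.
    if 'actual_model' in st.session_state:
        del st.session_state[st.session_state['actual_model']]  # the old table
        del st.session_state['actual_model']                    # the old bookkeeping entry
        gc.collect()  # encourage prompt release of the old table's memory
    st.session_state[key] = value
    st.session_state['actual_model'] = key

The del model and gc.collect() added in the else branch serve the same goal at a smaller scope: the instantiated model object is only needed to produce the memory table, so it is released as soon as calculate_memory returns.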