Florian valade committed
Commit 97675ea · Parent: ad27be1

Update graph to be more understandable

Files changed (1):
  1. app.py (+63, -10)
app.py CHANGED
@@ -3,7 +3,10 @@ import time
 import streamlit as st
 import torch
 import pandas as pd
+import plotly.graph_objects as go
+import numpy as np
 
+from plotly.subplots import make_subplots
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from typer import clear
 from annotated_text import annotated_text
@@ -24,14 +27,14 @@ def annotated_to_normal(text):
         result += elem
     return result
 
-def generate_next_token():
+def generate_next_token(device="cpu"):
     print(f"Generating next token from {st.session_state.messages}")
     inputs = ""
     for message in st.session_state.messages:
         inputs += message["role"] + ": " + annotated_to_normal(message["content"]) + "\n"
     inputs += "Assistant:"
     print(f"Inputs: {inputs}")
-    inputs = st.session_state.tokenizer.encode(inputs, return_tensors="pt")
+    inputs = st.session_state.tokenizer.encode(inputs, return_tensors="pt").to(device)
     for i in range(50):
         start = time.time()
         outputs = st.session_state.model(inputs)
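
Note on this hunk: generate_next_token() now takes a device argument and moves the encoded prompt onto it with .to(device), so the input ids live on the same device as the model before the forward pass. A minimal sketch of that pattern, assuming a plain gpt2 checkpoint as a stand-in (the Branchy-Phi-2 early-exit heads and the head_indices output are not reproduced here):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda:0" if torch.cuda.is_available() else "cpu"  # same selection rule as this commit
tokenizer = AutoTokenizer.from_pretrained("gpt2")          # stand-in checkpoint, not the demo model
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

# Encode the chat-style prompt and place it on the chosen device before calling the model.
input_ids = tokenizer.encode("User: hello\nAssistant:", return_tensors="pt").to(device)
with torch.no_grad():
    for _ in range(20):                   # greedy decoding, one token per step
        logits = model(input_ids).logits  # [batch, seq_len, vocab]
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=-1)
print(tokenizer.decode(input_ids[0]))
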
@@ -51,25 +54,26 @@ def generate_next_token():
             print(sorted(branch_locations, reverse=True))
             early_exit = (branch_locations.index(outputs.head_indices) + 1) / len(branch_locations)
         else:
-            early_exit = 0
+            early_exit = 1.25
         # Add data to dataframe
         new_row = pd.DataFrame({"Time taken (in ms)": [time_taken], "Early exit depth": [early_exit]})
         st.session_state.data = pd.concat([st.session_state.data, new_row], ignore_index=True)
         yield next_token, early_exit
 
 @st.cache_resource
-def load_model(model_str, tokenizer_str):
-    model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True)
+def load_model(model_str, tokenizer_str, device="cpu"):
+    model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True).to(device)
     model.eval()
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)
     return model, tokenizer
 
 model_str = "valcore/Branchy-Phi-2"
 tokenizer_str = "microsoft/Phi-2"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 if "model" not in st.session_state or "tokenizer" not in st.session_state:
-    print("Loading model...")
-    st.session_state.model, st.session_state.tokenizer = load_model(model_str, tokenizer_str)
+    print(f"Loading model on {device}")
+    st.session_state.model, st.session_state.tokenizer = load_model(model_str, tokenizer_str, device)
 
 # Initialize chat history and dataframe
 if "messages" not in st.session_state:
@@ -109,8 +113,8 @@ with col2:
     with st.chat_message("Assistant"):
         response = []
         with st.spinner('Running inference...'):
-            for next_token, early_exit in generate_next_token():
-                if early_exit > 0.0:
+            for next_token, early_exit in generate_next_token(device):
+                if early_exit > 0.0 and early_exit != 1.25:
                     response.append(tuple((next_token, str(early_exit))))
                 else:
                     response.append(next_token)
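
Note on this hunk: the rendering loop now also checks early_exit != 1.25, the sentinel value assigned in the hunk above when no branch exits early, so only tokens that genuinely exited through an intermediate head get a depth annotation. The value also matches the top of the new secondary-axis range in the chart below, so full-depth tokens plot as a band above the genuine depths in (0, 1]. A small sketch of that convention as I read it (NO_EARLY_EXIT is a hypothetical name; the commit uses the literal value inline):

# Depths in (0, 1] come from a real branch exit; 1.25 marks "ran the full model".
NO_EARLY_EXIT = 1.25  # hypothetical constant, for illustration only

def render_token(token, early_exit):
    # A (text, label) tuple is the shape annotated_text expects for highlighted tokens.
    if early_exit > 0.0 and early_exit != NO_EARLY_EXIT:
        return (token, str(early_exit))
    return token  # plain string: no early exit to show

print(render_token("hello", 0.5))   # ('hello', '0.5')
print(render_token("world", 1.25))  # 'world'
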
@@ -119,5 +123,54 @@ with col2:
 
     # Add assistant response to chat history
     st.session_state.messages.append({"role": "Assistant", "content": response})
-    st.line_chart(st.session_state.data, x=None, y=["Time taken (in ms)", "Early exit depth"])
+
+    # Assuming st.session_state.data is a pandas DataFrame
+    df = st.session_state.data
+
+    # Calculate the max time taken and add a 10% margin
+    max_time = df["Time taken (in ms)"].max()
+    time_axis_max = max_time * 1.1  # 10% margin
+
+
+    # Create figure with secondary y-axis
+    fig = make_subplots(specs=[[{"secondary_y": True}]])
+
+    # Add traces
+    fig.add_trace(
+        go.Scatter(x=df.index, y=df["Time taken (in ms)"], name="Time taken (in ms)"),
+        secondary_y=False,
+    )
+
+    fig.add_trace(
+        go.Scatter(x=df.index, y=df["Early exit depth"], name="Early exit depth"),
+        secondary_y=True,
+    )
+
+    # Set x-axis title
+    fig.update_xaxes(title_text="Index")
+
+    # Set y-axes titles
+    fig.update_yaxes(
+        title_text="Time taken (in ms)",
+        secondary_y=False,
+        range=[0, time_axis_max],
+        tickmode='linear',
+        dtick=np.ceil(time_axis_max / 5 / 10) * 10  # Round to nearest 10
+    )
+    fig.update_yaxes(
+        title_text="Early exit depth",
+        secondary_y=True,
+        range=[0, 1.25],
+        tickmode='linear',
+        dtick=0.25
+    )
+
+    fig.update_layout(
+        title_text="Time Taken vs Early Exit Depth",
+        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
+    )
+    # Use Streamlit to display the Plotly chart
+    st.plotly_chart(fig)
+
+    #st.line_chart(st.session_state.data, x=None, y=["Time taken (in ms)", "Early exit depth"])
 print(st.session_state.messages)
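
The replacement for st.line_chart is a Plotly figure with a secondary y-axis, so latency in milliseconds and early-exit depth no longer share one scale. The dtick expression np.ceil(time_axis_max / 5 / 10) * 10 appears to aim for roughly five ticks on the time axis, rounded up to a multiple of 10 ms. A standalone sketch of the same layout with made-up data, viewable outside Streamlit via fig.show():

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Synthetic per-token measurements, for illustration only.
df = pd.DataFrame({
    "Time taken (in ms)": [210, 180, 95, 220, 130],
    "Early exit depth": [1.25, 0.75, 0.25, 1.25, 0.5],  # 1.25 = no early exit (sentinel)
})

time_axis_max = df["Time taken (in ms)"].max() * 1.1  # 10% headroom, as in the commit

fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(go.Scatter(x=df.index, y=df["Time taken (in ms)"], name="Time taken (in ms)"),
              secondary_y=False)
fig.add_trace(go.Scatter(x=df.index, y=df["Early exit depth"], name="Early exit depth"),
              secondary_y=True)
fig.update_xaxes(title_text="Index")
fig.update_yaxes(title_text="Time taken (in ms)", secondary_y=False,
                 range=[0, time_axis_max], tickmode="linear",
                 dtick=np.ceil(time_axis_max / 5 / 10) * 10)
fig.update_yaxes(title_text="Early exit depth", secondary_y=True,
                 range=[0, 1.25], tickmode="linear", dtick=0.25)
fig.update_layout(title_text="Time Taken vs Early Exit Depth")
fig.show()
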
 