Florian Valade committed
Commit · 97675ea
Parent(s): ad27be1

Update graph to be more understandable

app.py CHANGED
@@ -3,7 +3,10 @@ import time
 import streamlit as st
 import torch
 import pandas as pd
+import plotly.graph_objects as go
+import numpy as np
 
+from plotly.subplots import make_subplots
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from typer import clear
 from annotated_text import annotated_text
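The three new imports exist only to build the dual-axis chart added at the end of the file. For reference, a minimal self-contained sketch of Plotly's secondary-y-axis pattern (toy data, names invented here):

import plotly.graph_objects as go
from plotly.subplots import make_subplots

# One subplot cell whose right-hand y-axis is enabled.
fig = make_subplots(specs=[[{"secondary_y": True}]])
# Toy series: latency on the left axis, a 0-1 ratio on the right axis.
fig.add_trace(go.Scatter(y=[120, 95, 140], name="latency (ms)"), secondary_y=False)
fig.add_trace(go.Scatter(y=[0.25, 0.5, 1.0], name="ratio"), secondary_y=True)
fig.show()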
@@ -24,14 +27,14 @@ def annotated_to_normal(text):
         result += elem
     return result
 
-def generate_next_token():
+def generate_next_token(device="cpu"):
     print(f"Generating next token from {st.session_state.messages}")
     inputs = ""
     for message in st.session_state.messages:
         inputs += message["role"] + ": " + annotated_to_normal(message["content"]) + "\n"
     inputs += "Assistant:"
     print(f"Inputs: {inputs}")
-    inputs = st.session_state.tokenizer.encode(inputs, return_tensors="pt")
+    inputs = st.session_state.tokenizer.encode(inputs, return_tensors="pt").to(device)
     for i in range(50):
         start = time.time()
         outputs = st.session_state.model(inputs)
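The only behavioural change in this hunk is the .to(device) on the encoded prompt: once the model lives on the GPU, a CPU-resident input tensor makes the forward pass fail. A minimal sketch of the mismatch this avoids, assuming nothing beyond stock PyTorch:

import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
layer = torch.nn.Linear(4, 4).to(device)   # weights on the target device
x = torch.randn(1, 4)                      # input created on the CPU
# layer(x) raises a RuntimeError when device is "cuda:0" (device mismatch);
# moving the input to the same device first is the fix the diff applies.
y = layer(x.to(device))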
@@ -51,25 +54,26 @@ def generate_next_token():
             print(sorted(branch_locations, reverse=True))
             early_exit = (branch_locations.index(outputs.head_indices) + 1) / len(branch_locations)
         else:
-            early_exit = 0.0
+            early_exit = 1.25
         # Add data to dataframe
         new_row = pd.DataFrame({"Time taken (in ms)": [time_taken], "Early exit depth": [early_exit]})
         st.session_state.data = pd.concat([st.session_state.data, new_row], ignore_index=True)
         yield next_token, early_exit
 
 @st.cache_resource
-def load_model(model_str, tokenizer_str):
-    model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True)
+def load_model(model_str, tokenizer_str, device="cpu"):
+    model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True).to(device)
     model.eval()
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)
     return model, tokenizer
 
 model_str = "valcore/Branchy-Phi-2"
 tokenizer_str = "microsoft/Phi-2"
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 if "model" not in st.session_state or "tokenizer" not in st.session_state:
-    print("Loading model")
-    st.session_state.model, st.session_state.tokenizer = load_model(model_str, tokenizer_str)
+    print(f"Loading model on {device}")
+    st.session_state.model, st.session_state.tokenizer = load_model(model_str, tokenizer_str, device)
 
 # Initialize chat history and dataframe
 if "messages" not in st.session_state:
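early_exit is the normalized depth of the auxiliary head that produced the token, so it always lands in (0, 1]; switching the fallback to 1.25 gives tokens that ran the full network a sentinel value just above that range, which is what lets them appear as a distinct band at the top of the new fixed [0, 1.25] axis. A worked example of the depth formula, with hypothetical branch locations:

# Hypothetical early-exit heads after layers 4, 8, 12 and 16.
branch_locations = [4, 8, 12, 16]
head_index = 8                 # the branch that actually emitted the token
early_exit = (branch_locations.index(head_index) + 1) / len(branch_locations)
print(early_exit)              # 0.5 -> the token exited halfway through the model

Passing device into the cached load_model also means @st.cache_resource keys the cached model on the device string, so a different device triggers a fresh load instead of returning a model resident on the wrong hardware.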
@@ -109,8 +113,8 @@ with col2:
     with st.chat_message("Assistant"):
         response = []
         with st.spinner('Running inference...'):
-            for next_token, early_exit in generate_next_token():
-                if early_exit > 0.0:
+            for next_token, early_exit in generate_next_token(device):
+                if early_exit > 0.0 and early_exit != 1.25:
                     response.append(tuple((next_token, str(early_exit))))
                 else:
                     response.append(next_token)
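The extra early_exit != 1.25 guard keeps the sentinel out of the chat display: only genuinely early-exited tokens become (token, label) tuples, which annotated_text renders as highlighted spans, while full-depth tokens stay plain strings. An illustration with toy values:

from annotated_text import annotated_text

# Tuples render as highlighted tokens labelled with their exit depth;
# bare strings render as ordinary text.
annotated_text(("quick", "0.25"), " plain ", ("deeper", "0.75"))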
@@ -119,5 +123,54 @@ with col2:
 
     # Add assistant response to chat history
     st.session_state.messages.append({"role": "Assistant", "content": response})
-    st.line_chart(st.session_state.data, x=None, y=["Time taken (in ms)", "Early exit depth"])
+
+    # Assuming st.session_state.data is a pandas DataFrame
+    df = st.session_state.data
+
+    # Calculate the max time taken and add a 10% margin
+    max_time = df["Time taken (in ms)"].max()
+    time_axis_max = max_time * 1.1  # 10% margin
+
+
+    # Create figure with secondary y-axis
+    fig = make_subplots(specs=[[{"secondary_y": True}]])
+
+    # Add traces
+    fig.add_trace(
+        go.Scatter(x=df.index, y=df["Time taken (in ms)"], name="Time taken (in ms)"),
+        secondary_y=False,
+    )
+
+    fig.add_trace(
+        go.Scatter(x=df.index, y=df["Early exit depth"], name="Early exit depth"),
+        secondary_y=True,
+    )
+
+    # Set x-axis title
+    fig.update_xaxes(title_text="Index")
+
+    # Set y-axes titles
+    fig.update_yaxes(
+        title_text="Time taken (in ms)",
+        secondary_y=False,
+        range=[0, time_axis_max],
+        tickmode='linear',
+        dtick=np.ceil(time_axis_max / 5 / 10) * 10  # Round to nearest 10
+    )
+    fig.update_yaxes(
+        title_text="Early exit depth",
+        secondary_y=True,
+        range=[0, 1.25],
+        tickmode='linear',
+        dtick=0.25
+    )
+
+    fig.update_layout(
+        title_text="Time Taken vs Early Exit Depth",
+        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
+    )
+    # Use Streamlit to display the Plotly chart
+    st.plotly_chart(fig)
+
+    #st.line_chart(st.session_state.data, x=None, y=["Time taken (in ms)", "Early exit depth"])
 print(st.session_state.messages)
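The left-axis tick spacing np.ceil(time_axis_max / 5 / 10) * 10 targets roughly five gridlines and rounds the step up to a multiple of 10 ms; checking the arithmetic with a hypothetical timing value:

import numpy as np

max_time = 230.0                   # slowest token so far, in ms (invented value)
time_axis_max = max_time * 1.1     # 253.0 -> axis ceiling with 10% headroom
dtick = np.ceil(time_axis_max / 5 / 10) * 10
print(dtick)                       # 60.0 -> a gridline every 60 ms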