Florian Valade committed
Commit 1b7b650
Parent(s): 48f630f
Initial push
- .gitignore +1 -0
- README.md +37 -1
- app.py +123 -0
- requirements.txt +5 -0
.gitignore
ADDED
@@ -0,0 +1 @@
+__pycache__
README.md
CHANGED
@@ -9,4 +9,40 @@ app_file: app.py
 pinned: false
 ---
 
-
+# Early Exit Computational Savings Demo
+
+## Overview
+
+This project demonstrates the concept of "early exiting" in deep learning models to save computational resources without significantly compromising model performance. Early exit strategies allow a neural network to make predictions at intermediate layers for easy-to-classify instances, reducing the overall computation time and resources needed for inference. The application runs on Streamlit, offering an interactive web interface to explore the functionality of the early exit model.
+
+## Features
+
+- **BranchyModel:** An implementation of a deep learning model with early exit points. This model architecture is designed to evaluate the performance and computational savings of using early exits.
+- **Utility Functions:** A set of utilities to support the model's operation, including data preprocessing and performance evaluation metrics.
+- **Streamlit Application:** A user-friendly web interface to interact with the model, visualize its performance, and understand the benefits of early exits.
+
+## Getting Started
+
+### Prerequisites
+
+Ensure you have Python 3.x installed on your machine. You can install all the required dependencies via:
+
+```bash
+pip install -r requirements.txt
+```
+
+### Running the Application
+
+To run the Streamlit application, execute the following command from the root of the project:
+
+```bash
+streamlit run app.py
+```
+
+This command starts a local web server and opens the application in your default web browser, allowing you to interact with the BranchyModel and explore its features.
+
+## Project Structure
+
+- **app.py**: The main Streamlit application script.
+- **requirements.txt**: Lists all the Python dependencies required by the project.
+- **src/**: Contains the source code for the project.
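To make the Overview above concrete, here is a minimal sketch of early-exit inference, not taken from this repository: a backbone whose intermediate classifier heads return a prediction as soon as their confidence clears a threshold. All names (`EarlyExitSketch`, `exit_heads`, `confidence_threshold`) are illustrative and are not the BranchyModel API.

```python
import torch
import torch.nn as nn

class EarlyExitSketch(nn.Module):
    """Illustrative early-exit classifier (not the BranchyModel implementation)."""

    def __init__(self, blocks, exit_heads, final_head, confidence_threshold=0.9):
        super().__init__()
        self.blocks = blocks                  # nn.ModuleList of backbone layers
        self.exit_heads = exit_heads          # nn.ModuleDict: layer index (str) -> classifier head
        self.final_head = final_head          # head used when no early exit fires
        self.confidence_threshold = confidence_threshold

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            x = block(x)
            if str(i) in self.exit_heads:
                logits = self.exit_heads[str(i)](x)
                confidence = torch.softmax(logits, dim=-1).max(dim=-1).values
                # Easy inputs return here; the remaining blocks are never executed,
                # which is where the computational savings come from.
                if bool((confidence >= self.confidence_threshold).all()):
                    return logits, i
        return self.final_head(x), len(self.blocks) - 1

# Toy usage: four linear "blocks" with exit heads after blocks 1 and 3.
blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
exit_heads = nn.ModuleDict({"1": nn.Linear(16, 10), "3": nn.Linear(16, 10)})
model = EarlyExitSketch(blocks, exit_heads, final_head=nn.Linear(16, 10))
logits, exit_layer = model(torch.randn(2, 16))
```

The returned layer index plays the same role as the early-exit depth that the Streamlit app below plots for each generated token.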
app.py
ADDED
@@ -0,0 +1,123 @@
+# Save this as app.py and run with `streamlit run app.py`
+import time
+import streamlit as st
+import torch
+import pandas as pd
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from annotated_text import annotated_text
+
+st.title("Multi-Head LLM Demo")
+st.markdown("""This is a demo of a multi-head language model with early exit capabilities.
+The model is based on the Phi-2 architecture and is available here: https://huggingface.co/valcore/Branchy-Phi-2.
+\nThe model has four heads, and generation can exit early at any of them based on a threshold. The graph shows the depth of early exit for each token (the deeper, the faster) and the time taken to generate each token.
+Early-exited tokens are annotated with the depth of early exit (a float up to 1, 1 being the deepest).
+""")
+
+def annotated_to_normal(text):
+    # Flatten an annotated message (a mix of plain strings and
+    # (token, annotation) tuples) back into a plain string.
+    result = ""
+    for elem in text:
+        if isinstance(elem, tuple):
+            result += elem[0]
+        else:
+            result += elem
+    return result
+
+def generate_next_token():
+    # Greedy decoding loop: yields one token at a time together with its
+    # early-exit depth so the UI can annotate and plot it.
+    print(f"Generating next token from {st.session_state.messages}")
+    inputs = ""
+    for message in st.session_state.messages:
+        inputs += message["role"] + ": " + annotated_to_normal(message["content"]) + "\n"
+    inputs += "Assistant:"
+    print(f"Inputs: {inputs}")
+    inputs = st.session_state.tokenizer.encode(inputs, return_tensors="pt")
+    for i in range(50):
+        start = time.time()
+        outputs = st.session_state.model(inputs)
+        stop = time.time()
+        next_token_logits = outputs.logits[:, -1, :].squeeze()
+        next_token_probs = torch.softmax(next_token_logits, dim=-1)
+        next_token_id = torch.argmax(next_token_probs, dim=-1)
+        if next_token_id == 50256:
+            # End-of-text token for the Phi-2 tokenizer: stop generating.
+            break
+        print(inputs.shape, next_token_id.shape)
+        inputs = torch.cat([inputs, next_token_id.unsqueeze(0).unsqueeze(-1)], dim=-1)
+        next_token = st.session_state.tokenizer.decode(next_token_id)
+        time_taken = (stop - start) * 1000  # convert seconds to ms to match the chart label
+        branch_locations = st.session_state.model.config.branch_locations
+        print(outputs.head_indices)
+        if outputs.head_indices in branch_locations:
+            print(sorted(branch_locations, reverse=True))
+            early_exit = (branch_locations.index(outputs.head_indices) + 1) / len(branch_locations)
+        else:
+            early_exit = 0
+        # Add data to dataframe
+        new_row = pd.DataFrame({"Time taken (in ms)": [time_taken], "Early exit depth": [early_exit]})
+        st.session_state.data = pd.concat([st.session_state.data, new_row], ignore_index=True)
+        yield next_token, early_exit
+
+@st.cache_resource
+def load_model(model_str, tokenizer_str):
+    # Cached across Streamlit reruns so the model is only loaded once.
+    model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True)
+    model.eval()
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)
+    return model, tokenizer
+
+model_str = "valcore/Branchy-Phi-2"
+tokenizer_str = "microsoft/Phi-2"
+
+if "model" not in st.session_state or "tokenizer" not in st.session_state:
+    print("Loading model...")
+    st.session_state.model, st.session_state.tokenizer = load_model(model_str, tokenizer_str)
+
+# Initialize chat history and dataframe
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+    st.session_state.data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth"])
+
+col1, col2 = st.columns([1, 4])
+
+with col1:
+    early_exit = st.checkbox("Early exit", value=False)
+    if early_exit:
+        # Calibrated per-head thresholds that allow early exits.
+        st.session_state.model.head_thresholds = [2.506962537765503, 2.656052589416504, 1.924393653869629, 1.4434680938720703]
+    else:
+        # Thresholds set high enough that early exit is effectively disabled.
+        st.session_state.model.head_thresholds = [10., 10., 10., 10.]
+    clear_session = st.button("Clear session")
+    if clear_session:
+        print("Clearing session")
+        st.session_state.messages = []
+        st.session_state.data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth"])
+
+with col2:
+    # Display chat messages from history on app rerun
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            annotated_text(message["content"])
+
+    prompt = st.chat_input("What is up?")
+    # React to user input
+    if prompt:
+        # Display user message in chat message container
+        with st.chat_message("User"):
+            st.markdown(prompt)
+        # Add user message to chat history
+        st.session_state.messages.append({"role": "User", "content": prompt})
+
+        # Display assistant response in chat message container
+        with st.chat_message("Assistant"):
+            response = []
+            with st.spinner('Running inference...'):
+                for next_token, early_exit in generate_next_token():
+                    if early_exit > 0.0:
+                        # Early-exited token: keep the depth so it can be shown as an annotation.
+                        response.append(tuple((next_token, str(early_exit))))
+                    else:
+                        response.append(next_token)
+                    print(response)
+            annotated_text(response)
+
+        # Add assistant response to chat history
+        st.session_state.messages.append({"role": "Assistant", "content": response})
+        st.line_chart(st.session_state.data, x=None, y=["Time taken (in ms)", "Early exit depth"])
+        print(st.session_state.messages)
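For readers poking at the model outside Streamlit: the "early exit depth" that app.py plots is simply the rank of the exiting head among the configured branches. Below is a standalone sketch of that mapping, assuming only the `branch_locations` and `head_indices` attributes used in app.py above; the branch layer numbers are made up for illustration.

```python
def early_exit_depth(head_index, branch_locations):
    """Mirror of app.py's depth computation: 0.0 when the full model answered,
    otherwise (position of the exiting branch + 1) / number of branches."""
    if head_index in branch_locations:
        return (branch_locations.index(head_index) + 1) / len(branch_locations)
    return 0.0

branches = [7, 15, 23, 31]                      # hypothetical branch layers, four heads as in the demo
assert early_exit_depth(7, branches) == 0.25    # exited at the first listed branch
assert early_exit_depth(31, branches) == 1.0    # exited at the last listed branch (depth 1)
assert early_exit_depth(-1, branches) == 0.0    # no early exit: token came from the full model
```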
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+streamlit==1.31.0
+torch==2.0.1
+pandas==2.0.3
+transformers==4.36.0
+st-annotated-text