Florian valade committed on
Commit 1b7b650 · 1 Parent(s): 48f630f

Initial push

Files changed (4)
  1. .gitignore +1 -0
  2. README.md +37 -1
  3. app.py +123 -0
  4. requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
README.md CHANGED
@@ -9,4 +9,40 @@ app_file: app.py
  pinned: false
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Early Exit Computational Savings Demo
+
+ ## Overview
+
+ This project demonstrates the concept of "early exiting" in deep learning models to save computational resources without significantly compromising model performance. Early exit strategies allow a neural network to make predictions at intermediate layers for easy-to-classify instances, reducing the overall computation time and resources needed for inference. The application runs on Streamlit, offering an interactive web interface to explore the early exit model.
+
+ ## Features
+
+ - **BranchyModel:** An implementation of a deep learning model with early exit points, designed to evaluate the performance and computational savings of using early exits.
+ - **Utility Functions:** A set of utilities to support the model's operation, including data preprocessing and performance evaluation metrics.
+ - **Streamlit Application:** A user-friendly web interface to interact with the model, visualize its performance, and understand the benefits of early exits.
+
+ ## Getting Started
+
+ ### Prerequisites
+
+ Ensure you have Python 3.x installed on your machine. You can install all the required dependencies via:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### Running the Application
+
+ To run the Streamlit application, execute the following command from the root of the project:
+
+ ```bash
+ streamlit run app.py
+ ```
+
+ The command will start a local web server and open the application in your default web browser, allowing you to interact with the BranchyModel and explore its features.
+
+ ## Project Structure
+
+ - **app.py**: The main Streamlit application script.
+ - **requirements.txt**: Lists all the Python dependencies required by the project.
+ - **src/**: Contains the source code for the project.
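The exit rule itself lives in the Branchy-Phi-2 model's remote code, but the mechanism the Overview describes is easy to sketch. Below is a minimal, hypothetical illustration of threshold-gated early exits; `blocks`, `branch_heads`, `final_head`, and the confidence statistic are illustrative assumptions, not the actual BranchyModel internals:

```python
import torch
import torch.nn as nn

def forward_with_early_exit(blocks: nn.ModuleList,
                            branch_heads: dict,   # layer index -> exit head
                            final_head: nn.Module,
                            thresholds: dict,     # layer index -> threshold
                            x: torch.Tensor):
    """Run blocks in order; after each block with a branch head attached,
    return that head's prediction early if its confidence clears the
    threshold, skipping all remaining blocks."""
    for i, block in enumerate(blocks):
        x = block(x)
        if i in branch_heads:
            logits = branch_heads[i](x)
            # Illustrative confidence score (top logit); the statistic and
            # threshold scale used by the real model may differ.
            confidence = logits.max().item()
            if confidence > thresholds[i]:
                return logits, i  # early exit at layer i
    return final_head(x), len(blocks)  # no head fired; full depth was used
```

Setting every threshold high enough that no head ever fires recovers the full model, which is what the app's "Early exit" checkbox toggles.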
app.py ADDED
@@ -0,0 +1,123 @@
+ # Save this as app.py and run with `streamlit run app.py`
+ import time
+ import streamlit as st
+ import torch
+ import pandas as pd
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from annotated_text import annotated_text
+
+ st.title("Multi-Head LLM Demo")
+ st.markdown("""This is a demo of a multi-head language model with early exit capabilities.
+ The model is based on the Phi-2 architecture and is available here: https://huggingface.co/valcore/Branchy-Phi-2.
+ \nThe model has four heads, each of which can trigger an early exit based on a threshold. The graph shows the depth of early exit for each token (the deeper the exit, the faster) and the time taken to generate each token.
+ Early-exited tokens are annotated with the depth of early exit (a float of at most 1, with 1 being the deepest).
+ """)
+
+ def annotated_to_normal(text):
+     # Flatten annotated_text content (plain strings and (token, label) tuples) back to text.
+     result = ""
+     for elem in text:
+         if isinstance(elem, tuple):
+             result += elem[0]
+         else:
+             result += elem
+     return result
+
+ def generate_next_token():
+     # Greedily generate up to 50 tokens, yielding each token with its early-exit depth.
+     print(f"Generating next token from {st.session_state.messages}")
+     inputs = ""
+     for message in st.session_state.messages:
+         inputs += message["role"] + ": " + annotated_to_normal(message["content"]) + "\n"
+     inputs += "Assistant:"
+     print(f"Inputs: {inputs}")
+     inputs = st.session_state.tokenizer.encode(inputs, return_tensors="pt")
+     for i in range(50):
+         start = time.time()
+         with torch.no_grad():  # inference only; no gradients needed
+             outputs = st.session_state.model(inputs)
+         stop = time.time()
+         next_token_logits = outputs.logits[:, -1, :].squeeze()
+         next_token_probs = torch.softmax(next_token_logits, dim=-1)
+         next_token_id = torch.argmax(next_token_probs, dim=-1)
+         if next_token_id == 50256:  # EOS token id for the Phi-2 tokenizer
+             break
+         print(inputs.shape, next_token_id.shape)
+         inputs = torch.cat([inputs, next_token_id.unsqueeze(0).unsqueeze(-1)], dim=-1)
+         next_token = st.session_state.tokenizer.decode(next_token_id)
+         time_taken = (stop - start) * 1000  # convert seconds to ms to match the column label
+         branch_locations = st.session_state.model.config.branch_locations
+         print(outputs.head_indices)
+         if outputs.head_indices in branch_locations:
+             # Normalized exit depth: 1/len for the shallowest head, 1.0 for the deepest.
+             early_exit = (branch_locations.index(outputs.head_indices) + 1) / len(branch_locations)
+         else:
+             early_exit = 0
+         # Add data to dataframe
+         new_row = pd.DataFrame({"Time taken (in ms)": [time_taken], "Early exit depth": [early_exit]})
+         st.session_state.data = pd.concat([st.session_state.data, new_row], ignore_index=True)
+         yield next_token, early_exit
+
+ @st.cache_resource
+ def load_model(model_str, tokenizer_str):
+     model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True)
+     model.eval()
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)
+     return model, tokenizer
+
+ model_str = "valcore/Branchy-Phi-2"
+ tokenizer_str = "microsoft/Phi-2"
+
+ if "model" not in st.session_state or "tokenizer" not in st.session_state:
+     print("Loading model...")
+     st.session_state.model, st.session_state.tokenizer = load_model(model_str, tokenizer_str)
+
+ # Initialize chat history and dataframe
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+     st.session_state.data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth"])
+
+ col1, col2 = st.columns([1, 4])
+
+ with col1:
+     early_exit = st.checkbox("Early exit", value=False)
+     if early_exit:
+         # Calibrated per-head thresholds for the four branch heads.
+         st.session_state.model.head_thresholds = [2.506962537765503, 2.656052589416504, 1.924393653869629, 1.4434680938720703]
+     else:
+         # Thresholds high enough that no head ever exits early.
+         st.session_state.model.head_thresholds = [10., 10., 10., 10.]
+     clear_session = st.button("Clear session")
+     if clear_session:
+         print("Clearing session")
+         st.session_state.messages = []
+         st.session_state.data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth"])
+
+ with col2:
+     # Display chat messages from history on app rerun
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             annotated_text(message["content"])
+
+     prompt = st.chat_input("What is up?")
+     # React to user input
+     if prompt:
+         # Display user message in chat message container
+         with st.chat_message("User"):
+             st.markdown(prompt)
+         # Add user message to chat history
+         st.session_state.messages.append({"role": "User", "content": prompt})
+
+         # Display assistant response in chat message container
+         with st.chat_message("Assistant"):
+             response = []
+             with st.spinner('Running inference...'):
+                 for next_token, early_exit in generate_next_token():
+                     if early_exit > 0.0:
+                         response.append((next_token, f"{early_exit:.2f}"))
+                     else:
+                         response.append(next_token)
+             print(response)
+             annotated_text(response)
+
+         # Add assistant response to chat history
+         st.session_state.messages.append({"role": "Assistant", "content": response})
+     st.line_chart(st.session_state.data, x=None, y=["Time taken (in ms)", "Early exit depth"])
+     print(st.session_state.messages)
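For a concrete sense of the `early_exit` value that annotates each token: with four branch heads, a token exiting at the first head gets depth 0.25 and one exiting at the last head gets 1.0. A small standalone check of the same computation (the `branch_locations` values here are hypothetical; the app reads the real ones from `model.config.branch_locations`):

```python
# Hypothetical branch locations (layer indices of the four exit heads).
branch_locations = [7, 15, 23, 31]

def exit_depth(head_index):
    # Mirrors the app's computation: normalized position of the exiting head.
    if head_index in branch_locations:
        return (branch_locations.index(head_index) + 1) / len(branch_locations)
    return 0  # no early exit: the token used the full model

assert exit_depth(7) == 0.25   # shallowest head, largest savings
assert exit_depth(31) == 1.0   # deepest head
assert exit_depth(None) == 0   # full-depth generation, left unannotated
```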
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ streamlit==1.31.0
+ torch==2.0.1
+ pandas==2.0.3
+ transformers==4.36.0
+ st-annotated-text