File size: 5,883 Bytes
7e9cb9e
5b2e6a5
 
97675ea
92603a4
97675ea
5b2e6a5
7e9cb9e
 
bda5ea2
7e9cb9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92603a4
 
7e9cb9e
 
 
 
 
 
 
 
a0417ab
 
 
 
 
 
7e9cb9e
a0417ab
7e9cb9e
 
 
bda5ea2
7e9cb9e
bda5ea2
7e9cb9e
 
 
 
 
bda5ea2
97675ea
7e9cb9e
 
 
 
 
 
 
 
 
 
 
 
 
 
97675ea
7e9cb9e
 
a0417ab
7e9cb9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import gradio as gr
import torch
import pandas as pd
import plotly.graph_objects as go
import spaces
from plotly.subplots import make_subplots
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import numpy as np

# Load the model and tokenizer.
# The model and the tokenizer come from different Hub repos: the model repo is
# loaded together with its custom modeling code, while the tokenizer is taken
# from the base Phi-2 repo.
model_str = "valcore/Branchy-Phi-2"
tokenizer_str = "microsoft/Phi-2"
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): trust_remote_code=True executes Python code shipped in the model
# repo. The non-standard attributes used elsewhere in this file
# (model.head_thresholds, outputs.head_indices, model.config.branch_locations)
# presumably come from that custom code — TODO confirm against the repo.
model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True).to(device)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)

# Per-token generation metrics for the current request (one row per token).
# generate_response rebinds this global at the start of each request and
# create_plot reads it to draw the chart.
data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth", "Token"])

# Early-exit thresholds keyed by the epsilon slider value (0.4 to 1.0 in steps
# of 0.1 — the keys must match the slider's values exactly). Each list holds
# four values that generate_response assigns to model.head_thresholds;
# presumably one per early-exit head, with the comparison semantics defined by
# the model's remote code — TODO confirm.
epsilon_thresholds = {
    0.4: [1.0307843685150146, 0.8693032264709473, 0.6637287139892578, 0.3111608028411865],
    0.5: [1.505380630493164, 1.5712471008300781, 1.1971790790557861, 0.6908178329467773],
    0.6: [2.0270779132843018, 1.8969502449035645, 1.4789371490478516, 0.9875392913818359],
    0.7: [2.506962537765503, 2.656052589416504, 1.924393653869629, 1.4434680938720703],
    0.8: [3.3786778450012207, 2.568857192993164, 2.5665550231933594, 2.006620407104492],
    0.9: [3.187114715576172, 3.442272663116455, 2.636230945587158, 2.460529088973999],
    1.0: [10.0, 10.0, 10.0, 10.0]  # Effectively disable early exits
}

# Cooperative stop flag: stop_gen sets it; generate_response resets it at the
# start of a request and checks it before producing each token. It is a single
# module-level global, so it is shared by every session served by this process.
stop_generation = False

def create_plot():
    """Build a dual-axis Plotly figure from the module-level ``data`` table.

    The primary y-axis shows the per-token latency (``"Time taken (in ms)"``)
    and the secondary y-axis shows the ``"Early exit depth"`` column, both
    plotted against the row index of ``data``. Hovering a point shows the
    decoded token from the ``"Token"`` column.

    Returns:
        plotly.graph_objects.Figure: the assembled figure.
    """
    token_positions = data.index
    token_labels = data["Token"]

    latency_trace = go.Scatter(
        x=token_positions,
        y=data["Time taken (in ms)"],
        name="Time taken (ms)",
        text=token_labels,
        hovertemplate="<b>Token:</b> %{text}<br><b>Time:</b> %{y:.2f} ms<extra></extra>",
    )
    depth_trace = go.Scatter(
        x=token_positions,
        y=data["Early exit depth"],
        name="Early exit depth",
        text=token_labels,
        hovertemplate="<b>Token:</b> %{text}<br><b>Depth:</b> %{y:.2f}<extra></extra>",
    )

    figure = make_subplots(specs=[[{"secondary_y": True}]])
    # Latency on the left axis first, then depth on the right axis.
    for trace, on_secondary_axis in ((latency_trace, False), (depth_trace, True)):
        figure.add_trace(trace, secondary_y=on_secondary_axis)

    figure.update_layout(
        title_text="Token Generation Metrics",
        xaxis_title="Token Index",
        yaxis_title="Time (ms)",
        yaxis2_title="Exit Depth",
        hovermode="closest",
    )
    # Depth values are fractions of at most 1.0; the fixed range leaves headroom above that.
    figure.update_yaxes(range=[0, 1.1], secondary_y=True)
    return figure

def truncate_context(input_ids, max_length=2048):
    """Keep at most the last ``max_length`` positions of a batch of token ids.

    Args:
        input_ids: tensor indexed as ``[batch, position]`` (presumably the
            ``(1, seq_len)`` output of ``tokenizer.encode``); only the length
            of the first row is checked.
        max_length: maximum number of trailing positions to retain.

    Returns:
        ``input_ids`` itself when it already fits, otherwise a slice holding
        only the most recent ``max_length`` positions of every row.
    """
    seq_len = len(input_ids[0])
    if seq_len <= max_length:
        return input_ids
    return input_ids[:, -max_length:]
    
def _format_prompt(message, chat_history):
    """Render prior (user, assistant) turns plus the new message as one prompt.

    Each past turn becomes a "User:" line followed by an "Assistant:" line; the
    prompt ends with a bare "Assistant:" so the model continues with the reply.
    """
    turns = [
        f"User: {user_msg}\nAssistant: {assistant_msg}\n"
        for user_msg, assistant_msg in chat_history
    ]
    turns.append(f"User: {message}\nAssistant:")
    return "".join(turns)


def _exit_depth(head_index, branch_locations):
    """Return the normalized depth at which a token was produced.

    A ``head_index`` found in ``branch_locations`` maps to its 1-based position
    divided by the number of branches; any other value maps to 1.0.
    """
    if head_index in branch_locations:
        return (branch_locations.index(head_index) + 1) / len(branch_locations)
    return 1.0


@spaces.GPU
def generate_response(message, chat_history, epsilon):
    """Greedily generate a reply token by token, streaming chat and metrics.

    Resets the global ``data`` table and ``stop_generation`` flag and sets the
    model's early-exit thresholds from ``epsilon_thresholds``. It then runs one
    forward pass per token until the tokenizer's EOS token is produced or
    ``stop_gen`` raises ``stop_generation``.

    Args:
        message: the new user message.
        chat_history: sequence of ``(user, assistant)`` pairs from the Chatbot
            (``None`` is treated as empty).
        epsilon: slider value selecting a row of ``epsilon_thresholds``.

    Yields:
        ``(history, history, plot_update)`` after every generated token, where
        ``history`` is ``chat_history`` plus the in-progress reply. If no token
        is produced (immediate EOS or an early stop), a single update with an
        empty reply is yielded so the user's turn still appears.
    """
    global data, stop_generation
    columns = ["Time taken (in ms)", "Early exit depth", "Token"]
    data = pd.DataFrame(columns=columns)
    stop_generation = False

    # Slider floats can carry rounding noise (e.g. 0.7000000000000001); snap to
    # the 0.1 grid used as keys of epsilon_thresholds before the exact lookup.
    model.head_thresholds = torch.tensor(epsilon_thresholds[round(float(epsilon), 1)])

    chat_history = list(chat_history or [])
    inputs = tokenizer.encode(
        _format_prompt(message, chat_history), return_tensors="pt"
    ).to(device)
    branch_locations = model.config.branch_locations

    generated_ids = []  # token ids of the reply generated so far
    records = []  # one metrics row per generated token
    new_history = chat_history + [(message, "")]

    while not stop_generation:
        inputs = truncate_context(inputs)
        start = time.perf_counter()
        # Inference only: don't build an autograd graph on every step. Scoped to
        # the call so grad mode is not left disabled while the generator is
        # suspended at a yield.
        with torch.no_grad():
            outputs = model(inputs)
        if inputs.is_cuda:
            # CUDA launches are asynchronous; wait so the timing covers the pass.
            torch.cuda.synchronize()
        stop = time.perf_counter()

        next_token_id = torch.argmax(outputs.logits[:, -1, :], dim=-1)
        token_id = next_token_id.item()
        if token_id == tokenizer.eos_token_id:
            break

        inputs = torch.cat([inputs, next_token_id.unsqueeze(0)], dim=-1)
        generated_ids.append(token_id)
        next_token = tokenizer.decode(next_token_id)
        # Decode the whole reply instead of concatenating per-token strings so a
        # character split across several tokens is not rendered as garbage.
        full_response = tokenizer.decode(generated_ids)

        records.append({
            "Time taken (in ms)": (stop - start) * 1000,
            "Early exit depth": _exit_depth(outputs.head_indices, branch_locations),
            "Token": next_token,
        })
        # Rebuilding from plain records avoids pandas' deprecated concatenation
        # with an empty frame.
        data = pd.DataFrame.from_records(records, columns=columns)

        new_history = chat_history + [(message, full_response)]
        yield new_history, new_history, gr.update(value=create_plot())

    if not records:
        # Nothing was streamed: still publish the user's turn so it does not
        # vanish from the chat.
        yield new_history, new_history, gr.update(value=create_plot())
        
def stop_gen():
    """Ask the running generation loop to finish before its next token.

    Raises the module-level ``stop_generation`` flag that ``generate_response``
    checks at the top of each iteration, and returns an update that disables
    the Stop button.
    """
    global stop_generation
    button_update = gr.update(interactive=False)
    stop_generation = True
    return button_update

def _enable_stop_button():
    """Return an update that makes the Stop button clickable again.

    ``stop_gen`` disables the button when pressed; without re-enabling it at the
    start of each request, Stop would stay disabled for the rest of the session.
    """
    return gr.update(interactive=True)


with gr.Blocks() as demo:
    gr.Markdown("# Multi-Head LLM Demo with Early Exit Capabilities 🤗")
    gr.Markdown("""This is a demo of a multi-head language model with early exit capabilities. 
                The model is based on the Phi-2 architecture and is available here: https://huggingface.co/valcore/Branchy-Phi-2.
                The model has four heads, each of which can be exited early based on a threshold. The graph shows the depth of early exit for each token and the time taken to generate each token.
                Use the slider to adjust the early exit threshold. Lower values allow for more early exits, potentially speeding up generation at the cost of accuracy.
                """)
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Message")
    # The slider's values (0.4..1.0, step 0.1) are the keys of epsilon_thresholds.
    epsilon = gr.Slider(minimum=0.4, maximum=1.0, value=0.7, step=0.1, label="Epsilon")

    with gr.Row():
        send = gr.Button("Send")
        stop = gr.Button("Stop Generation")

    graph = gr.Plot()

    # Every new request first re-enables Stop (stop_gen leaves it disabled),
    # then streams the reply. generate_response yields (history, history, plot),
    # hence the chatbot appearing twice in outputs.
    send.click(_enable_stop_button, outputs=[stop], queue=False).then(
        generate_response, inputs=[msg, chatbot, epsilon], outputs=[chatbot, chatbot, graph]
    )
    msg.submit(_enable_stop_button, outputs=[stop], queue=False).then(
        generate_response, inputs=[msg, chatbot, epsilon], outputs=[chatbot, chatbot, graph]
    )
    stop.click(stop_gen, outputs=[stop])

demo.queue().launch()