# Save this as app.py and run with `streamlit run app.py`
import streamlit as st
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.utils import generate_next_token, breaking_ties  # breaking_ties is imported but not used below
from src.BranchyModel import BranchyModel

st.title("Multi-Head LLM Demo")


def add_and_run(token, head):
    # Record the chosen head, the running mean of all heads chosen so far,
    # and the base model's head count for comparison.
    head_list = st.session_state["computation_pd"]["Head"].to_list() + [head]
    mean = sum(head_list) / len(head_list)
    st.session_state["computation_pd"] = pd.concat(
        [
            st.session_state["computation_pd"],
            pd.DataFrame(
                {
                    "Head": [head],
                    "Mean": [mean],
                    "Base model consumption": [st.session_state["head_number"]],
                }
            ),
        ],
        ignore_index=True,
    )
    # Append the chosen token and compute the next candidates.
    st.session_state["current_sentence"] += token
    _, st.session_state["logits"], _, st.session_state["head_tokens"] = generate_next_token(
        st.session_state.model, st.session_state.tokenizer, st.session_state["current_sentence"]
    )


def reset():
    st.session_state["computation_pd"] = pd.DataFrame(columns=["Head", "Mean", "Base model consumption"])
    st.session_state["current_sentence"] = "The climate in"
    _, st.session_state["logits"], _, st.session_state["head_tokens"] = generate_next_token(
        st.session_state.model, st.session_state.tokenizer, st.session_state["current_sentence"]
    )


@st.cache_resource
def load_model(penalty_alpha):
    # Checkpoints trained with different penalty_alpha values.
    penalty_map = {
        0.1: "model_20240118-144039.bin",
        0.5: "model_20240118-192548.bin",
        2: "model_20240118-211943.bin",
        5: "model_20240118-231333.bin",
        10: "model_20240119-010725.bin",
        20: "model_20240119-030115.bin",
        0: "model_20240119-135506.bin",
        1: "model_20240119-154900.bin",
        -20: "model_20240208-072350.bin",
        -10: "model_20240208-052958.bin",
        -5: "model_20240208-033606.bin",
        -2: "model_20240208-014211.bin",
        -1: "model_20240207-234817.bin",
        -0.5: "model_20240207-215423.bin",
        -0.1: "model_20240207-200020.bin",
    }
    model_str = "susnato/phi-1_5_dev"
    model = AutoModelForCausalLM.from_pretrained(model_str).to("cuda:1")
    tokenizer = AutoTokenizer.from_pretrained(model_str)
    branch_locations = list(range(0, 23, 5))  # early-exit branches at layers 0, 5, 10, 15, 20
    model = BranchyModel(branch_locations=branch_locations, model=model).to("cuda:1")

    # Load the checkpoint matching penalty_alpha, if one exists.
    model_path = penalty_map.get(penalty_alpha)
    if model_path:
        model.load_state_dict(torch.load(model_path, map_location="cuda:1"))
    else:
        print("Invalid penalty_alpha. Using default model weights.")
    return model, tokenizer
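
# For reference only: the app unpacks generate_next_token as
# `_, logits, _, head_tokens = ...`, so all it assumes is a 4-tuple whose
# second element is per-head logits and whose fourth is one decoded candidate
# token per head. The sketch below is a guess at that contract, NOT the real
# src.utils implementation: the name, the device argument, the `.logits`
# attribute on BranchyModel's output, the [heads, batch, seq, vocab] shape,
# and the first/third return slots (greedy ids, top-1 probabilities) are all
# assumptions made for illustration.
def _reference_generate_next_token(model, tokenizer, sentence, device="cuda:1"):
    inputs = tokenizer(sentence, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits  # assumed stacked per-head logits
    next_logits = logits[..., -1, :]                      # last position, per head
    head_ids = next_logits.argmax(dim=-1)                 # greedy token id per head
    head_conf = next_logits.softmax(dim=-1).amax(dim=-1)  # top-1 probability per head
    head_tokens = [tokenizer.decode(i) for i in head_ids.flatten().tolist()]
    return head_ids, next_logits, head_conf, head_tokens
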
if "model" not in st.session_state or "tokenizer" not in st.session_state:
    print("Loading model...")
    st.session_state.model, st.session_state.tokenizer = load_model(penalty_alpha=-2)  # Example penalty_alpha
    st.session_state["head_number"] = len(st.session_state.model.branch_locations) + 1
    print(f"Head number: {st.session_state['head_number']}")

# Session state to store the current sentence
if "current_sentence" not in st.session_state:
    reset()

# Create a column for each candidate token, one per head
cols = st.columns(len(st.session_state.head_tokens))

# One button per head: clicking appends that head's token and regenerates
for i, (col, token) in enumerate(zip(cols, st.session_state.head_tokens)):
    col.button(token, key=f"head_{i}", use_container_width=True, on_click=add_and_run, args=(token, i))

# Display the current sentence
st.markdown(f"{st.session_state['current_sentence']}")

# Reset button to start over
st.button("Reset", on_click=reset)

# Guard against an empty DataFrame (e.g. right after reset), where .iloc[-1] would raise.
if "computation_pd" in st.session_state and not st.session_state["computation_pd"].empty:
    st.line_chart(st.session_state["computation_pd"])
    # Compare the mean head used so far against the base model's head count.
    saved_budget = 100 - (
        (st.session_state["computation_pd"]["Mean"].iloc[-1] * 100)
        / st.session_state["computation_pd"]["Base model consumption"].iloc[-1]
    )
    st.markdown(f"You saved **{saved_budget:.2f}%** of the base model consumption.")
    # st.write(st.session_state['computation_pd'])
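
# Worked example of the savings readout (illustrative numbers, not app output):
# branch_locations = [0, 5, 10, 15, 20] gives head_number = 5 + 1 = 6.
# If the heads clicked so far are 1, 2 and 3, then Mean = (1 + 2 + 3) / 3 = 2.0
# and saved_budget = 100 - (2.0 * 100) / 6 ≈ 66.67%.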