File size: 3,699 Bytes
8b05694
2d5272c
6d0e9f5
315944a
6d0e9f5
 
2acdf87
6d0e9f5
8b05694
 
59546e8
6bd7eb3
2d5272c
59546e8
03784c0
8b05694
 
6d0e9f5
2acdf87
119ea18
6d0e9f5
 
 
 
03784c0
7f9077b
3815718
824e8e6
 
8f7757a
2d5272c
 
6d0e9f5
8b05694
 
ada655b
f362ef8
8b05694
 
 
 
 
 
 
6d0e9f5
 
03784c0
8b05694
 
6d0e9f5
03784c0
 
 
59546e8
f362ef8
 
 
 
8b05694
6d0e9f5
f362ef8
8b05694
f362ef8
03784c0
59546e8
8b05694
03784c0
 
59546e8
6d0e9f5
 
 
 
 
 
 
 
ada655b
03784c0
 
8b05694
03784c0
8b05694
9f80d5d
2d5272c
 
8b05694
 
9f80d5d
2d5272c
6d0e9f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2acdf87
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import streamlit as st
import os
import torch
from openai import OpenAI
import numpy as np
import sys
from dotenv import load_dotenv
import random
from huggingface_hub import InferenceClient

# Constants
MAX_TOKENS = 4000  # passed as max_tokens on every chat completion request
DEFAULT_TEMPERATURE = 0.5  # initial value of the sidebar temperature slider

# Load variables from a .env file (if one is found) into the environment
# before the API key is read below.
load_dotenv()

# OpenAI-compatible client pointed at the Hugging Face Inference API.
# The token comes from the API_KEY environment variable (or the .env file).
client = OpenAI(
    api_key=os.getenv('API_KEY'),
    base_url="https://api-inference.huggingface.co/v1",
)

# Sidebar display name -> Hugging Face model repository id.
model_links = {
    "Meta-Llama-3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
    "Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
    "Gemma-2-27b-it": "google/gemma-2-27b-it",
    "Falcon-7b-Instruct": "tiiuae/falcon-7b-instruct",
}

# Image files shown next to error messages; one is picked at random.
random_dog_images = ["broken_llama3.jpeg"]

def reset_conversation():
    """Clear the stored chat history.

    Replaces both ``conversation`` and ``messages`` in ``st.session_state``
    with fresh empty lists. Used as the "Reset Chat" button callback and
    when the selected model changes.
    """
    for history_key in ("conversation", "messages"):
        st.session_state[history_key] = []

# Sidebar button that wipes the chat history via reset_conversation().
st.sidebar.button(label="Reset Chat", on_click=reset_conversation)
    
def main():
    """Render the chat page and handle one round of user input.

    Draws the header and the sidebar controls (model picker and temperature
    slider), starts a fresh conversation when the selected model changes,
    replays the stored chat history, and forwards a newly submitted prompt
    to ``process_user_input``.
    """
    st.header('Multi-Models')

    # Sidebar for model selection and temperature
    selected_model = st.sidebar.selectbox("Select Model", list(model_links.keys()))
    temperature = st.sidebar.slider('Select a temperature value', 0.0, 1.0, DEFAULT_TEMPERATURE)

    # Remember the model used on the previous run so a switch can be detected.
    if "prev_option" not in st.session_state:
        st.session_state.prev_option = selected_model

    # Switching models starts a fresh conversation; reset_conversation()
    # already empties st.session_state.messages, so no separate clear is needed.
    if st.session_state.prev_option != selected_model:
        st.session_state.prev_option = selected_model
        reset_conversation()

    st.markdown(f'_powered_ by ***:violet[{selected_model}]***')

    # Display model info and a disclaimer in the sidebar
    st.sidebar.write(f"You're now chatting with **{selected_model}**")
    st.sidebar.markdown("*Generated content may be inaccurate or false.*")

    # Initialize chat history on the first run of the session
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input and response
    if prompt := st.chat_input("Type message here..."):
        process_user_input(client, prompt, selected_model, temperature)

def process_user_input(client, prompt, selected_model, temperature):
    """Show the user's prompt, stream the model's reply, and record both.

    Args:
        client: OpenAI-compatible client used for the chat completion call.
        prompt: Text the user just submitted.
        selected_model: Display name; must be a key of ``model_links``.
        temperature: Sampling temperature passed to the API.

    On success, the prompt and the streamed reply are appended to
    ``st.session_state.messages``. On failure, the prompt is removed from the
    history again and the error is shown via ``handle_error``, so the stored
    conversation never ends with an unanswered user turn.
    """
    # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Generate and display assistant response
    with st.chat_message("assistant"):
        try:
            stream = client.chat.completions.create(
                model=model_links[selected_model],
                messages=[
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages
                ],
                temperature=temperature,
                stream=True,
                max_tokens=MAX_TOKENS,
            )
            response = st.write_stream(stream)
        except Exception as e:  # UI boundary: report any API/streaming failure
            # Drop the unanswered prompt first. Otherwise every later request
            # would carry two consecutive user turns, which models whose chat
            # templates require alternating user/assistant roles reject.
            st.session_state.messages.pop()
            handle_error(e)
            return

    st.session_state.messages.append({"role": "assistant", "content": response})

def handle_error(error):
    """Tell the user the model call failed and show the raw error text.

    Writes a friendly notice, an image picked at random from
    ``random_dog_images``, and finally ``str(error)``.
    """
    notice = """😵‍💫 Looks like someone unplugged something!
    \n Either the model space is being updated or something is down."""
    st.write(notice)
    st.image(random.choice(random_dog_images))
    st.write("This was the error message:")
    st.write(str(error))

if __name__ == "__main__":
    # Render the chat page when this file is executed as the app script.
    main()