import os
import random

import streamlit as st
from dotenv import load_dotenv
from openai import OpenAI

# Load environment variables
load_dotenv()

# Constants
MAX_TOKENS = 4000          # upper bound on tokens generated per response
DEFAULT_TEMPERATURE = 0.5  # default for the sidebar temperature slider

# Initialize an OpenAI-compatible client for the Hugging Face Inference API
def initialize_client():
    api_key = os.environ.get('HUGGINGFACEHUB_API_TOKEN')
    if not api_key:
        st.error("HUGGINGFACEHUB_API_TOKEN not found in environment variables.")
        st.stop()
    return OpenAI(
        base_url="https://api-inference.huggingface.co/v1",
        api_key=api_key
    )

# Map display names to Hugging Face model repo ids
model_links = {
    "Meta-Llama-3.1-8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
    "Gemma-7b-it": "google/gemma-7b-it",
}

# Model information including logos
model_info = {
    "Meta-Llama-3.1-8B": {
        'description': """The Llama (3.1) model is a **Large Language Model (LLM)** that's able to have question and answer interactions.
        \nIt was created by the [**Meta's AI**](https://llama.meta.com/) team and has over **8 billion parameters.**\n""",
        "logo": "llama_logo.gif",
    },
    "Mistral-7B-Instruct-v0.3": {
        'description': """The Mistral-7B-Instruct-v0.3 is an instruct-tuned version of Mistral-7B.
        \nIt was created by [**Mistral AI**](https://mistral.ai/) and has **7 billion parameters.**\n""",
        "logo": "mistrail.jpeg",
    },
    "Gemma-7b-it": {
        'description': """Gemma is a family of lightweight, state-of-the-art open models from Google.
        \nThe 7B-it variant is instruction-tuned and has **7 billion parameters.**\n""",
        "logo": "gemma.jpeg",
    }
}

# Fallback image(s) shown with the error message
error_images = ["broken_llama3.jpeg"]

def reset_conversation():
    '''Clear the chat history stored in session state.'''
    st.session_state.conversation = []
    st.session_state.messages = []

st.sidebar.button('Reset Chat', on_click=reset_conversation)  # Reset button

def main():
    st.header('Multi-Models')

    # Initialize client
    client = initialize_client()

    # Sidebar for model selection and temperature
    selected_model = st.sidebar.selectbox("Select Model", list(model_links.keys()))
    temperature = st.sidebar.slider('Select a temperature value', 0.0, 1.0, DEFAULT_TEMPERATURE)

    if "prev_option" not in st.session_state:
        st.session_state.prev_option = selected_model

    if st.session_state.prev_option != selected_model:
        st.session_state.messages = []
        # st.write(f"Changed to {selected_model}")
        st.session_state.prev_option = selected_model
        reset_conversation()

    st.markdown(f'_Powered by_ ***:violet[{selected_model}]***')

    # Display model info and logo
    st.sidebar.write(f"You're now chatting with **{selected_model}**")
    st.sidebar.markdown(model_info[selected_model]['description'])
    st.sidebar.image(model_info[selected_model]['logo'], use_column_width=True)
    st.sidebar.markdown("*Generated content may be inaccurate or false.*")

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input and response
    if prompt := st.chat_input("Type message here..."):
        process_user_input(client, prompt, selected_model, temperature)

def process_user_input(client, prompt, selected_model, temperature):
    # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Generate and display assistant response
    with st.chat_message("assistant"):
        try:
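            # Replay the full message history so the model keeps conversation context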
            stream = client.chat.completions.create(
                model=model_links[selected_model],
                messages=[
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages
                ],
                temperature=temperature,
                stream=True,
                max_tokens=MAX_TOKENS,
            )
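            # st.write_stream renders tokens as they arrive and returns the full text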
            response = st.write_stream(stream)
        except Exception as e:
            handle_error(e)
            return

    st.session_state.messages.append({"role": "assistant", "content": response})

def handle_error(error):
    response = """😵‍💫 Looks like someone unplugged something!
    \n Either the model space is being updated or something is down."""
    st.write(response)
    error_image = random.choice(error_images)
    st.image(error_image)
    st.write("This was the error message:")
    st.write(str(error))

if __name__ == "__main__":
    main()