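"""Streamlit chat app for Groq-hosted models (LLaMA3, LLaMA2, Gemma, Mixtral).

Reads GROQ_API_KEY from the environment (a local .env file is loaded via
python-dotenv if present), lets the user pick a model, an optional system
prompt, and a max-token budget, then streams the model's reply into the chat.
"""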
import os
from dotenv import find_dotenv, load_dotenv
import streamlit as st
from typing import Generator
from groq import Groq
_ = load_dotenv(find_dotenv())
st.set_page_config(page_icon="💬", layout="wide", page_title="Groq & LLaMA3 Chat Bot...")
def icon(emoji: str):
    """Shows an emoji as a Notion-style page icon."""
    st.write(
        f'<span style="font-size: 78px; line-height: 1">{emoji}</span>',
        unsafe_allow_html=True,
    )
# icon("⚡️")
st.subheader("Groq Chat with LLaMA3 App", divider="rainbow", anchor=False)
client = Groq(
    api_key=os.environ["GROQ_API_KEY"],
)
# Initialize chat history and selected model
if "messages" not in st.session_state:
st.session_state.messages = []
if "selected_model" not in st.session_state:
st.session_state.selected_model = None
# Define model details
models = {
    "llama3-70b-8192": {"name": "LLaMA3-70b", "tokens": 8192, "developer": "Meta"},
    "llama3-8b-8192": {"name": "LLaMA3-8b", "tokens": 8192, "developer": "Meta"},
    "llama2-70b-4096": {"name": "LLaMA2-70b-chat", "tokens": 4096, "developer": "Meta"},
    "gemma-7b-it": {"name": "Gemma-7b-it", "tokens": 8192, "developer": "Google"},
    "mixtral-8x7b-32768": {
        "name": "Mixtral-8x7b-Instruct-v0.1",
        "tokens": 32768,
        "developer": "Mistral",
    },
}
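# Note: the model IDs above reflect Groq's lineup at the time this was written;
# check Groq's current model list before relying on any specific ID.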
# Layout for model selection and max_tokens slider
col1, col2 = st.columns([1, 3]) # Adjust the ratio to make the first column smaller
with col1:
    model_option = st.selectbox(
        "Choose a model:",
        options=list(models.keys()),
        format_func=lambda x: models[x]["name"],
        index=0,  # Default to the first model in the list
    )
max_tokens_range = models[model_option]["tokens"]

max_tokens = st.slider(
    "Max Tokens:",
    min_value=512,
    max_value=max_tokens_range,
    value=min(32768, max_tokens_range),
    step=512,
    help=f"Adjust the maximum number of tokens for the model's response. Max for selected model: {max_tokens_range}",
)

# Optional system prompt, prepended to the conversation when set
system_message = {}
if system_prompt := st.text_input("System Prompt"):
    system_message = {"role": "system", "content": system_prompt}
# Detect model change and clear chat history if model has changed
if st.session_state.selected_model != model_option:
    st.session_state.messages = []
    st.session_state.selected_model = model_option
# Add a "Clear Chat" button
if st.button("Clear Chat"):
    st.session_state.messages = []
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    avatar = "🤖" if message["role"] == "assistant" else "👨‍💻"
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])
def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
"""Yield chat response content from the Groq API response."""
for chunk in chat_completion:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
if prompt := st.chat_input("Enter your prompt here..."):
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user", avatar="👨‍💻"):
        st.markdown(prompt)

    # Rebuild the conversation for the API call, prepending the system prompt if set
    messages = [
        {"role": m["role"], "content": m["content"]}
        for m in st.session_state.messages
    ]
    if system_message:
        messages.insert(0, system_message)
    # Fetch the response from the Groq API as a stream
    full_response = None
    try:
        chat_completion = client.chat.completions.create(
            model=model_option,
            messages=messages,
            max_tokens=max_tokens,
            stream=True,
        )

        # Use the generator function with st.write_stream
        with st.chat_message("assistant", avatar="🤖"):
            chat_responses_generator = generate_chat_responses(chat_completion)
            full_response = st.write_stream(chat_responses_generator)
    except Exception as e:
        st.error(e, icon="🚨")

    # Append the full response to session_state.messages (skipped if the request failed)
    if isinstance(full_response, str):
        st.session_state.messages.append(
            {"role": "assistant", "content": full_response}
        )
    elif full_response is not None:
        # st.write_stream may return a list of rendered items; join them into one string
        combined_response = "\n".join(str(item) for item in full_response)
        st.session_state.messages.append(
            {"role": "assistant", "content": combined_response}
        )
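# To run this app locally (assuming the file is saved as app.py and a .env file
# next to it contains GROQ_API_KEY=<your key>):
#
#   pip install streamlit groq python-dotenv
#   streamlit run app.py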