# groq-llama3 / app.py
import os
from dotenv import find_dotenv, load_dotenv
import streamlit as st
from groq import Groq
import base64
# Load environment variables
load_dotenv(find_dotenv())
# Function to encode the image to a base64 string
def encode_image(uploaded_file):
"""
Encodes an uploaded image file into a base64 string.
Args:
uploaded_file: The file-like object uploaded via Streamlit.
Returns:
str: The base64 encoded string of the image.
"""
return base64.b64encode(uploaded_file.read()).decode('utf-8')
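
# Illustrative only (hypothetical file name): the helper works with any binary
# file-like object, producing the payload used in the data URI further below.
#   with open("photo.jpg", "rb") as f:
#       data_uri = f"data:image/jpeg;base64,{encode_image(f)}"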
# Initialize the Groq client using the API key from the environment variables
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# Set up Streamlit page configuration
st.set_page_config(
    page_icon="πŸ“ƒ",
    layout="wide",
    page_title="Groq & LLaMA3x Chat Bot"
)
# App Title
st.title("Groq Chat with LLaMA3x")
# Cache the model fetching function to improve performance
@st.cache_data
def fetch_available_models():
"""
Fetches the available models from the Groq API.
Returns a list of models or an empty list if there's an error.
"""
try:
models_response = client.models.list()
return models_response.data
except Exception as e:
st.error(f"Error fetching models: {e}")
return []
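
# Note: st.cache_data memoizes the result across Streamlit reruns, so the
# Groq API is only queried again after the cache is cleared or invalidated.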
# Load available models and filter them
available_models = fetch_available_models()
filtered_models = [
    model for model in available_models if model.id.startswith('llama-3')
]
# Prepare a dictionary of model metadata
models = {
    model.id: {
        "name": model.id,
        "tokens": 4000,
        "developer": model.owned_by,
    }
    for model in filtered_models
}
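
# Note: the 4000-token limit above is a hardcoded UI assumption, not a value
# reported by the Groq API; adjust it per model if your limits differ.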
# Initialize session state variables
if "messages" not in st.session_state:
st.session_state.messages = []
if "selected_model" not in st.session_state:
st.session_state.selected_model = None
# Sidebar: Controls
with st.sidebar:
    # Powered by Groq logo
    st.markdown(
        """
        <a href="https://groq.com" target="_blank" rel="noopener noreferrer">
            <img
                src="https://groq.com/wp-content/uploads/2024/03/PBG-mark1-color.svg"
                alt="Powered by Groq for fast inference."
                width="100%"
            />
        </a>
        """,
        unsafe_allow_html=True
    )

    st.markdown("---")
    # Clear the chat history and image state when the model changes
    def reset_chat_on_model_change():
        st.session_state.messages = []
        st.session_state.image_used = False
    # Model selection dropdown
    if models:
        model_option = st.selectbox(
            "Choose a model:",
            options=list(models.keys()),
            format_func=lambda x: f"{models[x]['name']} ({models[x]['developer']})",
            on_change=reset_chat_on_model_change,  # Reset chat when the model changes
        )
    else:
        st.warning("No available models to select.")
        model_option = None
    # Token limit slider
    if models:
        max_tokens_range = models[model_option]["tokens"]
        max_tokens = st.slider(
            "Max Tokens:",
            min_value=200,
            max_value=max_tokens_range,
            value=max(200, int(max_tokens_range * 0.5)),  # default must not fall below min_value
            step=256,
            help=f"Adjust the maximum number of tokens for the response. Maximum for the selected model: {max_tokens_range}"
        )
    else:
        max_tokens = 200
    # Additional options
    stream_mode = st.checkbox("Enable Streaming", value=True)

    # Button to clear the chat
    if st.button("Clear Chat"):
        st.session_state.messages = []
        st.session_state.image_used = False

    # Initialize session state for tracking uploaded image usage
    if "image_used" not in st.session_state:
        st.session_state.image_used = False  # Flag to track image usage
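
    # Lifecycle note: image_used flips to True once an image is attached to a
    # message, and is reset by "Clear Chat" or by switching models above.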
    # Check if the selected model supports vision
    base64_image = None
    uploaded_file = None
    if model_option and "vision" in model_option.lower():
        st.markdown(
            "### Upload an Image"
            "\n\n*One per conversation*"
        )

        # File uploader for images (only if an image hasn't been used yet)
        if not st.session_state.image_used:
            uploaded_file = st.file_uploader(
                "Upload an image for the model to process:",
                type=["png", "jpg", "jpeg"],
                help="Upload an image if the model supports vision tasks.",
                accept_multiple_files=False
            )
            if uploaded_file:
                base64_image = encode_image(uploaded_file)
                st.image(uploaded_file, caption="Uploaded Image")
    else:
        base64_image = None
st.markdown("### Usage Summary")
usage_box = st.empty()
# Disclaimer
st.markdown(
"""
-----
⚠️ **Important:**
*The responses provided by this application are generated automatically using an AI model.
Users are responsible for verifying the accuracy of the information before relying on it.
Always cross-check facts and data for critical decisions.*
"""
)
# Main Chat Interface
st.markdown("### Chat Interface")
# Display the chat history
for message in st.session_state.messages:
avatar = "πŸ”‹" if message["role"] == "assistant" else "πŸ§‘β€πŸ’»"
with st.chat_message(message["role"], avatar=avatar):
# Check if the content is a list (text and image combined)
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
st.markdown(item["text"])
elif item["type"] == "image_url":
# Handle base64-encoded image URLs
if item["image_url"]["url"].startswith("data:image"):
st.image(item["image_url"]["url"], caption="Uploaded Image")
st.session_state.image_used = True
else:
st.warning("Invalid image format or unsupported URL.")
else:
# For plain text content
st.markdown(message["content"])
# Capture user input
if user_input := st.chat_input("Enter your message here..."):
    # Append the user input to the session state,
    # including the image if one was uploaded
    if base64_image and not st.session_state.image_used:
        # Append the user message with the image to session state
        st.session_state.messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_input},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}",
                        },
                    },
                ],
            }
        )
        st.session_state.image_used = True
    else:
        st.session_state.messages.append({"role": "user", "content": user_input})

    # Display the user query (and the uploaded image, if attached) in the chat
    with st.chat_message("user", avatar="πŸ§‘β€πŸ’»"):
        st.markdown(user_input)
        # Show the image only if it was attached to the current message
        if base64_image and st.session_state.image_used:
            st.image(uploaded_file, caption="Uploaded Image")
            base64_image = None
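
    # Note: the data URI above always declares image/jpeg, even for PNG
    # uploads; matching the real MIME type (e.g. via uploaded_file.type,
    # a standard Streamlit UploadedFile attribute) would be stricter.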
    # Generate a response using the selected model
    try:
        full_response = ""
        usage_summary = ""

        if stream_mode:
            # Generate a response with streaming enabled
            chat_completion = client.chat.completions.create(
                model=model_option,
                messages=[
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages
                ],
                max_tokens=max_tokens,
                stream=True
            )

            with st.chat_message("assistant", avatar="πŸ”‹"):
                response_placeholder = st.empty()
                for chunk in chat_completion:
                    if chunk.choices[0].delta.content:
                        full_response += chunk.choices[0].delta.content
                        response_placeholder.markdown(full_response)
        else:
            # Generate a response without streaming
            chat_completion = client.chat.completions.create(
                model=model_option,
                messages=[
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages
                ],
                max_tokens=max_tokens,
                stream=False
            )
            response = chat_completion.choices[0].message.content
            usage_data = chat_completion.usage

            with st.chat_message("assistant", avatar="πŸ”‹"):
                st.markdown(response)
            full_response = response

            if usage_data:
                usage_summary = (
                    f"**Token Usage:**\n"
                    f"- Prompt Tokens: {usage_data.prompt_tokens}\n"
                    f"- Response Tokens: {usage_data.completion_tokens}\n"
                    f"- Total Tokens: {usage_data.total_tokens}\n\n"
                    f"**Timings:**\n"
                    f"- Prompt Time: {round(usage_data.prompt_time, 5)} secs\n"
                    f"- Response Time: {round(usage_data.completion_time, 5)} secs\n"
                    f"- Total Time: {round(usage_data.total_time, 5)} secs"
                )

        # Usage stats are only collected in non-streaming mode here
        if usage_summary:
            usage_box.markdown(usage_summary)

        # Append the assistant's response to the session state
        st.session_state.messages.append(
            {"role": "assistant", "content": full_response}
        )
    except Exception as e:
        st.error(f"Error generating the response: {e}")