import gradio as gr
from huggingface_hub import InferenceClient
import os
import groq
import warnings
import asyncio
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.llms.groq import Groq
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
# The tokenizer may emit a "clean_up_tokenization_spaces" deprecation warning
# that does not affect the behaviour of this code, so suppress it.
warnings.filterwarnings("ignore", message=".*clean_up_tokenization_spaces.*")
# Global variables
index = None
query_engine = None
# Initialize Groq LLM and ensure it is used
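# Note: no api_key is passed below, so this assumes the GROQ_API_KEY
# environment variable is set (e.g. as a Space secret) and that the Groq
# client picks it up from the environment.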
llm = Groq(model="mixtral-8x7b-32768")
Settings.llm = llm # Ensure Groq is the LLM being used
# Initialize our chosen embedding model
embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
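# Optionally register the embedding model globally as well (a small, assumed-safe
# addition mirroring the Settings.llm line above), so any component that falls
# back to Settings uses the same HuggingFace model.
Settings.embed_model = embed_model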
# These are our RAG functions, called in response to user-initiated
# events, e.g. clicking the Load Documents button on the GUI.
def load_documents(file_objs):
    """Read the selected files, build a vector index over them, and create the query engine."""
    global index, query_engine
    try:
        if not file_objs:
            return "Error: No files selected."

        documents = []
        document_names = []
        for file_obj in file_objs:
            document_names.append(file_obj.name)
            loaded_docs = SimpleDirectoryReader(input_files=[file_obj.name]).load_data()
            documents.extend(loaded_docs)

        if not documents:
            return "No documents found in the selected files."

        # Create the index from the documents using the Groq LLM and HuggingFace embeddings
        index = VectorStoreIndex.from_documents(
            documents,
            llm=llm,  # Ensure Groq is used here
            embed_model=embed_model
        )

        # Create the query engine
        query_engine = index.as_query_engine()
        return f"Successfully loaded {len(documents)} documents from the files: {', '.join(document_names)}"
    except Exception as e:
        return f"Error loading documents: {str(e)}"
async def perform_rag(query, history):
    """Answer a user question against the loaded index and append the exchange to the chat history."""
    global query_engine
    if query_engine is None:
        return history + [(query, "Please load documents first.")]
    try:
        response = await asyncio.to_thread(query_engine.query, query)
        return history + [(query, str(response))]
    except Exception as e:
        return history + [(query, f"Error processing query: {str(e)}")]
def clear_all():
    """Reset the index, the query engine, and all GUI components to their default states."""
    global index, query_engine
    index = None
    query_engine = None
    return None, "", [], ""  # Reset file input, load output, chatbot, and message input
# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# RAG Multi-file, Multi-format Chat Application")

    with gr.Row():
        file_input = gr.File(label="Select files to load (txt,csv,xlsx,docx,PDF)", file_count="multiple")
        load_btn = gr.Button("Load Documents")
    load_output = gr.Textbox(label="Load Status")

    msg = gr.Textbox(label="Enter your question")
    chatbot = gr.Chatbot()
    clear = gr.Button("Clear")

    # Set up event handlers
    load_btn.click(load_documents, inputs=[file_input], outputs=[load_output])
    msg.submit(perform_rag, inputs=[msg, chatbot], outputs=[chatbot])
    clear.click(clear_all, outputs=[file_input, load_output, chatbot, msg], queue=False)
# Run the app
if __name__ == "__main__":
    demo.queue()
    demo.launch()
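# A minimal sketch of how this app is expected to be run locally, assuming the
# Groq and HuggingFace llama-index integrations are installed alongside Gradio:
#
#   pip install gradio llama-index llama-index-llms-groq llama-index-embeddings-huggingface
#   export GROQ_API_KEY=...   # Groq API key (assumed to be available in the environment)
#   python app.py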