# Fislac_Bot / app-file.py
import spaces
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch
import logging
import sys
import os
from accelerate import init_empty_weights
from typing import List, Dict
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.llms import HuggingFacePipeline
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Get HuggingFace token from environment variable
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
if not hf_token:
logger.error("HUGGINGFACE_TOKEN environment variable not set")
raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable")
# Constants
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
KNOWLEDGE_BASE_DIR = "knowledge_base"
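# The knowledge_base/ directory is expected to ship alongside this script and contain the source PDFs.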
class DocumentLoader:
"""Class to manage PDF document loading."""
@staticmethod
def load_pdfs(directory_path: str) -> List:
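        """Load every PDF in directory_path as LangChain documents, tagging each page with metadata."""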
documents = []
pdf_files = [f for f in os.listdir(directory_path) if f.endswith('.pdf')]
for pdf_file in pdf_files:
pdf_path = os.path.join(directory_path, pdf_file)
try:
loader = PyPDFLoader(pdf_path)
pdf_documents = loader.load()
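                # Tag each page so TextProcessor can choose a splitter per document type:
                # the Valencia et al. paper is treated as 'technical', everything else as 'qa'.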
for doc in pdf_documents:
doc.metadata.update({
'title': pdf_file,
'type': 'technical' if 'Valencia' in pdf_file else 'qa',
'language': 'en',
'page': doc.metadata.get('page', 0)
})
documents.append(doc)
logger.info(f"Document {pdf_file} loaded successfully")
except Exception as e:
logger.error(f"Error loading {pdf_file}: {str(e)}")
return documents
class TextProcessor:
"""Class to process and split text into chunks."""
def __init__(self):
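        # Two splitters: larger chunks with more overlap for the technical paper,
        # smaller chunks for Q&A-style documents.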
self.technical_splitter = RecursiveCharacterTextSplitter(
chunk_size=800,
chunk_overlap=200,
separators=["\n\n", "\n", ". ", " ", ""],
length_function=len
)
self.qa_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=100,
separators=["\n\n", "\n", ". ", " ", ""],
length_function=len
)
def process_documents(self, documents: List) -> List:
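        """Split documents into chunks, selecting the splitter from the 'type' metadata set at load time."""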
if not documents:
logger.warning("No documents to process")
return []
processed_chunks = []
for doc in documents:
splitter = self.technical_splitter if doc.metadata['type'] == 'technical' else self.qa_splitter
chunks = splitter.split_documents([doc])
processed_chunks.extend(chunks)
logger.info(f"Documents processed into {len(processed_chunks)} chunks")
return processed_chunks
class RAGSystem:
"""Main RAG system class."""
def __init__(self, model_name: str = MODEL_NAME):
self.model_name = model_name
self.embeddings = None
self.vector_store = None
self.qa_chain = None
self.tokenizer = None
self.model = None
def initialize_system(self):
"""Initialize complete RAG system."""
try:
logger.info("Starting RAG system initialization...")
# Load and process documents
loader = DocumentLoader()
documents = loader.load_pdfs(KNOWLEDGE_BASE_DIR)
processor = TextProcessor()
processed_chunks = processor.process_documents(documents)
# Initialize embeddings
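            # multilingual-e5-large runs on the GPU; normalized embeddings make FAISS's
            # L2 distance equivalent to ranking by cosine similarity.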
self.embeddings = HuggingFaceEmbeddings(
model_name="intfloat/multilingual-e5-large",
model_kwargs={'device': 'cuda'},
encode_kwargs={'normalize_embeddings': True}
)
# Create vector store
self.vector_store = FAISS.from_documents(
processed_chunks,
self.embeddings
)
# Initialize LLM
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True,
token=hf_token
)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float16,
trust_remote_code=True,
token=hf_token,
device_map="auto"
)
# Create generation pipeline
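            # Conservative decoding (low temperature, repetition penalty) keeps answers
            # close to the retrieved context.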
pipe = pipeline(
"text-generation",
model=self.model,
tokenizer=self.tokenizer,
max_new_tokens=512,
                do_sample=True,  # enable sampling so temperature/top_p below take effect
                temperature=0.1,
top_p=0.95,
repetition_penalty=1.15,
device_map="auto"
)
            llm = HuggingFacePipeline(pipeline=pipe)  # LangChain wrapper around the transformers pipeline
# Create prompt template
prompt_template = """
Context: {context}
Based on the context above, please provide a clear and concise answer to the following question.
If the information is not in the context, explicitly state so.
Question: {question}
"""
PROMPT = PromptTemplate(
template=prompt_template,
input_variables=["context", "question"]
)
# Set up QA chain
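            # "stuff" chain: the top-6 retrieved chunks are concatenated directly into
            # the {context} slot of the prompt above.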
self.qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=self.vector_store.as_retriever(
search_kwargs={"k": 6}
),
return_source_documents=True,
chain_type_kwargs={"prompt": PROMPT}
)
logger.info("RAG system initialized successfully")
except Exception as e:
logger.error(f"Error during RAG system initialization: {str(e)}")
raise
def generate_response(self, question: str) -> Dict:
"""Generate response for a given question."""
try:
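            # RetrievalQA returns the generated text under 'result' and, because
            # return_source_documents=True, the retrieved chunks under 'source_documents'.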
result = self.qa_chain({"query": question})
response = {
'answer': result['result'],
'sources': []
}
for doc in result['source_documents']:
source = {
'title': doc.metadata.get('title', 'Unknown'),
'content': doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content,
'metadata': doc.metadata
}
response['sources'].append(source)
return response
except Exception as e:
logger.error(f"Error generating response: {str(e)}")
raise
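# ZeroGPU: each call to process_response is allocated a GPU for up to 60 seconds.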
@spaces.GPU(duration=60)
def process_response(user_input: str, chat_history: List) -> List:
    """Process user input, generate a response, and append the new turn to chat_history."""
try:
response = rag_system.generate_response(user_input)
# Clean and format response
answer = response['answer']
if "Answer:" in answer:
answer = answer.split("Answer:")[-1].strip()
# Format sources
sources = set([source['title'] for source in response['sources'][:3]])
if sources:
answer += "\n\nπŸ“š Sources consulted:\n" + "\n".join([f"β€’ {source}" for source in sources])
chat_history.append((user_input, answer))
return chat_history
except Exception as e:
logger.error(f"Error in process_response: {str(e)}")
error_message = f"Sorry, an error occurred: {str(e)}"
chat_history.append((user_input, error_message))
return chat_history
# Initialize RAG system
logger.info("Initializing RAG system...")
rag_system = RAGSystem()
rag_system.initialize_system()
logger.info("RAG system initialization completed")
# Create Gradio interface
try:
logger.info("Creating Gradio interface...")
with gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}") as demo:
gr.HTML("""
<div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
            <h1 style="color: #2d333a;">📊 FislacBot</h1>
<p style="color: #4a5568;">
AI Assistant specialized in fiscal analysis and FISLAC documentation
</p>
</div>
""")
chatbot = gr.Chatbot(
show_label=False,
container=True,
height=500,
bubble_full_width=True,
show_copy_button=True,
scale=2
)
with gr.Row():
message = gr.Textbox(
placeholder="πŸ’­ Type your question here...",
show_label=False,
container=False,
scale=8,
autofocus=True
)
            clear = gr.Button("🗑️ Clear", size="sm", scale=1)
# Suggested questions
        gr.HTML('<p style="color: #2d333a; font-weight: bold; margin: 20px 0 10px 0;">💡 Suggested questions:</p>')
with gr.Row():
suggestion1 = gr.Button("What is FISLAC?", scale=1)
suggestion2 = gr.Button("What are the main modules of FISLAC?", scale=1)
with gr.Row():
suggestion3 = gr.Button("What macroeconomic variables are relevant for advanced economies?", scale=1)
suggestion4 = gr.Button("How does fiscal risk compare between emerging and advanced countries?", scale=1)
# Footer
gr.HTML("""
<div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
background-color: #f8f9fa; border-radius: 10px;">
<div style="margin-bottom: 15px;">
                <h3 style="color: #2d333a;">🔍 About this assistant</h3>
<p style="color: #666; font-size: 14px;">
                    This bot uses RAG (Retrieval-Augmented Generation) technology combining:
</p>
<ul style="list-style: none; color: #666; font-size: 14px;">
                    <li>🔹 LLM Engine: Llama-2-7b-chat-hf</li>
                    <li>🔹 Embeddings: multilingual-e5-large</li>
                    <li>🔹 Vector Store: FAISS</li>
</ul>
</div>
<div style="border-top: 1px solid #ddd; padding-top: 15px;">
<p style="color: #666; font-size: 14px;">
<strong>Current Knowledge Base:</strong><br>
                    • Valencia et al. (2022) - "Assessing macro-fiscal risk for Latin American and Caribbean countries"<br>
                    • FISLAC Technical Documentation
</p>
</div>
<div style="border-top: 1px solid #ddd; margin-top: 15px; padding-top: 15px;">
<p style="color: #666; font-size: 14px;">
Created by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>,
                    AI Consultant 🤖
</p>
</div>
</div>
""")
# Configure event handlers
def submit(user_input, chat_history):
return process_response(user_input, chat_history)
message.submit(submit, [message, chatbot], [chatbot])
clear.click(lambda: None, None, chatbot)
# Handle suggested questions
for btn in [suggestion1, suggestion2, suggestion3, suggestion4]:
btn.click(submit, [btn, chatbot], [chatbot])
logger.info("Gradio interface created successfully")
demo.launch()
except Exception as e:
logger.error(f"Error in Gradio interface creation: {str(e)}")
raise