# ghana-streamlit / app.py
# Origin: Hugging Face Space by poemsforaphrodite ("Update app.py", commit 5fe0bae, verified)
import streamlit as st
import PyPDF2
import io
import os
from dotenv import load_dotenv
from pinecone import Pinecone, ServerlessSpec
from openai import OpenAI
import uuid
import re
import time
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Load environment variables from .env file
load_dotenv()

# Initialize OpenAI client (reads OPENAI_API_KEY from the environment).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Initialize Pinecone configuration.
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_ENVIRONMENT = os.getenv("PINECONE_ENVIRONMENT")
INDEX_NAME = "ghana"
EMBEDDING_MODEL = "text-embedding-3-large"
EMBEDDING_DIMENSION = 3072  # vector size produced by text-embedding-3-large

# Initialize Pinecone
pc = Pinecone(api_key=PINECONE_API_KEY)

# Check if the index exists; create it on first run (import-time side effect:
# this makes a network call every time the module is loaded).
if INDEX_NAME not in pc.list_indexes().names():
    # Create the index with updated dimensions
    pc.create_index(
        name=INDEX_NAME,
        dimension=EMBEDDING_DIMENSION,
        metric="cosine",
        spec=ServerlessSpec(
            # Assuming environment is in format 'gcp-starter'.
            # NOTE(review): this split is wrong for multi-hyphen values such as
            # 'us-east-1-aws' or region names like 'us-east-1' — confirm the
            # actual PINECONE_ENVIRONMENT format before relying on it.
            cloud=PINECONE_ENVIRONMENT.split('-')[0],
            region=PINECONE_ENVIRONMENT.split('-')[1]
        )
    )
else:
    # Optionally, verify the existing index's dimension matches; a mismatch
    # would make every upsert/query fail, so fail fast here instead.
    existing_index = pc.describe_index(INDEX_NAME)
    if existing_index.dimension != EMBEDDING_DIMENSION:
        raise ValueError(f"Existing index '{INDEX_NAME}' has dimension {existing_index.dimension}, expected {EMBEDDING_DIMENSION}. Please choose a different index name or adjust accordingly.")

# Connect to the Pinecone index used by all helpers below.
index = pc.Index(INDEX_NAME)
def transcribe_pdf(pdf_file):
    """Extract text from a PDF, chunk it, embed each chunk, and upsert to Pinecone.

    Args:
        pdf_file: Raw PDF bytes (e.g. the result of ``uploaded_file.read()``).

    Returns:
        A human-readable status string describing what was (or wasn't) upserted.
    """
    print("Starting PDF transcription...")
    # Read PDF and extract text page by page; pages with no extractable text
    # (images/scans) return None/'' and are skipped.
    pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_file))
    text = ""
    for page in pdf_reader.pages:
        page_text = page.extract_text()
        if page_text:
            text += page_text + "\n"
    print(f"Extracted {len(text)} characters from PDF.")

    if not text.strip():
        # Guard: an image-only/scanned PDF yields no text; bail out early
        # rather than reporting a misleading "0 chunks upserted" success.
        return "No extractable text was found in the PDF; nothing was upserted."

    # Dynamic chunking into overlapping ~500-token pieces.
    chunks = dynamic_chunking(text, max_tokens=500, overlap=50)
    print(f"Created {len(chunks)} chunks from the extracted text.")

    # Process chunks one by one, reporting progress in the Streamlit UI.
    progress_bar = st.progress(0)
    for i, chunk in enumerate(chunks):
        print(f"Processing chunk {i+1}/{len(chunks)}...")
        # Generate embedding for the chunk.
        embedding = get_embedding(chunk)
        # Each vector gets a random UUID id; the chunk text rides along as metadata.
        upsert_data = [(str(uuid.uuid4()), embedding, {"text": chunk})]
        print(f"Upserting vector to Pinecone index '{INDEX_NAME}'...")
        index.upsert(vectors=upsert_data)
        # Update progress bar.
        progress_bar.progress((i + 1) / len(chunks))
        # Small delay between requests to avoid potential rate limits;
        # no need to sleep after the final chunk.
        if i + 1 < len(chunks):
            time.sleep(0.5)
    progress_bar.empty()
    return f"Successfully processed and upserted {len(chunks)} chunks to Pinecone index '{INDEX_NAME}'."
def dynamic_chunking(text, max_tokens=200, overlap=100):
    """Split *text* into overlapping chunks of whitespace-separated tokens.

    Args:
        text: Input text to split.
        max_tokens: Maximum number of tokens per chunk.
        overlap: Number of tokens shared between consecutive chunks.

    Returns:
        A list of chunk strings; empty list for empty/whitespace-only text.

    Raises:
        ValueError: If ``max_tokens <= 0`` or ``overlap`` is outside
            ``[0, max_tokens)`` — those settings cannot make forward progress.
    """
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be positive, got {max_tokens}")
    if not 0 <= overlap < max_tokens:
        # Bug guard: with overlap >= max_tokens the stride (max_tokens - overlap)
        # is <= 0 and the original while-loop never terminated.
        raise ValueError(f"overlap must be in [0, max_tokens), got {overlap}")
    print(f"Starting dynamic chunking with max_tokens={max_tokens} and overlap={overlap}...")
    tokens = re.findall(r'\S+', text)
    stride = max_tokens - overlap
    chunks = []
    for start in range(0, len(tokens), stride):
        chunks.append(' '.join(tokens[start:start + max_tokens]))
    print(f"Dynamic chunking complete. Created {len(chunks)} chunks.")
    return chunks
def get_embedding(chunk):
    """Return the embedding vector for *chunk* using the configured OpenAI model.

    Args:
        chunk: Text to embed.

    Returns:
        The embedding as a list of floats (expected length EMBEDDING_DIMENSION).

    Raises:
        Re-raises any exception from the OpenAI API call after logging it.
    """
    print("Generating embedding for chunk...")
    try:
        response = client.embeddings.create(
            input=chunk,  # the embeddings API accepts a plain string directly
            model=EMBEDDING_MODEL
        )
        embedding = response.data[0].embedding
        print("Successfully generated embedding.")
        return embedding
    except Exception as e:
        print(f"Error during embedding generation: {str(e)}")
        # Bare `raise` re-raises the active exception and keeps its traceback
        # clean (idiomatic; `raise e` adds a redundant frame).
        raise
def clear_database():
    """Delete every vector from the Pinecone index.

    Returns a human-readable status string for both the success and the
    failure case; errors are reported, never raised to the caller.
    """
    print("Clearing the Pinecone index...")
    try:
        index.delete(delete_all=True)
    except Exception as exc:
        failure = f"Error clearing the Pinecone index: {str(exc)}"
        print(failure)
        return failure
    return "Successfully cleared all vectors from the Pinecone index."
def query_database(query_text):
    """Answer *query_text* using the 5 most similar stored chunks as context.

    Embeds the query, retrieves the top-5 matches from Pinecone, concatenates
    their stored text into a context block, and delegates to generate_answer.
    Returns an answer string, a "no results" notice, or an error string.
    """
    print(f"Querying database with: {query_text}")
    try:
        query_embedding = get_embedding(query_text)
        results = index.query(vector=query_embedding, top_k=5, include_metadata=True)
        # Stitch the retrieved chunk texts together, blank-line separated.
        snippets = [
            match.get('metadata', {}).get('text', '')
            for match in results['matches']
        ]
        context = ''.join(f"{snippet}\n\n" for snippet in snippets)
        if not context:
            return "No relevant information found in the database."
        return generate_answer(query_text, context)
    except Exception as exc:
        print(f"Error querying the database: {str(exc)}")
        return f"Error querying the database: {str(exc)}"
def generate_answer(query, context):
    """Ask gpt-4o-mini to answer *query* grounded in *context*.

    Returns the model's answer text, or an error string if the API call fails.
    """
    system_message = "You are an assistant for the Ghana Labor Act. Use the provided context to answer the user's question accurately and concisely."
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
            ],
        )
        return completion.choices[0].message.content
    except Exception as exc:
        print(f"Error generating answer: {str(exc)}")
        return f"Error generating answer: {str(exc)}"
def generate_hr_document(document_type, additional_info):
    """Generate a plain-text HR document of the given type via gpt-4o-mini.

    Args:
        document_type: Name of the document (e.g. "Offer Letter").
        additional_info: Free-form requirements to include in the prompt.

    Returns:
        The generated document text, or an error string if the call fails.
    """
    print(f"Generating HR document: {document_type}")
    system_message = "You are an expert HR assistant. Generate a professional HR document based on the given type and additional information. Format the response as plain text, not markdown."
    user_prompt = f"""Generate a professional {document_type} for an HR department.
Additional information: {additional_info}
Important: Format the response as plain text, not markdown. Use appropriate line breaks and spacing for readability."""
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_prompt},
            ],
        )
        return completion.choices[0].message.content
    except Exception as exc:
        print(f"Error generating HR document: {str(exc)}")
        return f"Error generating HR document: {str(exc)}"
def calculate_paye(annual_income):
    """Compute annual PAYE tax using Ghana's graduated bands.

    Each tuple is (band width in GH₵, marginal rate); income fills the bands
    in order and each filled portion is taxed at that band's rate.

    Args:
        annual_income: Taxable annual income in GH₵.

    Returns:
        Total PAYE owed (numeric; 0 for non-positive income).
    """
    bands = (
        (5880, 0),
        (1320, 0.05),
        (1560, 0.10),
        (38000, 0.175),
        (192000, 0.25),
        (366240, 0.30),
        (float('inf'), 0.35),
    )
    untaxed = annual_income
    owed = 0
    for width, rate in bands:
        if untaxed <= 0:
            break
        portion = min(width, untaxed)
        owed += portion * rate
        untaxed -= portion
    return owed
def calculate_ssnit(basic_salary):
    """Return the employee SSNIT contribution: 5.5% of *basic_salary*."""
    ssnit_rate = 0.055
    return ssnit_rate * basic_salary
def main():
    """Render the Streamlit UI: five tabs covering PDF upload/ingestion,
    RAG querying, HR document generation, a Ghana tax calculator, and
    clearing the Pinecone index."""
    st.set_page_config(page_title="HR Document Assistant", layout="wide")
    # Create a header with logo and title
    col1, col2 = st.columns([1, 4])
    with col1:
        st.image("logo.png", width=200)  # Adjust the width as needed
    with col2:
        st.title("HR Document Assistant")
    tab1, tab2, tab3, tab4, tab5 = st.tabs([
        "📤 Upload PDF",
        "🔍 Query Database",
        "📝 Generate HR Document",
        "🧮 Tax Calculator",
        "🗑️ Clear Database"
    ])
    with tab1:
        # --- PDF ingestion tab: extract, chunk, embed, upsert ---
        st.header("Upload PDF")
        st.write("Upload a PDF file to extract its text content, chunk it dynamically, and upsert the chunks to the Pinecone index.")
        pdf_file = st.file_uploader("Upload PDF", type="pdf")
        if st.button("📥 Transcribe and Upsert"):
            if pdf_file is not None:
                with st.spinner("Processing PDF..."):
                    # transcribe_pdf expects raw bytes, hence .read().
                    result = transcribe_pdf(pdf_file.read())
                    st.success(result)
            else:
                st.error("Please upload a PDF file first.")
    with tab2:
        # --- RAG query tab ---
        st.header("Query Database")
        st.write("Enter a query about the Ghana Labor Act.")
        query = st.text_input("Enter your query", placeholder="What does the Act say about...?")
        if st.button("🔎 Get Answer"):
            answer = query_database(query)
            st.markdown("### Answer:")
            st.write(answer)
    with tab3:
        # --- HR document generation tab ---
        st.header("Generate HR Document")
        st.write("Select an HR document type and provide additional information to generate the document.")
        document_types = [
            "Employment Contract", "Offer Letter", "Job Description", "Employee Handbook",
            "Performance Review Form", "Disciplinary Action Form", "Leave Request Form",
            "Onboarding Checklist", "Termination Letter", "Non-Disclosure Agreement (NDA)",
            "Code of Conduct", "Workplace Policy", "Benefits Summary", "Compensation Plan",
            "Training and Development Plan", "Resignation Letter", "Exit Interview Form",
            "Employee Grievance Form", "Time-off Request Form", "Workplace Safety Guidelines"
        ]
        selected_document = st.selectbox("Select HR Document Type", document_types)
        additional_info = st.text_area(
            "Additional Information",
            placeholder="Enter any specific details or requirements for the document..."
        )
        if st.button("✍️ Generate Document"):
            with st.spinner("Generating document..."):
                document = generate_hr_document(selected_document, additional_info)
                st.subheader(f"Generated {selected_document}")
                st.text_area("Document Content", value=document, height=400)
                st.download_button(
                    label="Download Document",
                    data=document,
                    file_name=f"{selected_document.lower().replace(' ', '_')}.txt",
                    mime="text/plain"
                )
    with tab4:
        # --- Tax calculator tab ---
        st.header("Tax Calculator")
        st.write("Calculate PAYE and SSNIT contributions based on annual income and basic salary.")
        # Preset examples mapping label -> (annual_income, basic_salary) in GH₵.
        salary_examples = {
            "Entry Level": (36000, 30000),
            "Mid Level": (72000, 60000),
            "Senior Level": (120000, 90000),
            "Executive": (240000, 180000)
        }
        selected_example = st.selectbox(
            "Select a salary example or enter custom values:",
            ["Custom"] + list(salary_examples.keys())
        )
        if selected_example == "Custom":
            annual_income = st.number_input("Annual Income (GH₵)", min_value=0.0, value=0.0, step=1000.0)
            basic_salary = st.number_input("Basic Salary (GH₵)", min_value=0.0, value=0.0, step=1000.0)
        else:
            annual_income, basic_salary = salary_examples[selected_example]
            st.write(f"Annual Income: GH₵ {annual_income:.2f}")
            st.write(f"Basic Salary: GH₵ {basic_salary:.2f}")
        if st.button("Calculate Taxes"):
            # SSNIT is deducted first; PAYE applies to income net of SSNIT.
            ssnit_contribution = calculate_ssnit(basic_salary)
            taxable_income = annual_income - ssnit_contribution
            paye = calculate_paye(taxable_income)
            net_income = annual_income - ssnit_contribution - paye
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("Tax Breakdown")
                st.write(f"SSNIT Contribution: GH₵ {ssnit_contribution:.2f}")
                st.write(f"PAYE: GH₵ {paye:.2f}")
                st.write(f"Total Deductions: GH₵ {(ssnit_contribution + paye):.2f}")
                st.write(f"Net Income: GH₵ {net_income:.2f}")
            with col2:
                # Pie chart for income breakdown
                fig, ax = plt.subplots(figsize=(3, 2))
                sizes = [ssnit_contribution, paye, net_income]
                labels = ['SSNIT', 'PAYE', 'Net']
                colors = ['#ff9999', '#66b3ff', '#99ff99']
                ax.pie(
                    sizes, labels=labels, colors=colors,
                    autopct='%1.1f%%', startangle=90, textprops={'fontsize': 6}
                )
                ax.axis('equal')  # keep the pie circular
                plt.title("Income Breakdown", fontsize=8)
                st.pyplot(fig)
            # Display tax rates by income bracket as a table
            st.subheader("Tax Rates by Income Bracket")
            tax_data = {
                "Income Range (GH₵)": [
                    "0 - 5,880", "5,881 - 7,200", "7,201 - 8,760",
                    "8,761 - 46,760", "46,761 - 238,760",
                    "238,761 - 605,000", "Above 605,000"
                ],
                "Rate (%)": [0, 5, 10, 17.5, 25, 30, 35]
            }
            df = pd.DataFrame(tax_data)
            st.table(df)
    with tab5:
        # --- Destructive maintenance tab ---
        st.header("Clear Database")
        st.write("Use this option carefully. It will remove all data from the Pinecone index.")
        if st.button("🗑️ Clear Database"):
            result = clear_database()
            st.success(result)
    # Footer note rendered below the tab strip.
    st.markdown("""
### 📌 Note
- Ensure you have the necessary API keys set up for OpenAI and Pinecone.
- The PDF upload process may take some time depending on the file size.
- Generated HR documents are based on AI and may require human review.
""")
# Script entry point: launch the Streamlit UI when run directly.
if __name__ == "__main__":
    main()