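# Gradio app for extracting structured information (invoices, salary slips, etc.)
# from readable PDF documents with a Groq-hosted LLM behind a small RAG pipeline.
# `rag` and `prompts` are project-local modules that provide the RAG wrapper,
# the service-context setup, and the extraction prompt template.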
import gradio as gr
from rag import RAG, ServiceContextModule
from llama_index.core import set_global_service_context
import json
from prompts import general_prompt
from gradio_pdf import PDF
import requests
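# Module-level cache: the service context and the model it was built for, so a
# repeated click on "Initialize" with the same model does not rebuild the context.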
service_context_module = None
current_model = None
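# Validate the Groq API key by listing the models it can access, then (re)build
# the global service context when the selected model has changed. Returns an
# update for the model dropdown with the models available to this key.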
def initialize(api_key, model_name):
    global service_context_module, current_model
    gr.Info("Initializing app")
    url = "https://api.groq.com/openai/v1/models"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    try:
        response = requests.get(url, headers=headers, timeout=30)
        response.raise_for_status()
        data = response.json()
        models = [model["id"] for model in data["data"]]
    except Exception as e:
        # gr.Error must be raised (not just constructed) to show up in the UI.
        raise gr.Error("Invalid API KEY") from e
    if not service_context_module or current_model != model_name:
        try:
            service_context_module = ServiceContextModule(api_key, model_name)
        except Exception as e:
            print(e)
            raise gr.Error("Failed to initialize the service context") from e
        current_model = model_name
        gr.Info("App started")
        set_global_service_context(
            service_context=service_context_module.service_context
        )
    else:
        gr.Info("App is already running")
    return gr.update(choices=models)
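# Dispatch on file type; only PDF input is supported at the moment.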
def process_document(file, query):
    if file.endswith(".pdf"):
        return process_pdf(file, query=query)
    else:
        return "Unsupported file format"
def postprocess_json_string(json_string: str) -> dict:
    json_string = json_string.replace("'", '"')
    json_string = json_string[json_string.rfind("{") : json_string.rfind("}") + 1]
    try:
        json_data = json.loads(json_string)
    except Exception as e:
        print("Error parsing output, invalid json format", e)
        json_data = {}
    return json_data
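# Build a one-off RAG index over the uploaded PDF, query the LLM for the
# requested fields, and parse its response into a dict.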
def process_pdf(file, query):
    rag_module = RAG(filepaths=[file])
    fields = [field.strip() for field in query.split(",")]
    formatted_prompt = general_prompt(fields=fields)
    response = rag_module.run_query_engine(prompt=formatted_prompt)
    extracted_json = postprocess_json_string(json_string=response)
    return extracted_json
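# UI layout: an "Init Section" tab for the API key / model choice and an
# "App Section" tab with the document-and-entities extraction form.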
with gr.Blocks(title="Document Information Extractor.") as app:
    gr.Markdown(
        value="""
# Welcome to Document Information Extractor.
Created by [@rajsinghparihar](https://huggingface.co/rajsinghparihar) for extracting useful information from PDF documents like invoices, salary slips, etc.
## Usage:
- In the Init Section, enter your `GROQ_API_KEY` in the corresponding labeled textbox.
- Choose the model from the list of available models.
- Click `Initialize` to start the app.
- In the App Section, upload a document (PDF files only; currently works for readable PDFs, OCR support will be added later).
- Enter the entities you want to extract as a comma-separated string (check the examples for more info).
- Click Submit to see the extracted entities as a JSON object.
"""
    )
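    # Init tab: collect the Groq API key and model, then repopulate the model
    # dropdown with whatever the key can actually access.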
with gr.Tab(label="Init Section") as init_tab:
with gr.Row():
api_key = gr.Text(
label="Enter your Groq API KEY",
type="password",
)
available_models = gr.Dropdown(
value="llama3-70b-8192",
label="Choose your LLM",
choices=[
"gemma-7b-it",
"llama3-70b-8192",
"llama3-8b-8192",
"mixtral-8x7b-32768",
"whisper-large-v3",
],
)
init_btn = gr.Button(value="Initialize")
init_btn.click(
fn=initialize,
inputs=[api_key, available_models],
outputs=available_models,
)
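    # App tab: a standard gr.Interface wired to process_document, with two
    # bundled example documents.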
with gr.Tab(label="App Section") as app_tab:
iface = gr.Interface(
fn=process_document,
inputs=[
PDF(label="Document"),
gr.Text(
label="Entities you wanna extract in comma separated string format"
),
],
outputs=gr.JSON(label="Extracted Entities"),
description="Upload a PDF document and extract specified entities from it.",
examples=[
[
"examples/Commerce Bank Statement Sample.pdf",
"Customer Name, Account Number, Statement Date, Ending Balance, Total Deposits, Checks Paid",
],
[
"examples/Salary-Slip-pdf.pdf",
"Employee Name, Bank Name, Location, Total Salary, Total Deductions",
],
],
cache_examples="lazy",
)
gr.Markdown("""
## Pros of LLMs as information extractors over current extraction solutions:
- LLMs are able to understand the scope of the problem from the context and are more robust to typos or extraction failure
## Cons
- Higher Inference Cost
- Can't use free APIs for Sensitive documents.
""")
app.launch(server_name="0.0.0.0", server_port=7860)