Frag-dein-PDF / run.py
AFischer1985's picture
Update run.py
d8b9fc9 verified
raw
history blame
12.3 kB
###########################################################################################
# Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
# Author: Andreas Fischer
# Date: October 10th, 2024
# Last update: October 14th, 2024
##########################################################################################
import os
import torch
from transformers import AutoTokenizer, AutoModel # chromaDB
from datetime import datetime, date #add_doc,
import chromadb #chromaDB
from chromadb import Documents, EmbeddingFunction, Embeddings #chromaDB
from chromadb.utils import embedding_functions #chromaDB
import ocrmypdf #convertPDF
from pypdf import PdfReader #convertPDF
import re #format_prompt
import gradio as gr # multimodal_response
from huggingface_hub import InferenceClient #multimodal_response
#---------------------------------------------------
# Specify models for text generation and embeddings
#---------------------------------------------------
# Text-generation model; passed to huggingface_hub.InferenceClient in multimodal_response.
myModel="mistralai/Mixtral-8x7b-instruct-v0.1"
# Scratch code for inspecting the chat template of the generation model (kept for reference):
#mod="mistralai/Mixtral-8x7b-instruct-v0.1"
#tok=AutoTokenizer.from_pretrained(mod) #,token="hf_...")
#cha=[{"role":"system","content":"A"},{"role":"user","content":"B"},{"role":"assistant","content":"C"}]
#cha=[{"role":"user","content":"U1"},{"role":"assistant","content":"A1"},{"role":"user","content":"U2"},{"role":"assistant","content":"A2"}]
#res=tok.apply_chat_template(cha)
#print(tok.decode(res))
# Embedding model used by JinaEmbeddingFunction; loaded in bfloat16 with the repository's custom model code.
jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True, torch_dtype=torch.bfloat16)
#jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
# Use the first CUDA device when one is available, otherwise stay on the CPU.
device='cuda:0' if torch.cuda.is_available() else 'cpu'
jina.to(device) #cuda:0
print(device)
#-----------------
# ChromaDB-client
#-----------------
class JinaEmbeddingFunction(EmbeddingFunction):
    """ChromaDB embedding function backed by the module-level ``jina`` model."""

    def __call__(self, input: Documents) -> Embeddings:
        """Encode *input* with ``jina.encode`` and return the result via ``.tolist()``."""
        vectors = jina.encode(input)  # max_length=2048
        return vectors.tolist()
# Use the developer's local database directory when it exists,
# otherwise fall back to the directory inside the app container.
dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db/"
onPrem = os.path.exists(dbPath)
if not onPrem:
    dbPath = "/home/user/app/db/"
print(dbPath)

# Open the persistent ChromaDB store and log a few facts about it.
client = chromadb.PersistentClient(path=dbPath)
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())

# One shared embedding function instance, reachable under both names.
embeddingModel = jina_ef = JinaEmbeddingFunction()

# Running list of (date, session id) pairs; "0" is the placeholder session.
databases = [(date.today(), "0")]
#---------------------------------------------------------------------
# Function for formatting single message according to prompt template
#---------------------------------------------------------------------
def format_prompt0(message, history):
    """Wrap a single message in the ``<s>[INST] ... [/INST]`` template.

    *history* is accepted for signature compatibility but is not used:
    earlier turns are not included in the prompt.
    """
    return f"<s>[INST] {message} [/INST]"
#-------------------------------------------------------------------------
# Function for formatting multiturn-dialogue according to prompt template
#-------------------------------------------------------------------------
def format_prompt(message, history, system=None, RAGAddon=None, system2=None, zeichenlimit=None,historylimit=4, removeHTML=False):
    """Build an ``[INST]``-style prompt from a multi-turn dialogue.

    Layout: ``<s> [INST] system [/INST]</s> [INST] U1 [/INST] A1</s> ... [INST] message [/INST]``

    Args:
        message: Current user message, or None to omit the final turn.
        history: Iterable of ``(user_message, bot_response)`` pairs, or None.
        system: Optional system text, emitted as the first ``[INST]`` turn.
        RAGAddon: Optional text appended to *system*; when *system* is None
            it is used as the system text on its own.
        system2: Optional raw text appended after the final user turn.
        zeichenlimit: Maximum number of characters kept per message
            (None means practically unlimited).
        historylimit: Number of most recent history pairs to include;
            values <= 0 include no history at all.
        removeHTML: When True, replace every remaining ``<...>`` tag in bot
            responses with a newline.

    Returns:
        The formatted prompt string, always starting with ``<s>``.
    """
    if zeichenlimit is None:
        zeichenlimit = 1000000000  # practically unlimited
    startOfString = "<s>"
    template0 = " [INST] {system} [/INST]</s>"
    template1 = " [INST] {message} [/INST]"
    template2 = " {response}</s>"

    if RAGAddon is not None:
        # Concatenating onto a missing system prompt used to raise TypeError.
        system = RAGAddon if system is None else system + RAGAddon

    parts = []
    if system is not None:
        parts.append(template0.format(system=system))

    # history[-0:] would be the *whole* list, so a non-positive limit needs
    # an explicit guard to really mean "no history".
    if history is not None and historylimit > 0:
        for user_message, bot_response in history[-historylimit:]:
            if user_message is None:
                user_message = ""
            if bot_response is None:
                bot_response = ""
            # Remove RAG components. The sources block is emitted as
            # "\n\n<br><details open>...</details>", so the optional <br>
            # and attributes on the <details> tag have to be tolerated.
            bot_response = re.sub(
                r"\n\n(?:<br>)?<details[^>]*>.*?</details>",
                "",
                bot_response,
                flags=re.DOTALL,
            )
            if removeHTML:
                # Remove HTML components in general (may cause bugs with markdown rendering).
                bot_response = re.sub("<(.*?)>", "\n", bot_response)
            parts.append(template1.format(message=user_message[:zeichenlimit]))
            parts.append(template2.format(response=bot_response[:zeichenlimit]))

    if message is not None:
        parts.append(template1.format(message=message[:zeichenlimit]))
    if system2 is not None:
        parts.append(system2)
    return startOfString + "".join(parts)
#--------------------------------------------
# Function for converting pdf-files to text
#--------------------------------------------
def convertPDF(pdf_file, allow_ocr=False):
    """Extract per-page text and basic metadata from a PDF.

    Args:
        pdf_file: Path of the PDF to read (handed to ``PdfReader``). When OCR
            is triggered it must be a filesystem path, because the OCR result
            is written next to it with an ``_ocr`` suffix.
        allow_ocr: When True, and the PDF contains images but yields fewer
            than 1000 characters of text, run ``ocrmypdf`` and re-extract
            the text from the OCR output.

    Returns:
        ``(page_list, full_text, metadata)`` where *page_list* holds the text
        of every non-empty page, *full_text* is those texts joined by blank
        lines, and *metadata* is a dict with the document info fields plus
        ``image_count``, ``page_count`` and ``char_count``.
    """
    def extract_text_from_pdf(reader):
        # Collect the text of each page, skipping pages without any text.
        page_list = []
        for page in reader.pages:
            text = page.extract_text() or ""
            if len(text) > 0:
                page_list.append(text)
        full_text = "\n\n".join(page_list).strip()
        return full_text, len(reader.pages), page_list

    reader = PdfReader(pdf_file)
    # Check if there are any images
    image_count = sum(len(page.images) for page in reader.pages)
    # Extract text first, so the OCR decision below is based on the real
    # amount of text instead of an always-empty placeholder.
    full_text, page_count, page_list = extract_text_from_pdf(reader)

    # If there are images and not much content, perform OCR on the document
    if allow_ocr:
        print(f"{image_count} Images")
        if image_count > 0 and len(full_text) < 1000:
            # splitext keeps this correct for upper-case ".PDF" extensions,
            # where a plain replace(".pdf", ...) would leave the path
            # unchanged and overwrite the input file.
            root, ext = os.path.splitext(pdf_file)
            out_pdf_file = root + "_ocr" + (ext or ".pdf")
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader = PdfReader(out_pdf_file)
            full_text, page_count, page_list = extract_text_from_pdf(reader)

    l = len(page_list)
    print(f"{l} Pages")

    # Extract metadata; a PDF without a document-information dictionary
    # has no metadata object, so fall back to None for every field.
    info = reader.metadata
    metadata = {
        "author": getattr(info, "author", None),
        "creator": getattr(info, "creator", None),
        "producer": getattr(info, "producer", None),
        "subject": getattr(info, "subject", None),
        "title": getattr(info, "title", None),
        "image_count": image_count,
        "page_count": page_count,
        "char_count": len(full_text),
    }
    return page_list, full_text, metadata
#------------------------------------------
# Function for splitting text with overlap
#------------------------------------------
def split_with_overlap(text,chunk_size=3500, overlap=700):
    """Cut *text* into windows of at most *chunk_size* characters.

    Consecutive windows start ``chunk_size - overlap`` characters apart
    (at least 1), so neighbouring chunks share up to *overlap* characters.
    An empty *text* yields an empty list.
    """
    total = len(text)
    stride = max(1, chunk_size - overlap)
    return [
        text[start:min(start + chunk_size, total)]
        for start in range(0, total, stride)
    ]
#---------------------------------------------------------------
# Function for adding docs to ChromaDB and/or return collection
#---------------------------------------------------------------
def add_doc(path, session):
    """Return the session's ChromaDB collection, indexing a PDF if one is given.

    Args:
        path: Candidate file path. Only an existing file whose name ends in
            ``.pdf`` (case-insensitive) is treated as an attachment.
        session: Session identifier; the collection is named ``DB_<session>``.

    Returns:
        The ChromaDB collection for this session (created on first use).

    Notes:
        Only the first five non-empty pages of a PDF are indexed, and only
        when the collection is still empty.
        NOTE(review): a second PDF uploaded in the same session is therefore
        not indexed - confirm whether that is intended.
    """
    print("def add_doc!")
    print(path)
    anhang = str.lower(path).endswith(".pdf") and os.path.exists(path)
    if anhang:
        pages = convertPDF(path)[0]
        if len(pages) > 5:
            gr.Info("PDF uploaded to DB_"+str(session)+", start Indexing excerpt (first 5 pages)!")
        else:
            gr.Info("PDF uploaded to DB_"+str(session)+", start Indexing!")
        doc = "\n\n".join(pages[0:5])
    else:
        gr.Info("No PDF attached - answer based on DB_"+str(session)+".")

    client = chromadb.PersistentClient(path=dbPath)
    print(str(client.list_collections()))
    print(str(session))
    dbName = "DB_"+str(session)
    # Let ChromaDB decide whether the collection exists. Substring-matching
    # the repr of list_collections() gives false positives for names sharing
    # a prefix and depends on the repr format of the installed version.
    collection = client.get_or_create_collection(
        name=dbName,
        embedding_function=embeddingModel,
        metadata={"hnsw:space": "cosine"})

    if anhang:
        corpus = split_with_overlap(doc, 3500, 700)
        print(len(corpus))
        then = datetime.now()
        x = collection.get(include=[])["ids"]
        print(len(x))
        if len(x) == 0:
            chunkSize = 40000
            # Ceiling division: round(n/chunkSize + 0.5) uses banker's
            # rounding and produced an extra, empty batch whenever the corpus
            # size was an exact multiple of the batch size.
            batchCount = -(-len(corpus) // chunkSize)
            for i, start in enumerate(range(0, len(corpus), chunkSize)):
                print("embed batch "+str(i)+" of "+str(batchCount))
                batch = corpus[start:start + chunkSize]
                # ChromaDB-unique ids, numbered from 1 after any existing entries.
                ids = [str(start + k + len(x) + 1) for k in range(len(batch))]
                # NOTE(review): the stored date is a fixed literal, not the upload date.
                collection.add(documents=batch, ids=ids,
                               metadatas=[{"date": str("2024-10-10")} for b in batch])
                print("finished batch "+str(i)+" of "+str(batchCount))
            now = datetime.now()
            gr.Info(f"Indexing complete!")
            print(now-then)  # elapsed indexing time
    return(collection)
#split_with_overlap("test me if you can",2,1)
#--------------------------------------------------------
# Function for response to user queries and pot. addenda
#--------------------------------------------------------
def multimodal_response(message, history, dropdown, hfToken, request: gr.Request):
    """Answer a chat message using context retrieved from the session's collection.

    Generator used as the Gradio ChatInterface callback: it yields the
    growing answer while tokens stream in and finally appends a collapsible
    list of the retrieved sources.

    Args:
        message: Multimodal message dict with the keys ``"text"`` and ``"files"``.
        history: Previous ``(user, bot)`` turns of the conversation.
        dropdown: Value of the "Retrieval Version" dropdown (currently unused).
        hfToken: Optional HuggingFace token; only used when it starts with ``hf_``.
        request: Gradio request; its ``session_hash`` selects the collection.

    Yields:
        The partial (and finally the complete) response text.
    """
    print("def multimodal response!")
    if hfToken and hfToken.startswith("hf_"):  # use HF-hub with custom token if token is provided
        inferenceClient = InferenceClient(model=myModel, token=hfToken)
    else:
        inferenceClient = InferenceClient(myModel)

    global databases
    session = request.session_hash if request else "0"
    print(databases)
    if not databases[-1][1] == session:
        databases.append((date.today(), session))

    query = message["text"]
    if len(message["files"]) > 0:  # is there at least one file attached?
        collection = add_doc(message["files"][0], session)
    else:
        # No attachment: only fetch the session's collection. The user's text
        # must not be passed as a path, otherwise typing the path of a PDF
        # that exists on the server would get that file indexed.
        collection = add_doc("", session)

    context = collection.query(query_texts=[query], n_results=1)
    context = ["<context "+str(i)+"> "+str(c)+"</context "+str(i)+">" for i, c in enumerate(context["documents"][0])]
    gr.Info("Kontext:\n"+str(context))

    generate_kwargs = dict(
        temperature=float(0.9),
        max_new_tokens=5000,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        seed=42,
    )
    system = "Mit Blick auf das folgende Gespräch und den relevanten Kontext, antworte auf die aktuelle Frage des Nutzers. "+\
        "Antworte ausschließlich auf Basis der Informationen im Kontext.\n\nKontext:\n\n"+\
        str("\n\n".join(context))
    print(system)
    formatted_prompt = format_prompt(query, history, system=system)
    print(formatted_prompt)

    output = ""
    try:
        stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        for response in stream:
            output += response.token.text
            yield output
    except Exception as e:
        # Best effort: any generation failure is reported to the user as a
        # missing or invalid token, and the retrieved sources are still shown.
        print(str(e))
        output = "Für weitere Antworten von der KI gebe bitte einen gültigen HuggingFace-Token an."
        if len(context) > 0:
            output += "\nBis dahin helfen dir hoffentlich die folgenden Quellen weiter:"
        yield output

    if len(context) > 0:
        output = output+"\n\n<br><details open><summary><strong>Quellen</strong></summary><br><ul>"+ "".join(["<li>" + c + "</li>" for c in context])+"</ul></details>"
        yield output
#------------------------------
# Launch Gradio-ChatInterface
#------------------------------
# Extra inputs of the chat UI; their values arrive in multimodal_response
# as the `dropdown` and `hfToken` arguments.
retrievalVersionInput = gr.Dropdown(
    info="select retrieval version",
    choices=["1","2","3"],
    value="1",
    label="Retrieval Version")
hfTokenInput = gr.Textbox(
    value="",
    label="HF_token")

# Multimodal chat interface: messages may carry file attachments.
i = gr.ChatInterface(
    multimodal_response,
    title="Frag dein PDF",
    multimodal=True,
    additional_inputs=[retrievalVersionInput, hfTokenInput])
i.launch() #allowed_paths=["."])