###########################################################################################
# Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
# Author: Andreas Fischer
# Date: October 10th, 2024
# Last update: October 10th, 2024
###########################################################################################
import os
import chromadb
from datetime import datetime
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.utils import embedding_functions
from transformers import AutoTokenizer, AutoModel
import torch
# Load the German Jina embedding model (trust_remote_code provides its encode() method)
jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True, torch_dtype=torch.bfloat16)
#jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device='cpu'  # uncomment to force CPU
jina.to(device)  # e.g. cuda:0
print(device)
class JinaEmbeddingFunction(EmbeddingFunction):
    def __call__(self, input: Documents) -> Embeddings:
        embeddings = jina.encode(input)  #max_length=2048
        return embeddings.tolist()
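# Example (sketch): the embedding function can be called directly on a list of strings;
# jina-embeddings-v2-base-de should return one 768-dimensional vector per input text.
#jina_ef_test = JinaEmbeddingFunction()
#print(len(jina_ef_test(["Hallo Welt"])[0]))  # expected: 768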
# Use the local DB path if it exists (on-prem), otherwise the path used on Hugging Face Spaces
dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db"
onPrem = True if(os.path.exists(dbPath)) else False
if(onPrem==False): dbPath = "/home/user/app/db"
#onPrem=True  # uncomment to override automatic detection
print(dbPath)
path = dbPath
client = chromadb.PersistentClient(path=path)
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())
jina_ef = JinaEmbeddingFunction()
embeddingModel = jina_ef
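# Example (sketch, hypothetical collection name "demo"): the custom embedding function can be
# handed to a ChromaDB collection so that documents and queries are embedded with the Jina model:
#demo_collection = client.get_or_create_collection("demo", embedding_function=embeddingModel, metadata={"hnsw:space": "cosine"})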
from huggingface_hub import InferenceClient
import gradio as gr
import json

inferenceClient = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
    #"mistralai/Mistral-7B-Instruct-v0.1"
)
def format_prompt(message, history):
    prompt = "<s>"
    #for user_prompt, bot_response in history:
    #    prompt += f"[INST] {user_prompt} [/INST]"
    #    prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
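# Example (sketch) of the resulting Mistral/Mixtral instruct prompt (history is currently ignored):
#format_prompt("What is RAG?", [])  # -> "<s>[INST] What is RAG? [/INST]"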
from pypdf import PdfReader
import ocrmypdf

def convertPDF(pdf_file, allow_ocr=False):
    reader = PdfReader(pdf_file)
    full_text = ""
    page_list = []

    def extract_text_from_pdf(reader):
        full_text = ""
        page_list = []
        page_count = 0  # start at 0 to avoid an off-by-one count
        for idx, page in enumerate(reader.pages):
            text = page.extract_text()
            if len(text) > 0:
                page_list.append(text)
                #full_text += f"---- Page {idx} ----\n" + text + "\n\n"
            page_count += 1
        return full_text.strip(), page_count, page_list

    # Check if there are any images
    image_count = sum(len(page.images) for page in reader.pages)
    # If there are images and not much content, perform OCR on the document
    if allow_ocr:
        print(f"{image_count} Images")
        if image_count > 0 and len(full_text) < 1000:
            out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader = PdfReader(out_pdf_file)
    # Extract text:
    full_text, page_count, page_list = extract_text_from_pdf(reader)
    l = len(page_list)
    print(f"{l} Pages")
    # Extract metadata
    metadata = {
        "author": reader.metadata.author,
        "creator": reader.metadata.creator,
        "producer": reader.metadata.producer,
        "subject": reader.metadata.subject,
        "title": reader.metadata.title,
        "image_count": image_count,
        "page_count": page_count,
        "char_count": sum(len(p) for p in page_list),  # full_text stays empty, so count characters per page
    }
    return page_list, full_text, metadata
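# Example usage (sketch, assuming a local file "example.pdf" exists):
#pages, full_text, metadata = convertPDF("example.pdf")
#print(metadata["page_count"], metadata["char_count"])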
def split_with_overlap(text, chunk_size=3500, overlap=700):
    chunks = []
    step = max(1, chunk_size - overlap)
    for i in range(0, len(text), step):
        end = min(i + chunk_size, len(text))
        #chunk = text[i:i+chunk_size]
        chunks.append(text[i:end])
    return chunks
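# Example (sketch): with chunk_size=4 and overlap=2 the step is 2, so consecutive chunks share two characters:
#split_with_overlap("abcdefghij", 4, 2)  # -> ['abcd', 'cdef', 'efgh', 'ghij', 'ij']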
def add_doc(path):
    print("def add_doc!")
    print(path)
    if(str.lower(path).endswith(".pdf")):
        doc = convertPDF(path)
        doc = "\n\n".join(doc[0])
        gr.Info("PDF uploaded, start indexing!")
    else:
        gr.Info("Error: Only PDFs are accepted!")
        return None  # nothing to index
    client = chromadb.PersistentClient(path="output/general_knowledge")
    print(str(client.list_collections()))
    #global collection
    dbName = "test"
    if("name="+dbName in str(client.list_collections())):
        client.delete_collection(name=dbName)  # re-index from scratch on every upload
    collection = client.create_collection(
        dbName,
        embedding_function=embeddingModel,
        metadata={"hnsw:space": "cosine"})
    corpus = split_with_overlap(doc, 3500, 700)
    print(len(corpus))
    then = datetime.now()
    x = collection.get(include=[])["ids"]
    print(len(x))
    if(len(x)==0):
        chunkSize = 40000  # number of chunks added to the collection per batch
        nBatches = -(-len(corpus) // chunkSize)  # ceiling division
        for i in range(nBatches):
            print("embed batch "+str(i)+" of "+str(nBatches))
            ids = list(range(i*chunkSize, (i*chunkSize+chunkSize)))
            batch = corpus[i*chunkSize:(i*chunkSize+chunkSize)]
            textIDs = [str(id) for id in ids[0:len(batch)]]
            ids = [str(id+len(x)+1) for id in ids[0:len(batch)]]  # id refers to the ChromaDB-unique ID
            collection.add(documents=batch, ids=ids,
                metadatas=[{"date": str("2024-10-10")} for b in batch])  #"textID":textIDs, "id":ids,
            print("finished batch "+str(i)+" of "+str(nBatches))
    now = datetime.now()
    gr.Info("Indexing complete!")
    print(now-then)  # too many GB for sentences (GPU); 0:00:10.375087 for chunks
    return collection
#split_with_overlap("test me if you can",2,1)
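# Example usage (sketch, assuming "example.pdf" exists): indexes the PDF into the "test" collection
# under output/general_knowledge and returns the ChromaDB collection object.
#collection = add_doc("example.pdf")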
import re
def multimodalResponse(message, history, retrievalVersion):
    print("def multimodal response!")
    length = str(len(history))
    query = message["text"]
    client = chromadb.PersistentClient(path="output/general_knowledge")
    print(str(client.list_collections()))
    if(len(message["files"])>0):  # is there at least one file attached?
        collection = add_doc(message["files"][0])
    else:  # otherwise fall back to the collection indexed on a previous upload
        collection = client.get_collection(name="test", embedding_function=embeddingModel)
    x = collection.get(include=[])["ids"]
    context = collection.query(query_texts=[query], n_results=1)
    print(str(context))
    #context=["<context "+str(i+1)+">\n"+c+"\n</context "+str(i+1)+">" for i, c in enumerate(retrievedTexts)]
    #context="\n\n".join(context)
    #return context
    temperature = 0.9
    top_p = 0.95
    if temperature < 1e-2: temperature = 1e-2
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=float(temperature),
        max_new_tokens=5000,
        top_p=top_p,
        repetition_penalty=1.0,
        do_sample=True,
        seed=42,
    )
    system = "Given the following conversation, relevant context, and a follow up question, "+\
        "reply with an answer to the current question the user is asking. "+\
        "Return only your response to the question given the above information "+\
        "following the user's instructions as needed.\n\nContext:"+\
        str(context)
    print(system)
    formatted_prompt = format_prompt(system+"\n"+query, history)
    stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        yield output
    #output=output+"\n\n<br><details open><summary><strong>Sources</strong></summary><br><ul>"+ "".join(["<li>" + s + "</li>" for s in combination])+"</ul></details>"
    yield output
i = gr.ChatInterface(multimodalResponse,
    title="pdfChatbot",
    multimodal=True,
    additional_inputs=[
        gr.Dropdown(
            info="select retrieval version",
            choices=["1","2","3"],
            value="1",
            label="Retrieval Version")])
i.launch()  #allowed_paths=["."])