###########################################################################################
# Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
# Author: Andreas Fischer
# Date: October 10th, 2024
# Last update: October 10th, 2024
##########################################################################################

import json
import math
import os
import re
from datetime import datetime

import chromadb
import gradio as gr
import ocrmypdf
import torch
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.utils import embedding_functions
from huggingface_hub import InferenceClient
from pypdf import PdfReader
from transformers import AutoModel, AutoTokenizer

# --- Embedding model ---------------------------------------------------------------------

jina = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-de",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
# jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"  # uncomment to force CPU inference
jina.to(device)  # cuda:0
print(device)


class JinaEmbeddingFunction(EmbeddingFunction):
    """ChromaDB embedding function that encodes documents with the module-level Jina model."""

    def __call__(self, input: Documents) -> Embeddings:
        # `jina.encode` comes from the model's remote code; `.tolist()` converts its
        # array result into the plain nested lists ChromaDB expects.
        embeddings = jina.encode(input)  # max_length=2048
        return embeddings.tolist()


# --- Vector store ------------------------------------------------------------------------

dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db"
# If the author's local DB folder exists we are on-premises; otherwise use the hosted path.
onPrem = os.path.exists(dbPath)
if not onPrem:
    dbPath = "/home/user/app/db"
# onPrem = True  # uncomment to override automatic detection
print(dbPath)

path = dbPath
client = chromadb.PersistentClient(path=path)
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())

jina_ef = JinaEmbeddingFunction()
embeddingModel = jina_ef

# Store and collection that hold the most recently uploaded PDF; the collection is
# dropped and re-created on every upload by `add_doc`.
# NOTE(review): this store differs from `client`/`dbPath` above — confirm that is intended.
knowledgeDbPath = "output/general_knowledge"
collectionName = "test"


def _collection_names(db_client):
    """Return the collection names of *db_client*.

    Older chromadb releases return Collection objects from ``list_collections()``,
    newer ones return plain name strings; ``getattr(..., "name", c)`` handles both.
    """
    return [getattr(c, "name", c) for c in db_client.list_collections()]


# --- LLM ---------------------------------------------------------------------------------

inferenceClient = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
    # "mistralai/Mistral-7B-Instruct-v0.1"
)


def format_prompt(message, history):
    """Wrap *message* in ``[INST] ... [/INST]`` instruction markers.

    *history* is accepted for interface compatibility but deliberately ignored:
    every turn is sent to the model without the previous conversation.
    """
    return f"[INST] {message} [/INST]"


# --- PDF handling ------------------------------------------------------------------------


def _extract_pages(reader):
    """Return the non-empty extracted text of each page of *reader*, in page order."""
    pages = []
    for page in reader.pages:
        text = page.extract_text() or ""
        if text:
            pages.append(text)
    return pages


def convertPDF(pdf_file, allow_ocr=False):
    """Extract text and metadata from a PDF file.

    Args:
        pdf_file: Path to the PDF.
        allow_ocr: If true and the PDF contains images but yields fewer than 1000
            characters of text, run OCR (ocrmypdf) into ``<name>_ocr.pdf`` next to
            the input and extract text from that file instead.

    Returns:
        ``(page_list, full_text, metadata)``: the non-empty page texts, those texts
        joined by blank lines, and a dict with document-info fields (``None`` when
        absent), ``image_count``, ``page_count`` (pages in the document that was
        read) and ``char_count`` (length of ``full_text``).
    """
    reader = PdfReader(pdf_file)
    page_list = _extract_pages(reader)
    full_text = "\n\n".join(page_list).strip()

    # Count images so we can tell scanned documents from text documents.
    image_count = sum(len(page.images) for page in reader.pages)

    # If there are images and not much extracted content, perform OCR on the document.
    if allow_ocr:
        print(f"{image_count} Images")
        if image_count > 0 and len(full_text) < 1000:
            # splitext works for any extension casing and never yields the input path.
            base, _ = os.path.splitext(pdf_file)
            out_pdf_file = base + "_ocr.pdf"
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader = PdfReader(out_pdf_file)
            page_list = _extract_pages(reader)
            full_text = "\n\n".join(page_list).strip()

    print(f"{len(page_list)} Pages")

    # reader.metadata can be None; getattr(None, field, None) then yields None.
    info = reader.metadata
    metadata = {
        "author": getattr(info, "author", None),
        "creator": getattr(info, "creator", None),
        "producer": getattr(info, "producer", None),
        "subject": getattr(info, "subject", None),
        "title": getattr(info, "title", None),
        "image_count": image_count,
        "page_count": len(reader.pages),
        "char_count": len(full_text),
    }
    return page_list, full_text, metadata


def split_with_overlap(text, chunk_size=3500, overlap=700):
    """Split *text* into chunks of at most *chunk_size* characters.

    Consecutive chunks share *overlap* characters (the start advances by
    ``max(1, chunk_size - overlap)``). Splitting stops once a chunk reaches the end
    of the text, so no trailing chunk is a mere suffix of the previous one.

    Raises:
        ValueError: if *chunk_size* is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    step = max(1, chunk_size - overlap)
    chunks = []
    for start in range(0, len(text), step):
        end = min(start + chunk_size, len(text))
        chunks.append(text[start:end])
        if end == len(text):
            break
    return chunks


def add_doc(path):
    """Index the PDF at *path* into a freshly re-created ChromaDB collection.

    The collection ``collectionName`` in the store at ``knowledgeDbPath`` is dropped
    (if present) and re-created with cosine distance, then filled with overlapping
    text chunks of the PDF.

    Returns:
        The populated chromadb collection.

    Raises:
        gr.Error: if *path* does not end in ``.pdf`` (checked before touching the DB).
    """
    print("def add_doc!")
    print(path)
    if not str(path).lower().endswith(".pdf"):
        raise gr.Error("Only pdfs are accepted!")

    pages, _, _ = convertPDF(path)
    doc = "\n\n".join(pages)
    gr.Info("PDF uploaded, start Indexing!")

    knowledgeClient = chromadb.PersistentClient(path=knowledgeDbPath)
    print(str(knowledgeClient.list_collections()))
    if collectionName in _collection_names(knowledgeClient):
        knowledgeClient.delete_collection(name=collectionName)
    collection = knowledgeClient.create_collection(
        collectionName,
        embedding_function=embeddingModel,
        metadata={"hnsw:space": "cosine"},
    )

    corpus = split_with_overlap(doc, 3500, 700)
    print(len(corpus))

    then = datetime.now()
    # IDs continue after any existing entries (0 for the freshly created collection),
    # numbered from 1.
    offset = collection.count()
    batchSize = 40000  # chunks per collection.add call
    batchCount = math.ceil(len(corpus) / batchSize)
    for batchIndex, start in enumerate(range(0, len(corpus), batchSize)):
        print(f"embed batch {batchIndex} of {batchCount}")
        batch = corpus[start:start + batchSize]
        ids = [str(offset + start + j + 1) for j in range(len(batch))]
        collection.add(
            documents=batch,
            ids=ids,
            metadatas=[{"date": "2024-10-10"} for _ in batch],
        )
        print(f"finished batch {batchIndex} of {batchCount}")
    now = datetime.now()
    gr.Info("Indexing complete!")
    print(now - then)  # too many GB for sentences (GPU), resp. 0:00:10.375087 for chunks
    return collection


# --- Chat --------------------------------------------------------------------------------


def multimodalResponse(message, history, dropdown):
    """Stream an LLM answer to *message*, grounded in the indexed PDF.

    If a file is attached it is indexed first (replacing the previous index);
    otherwise the existing collection is queried. The single best-matching chunk is
    inserted into the prompt as context.

    Args:
        message: Gradio multimodal message dict with ``"text"`` and ``"files"``.
        history: Chat history; forwarded to ``format_prompt``, which ignores it.
        dropdown: Selected "Retrieval Version"; currently unused.

    Yields:
        The answer generated so far, growing token by token.

    Raises:
        gr.Error: if no file is attached and no document has been indexed yet, or
            if the attached file is not a PDF.
    """
    print("def multimodal response!")
    query = message["text"]
    files = message.get("files") or []
    if files:  # is there at least one file attached?
        collection = add_doc(files[0])
    else:
        knowledgeClient = chromadb.PersistentClient(path=knowledgeDbPath)
        print(str(knowledgeClient.list_collections()))
        if collectionName not in _collection_names(knowledgeClient):
            raise gr.Error("Please upload a PDF first.")
        collection = knowledgeClient.get_collection(
            name=collectionName, embedding_function=embeddingModel
        )

    # Insert the retrieved chunk's text (not the raw query-result dict) into the prompt.
    context = ""
    if collection.count() > 0:
        results = collection.query(query_texts=[query], n_results=1)
        print(str(results))
        documents = results.get("documents") or [[]]
        context = "\n\n".join(documents[0])

    generate_kwargs = dict(
        temperature=float(0.9),
        max_new_tokens=5000,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        seed=42,
    )
    system = (
        "Given the following conversation, relevant context, and a follow up question, "
        "reply with an answer to the current question the user is asking. "
        "Return only your response to the question given the above information "
        "following the users instructions as needed.\n\nContext:\n"
        + context
    )
    print(system)
    formatted_prompt = format_prompt(system + "\n" + query, history)
    stream = inferenceClient.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = ""
    for response in stream:
        # Skip special tokens such as the end-of-sequence marker.
        if response.token.special:
            continue
        output += response.token.text
        yield output
    yield output


i = gr.ChatInterface(
    multimodalResponse,
    title="Frag dein PDF",
    multimodal=True,
    additional_inputs=[
        gr.Dropdown(
            info="select retrieval version",
            choices=["1", "2", "3"],
            value="1",
            label="Retrieval Version",
        )
    ],
)
i.launch()  # allowed_paths=["."]