###########################################################################################
# Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
# Author: Andreas Fischer
# Date: October 10th, 2024
# Last update: October 25th, 2024
##########################################################################################

import os
import torch
from transformers import AutoTokenizer, AutoModel  # chromaDB
from datetime import datetime, date  # add_doc
import chromadb  # chromaDB
from chromadb import Documents, EmbeddingFunction, Embeddings  # chromaDB
from chromadb.utils import embedding_functions  # chromaDB
import ocrmypdf  # convertPDF
from pypdf import PdfReader  # convertPDF
import re  # format_prompt
import gradio as gr  # multimodal_response
from huggingface_hub import InferenceClient  # multimodal_response


#---------------------------------------------------
# Specify models for text generation and embeddings
#---------------------------------------------------

myModel = "mistralai/Mixtral-8x7b-instruct-v0.1"
# Alternatives:
#myModel="princeton-nlp/gemma-2-9b-it-SimPO"
#myModel="google/gemma-2-2b-it"
#myModel="meta-llama/Llama-3.1-8B-Instruct"
# To inspect a model's chat template, render a sample conversation with
# AutoTokenizer.from_pretrained(myModel).apply_chat_template(...) and decode it.

# trust_remote_code=True executes code shipped with the model repository.
jina = AutoModel.from_pretrained(
    'jinaai/jina-embeddings-v2-base-de',
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
#jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
jina.to(device)  # cuda:0
print(device)


#-----------------
# ChromaDB-client
#-----------------

class JinaEmbeddingFunction(EmbeddingFunction):
    """ChromaDB embedding function that encodes documents with the module-level ``jina`` model."""

    def __call__(self, input: Documents) -> Embeddings:
        embeddings = jina.encode(input)  # max_length=2048
        return embeddings.tolist()


# Use the local development path if it exists; otherwise fall back to the
# deployment path.
dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db/"
onPrem = os.path.exists(dbPath)
if not onPrem:
    dbPath = "/home/user/app/db/"
print(dbPath)

client = chromadb.PersistentClient(path=dbPath)
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())

jina_ef = JinaEmbeddingFunction()
embeddingModel = jina_ef

databases = [(date.today(), "0")]  # start a list of databases


#---------------------------------------------------------------------
# Function for formatting single message according to prompt template
#---------------------------------------------------------------------

def format_prompt0(message, history):
    """Wrap a single user message in the Mixtral ``[INST] ... [/INST]`` template.

    ``history`` is accepted for signature compatibility but ignored.
    """
    return f"[INST] {message} [/INST]"


#-------------------------------------------------------------------------
# Function for formatting multiturn-dialogue according to prompt template
#-------------------------------------------------------------------------

# Assistant replies may end with an appended block listing the RAG sources;
# that block is stripped before a reply is fed back into the prompt as history.
# NOTE(review): the tag names in this pattern were lost from the source (only
# "\n\n((.|\n)*?)" survived, split across lines where the tags had been); a
# <details>...</details> wrapper is inferred -- confirm against the code that
# renders the sources. DOTALL + ".*?" is equivalent to the original "(.|\n)*?".
_RAG_SOURCES_PATTERN = re.compile(r"\n\n<details[^>]*>.*?</details>", re.DOTALL)

# Alternative templates (pass as keyword arguments to format_prompt):
#   meta-llama/Llama-3.1-8B-Instruct (marked as unverified in the original):
#     startOfString="",
#     template0="<|start_header_id|>system<|end_header_id|>\n\n{system}\n<|eot_id|>",
#     template1="<|start_header_id|>user<|end_header_id|>\n\n{message}<|eot_id|>",
#     template2="<|start_header_id|>assistant<|end_header_id|>\n\n{response}"
#   google/gemma-2-2b-it: the template strings recorded here lost their
#     turn-marker tokens; rebuild them from the model's chat template.


def format_prompt(message, history=None, system=None, RAGAddon=None, system2=None,
                  zeichenlimit=None, historylimit=4, removeHTML=False,
                  startOfString="",
                  template0=" [INST] {system} [/INST] ",
                  template1=" [INST] {message} [/INST]",
                  template2=" {response}"):  # mistralai/Mixtral-8x7B-Instruct-v0.1
    """Assemble a multi-turn prompt string from format templates.

    The default templates follow mistralai/Mixtral-8x7B-Instruct-v0.1.

    Args:
        message: Current user message; omitted when None.
        history: Sequence of (user_message, bot_response) pairs. Only the last
            ``historylimit`` pairs are used; None entries are treated as "".
        system: System prompt, rendered with ``template0``.
        RAGAddon: Text appended to the system prompt (e.g. retrieved context).
            Works even when ``system`` is None.
        system2: Raw text appended after the current message.
        zeichenlimit: Maximum number of characters kept from each message and
            response ("Zeichenlimit" = character limit); None means unlimited.
        historylimit: Number of most recent history pairs to include; a value
            of 0 or less includes no history.
        removeHTML: If true, replace every "<...>" tag in bot responses with a
            newline (may interfere with markdown rendering).
        startOfString: Prefix prepended to the finished prompt.
        template0, template1, template2: Format strings for the system, user
            and assistant turns ({system}, {message}, {response} placeholders).

    Returns:
        The assembled prompt string.
    """
    # Slicing with None keeps the whole string, so no sentinel limit is needed.
    limit = zeichenlimit
    prompt = ""
    if RAGAddon is not None:
        system = (system or "") + RAGAddon
    if system is not None:
        prompt += template0.format(system=system)
    # history[-0:] would be the *whole* list, so 0 must be handled explicitly.
    if history is not None and historylimit > 0:
        for user_message, bot_response in history[-historylimit:]:
            if user_message is None:
                user_message = ""
            if bot_response is None:
                bot_response = ""
            bot_response = _RAG_SOURCES_PATTERN.sub("", bot_response)  # remove RAG-components
            if removeHTML:
                # remove HTML-components in general (may cause bugs with markdown-rendering)
                bot_response = re.sub("<(.*?)>", "\n", bot_response)
            prompt += template1.format(message=user_message[:limit])
            prompt += template2.format(response=bot_response[:limit])
    if message is not None:
        prompt += template1.format(message=message[:limit])
    if system2 is not None:
        prompt += system2
    return startOfString + prompt


#--------------------------------------------
# Function for converting pdf-files to text
#--------------------------------------------

def convertPDF(pdf_file, allow_ocr=False):
    """Extract per-page text and document metadata from a PDF file.

    Args:
        pdf_file: Path of the PDF file.
        allow_ocr: If true and the PDF contains images but yields fewer than
            1000 characters of text, run OCR (ocrmypdf, ``force_ocr=True``)
            into "<name>_ocr.pdf" next to the input and read text from there.

    Returns:
        Tuple ``(page_list, full_text, metadata)``:
        page_list -- texts of the non-empty pages, in page order;
        full_text -- those page texts joined by blank lines (stripped);
        metadata  -- dict with author/creator/producer/subject/title (None
                     when absent), image_count, page_count (total pages) and
                     char_count (length of full_text).
    """

    def extract_text_from_pdf(reader):
        # Returns (full_text, total page count, list of non-empty page texts).
        page_list = []
        for page in reader.pages:
            text = page.extract_text()
            if text:
                page_list.append(text)
        full_text = "\n\n".join(page_list).strip()
        return full_text, len(reader.pages), page_list

    reader = PdfReader(pdf_file)
    # Extract text first so the OCR heuristic below sees the real amount of text.
    full_text, page_count, page_list = extract_text_from_pdf(reader)

    # Check if there are any images
    image_count = sum(len(page.images) for page in reader.pages)

    # If there are images and not much content, perform OCR on the document
    if allow_ocr:
        print(f"{image_count} Images")
        if image_count > 0 and len(full_text) < 1000:
            # splitext only touches the real extension (str.replace would rewrite
            # every ".pdf" in the path and overwrite the input if there is none).
            root, _ = os.path.splitext(pdf_file)
            out_pdf_file = root + "_ocr.pdf"
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader = PdfReader(out_pdf_file)
            full_text, page_count, page_list = extract_text_from_pdf(reader)

    print(f"{len(page_list)} Pages")

    # The document-information dictionary is optional in PDFs.
    info = reader.metadata
    metadata = {
        "author": getattr(info, "author", None),
        "creator": getattr(info, "creator", None),
        "producer": getattr(info, "producer", None),
        "subject": getattr(info, "subject", None),
        "title": getattr(info, "title", None),
        "image_count": image_count,
        "page_count": page_count,
        "char_count": len(full_text),
    }
    return page_list, full_text, metadata


#------------------------------------------
# Function
# NOTE(review): this span is not valid Python as shown. Its newlines were
# collapsed, and text appears to be missing wherever the original contained a
# "<...>"-like sequence:
#   * The first line continues a section-header comment ("Function for
#     splitting text with overlap") that begins on the preceding line.
#   * In split_with_overlap the loop collapses to "while i0):" -- presumably
#     "while i<len(text): ... >0):" -- so the whole chunking loop is missing.
#   * The "fiveChars" regex collapses to "(?5):" (presumably a "(?<..."
#     lookbehind). That gap also swallowed the splitRegex definition and the
#     header/start of the function that indexes an uploaded PDF (presumably
#     add_doc).
# The missing code cannot be reconstructed from here; restore it from version
# control. Observations on the visible tail of the indexing function:
#   * The existing-collection test matches "name="+dbName against
#     str(client.list_collections()), i.e. it depends on the repr format of
#     that return value; client.get_or_create_collection(...) avoids this.
#   * Chunks are added only when the collection is empty (len(x)==0), so a
#     later PDF uploaded to an existing DB_<session> is split but not indexed.
#   * round(n/chunkSize+0.5) is meant as a ceiling, but rounds half to even:
#     for n = 40000, 120000, ... it yields one extra, empty batch. Prefer
#     math.ceil(n/chunkSize).
#   * textIDs is computed but never used; the "finished batch" print divides
#     by the literal 40000 instead of chunkSize; every chunk gets the
#     hard-coded metadata date "2024-10-10" rather than the indexing date.
for splitting text with overlap #------------------------------------------ def split_with_overlap0(text,chunk_size=3500, overlap=700): """ Split text in chunks based on number of characters (chunk_size) with chunks overlapping (overlap)""" chunks=[] step=max(1,chunk_size-overlap) for i in range(0,len(text),step): end=min(i+chunk_size,len(text)) chunks.append(text[i:end]) return chunks import re def split_with_overlap(text, chunk_size=3500, overlap=700, pattern=r'([.!;?][ \n\r]|[\n\r]{2,})', variant=1, verbose=False): """ Split text in chunks based on regex (pattern) matches. By default the pattern is '([.!;?][ \\n\\r]|[\\n\\r]{2,})' Chunks are no longer than a certain number of characters (chunk_size) with chunks overlapping (overlap). By default (variant=1) chunking is based on complete sentences, but it's also possible to split only within the left overlap region and within the rest of the chunk-size (variant==2) or strictly within both overlap-regions (variant=3). """ chunks = [] overlap=min(overlap,chunk_size) # Overlap cannot be larger than chunk_size step = max(1, chunk_size - overlap) # step is derived from chunk_size and overlap def find_pattern(text): # Helper that searches for the pattern return re.search(pattern, text) i, lastEnd = 0,0 while i0): chunks.append(text[end:]) # Append any remaining text at the end return chunks fiveChars= "(?5): if(not "cuda" in device): doc="\n\n".join(doc[0][0:5]) gr.Info("PDF uploaded to DB_"+str(session)+", start Indexing excerpt (demo-mode: first 5 pages on CPU setups)!") else: doc="\n\n".join(doc[0]) gr.Info("PDF uploaded to DB_"+str(session)+", start Indexing!") else: doc="\n\n".join(doc[0]) gr.Info("PDF uploaded to DB_"+str(session)+", start Indexing!") anhang=True else: gr.Info("No PDF attached - answer based on DB_"+str(session)+".") client = chromadb.PersistentClient(path=dbPath) print(str(client.list_collections())) print(str(session)) dbName="DB_"+str(session) if(not "name="+dbName in 
str(client.list_collections())): # client.delete_collection(name=dbName) collection = client.create_collection( name=dbName, embedding_function=embeddingModel, metadata={"hnsw:space": "cosine"}) else: collection = client.get_collection( name=dbName, embedding_function=embeddingModel) if(anhang==True): corpus=split_with_overlap(doc,3500,700,pattern=splitRegex) print("Length of corpus: "+str(len(corpus))) print("Corpus:"+str(corpus)) then = datetime.now() x=collection.get(include=[])["ids"] print(len(x)) if(len(x)==0): chunkSize=40000 for i in range(round(len(corpus)/chunkSize+0.5)): print("embed batch "+str(i)+" of "+str(round(len(corpus)/chunkSize+0.5))) ids=list(range(i*chunkSize,(i*chunkSize+chunkSize))) batch=corpus[i*chunkSize:(i*chunkSize+chunkSize)] textIDs=[str(id) for id in ids[0:len(batch)]] ids=[str(id+len(x)+1) for id in ids[0:len(batch)]] collection.add(documents=batch, ids=ids, metadatas=[{"date": str("2024-10-10")} for b in batch]) print("finished batch "+str(i)+" of "+str(round(len(corpus)/40000+0.5))) now = datetime.now() gr.Info(f"Indexing complete!") print(now-then) return(collection)