###########################################################################################
# Title: Gradio Interface to LLM-chatbot with dynamic RAG-functionality and ChromaDB
# Author: Andreas Fischer
# Date: October 10th, 2024
# Last update: October 24th, 2024
##########################################################################################

import os
import torch
from transformers import AutoTokenizer, AutoModel  # chromaDB
from datetime import datetime, date  # add_doc,
import chromadb  # chromaDB
from chromadb import Documents, EmbeddingFunction, Embeddings  # chromaDB
from chromadb.utils import embedding_functions  # chromaDB
import ocrmypdf  # convertPDF
from pypdf import PdfReader  # convertPDF
import re  # format_prompt
import gradio as gr  # multimodal_response
from huggingface_hub import InferenceClient  # multimodal_response

#---------------------------------------------------
# Specify models for text generation and embeddings
#---------------------------------------------------

# Text-generation model; passed to InferenceClient in multimodal_response.
myModel = "mistralai/Mixtral-8x7b-instruct-v0.1"

# Snippet for inspecting the model's chat template (kept for reference, not executed):
#mod="mistralai/Mixtral-8x7b-instruct-v0.1"
#tok=AutoTokenizer.from_pretrained(mod) #,token="hf_...")
#cha=[{"role":"system","content":"A"},{"role":"user","content":"B"},{"role":"assistant","content":"C"}]
#cha=[{"role":"user","content":"U1"},{"role":"assistant","content":"A1"},{"role":"user","content":"U2"},{"role":"assistant","content":"A2"}]
#res=tok.apply_chat_template(cha)
#print(tok.decode(res))

# Embedding model. trust_remote_code is needed because the encode() method used
# by JinaEmbeddingFunction comes from the model repository's own code.
jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True, torch_dtype=torch.bfloat16)
#jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
jina.to(device)
print(device)

#-----------------
# ChromaDB-client
#-----------------

class JinaEmbeddingFunction(EmbeddingFunction):
    """ChromaDB embedding function backed by the module-level ``jina`` model."""

    def __call__(self, input: Documents) -> Embeddings:
        # The parameter is named ``input`` as in chromadb's EmbeddingFunction.__call__.
        embeddings = jina.encode(input)  # max_length=2048
        return embeddings.tolist()


# Prefer the local development DB directory when it exists; otherwise fall back
# to the alternative path (presumably the deployed app's location — confirm).
dbPath = "/home/af/Schreibtisch/Code/gradio/Chroma/db/"
onPrem = os.path.exists(dbPath)
if not onPrem:
    dbPath = "/home/user/app/db/"
print(dbPath)

client = chromadb.PersistentClient(path=dbPath)
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())
jina_ef = JinaEmbeddingFunction()
embeddingModel = jina_ef
databases = [(date.today(), "0")]  # start a list of databases as (date, session id) pairs

#---------------------------------------------------------------------
# Function for formatting single message according to prompt template
#---------------------------------------------------------------------

def format_prompt0(message, history):
    """Wrap a single *message* in ``[INST] ... [/INST]`` tags.

    *history* is accepted but ignored; only the current message is formatted.
    """
    return f"[INST] {message} [/INST]"

#-------------------------------------------------------------------------
# Function for formatting multiturn-dialogue according to prompt template
#-------------------------------------------------------------------------

# Collapsible "sources" section appended to bot answers; stripped from history so
# old sources are not fed back to the model.
# NOTE(review): the tag text of this pattern was lost in the source as shown and is
# reconstructed here as <details>...</details> — confirm against the original.
_RAG_SOURCES_PATTERN = re.compile(r"\n\n<details>.*?</details>", re.DOTALL)
# Any HTML tag (used only when removeHTML=True).
_HTML_TAG_PATTERN = re.compile(r"<(.*?)>")


def format_prompt(message, history=None, system=None, RAGAddon=None, system2=None, zeichenlimit=None, historylimit=4,
                  removeHTML=False, startOfString="", template0=" [INST] {system} [/INST] ",
                  template1=" [INST] {message} [/INST]", template2=" {response}"):
    """Build a multi-turn prompt string from a chat history.

    Args:
        message: Current user message, rendered last with *template1*; None to omit.
        history: Sequence of (user_message, bot_response) pairs. Only the last
            *historylimit* pairs are used; None entries are treated as "".
        system: System prompt, rendered first with *template0*.
        RAGAddon: Text appended to *system*. If *system* is None, this text is
            used as the system prompt on its own.
        system2: Raw text appended after the current message.
        zeichenlimit: Maximum number of characters kept per message/response
            (None = unlimited).
        historylimit: Number of most recent history turns to include; values
            <= 0 include no history.
        removeHTML: Replace every HTML tag in bot responses with a newline
            (may cause bugs with markdown-rendering).
        startOfString: Prefix prepended to the whole prompt.
        template0, template1, template2: Format strings for the system, user and
            bot parts.

    Returns:
        The assembled prompt string.
    """
    prompt = ""
    if RAGAddon is not None:
        # `system += RAGAddon` raised TypeError when no system prompt was given.
        system = (system or "") + RAGAddon
    if system is not None:
        prompt += template0.format(system=system)
    # Guard: history[-0:] would return the WHOLE history instead of none of it.
    if history is not None and historylimit > 0:
        for user_message, bot_response in history[-historylimit:]:
            user_message = "" if user_message is None else user_message
            bot_response = "" if bot_response is None else bot_response
            bot_response = _RAG_SOURCES_PATTERN.sub("", bot_response)  # remove RAG-components
            if removeHTML:
                bot_response = _HTML_TAG_PATTERN.sub("\n", bot_response)
            # Slicing with None keeps the full string (no limit).
            prompt += template1.format(message=user_message[:zeichenlimit])
            prompt += template2.format(response=bot_response[:zeichenlimit])
    if message is not None:
        prompt += template1.format(message=message[:zeichenlimit])
    if system2 is not None:
        prompt += system2
    return startOfString + prompt

#--------------------------------------------
# Function for converting pdf-files to text
#--------------------------------------------

def _extract_text_from_pdf(reader):
    """Return ``(full_text, page_count, page_list)`` for a pypdf ``PdfReader``.

    page_list holds the text of every page that yielded non-empty text.
    full_text is those pages joined by blank lines.
    page_count is the total number of pages in the document.
    """
    page_list = []
    for page in reader.pages:
        text = page.extract_text() or ""
        if text:
            page_list.append(text)
    full_text = "\n\n".join(page_list).strip()
    return full_text, len(reader.pages), page_list


def convertPDF(pdf_file, allow_ocr=False, ocr_min_chars=1000):
    """Extract per-page text and metadata from a PDF, optionally running OCR first.

    Args:
        pdf_file: Path of the PDF file.
        allow_ocr: If True and the PDF contains images but yields fewer than
            *ocr_min_chars* characters of text, OCR the file with ocrmypdf.
            The OCR output is written to ``<name>_ocr.pdf`` next to the input,
            and text is re-extracted from that file.
        ocr_min_chars: Text-length threshold below which OCR is attempted.

    Returns:
        ``(page_list, full_text, metadata)``. metadata holds the document info
        fields (None when missing) plus image_count, page_count and char_count.
    """
    reader = PdfReader(pdf_file)
    # Check if there are any images
    image_count = sum(len(page.images) for page in reader.pages)
    # Extract first: the OCR heuristic below needs the amount of real text.
    full_text, page_count, page_list = _extract_text_from_pdf(reader)
    if allow_ocr:
        print(f"{image_count} Images")
        # Images but little text suggests a scanned document.
        if image_count > 0 and len(full_text) < ocr_min_chars:
            # splitext only replaces the extension. str.replace(".pdf", ...) would
            # also rewrite ".pdf" inside directory names, and would leave ".PDF"
            # names unchanged so that the output path equals the input path.
            root, _ = os.path.splitext(pdf_file)
            out_pdf_file = root + "_ocr.pdf"
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader = PdfReader(out_pdf_file)
            full_text, page_count, page_list = _extract_text_from_pdf(reader)
    print(f"{len(page_list)} Pages")
    # reader.metadata is None when the PDF has no document-information dictionary;
    # getattr(None, ..., None) then yields None for every field instead of crashing.
    info = reader.metadata
    metadata = {key: getattr(info, key, None) for key in ("author", "creator", "producer", "subject", "title")}
    metadata.update(image_count=image_count, page_count=page_count, char_count=len(full_text))
    return page_list, full_text, metadata

#------------------------------------------
# Function
for splitting text with overlap #------------------------------------------ def split_with_overlap0(text,chunk_size=3500, overlap=700): """ Split text in chunks based on number of characters (chunk_size) with chunks overlapping (overlap)""" chunks=[] step=max(1,chunk_size-overlap) for i in range(0,len(text),step): end=min(i+chunk_size,len(text)) chunks.append(text[i:end]) return chunks import re def split_with_overlap(text, chunk_size=3500, overlap=700, pattern=r'([.!;?][ \n\r]|[\n\r]{2,})', variant=1, verbose=False): """ Split text in chunks based on regex (pattern) matches. By default the pattern is '([.!;?][ \\n\\r]|[\\n\\r]{2,})' Chunks are no longer than a certain number of characters (chunk_size) with chunks overlapping (overlap). By default (variant=1) chunking is based on complete sentences, but it's also possible to split only within the left overlap region and within the rest of the chunk-size (variant==2) or strictly within both overlap-regions (variant=3). """ chunks = [] overlap=min(overlap,chunk_size) # Overlap kann nicht größer sein als chunk_size step = max(1, chunk_size - overlap) # step richtet sich nach chunk_size und overlap def find_pattern(text): # Funktion zur Suche nach dem Muster return re.search(pattern, text) i, lastEnd = 0,0 while i0): chunks.append(text[end:]) # Ergänze am ende etwaigen Rest return chunks fiveChars= "(?5): if(not "cuda" in device): doc="\n\n".join(doc[0][0:5]) gr.Info("PDF uploaded to DB_"+str(session)+", start Indexing excerpt (first 5 pages on CPU setups)!") else: doc="\n\n".join(doc[0]) gr.Info("PDF uploaded to DB_"+str(session)+", start Indexing!") else: doc="\n\n".join(doc[0]) gr.Info("PDF uploaded to DB_"+str(session)+", start Indexing!") anhang=True else: gr.Info("No PDF attached - answer based on DB_"+str(session)+".") client = chromadb.PersistentClient(path=dbPath) print(str(client.list_collections())) print(str(session)) dbName="DB_"+str(session) if(not "name="+dbName in str(client.list_collections())): # 
client.delete_collection(name=dbName) collection = client.create_collection( name=dbName, embedding_function=embeddingModel, metadata={"hnsw:space": "cosine"}) else: collection = client.get_collection( name=dbName, embedding_function=embeddingModel) if(anhang==True): corpus=split_with_overlap(doc,3500,700,pattern=splitRegex) print("Length of corpus: "+str(len(corpus))) print("Corpus:"+str(corpus)) then = datetime.now() x=collection.get(include=[])["ids"] print(len(x)) if(len(x)==0): chunkSize=40000 for i in range(round(len(corpus)/chunkSize+0.5)): #0 is first batch, 3 is last (incomplete) batch given 133497 texts print("embed batch "+str(i)+" of "+str(round(len(corpus)/chunkSize+0.5))) ids=list(range(i*chunkSize,(i*chunkSize+chunkSize))) batch=corpus[i*chunkSize:(i*chunkSize+chunkSize)] textIDs=[str(id) for id in ids[0:len(batch)]] ids=[str(id+len(x)+1) for id in ids[0:len(batch)]] # id refers to chromadb-unique ID collection.add(documents=batch, ids=ids, metadatas=[{"date": str("2024-10-10")} for b in batch]) #"textID":textIDs, "id":ids, print("finished batch "+str(i)+" of "+str(round(len(corpus)/40000+0.5))) now = datetime.now() gr.Info(f"Indexing complete!") print(now-then) #zu viel GB für sentences (GPU), bzw. 0:00:10.375087 für chunks return(collection) #-------------------------------------------------------- # Function for response to user queries and pot. 
addenda #-------------------------------------------------------- def multimodal_response(message, history, dropdown, hfToken, request: gr.Request): print("def multimodal response!") if(hfToken.startswith("hf_")): # use HF-hub with custom token if token is provided inferenceClient = InferenceClient(model=myModel, token=hfToken) else: inferenceClient = InferenceClient(myModel) global databases if request: session=request.session_hash else: session="0" length=str(len(history)) print(databases) if(not databases[-1][1]==session): databases.append((date.today(),session)) #print(databases) query=message["text"] if(len(message["files"])>0): # is there at least one file attached? collection=add_doc(message["files"][0], session) else: # otherwise, you still want to get the collection with the session-based db collection=add_doc(message["text"], session) client = chromadb.PersistentClient(path=dbPath) print(str(client.list_collections())) x=collection.get(include=[])["ids"] ragQuery=[format_prompt(query, history) if len(history)>0 else query] context=collection.query(query_texts=ragQuery, n_results=3) context=[" "+str(c)+"" for i,c in enumerate(context["documents"][0])] gr.Info("Kontext:\n"+str(context)) generate_kwargs = dict( temperature=float(0.9), max_new_tokens=5000, top_p=0.95, repetition_penalty=1.0, do_sample=True, seed=42, ) system="Mit Blick auf das folgende Gespräch und den relevanten Kontext, antworte auf die aktuelle Frage des Nutzers. "+\ "Antworte ausschließlich auf Basis der Informationen im Kontext.\n\nKontext:\n\n"+\ str("\n\n".join(context)) #"Given the following conversation, relevant context, and a follow up question, "+\ #"reply with an answer to the current question the user is asking. 
"+\ #"Return only your response to the question given the above information "+\ #"following the users instructions as needed.\n\nContext:"+\ print(system) #formatted_prompt = format_prompt0(system+"\n"+query, history) formatted_prompt = format_prompt(query, history,system=system) print(formatted_prompt) output = "" try: stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False) for response in stream: output += response.token.text yield output except Exception as e: output = "Für weitere Antworten von der KI gebe bitte einen gültigen HuggingFace-Token an." if(len(context)>0): output += "\nBis dahin helfen dir hoffentlich die folgenden Quellen weiter:" yield output print(str(e)) if(len(context)>0): output=output+"\n\n
Quellen
    "+ "".join(["
  • " + c + "
  • " for c in context])+"
" yield output #------------------------------ # Launch Gradio-ChatInterface #------------------------------ i=gr.ChatInterface(multimodal_response, title="Frag dein PDF", multimodal=True, additional_inputs=[ gr.Dropdown( info="Wähle eine Variante", choices=["1","2","3"], value="1", label="Variante"), gr.Textbox( value="", label="HF_token"), ]) i.launch() #allowed_paths=["."])