# app.py
"""Gradio RAG demo: upload documents, embed them, store them in Chroma, chat.

Flow, as implemented below:

* ``EmbeddingGenerator`` asks a chat model to pick a task and to extract
  metadata for a text, then embeds the text with a Hugging Face model.
* Documents are stored in a Chroma collection served over HTTP.
* ``respond`` retrieves similar documents for a chat message and streams the
  chat model's answer.
"""
# NOTE(review): import order is kept exactly as in the original file (``spaces``
# first); only ``socket`` and ``subprocess`` were added. Confirm before
# regrouping, in case the order is deliberate.
import spaces
from torch.nn import DataParallel
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import InferenceClient
from openai import OpenAI
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_chroma import Chroma
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.config import Settings
from chromadb import HttpClient
import os
import re
import socket
import subprocess
import uuid
import gradio as gr
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from utils import load_env_variables, parse_and_route , escape_special_characters
from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name , metadata_prompt
import time
import httpx

load_dotenv()

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['CUDA_CACHE_DISABLE'] = '1'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Values that were repeated as literals throughout the file.
CHAT_MODEL = "yi-large"
CHROMA_HOST = "localhost"
CHROMA_PORT = 8000
COLLECTION_NAME = "Tonic-instruct"

### Utils
hf_token, yi_token = load_env_variables()


def clear_cuda_cache():
    """Release cached CUDA memory held by PyTorch."""
    torch.cuda.empty_cache()


client = OpenAI(api_key=yi_token, base_url=API_BASE)


class EmbeddingGenerator:
    """Embed a text and extract metadata for it with the help of a chat model."""

    # Matches ``"key": "value"`` pairs in the metadata completion.
    _METADATA_PATTERN = re.compile(r'\"(\w+)\": \"([^\"]+)\"')

    def __init__(self, model_name: str, token: str, intention_client):
        """Load the tokenizer and model and keep the chat client.

        Args:
            model_name: Hugging Face model id of the embedding model.
            token: Hugging Face access token.
            intention_client: OpenAI-compatible client used for the task
                routing and metadata completions.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, token=token, trust_remote_code=True).to(self.device)
        self.intention_client = intention_client

    def clear_cuda_cache(self):
        """Release cached CUDA memory held by PyTorch."""
        torch.cuda.empty_cache()

    def _chat(self, system_prompt: str, user_text: str) -> str:
        """Run one non-streaming chat completion and return its text."""
        completion = self.intention_client.chat.completions.create(
            model=CHAT_MODEL,
            messages=[
                {"role": "system", "content": escape_special_characters(system_prompt)},
                {"role": "user", "content": user_text},
            ],
        )
        # The client returns message objects, so the text is an attribute
        # (``message['content']`` is not supported). It may be None.
        return completion.choices[0].message.content or ""

    @spaces.GPU
    def compute_embeddings(self, input_text: str):
        """Return ``(embeddings, metadata)`` for ``input_text``.

        Args:
            input_text: The document or query text to embed.

        Returns:
            A tuple ``(embeddings_list, metadata)``. ``embeddings_list`` is a
            list holding one L2-normalised embedding (a list of floats);
            ``metadata`` is the dict produced by ``extract_metadata``.

        Raises:
            ValueError: If the routed task is not a key of ``tasks``. The
                original returned an error string here, which no caller could
                unpack into the expected 2-tuple.
        """
        escaped_input_text = escape_special_characters(input_text)

        # Parse and route the intention
        intention_output = self._chat(intention_prompt, escaped_input_text)
        parsed_task = parse_and_route(intention_output)
        selected_task = next(iter(parsed_task), None)

        # Construct the prompt
        try:
            task_description = tasks[selected_task]
        except KeyError:
            print(f"Selected task not found: {selected_task}")
            raise ValueError(
                f"Error: Task '{selected_task}' not found. Please select a valid task."
            ) from None
        query_prefix = f"Instruct: {task_description}\nQuery: "
        # The prefix was previously built but never applied to the text.
        queries = [query_prefix + escaped_input_text]

        # Get the metadata
        metadata_output = self._chat(metadata_prompt, escaped_input_text)
        metadata = self.extract_metadata(metadata_output)

        # Get the embeddings
        with torch.no_grad():
            inputs = self.tokenizer(
                queries,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=4096,
            ).to(self.device)
            outputs = self.model(**inputs)
            # Mean over the sequence dimension; only one text is encoded per
            # call, so there are no padding positions to mask out.
            query_embeddings = outputs.last_hidden_state.mean(dim=1)

        # Normalize embeddings
        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
        embeddings_list = query_embeddings.detach().cpu().numpy().tolist()
        self.clear_cuda_cache()
        return embeddings_list, metadata

    def extract_metadata(self, metadata_output: str):
        """Collect ``"key": "value"`` pairs from ``metadata_output`` into a dict."""
        return dict(self._METADATA_PATTERN.findall(metadata_output or ""))


class MyEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by an ``EmbeddingGenerator``."""

    def __init__(self, embedding_generator: EmbeddingGenerator):
        self.embedding_generator = embedding_generator

    def __call__(self, input: Documents) -> Embeddings:
        """Return one embedding per document in ``input``.

        Only embeddings are returned. The previous version returned an
        ``(embeddings, metadata)`` tuple and flattened each metadata dict into
        its keys.
        """
        embeddings = []
        for doc in input:
            doc_embeddings, _metadata = self.embedding_generator.compute_embeddings(doc)
            embeddings.extend(doc_embeddings)
        return embeddings


def load_documents(file_path: str, mode: str = "elements"):
    """Load ``file_path`` with ``UnstructuredFileLoader`` and return the page texts."""
    loader = UnstructuredFileLoader(file_path, mode=mode)
    docs = loader.load()
    return [doc.page_content for doc in docs]


def _port_is_open(host: str, port: int, timeout: float = 0.5) -> bool:
    """Return True if a TCP connection to ``host:port`` can be opened."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


def start_chroma_server(host: str = CHROMA_HOST, port: int = CHROMA_PORT, startup_timeout: float = 30.0):
    """Start ``chroma run`` unless something already listens on ``host:port``.

    Args:
        host: Host the server should bind to.
        port: Port the server should listen on.
        startup_timeout: Seconds to wait for the port to accept connections.

    Returns:
        The ``subprocess.Popen`` handle of the started server, or ``None`` when
        the port was already accepting connections.

    Raises:
        RuntimeError: If the server process exits before the port opens.
        ConnectionError: If the port does not open within ``startup_timeout``.
    """
    if _port_is_open(host, port):
        return None

    # NOTE(review): ``initialize_chroma`` calls ``client.reset()``; the server is
    # presumably configured from its own environment, so ALLOW_RESET is set for
    # the child process - confirm against the installed chromadb version.
    env = dict(os.environ, ALLOW_RESET="TRUE")
    # Argument list instead of a shell string: no shell is involved.
    process = subprocess.Popen(
        ["chroma", "run", "--host", host, "--port", str(port)],
        env=env,
    )

    deadline = time.monotonic() + startup_timeout
    while time.monotonic() < deadline:
        if _port_is_open(host, port):
            return process
        if process.poll() is not None:
            raise RuntimeError(
                f"Chroma server exited early with code {process.returncode}."
            )
        time.sleep(0.5)

    process.terminate()
    raise ConnectionError(
        f"Chroma server did not open {host}:{port} within {startup_timeout} seconds."
    )


def wait_for_chroma_server(client, retries=10, delay=0.5):
    """Poll ``client.heartbeat()`` until it succeeds.

    Args:
        client: Chroma client exposing ``heartbeat()``.
        retries: Maximum number of attempts.
        delay: Seconds to sleep after a failed attempt.

    Returns:
        True once a heartbeat succeeds, False after ``retries`` failures.
    """
    for _ in range(retries):
        try:
            client.heartbeat()
            print("Chroma server is up and running!")
            return True
        except Exception as e:
            print(f"Attempt to connect to Chroma server failed: {e}")
            time.sleep(delay)
    print("Failed to connect to Chroma server after multiple attempts.")
    return False


def initialize_chroma(collection_name: str, embedding_function: MyEmbeddingFunction):
    """Connect to the Chroma server, reset it, and create ``collection_name``.

    ``embedding_function`` is accepted for interface compatibility but is not
    used here.

    Returns:
        A tuple ``(client, collection)``.

    Raises:
        ConnectionError: If the server never answers a heartbeat.
    """
    host = CHROMA_HOST
    port = CHROMA_PORT
    client = HttpClient(host=host, port=port, settings=Settings(allow_reset=True, anonymized_telemetry=False))
    if not wait_for_chroma_server(client):
        raise ConnectionError("Could not connect to Chroma server. Ensure it is running.")
    client.reset()  # Empties and completely resets the database
    collection = client.create_collection(collection_name)
    return client, collection


def add_documents_to_chroma(client, collection, documents: list, embedding_function: MyEmbeddingFunction):
    """Embed each text in ``documents`` and add it to ``collection``.

    Args:
        client: Chroma client (unused; kept for interface compatibility).
        collection: Target Chroma collection.
        documents: Document texts (strings).
        embedding_function: Provides the ``EmbeddingGenerator`` used to embed.

    Blank documents are skipped. Documents whose task cannot be routed
    (``ValueError`` from ``compute_embeddings``) are reported and skipped so
    that one bad document does not abort the whole batch.
    """
    generator = embedding_function.embedding_generator
    for doc in documents:
        if not doc or not doc.strip():
            continue
        try:
            embeddings, metadata = generator.compute_embeddings(doc)
        except ValueError as err:
            print(f"Skipping document: {err}")
            continue
        # ``metadata`` is a single dict for the document. Zipping it with the
        # embeddings (as before) iterated over its keys instead.
        # NOTE(review): an empty dict is replaced by None because Chroma is
        # expected to reject empty metadata dicts - confirm for the installed
        # version.
        collection.add(
            ids=[str(uuid.uuid1()) for _ in embeddings],
            documents=[doc] * len(embeddings),
            embeddings=embeddings,
            metadatas=[metadata] * len(embeddings) if metadata else None,
        )


def query_chroma(client, collection_name: str, query_text: str, embedding_function: MyEmbeddingFunction):
    """Return the documents most similar to ``query_text``.

    Args:
        client: Chroma client.
        collection_name: Name of the collection to search.
        query_text: Free-text query.
        embedding_function: Provides the ``EmbeddingGenerator`` used to embed.

    Returns:
        The list of documents returned by the vector store.
    """
    # Compute query embeddings; the extracted query metadata is not used as a
    # filter because it would have to match stored metadata exactly.
    query_embeddings, _query_metadata = embedding_function.embedding_generator.compute_embeddings(query_text)

    # Initialize Chroma with the collection
    db = Chroma(client=client, collection_name=collection_name, embedding_function=embedding_function)

    # ``similarity_search`` takes a query string, not embeddings, so search by
    # the precomputed vector instead.
    return db.similarity_search_by_vector(embedding=query_embeddings[0])


# Initialize clients
intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
embedding_generator = EmbeddingGenerator(model_name=model_name, token=hf_token, intention_client=intention_client)
embedding_function = MyEmbeddingFunction(embedding_generator=embedding_generator)

# The server has to be running before initialize_chroma() connects to it; it
# used to be started only in the __main__ block, after this point.
chroma_server_process = start_chroma_server()
chroma_client, chroma_collection = initialize_chroma(collection_name=COLLECTION_NAME, embedding_function=embedding_function)


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat answer to ``message`` grounded in retrieved documents.

    Args:
        message: The user's new message.
        history: Previous ``(user, assistant)`` turns.
        system_message: System prompt for the chat model.
        max_tokens: Completion token limit.
        temperature: Sampling temperature.
        top_p: Nucleus sampling parameter.

    Yields:
        The answer accumulated so far, once per received text delta.
    """
    retrieved_text = query_documents(message)
    messages = [{"role": "system", "content": escape_special_characters(system_message)}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": f"{retrieved_text}\n\n{escape_special_characters(message)}"})

    # The OpenAI client has no ``chat_completion`` method; streaming goes
    # through ``chat.completions.create(..., stream=True)``.
    stream = intention_client.chat.completions.create(
        model=CHAT_MODEL,
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )

    response = ""
    for chunk in stream:
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        # Deltas without text (None or empty) must not be concatenated.
        if token:
            response += token
            yield response


def upload_documents(files):
    """Load the uploaded files, embed them, and store them in Chroma.

    Args:
        files: Uploaded files from the Gradio file component; each entry is
            either a path string or an object with a ``name`` attribute.

    Returns:
        A status message for the UI.
    """
    if not files:
        return "No documents were provided."
    for file in files:
        file_path = getattr(file, "name", file)
        # The store expects plain strings, so load page texts rather than
        # Document objects. ``mode="single"`` matches the loader default the
        # original upload path relied on.
        documents = load_documents(file_path, mode="single")
        add_documents_to_chroma(chroma_client, chroma_collection, documents, embedding_function)
    return "Documents uploaded and processed successfully!"


def query_documents(query):
    """Return the texts of the documents most similar to ``query``.

    The texts are joined by blank lines; an empty query returns ``""``.
    """
    if not query or not query.strip():
        return ""
    results = query_chroma(chroma_client, COLLECTION_NAME, query, embedding_function)
    return "\n\n".join(result.page_content for result in results)


with gr.Blocks() as demo:
    with gr.Tab("Upload Documents"):
        document_upload = gr.File(file_count="multiple", file_types=["document"])
        upload_button = gr.Button("Upload and Process")
        upload_button.click(upload_documents, inputs=document_upload, outputs=gr.Text())

    with gr.Tab("Ask Questions"):
        with gr.Row():
            chat_interface = gr.ChatInterface(
                respond,
                additional_inputs=[
                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
                ],
            )
        query_input = gr.Textbox(label="Query")
        query_button = gr.Button("Query")
        query_output = gr.Textbox()
        query_button.click(query_documents, inputs=query_input, outputs=query_output)

if __name__ == "__main__":
    demo.launch()