YiJina / app.py
Tonic's picture
add chroma collection name
d24a7c2
raw
history blame
7.85 kB
# main.py
import spaces
import torch
import torch.nn.functional as F
from torch.nn import DataParallel
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
import threading
import queue
import os
import json
import numpy as np
import gradio as gr
from huggingface_hub import InferenceClient
import openai
from openai import OpenAI
from globalvars import API_BASE, intention_prompt, tasks , system_message, model_name
from dotenv import load_dotenv
import re
from utils import load_env_variables
import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.config import Settings
from chromadb import HttpClient
from langchain_community.document_loaders import UnstructuredFileLoader
from utils import load_env_variables , parse_and_route
# CUDA-related environment settings, written before any model is loaded below.
os.environ.update(
    {
        'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:30',
        'CUDA_LAUNCH_BLOCKING': '1',
        'CUDA_CACHE_DISABLE': '1',
    }
)

# Use the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# hf_token is later passed to `from_pretrained`; yi_token is the API key for
# the OpenAI-compatible clients created in this module.
hf_token, yi_token = load_env_variables()
def clear_cuda_cache():
    """Ask torch to release the CUDA memory held by its caching allocator."""
    torch.cuda.empty_cache()
# OpenAI-compatible client that talks to API_BASE using the Yi API key.
client = OpenAI(base_url=API_BASE, api_key=yi_token)
class EmbeddingGenerator:
    """Embed text with a Hugging Face model, guided by an LLM-selected task.

    For every input the intention client is asked (with `intention_prompt`)
    which task the text belongs to; the matching description from `tasks`
    is turned into an instruction prefix for the text that gets embedded.
    """

    def __init__(self, model_name: str, token: str, intention_client):
        """Load tokenizer and model and remember the intention client.

        Args:
            model_name: Model identifier passed to `from_pretrained`.
            token: Access token passed to `from_pretrained`.
            intention_client: Client exposing `chat.completions.create`.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, token=token, trust_remote_code=True).to(self.device)
        self.intention_client = intention_client

    def clear_cuda_cache(self):
        """Release the CUDA memory cached by torch."""
        torch.cuda.empty_cache()

    @spaces.GPU
    def compute_embeddings(self, input_text: str) -> list[list[float]]:
        """Return the L2-normalised embedding of `input_text`.

        Args:
            input_text: The text to embed.

        Returns:
            A list holding one embedding (itself a list of floats).  The
            result is always a list of embeddings: when no known task can be
            determined the text is embedded without an instruction prefix
            instead of returning an error string, which callers would
            otherwise treat as embedding data.
        """
        # Get the intention
        intention_completion = self.intention_client.chat.completions.create(
            model="yi-large",
            messages=[
                {"role": "system", "content": intention_prompt},
                {"role": "user", "content": input_text},
            ],
        )
        # The v1 OpenAI client returns typed objects, so the message text is
        # an attribute; subscripting it with ['content'] raises TypeError.
        intention_output = intention_completion.choices[0].message.content or ""

        # Parse and route the intention
        parsed_task = parse_and_route(intention_output)
        # next(..., None) tolerates an empty routing result.
        selected_task = next(iter(parsed_task.keys()), None)

        # Construct the prompt
        try:
            task_description = tasks[selected_task]
        except KeyError:
            print(f"Selected task not found: {selected_task}")
            text_to_embed = input_text
        else:
            # Apply the instruction prefix; previously it was built but the
            # bare input text was embedded, making the task lookup a no-op.
            text_to_embed = f"Instruct: {task_description}\nQuery: {input_text}"

        # Get the embeddings
        try:
            with torch.no_grad():
                inputs = self.tokenizer(
                    [text_to_embed],
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=4096,
                ).to(self.device)
                outputs = self.model(**inputs)
                # Mean-pool over dim 1 of the last hidden state.
                query_embeddings = outputs.last_hidden_state.mean(dim=1)
                # Normalize embeddings
                query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
                embeddings_list = query_embeddings.detach().cpu().numpy().tolist()
        finally:
            # Free cached GPU memory even when the forward pass fails.
            self.clear_cuda_cache()
        return embeddings_list
class MyEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by an `EmbeddingGenerator`."""

    def __init__(self, embedding_generator: EmbeddingGenerator):
        """Keep a reference to the generator used for every document."""
        self.embedding_generator = embedding_generator

    def __call__(self, input: Documents) -> Embeddings:
        """Embed each document and return all embeddings as one flat list."""
        flattened = []
        for document in input:
            # compute_embeddings yields a list per document; splice its
            # items into the combined result.
            flattened.extend(self.embedding_generator.compute_embeddings(document))
        return flattened
class DocumentLoader:
    """Read a file through `UnstructuredFileLoader` and expose its text."""

    def __init__(self, file_path: str, mode: str = "elements"):
        """Store the path to load and the loader mode to use."""
        self.file_path, self.mode = file_path, mode

    def load_documents(self):
        """Return the `page_content` of every document the loader produces."""
        loaded = UnstructuredFileLoader(self.file_path, mode=self.mode).load()
        return [entry.page_content for entry in loaded]
class ChromaManager:
    """Own a Chroma collection and the embedding function used with it."""

    def __init__(self, collection_name: str, embedding_function: MyEmbeddingFunction):
        """Connect to Chroma, wipe it, and create a fresh collection.

        Args:
            collection_name: Name of the collection to create.
            embedding_function: Callable mapping a list of texts to a list
                of embeddings.
        """
        self.client = HttpClient(settings=Settings(allow_reset=True))
        self.client.reset()  # resets the database
        self.collection = self.client.create_collection(collection_name)
        self.embedding_function = embedding_function

    def add_documents(self, documents: list):
        """Embed each text in `documents` and add it under a fresh id."""
        # Imported locally: the module never imports uuid, so the original
        # call raised NameError on the first document.
        import uuid

        for doc in documents:
            self.collection.add(
                # Random ids are enough for uniqueness and, unlike uuid1,
                # do not encode the host address or a timestamp.
                ids=[str(uuid.uuid4())],
                documents=[doc],
                embeddings=self.embedding_function([doc]),
            )

    def query(self, query_text: str, k: int = 4):
        """Return up to `k` stored documents most similar to `query_text`.

        The collection is queried directly with an embedding from
        `self.embedding_function`.  The previous implementation referenced
        an un-imported `Chroma` wrapper, which would also have required an
        `embed_query` method the embedding function does not provide.

        Args:
            query_text: Text to search for.
            k: Maximum number of results (default 4).

        Returns:
            A list of LangChain `Document` objects; empty when the
            collection holds no documents.
        """
        from langchain_community.docstore.document import Document

        available = self.collection.count()
        if available == 0 or k <= 0:
            return []
        hits = self.collection.query(
            query_embeddings=self.embedding_function([query_text]),
            # Never ask for more results than the collection contains.
            n_results=min(k, available),
            include=["documents"],
        )
        # Results are grouped per query embedding; only one was sent.
        matched_texts = (hits.get("documents") or [[]])[0]
        return [Document(page_content=text) for text in matched_texts]
# Initialize clients
# The intention client is an OpenAI-compatible client for API_BASE; it is
# shared by the embedding generator and the chat handler.
intention_client = OpenAI(base_url=API_BASE, api_key=yi_token)
embedding_generator = EmbeddingGenerator(
    model_name=model_name,
    token=hf_token,
    intention_client=intention_client,
)
embedding_function = MyEmbeddingFunction(embedding_generator=embedding_generator)
chroma_manager = ChromaManager(
    embedding_function=embedding_function,
    collection_name="Tonic-instruct",
)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat reply to `message`, grounded in retrieved documents.

    Args:
        message: The latest user message.
        history: Earlier (user, assistant) turns.
        system_message: System prompt placed first in the conversation.
        max_tokens: Completion length limit forwarded to the API.
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling value forwarded to the API.

    Yields:
        The reply accumulated so far, once per received text fragment.
    """
    retrieved_text = query_documents(message)

    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": f"{retrieved_text}\n\n{message}"})

    # The OpenAI client has no `chat_completion` method (that is the
    # huggingface_hub InferenceClient API); use chat.completions.create and
    # name the model, matching the call made for intention detection.
    stream = intention_client.chat.completions.create(
        model="yi-large",
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )

    response = ""
    for chunk in stream:
        # Chunks may carry no choices or a delta without text (for example
        # the final chunk); appending None to a str would raise TypeError.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if token:
            response += token
            yield response
def upload_documents(files):
    """Load every uploaded file and store its contents in Chroma.

    Args:
        files: Uploaded files; either objects with a `name` attribute
            holding the path, or plain path strings.  May be None or empty
            when nothing was uploaded.

    Returns:
        A status message for display in the UI.
    """
    # Nothing uploaded: iterating None would raise TypeError, and an empty
    # upload should not be reported as a success.
    if not files:
        return "No documents were provided."
    for file in files:
        # Accept both file-like objects (path in `.name`) and bare paths.
        path = getattr(file, "name", file)
        documents = DocumentLoader(path).load_documents()
        chroma_manager.add_documents(documents)
    return "Documents uploaded and processed successfully!"
def query_documents(query):
    """Return the text of the stored documents most similar to `query`.

    Args:
        query: Search text.  None or blank input yields an empty string
            without querying the store.

    Returns:
        The matching documents' text, separated by blank lines.
    """
    # Skip the embedding round-trip when there is nothing to search for.
    if query is None or not str(query).strip():
        return ""
    results = chroma_manager.query(query)
    # LangChain documents expose their text as `page_content`; they have no
    # `content` attribute.
    return "\n\n".join(result.page_content for result in results)
# Gradio UI: one tab to upload documents, one tab to chat and to query the
# document store directly.
# NOTE(review): the source had no indentation, so the nesting below is a
# reconstruction — confirm which widgets belong inside each Row / Tab.
with gr.Blocks() as demo:
    with gr.Tab("Upload Documents"):
        with gr.Row():
            document_upload = gr.File(file_count="multiple", file_types=["document"])
            upload_button = gr.Button("Upload and Process")
            # The uploaded files go to upload_documents; its return value is
            # shown in a Text component created inline here.
            upload_button.click(upload_documents, inputs=document_upload, outputs=gr.Text())
    with gr.Tab("Ask Questions"):
        with gr.Row():
            chat_interface = gr.ChatInterface(
                respond,
                # Listed in the order of respond's parameters after message
                # and history: system_message, max_tokens, temperature, top_p.
                additional_inputs=[
                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
                ],
            )
        # Direct retrieval: shows what query_documents returns for a query.
        query_input = gr.Textbox(label="Query")
        query_button = gr.Button("Query")
        query_output = gr.Textbox()
        query_button.click(query_documents, inputs=query_input, outputs=query_output)
# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()