# Hugging Face Spaces page residue captured with this file (status shown: "Build error").
import json
import os
import urllib.error
import urllib.request
from io import StringIO

import chromadb
import gradio as gr
import numpy as np
import openai
import pandas as pd
import plotly.express as px
from chromadb.config import Settings
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE
# Initialize a persistent Chroma client stored under ./chroma_db.
# chromadb >= 0.4 dropped the "duckdb+parquet" Settings backend (passing it now
# raises) in favour of PersistentClient; the legacy configuration is kept only
# as a fallback for old releases that do not provide PersistentClient.
CHROMA_PERSIST_DIR = "./chroma_db"
if hasattr(chromadb, "PersistentClient"):
    chroma_client = chromadb.PersistentClient(path=CHROMA_PERSIST_DIR)
else:
    chroma_client = chromadb.Client(Settings(
        chroma_db_impl="duckdb+parquet",
        persist_directory=CHROMA_PERSIST_DIR,
    ))
# Azure OpenAI chat deployments offered in the "ChatGPT Model to Use" dropdown.
# Each spec is (dropdown name, full chat-completions endpoint URL, name of the
# environment variable holding that deployment's API key). Order is preserved
# and becomes the dropdown order.
_AZURE_CHAT_DEPLOYMENTS = [
    (
        "gpt-4",
        "https://roger-m38jr9pd-eastus2.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2024-08-01-preview",
        "GPT4_API_KEY",
    ),
    (
        "gpt-4o",
        "https://roger-m38jr9pd-eastus2.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview",
        "GPT4O_API_KEY",
    ),
    (
        "gpt-35-turbo",
        "https://rogerkoranteng.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-08-01-preview",
        "GPT35_TURBO_API_KEY",
    ),
    (
        "gpt-4-32k",
        "https://roger-m38orjxq-australiaeast.openai.azure.com/openai/deployments/gpt-4-32k/chat/completions?api-version=2024-08-01-preview",
        "GPT4_32K_API_KEY",
    ),
]

# name -> {"endpoint": url, "api_key": key read from the environment (None if unset)}
model_config = {
    deployment: {"endpoint": url, "api_key": os.getenv(key_var)}
    for deployment, url, key_var in _AZURE_CHAT_DEPLOYMENTS
}
# Function to process uploaded CSV
def process_csv_text(temp_file):
    """Load an uploaded CSV into a DataFrame and refresh the column dropdown.

    Args:
        temp_file: The upload value Gradio passes in. It may be raw CSV text,
            a filesystem path (e.g. Gradio 4's default ``type="filepath"``
            upload), a temp-file object exposing ``.name``, or ``None`` when
            nothing is uploaded.

    Returns:
        ``(df, update)``. ``df`` is the parsed CSV (empty when there is no
        input). ``update`` is a Gradio update that sets the dropdown choices
        to ``df``'s column names.
    """
    if temp_file is None:
        df = pd.DataFrame()
    elif isinstance(temp_file, str):
        # A str is either the path of the uploaded file or inline CSV content.
        source = temp_file if os.path.isfile(temp_file) else StringIO(temp_file)
        df = pd.read_csv(source)
    else:
        df = pd.read_csv(temp_file.name)
    # gr.Dropdown.update() was removed in Gradio 4; gr.update() works in 3.x and 4.x.
    return df, gr.update(choices=list(df.columns))
# Insert or update ChromaDB with embeddings
def insert_or_update_chroma(col, table, model_name, similarity_metric, client=chroma_client):
    """(Re)build the "my_collection" Chroma collection from one DataFrame column.

    Any existing "my_collection" is deleted first, so each call fully replaces
    the stored documents and embeddings.

    Args:
        col: Name of the column in ``table`` whose values become the documents.
        table: pandas DataFrame holding the uploaded CSV.
        model_name: sentence-transformers model name used to embed documents.
        similarity_metric: HNSW distance space stored in the collection
            metadata (the UI offers "cosine" or "l2").
        client: Chroma client to write to.

    Returns:
        A status string for the UI textbox. Failures are reported in that
        string (including the underlying error) instead of being raised.
    """
    if table is None or col is None or col not in table.columns:
        return "Error in embedding calculations: select a column from the uploaded CSV first"

    # Chroma only accepts string documents; replace missing cells with "".
    documents = table[col].fillna("").astype(str).tolist()
    if not documents:
        return "Error in embedding calculations: the selected column has no rows"

    try:
        client.delete_collection(name="my_collection")
    except Exception:
        # Deliberate best-effort: the collection may not exist yet, and the
        # "not found" exception type differs between chromadb releases.
        pass

    try:
        collection = client.create_collection(
            name="my_collection",
            # Chroma calls embedding_function(list_of_texts); a bare
            # SentenceTransformer is not callable that way, so use Chroma's adapter.
            embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=model_name
            ),
            metadata={"hnsw:space": similarity_metric},
        )
        collection.add(
            documents=documents,
            metadatas=[{"source": i} for i in range(len(documents))],
            ids=[str(i + 1) for i in range(len(documents))],
        )
    except Exception as e:
        return f"Error in embedding calculations: {e}"
    return "Embedding calculations and insertions successful"
# Show plot with embeddings using t-SNE
def show_fig():
    """Project all stored embeddings to 2-D with t-SNE and plot them.

    Returns:
        ``(fig, transformed)``. ``fig`` is a plotly scatter whose hover text
        is each stored document. ``transformed`` is the ``(n, 2)`` array of
        t-SNE coordinates.

    Raises:
        gr.Error: if fewer than two embeddings are stored.
    """
    collection = chroma_client.get_collection(name="my_collection")
    stored = collection.get(include=['embeddings', 'documents'])
    # Depending on the chromadb version this is a list of lists or a 2-D numpy
    # array; np.asarray handles both (a 2-D array cannot be one DataFrame column).
    vectors = np.asarray(stored['embeddings'], dtype=float)
    n_points = len(vectors)
    if n_points < 2:
        raise gr.Error("Need at least two embedded rows to draw a t-SNE plot")

    # scikit-learn requires perplexity < n_samples; the default (30) fails on small CSVs.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30.0, n_points - 1))
    transformed = tsne.fit_transform(vectors)

    df = pd.DataFrame({
        'text': stored['documents'],
        'tsne_x': transformed[:, 0],
        'tsne_y': transformed[:, 1],
    })
    fig = px.scatter(df, x='tsne_x', y='tsne_y', hover_name='text')
    return fig, transformed
# Show test string figure
def show_test_string_fig(test_string, tsne, model_name, similarity_metric):
    """Plot the stored documents together with an ad-hoc test string in 2-D.

    The test string is embedded with the given sentence-transformers model but
    is NOT written to the collection, so it never shows up in chat retrieval.
    scikit-learn's TSNE cannot project unseen points (it has no ``transform``),
    so the projection is refitted on the stored embeddings plus the test one.

    Args:
        test_string: Text to locate among the stored documents.
        tsne: Current value of the UI's ``tsne`` state; returned unchanged.
        model_name: sentence-transformers model; it should be the one the
            collection was built with.
        similarity_metric: Unused; kept so the UI event signature is unchanged.

    Returns:
        ``(fig, tsne)``. ``fig`` is a scatter coloured by ``set``
        ('orig' for stored rows, 'test_string' for the test point).

    Raises:
        gr.Error: if nothing is stored yet or the embedding sizes differ
        (different model than the one used for insertion).
    """
    collection = chroma_client.get_collection(name="my_collection")
    stored = collection.get(include=['embeddings', 'documents'])
    stored_vectors = np.asarray(stored['embeddings'], dtype=float)
    if len(stored_vectors) == 0:
        raise gr.Error("No embeddings stored yet; insert a column first")

    embed = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
    test_vector = np.asarray(embed([test_string]), dtype=float)
    if stored_vectors.ndim != 2 or test_vector.shape[-1] != stored_vectors.shape[1]:
        raise gr.Error("Test string embedding size does not match the stored embeddings; "
                       "use the same embedding model as when inserting")

    vectors = np.vstack([stored_vectors, test_vector])
    n_points = len(vectors)
    # perplexity must stay below the number of samples (scikit-learn requirement).
    projector = TSNE(n_components=2, random_state=42, perplexity=min(30.0, n_points - 1))
    transformed = projector.fit_transform(vectors)

    df = pd.DataFrame({
        'text': list(stored['documents']) + [test_string],
        'set': ['orig'] * len(stored_vectors) + ['test_string'],
        'tsne_x': transformed[:, 0],
        'tsne_y': transformed[:, 1],
    })
    fig = px.scatter(df, x='tsne_x', y='tsne_y', hover_name='text', color='set')
    return fig, tsne
# Function to interact with OpenAI's Azure API
def ask_gpt(message, messages_history, embedding_model, system_prompt, temperature, max_tokens, chatgpt_model):
    """Send one user turn, augmented with retrieved context, to an Azure OpenAI deployment.

    Args:
        message: The user's question.
        messages_history: Prior OpenAI-style ``{"role", "content"}`` dicts. An
            empty or ``None`` history is seeded with ``system_prompt``. The
            caller's list is not mutated.
        embedding_model: sentence-transformers model used for retrieval.
        system_prompt: System message used when starting a new conversation.
        temperature: Sampling temperature.
        max_tokens: Completion token limit. It is coerced to int because UI
            sliders may deliver floats.
        chatgpt_model: Key into ``model_config`` selecting endpoint and API key.

    Returns:
        ``(reply, history)``. ``history`` is a new list that ends with this
        user turn (including the retrieved context) and the assistant reply.

    Raises:
        RuntimeError: if the deployment has no API key configured or the
            service answers with an HTTP error.
    """
    model_info = model_config[chatgpt_model]
    if not model_info["api_key"]:
        raise RuntimeError(f"No API key configured for {chatgpt_model}")

    history = list(messages_history or [])
    if not history:
        history = [{"role": "system", "content": system_prompt}]
    history.append({"role": "user", "content": retrieve_similar(message, embedding_model)})

    # The configured URL already names the Azure deployment and api-version, so
    # POST to it directly with that deployment's key. The previous
    # openai.ChatCompletion call ignored both (and was removed in openai>=1.0).
    payload = json.dumps({
        "messages": history,
        "temperature": float(temperature),
        "max_tokens": int(max_tokens),
    }).encode("utf-8")
    request = urllib.request.Request(
        model_info["endpoint"],
        data=payload,
        headers={"Content-Type": "application/json", "api-key": model_info["api_key"]},
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=120) as response:
            body = json.load(response)
    except urllib.error.HTTPError as err:
        detail = err.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"Azure OpenAI request failed ({err.code}): {detail}") from err

    reply = body["choices"][0]["message"]["content"]
    # Record the answer so the next turn sends the full conversation.
    history.append({"role": "assistant", "content": reply})
    return reply, history
# Function to retrieve similar questions from ChromaDB
def retrieve_similar(prompt, embedding_model, client=chroma_client, n_results=10):
    """Prefix ``prompt`` with the most similar stored documents as numbered context.

    Args:
        prompt: User question, also used as the similarity query.
        embedding_model: sentence-transformers model; it should match the
            model the collection was built with.
        client: Chroma client holding "my_collection".
        n_results: Maximum number of documents to include (default 10).

    Returns:
        The string "Information: " followed by one numbered line per retrieved
        document, then a final "Question: <prompt>" line. When nothing is
        retrieved, only a newline and the "Question: <prompt>" line remain.
    """
    collection = client.get_collection(
        name="my_collection",
        # A bare SentenceTransformer is not a valid Chroma embedding function.
        embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=embedding_model
        ),
    )
    # Never request more neighbours than are stored (older chromadb raises).
    k = min(n_results, collection.count())
    documents = []
    if k > 0:
        results = collection.query(query_texts=[prompt], n_results=k)
        documents = results['documents'][0]

    additional_context = ''
    if documents:
        numbered = [f"{i}. {doc}" for i, doc in enumerate(documents, start=1)]
        additional_context = 'Information: \n' + '\n'.join(numbered)
    return additional_context + '\nQuestion: ' + prompt
# Gradio App Setup
with gr.Blocks() as demo:
    # Tab 1: Upload CSV and Display Data
    with gr.Tab("Upload data"):
        upload_button = gr.UploadButton(label="Upload csv", file_types=['.csv'], file_count="single")
        # max_rows / overflow_row_behaviour no longer exist on gr.Dataframe (Gradio 4).
        table = gr.Dataframe(type="pandas", interactive=True)
        upload_cols = gr.Dropdown(choices=[], label='Dataframe columns')

    # Tab 2: ChromaDB, Embeddings, and Plotting
    with gr.Tab("Select Column and insert embeddings to ChromaDb"):
        with gr.Row():
            gr.Markdown("<br>")
        with gr.Row():
            # Populated with the real column names by the upload handler below.
            cols = gr.Dropdown(choices=['text_column_1_placeholder'], label='Dataframe columns')
        with gr.Row():
            embedding_model = gr.Dropdown(
                value='all-MiniLM-L6-v2',
                choices=['all-MiniLM-L6-v2', 'intfloat/e5-small-v2', 'intfloat/e5-base-v2',
                         'intfloat/e5-large-v2', 'paraphrase-multilingual-MiniLM-L12-v2'],
                label='Embedding model to use',
            )
            similarity_metric = gr.Dropdown(value='cosine', choices=['cosine', 'l2'], label='Similarity metric to use')
        with gr.Row():
            embedding_button = gr.Button(value="Insert or update rows from selected column to embeddings db")
            text = gr.Textbox(
                label='Process status for Chroma',
                placeholder='This will be updated once you click "Insert or update rows from selected column to embeddings db"',
            )
        with gr.Row():
            show_embeddings_button = gr.Button(value="Calculate 2d values from embeddings and show scatter plot")
            embeddings_plot = gr.Plot()
        with gr.Row():
            tsne = gr.State(value=None)
            test_string = gr.Textbox(label='test string to try to embed', value="Insert test string here")
        with gr.Row():
            calculate_2d_repr_button = gr.Button(value="See where text string is in 2d")
            embeddings_plot_with_text_string = gr.Plot()
        embedding_button.click(insert_or_update_chroma, inputs=[cols, table, embedding_model, similarity_metric], outputs=[text])
        show_embeddings_button.click(show_fig, inputs=[], outputs=[embeddings_plot, tsne])
        calculate_2d_repr_button.click(show_test_string_fig, inputs=[test_string, tsne, embedding_model, similarity_metric], outputs=[embeddings_plot_with_text_string, tsne])

    def _on_upload(temp_file):
        """Load the CSV and push its column names into both tabs' dropdowns."""
        df, columns_update = process_csv_text(temp_file)
        return df, columns_update, columns_update

    # Previously only tab 1's dropdown was refreshed, while tab 2's (the one the
    # embedding button reads) kept its placeholder choice.
    upload_button.upload(fn=_on_upload, inputs=upload_button, outputs=[table, upload_cols, cols], api_name="upload_csv")

    # Tab 3: Chat with GPT Models
    with gr.Tab("Chat"):
        system_prompt = gr.Textbox(value="You are a helpful assistant.", label="System Message")
        chatgpt_model = gr.Dropdown(value="gpt-4", choices=list(model_config.keys()), label="ChatGPT Model to Use")
        temperature = gr.Slider(minimum=0, maximum=2, step=0.1, value=0.7, label="Temperature")
        max_tokens = gr.Slider(minimum=50, maximum=2000, step=50, value=300, label="Max Tokens")
        chatbot = gr.Chatbot(label="ChatGPT Chat")
        clear_button = gr.Button("Clear Chat History")
        msg = gr.Textbox()
        msg_log = gr.Textbox("Message history will be visible here", label='Message history')
        # OpenAI-format {"role", "content"} history sent to the API; the Chatbot
        # component only holds the (user, assistant) display pairs.
        api_history = gr.State(value=[])

    def _respond(message, chat_pairs, history, embedding_model_name, system_message, temp, token_limit, model_name):
        """Run one chat turn; return cleared input, transcript, API history and log text."""
        reply, history = ask_gpt(message, history, embedding_model_name, system_message, temp, token_limit, model_name)
        chat_pairs = list(chat_pairs or []) + [[message, reply]]
        log = "\n\n".join(f"{m['role']}: {m['content']}" for m in history)
        return "", chat_pairs, history, log

    def _clear_chat():
        """Reset the transcript, the API history and the log box."""
        return None, [], "Message history will be visible here"

    # ask_gpt needs all seven inputs; the old wiring passed two, and
    # gr.Chatbot has no .submit event at all.
    msg.submit(
        _respond,
        inputs=[msg, chatbot, api_history, embedding_model, system_prompt, temperature, max_tokens, chatgpt_model],
        outputs=[msg, chatbot, api_history, msg_log],
    )
    clear_button.click(fn=_clear_chat, inputs=None, outputs=[chatbot, api_history, msg_log])
# Launch the Gradio interface only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()