# Spaces: Sleeping
import gradio as gr | |
import chromadb | |
from sentence_transformers import CrossEncoder, SentenceTransformer | |
import json | |
# --- Vector store setup -----------------------------------------------------
print("Setup client")
chroma_client = chromadb.Client()
# get_or_create_collection (rather than create_collection) so that executing
# this module a second time in the same process does not fail because the
# collection already exists.
collection = chroma_client.get_or_create_collection(
    name="food_collection",
    metadata={"hnsw:space": "cosine"},  # l2 is the default
)

# --- Source data ------------------------------------------------------------
print("load data")
# JSON text is UTF-8; be explicit instead of relying on the platform default.
with open("test_json.json", "r", encoding="utf-8") as f:
    payload = json.load(f)
_EMBEDDING_MODEL_NAME = "mixedbread-ai/mxbai-embed-large-v1"
_embedding_model = None  # populated lazily by _get_embedding_model()


def _get_embedding_model():
    """Load the sentence-embedding model once and reuse it on later calls."""
    global _embedding_model
    if _embedding_model is None:
        _embedding_model = SentenceTransformer(_EMBEDDING_MODEL_NAME)
    return _embedding_model


def embedding_function(items_to_embed: list[str]) -> list[list[float]]:
    """Embed each string in ``items_to_embed`` with the mxbai embedding model.

    Args:
        items_to_embed: The texts to encode.

    Returns:
        One vector per input text, each converted to a plain Python list via
        ``tolist()``. An empty input yields an empty list without loading
        the model.
    """
    print("embedding")
    if not items_to_embed:
        # Nothing to encode; the diagnostics below index element 0 and would
        # raise IndexError on an empty result.
        return []

    embedded_items = _get_embedding_model().encode(items_to_embed)
    # Diagnostics: size and element types before conversion.
    print(len(embedded_items))
    print(type(embedded_items[0]))
    print(type(embedded_items[0][0]))

    embedded_list = [item.tolist() for item in embedded_items]
    # Diagnostics: size and element types after conversion to plain lists.
    print(len(embedded_list))
    print(type(embedded_list[0]))
    print(type(embedded_list[0][0]))
    return embedded_list
# --- Index the documents ----------------------------------------------------
print('upserting')
print("printing item:")
documents = [item['doc'] for item in payload]
# NOTE(review): these vectors are computed but not passed to collection.add()
# below, so the collection embeds the documents with its own embedding
# function. Supplying them via `embeddings=` would also require querying with
# `query_embeddings=` produced by the same model.
embedding = embedding_function(documents)
print(type(embedding))
collection.add(
    # Collection.add() has no `collection_name` parameter; the target
    # collection is the object the method is called on.
    documents=documents,
    # Chroma metadata values must be scalars (str/int/float/bool), so the
    # full record is stored as a JSON string instead of a nested dict.
    metadatas=[{'payload': json.dumps(item)} for item in payload],
    ids=[f"id_{idx}" for idx in range(len(payload))],
)
def search_chroma(query: str) -> str:
    """Search the food collection and format the hits as Markdown.

    Args:
        query: Free-text description of the meal being looked for.

    Returns:
        A Markdown string with one ``# Dish:`` section per matching document,
        or an empty string when ``query`` is blank.
    """
    if not query or not query.strip():
        # Nothing to search for; skip the round trip to the vector store.
        return ""

    # Query the collection object directly. The text is passed as
    # `query_texts` so the collection embeds it itself.
    results = collection.query(
        query_texts=[query],
    )
    # `documents` holds one list of hits per query text; only one was sent.
    matches = results['documents'][0]
    return "".join(f"# Dish:\n{item}\n\n" for item in matches)
_RERANK_MODEL_NAME = "mixedbread-ai/mxbai-rerank-xsmall-v1"
_rerank_model = None  # populated lazily by _get_rerank_model()


def _get_rerank_model():
    """Load the cross-encoder reranker once and reuse it on later calls."""
    global _rerank_model
    if _rerank_model is None:
        _rerank_model = CrossEncoder(_RERANK_MODEL_NAME)
    return _rerank_model


def reranking_results(query: str, top_k_results: list[str]):
    """Rerank candidate documents against ``query`` with the xsmall reranker.

    Args:
        query: The user's search text.
        top_k_results: Candidate documents to reorder.

    Returns:
        Whatever ``CrossEncoder.rank(..., return_documents=True)`` produces
        for the candidates, or an empty list when there are no candidates.
    """
    if not top_k_results:
        # No candidates: avoid loading the model just to rank nothing.
        return []
    return _get_rerank_model().rank(query, top_k_results, return_documents=True)
def run_query(query_string: str):
    """Handle a UI search request by delegating to ``search_chroma``.

    Args:
        query_string: The text entered by the user.

    Returns:
        The value returned by ``search_chroma(query_string)``, unchanged.
    """
    return search_chroma(query_string)
# --- Gradio UI --------------------------------------------------------------
# A prompt, a row holding the query box and the Markdown output, and a button
# that sends the query text through run_query.
with gr.Blocks() as meal_search:
    gr.Markdown("Start typing below and then click **Run** to see the output.")
    with gr.Row():
        inp = gr.Textbox(placeholder="What sort of meal are you after?")
        out = gr.Markdown()
    btn = gr.Button("Run")
    # Clicking the button calls run_query(inp) and renders the result in `out`.
    btn.click(fn=run_query, inputs=inp, outputs=out)

meal_search.launch()