Spaces:
Sleeping
Sleeping
import gradio as gr | |
import jsonlines | |
from sentence_transformers import CrossEncoder, SentenceTransformer | |
import json | |
from qdrant_client import QdrantClient | |
print("Setup client")
# chroma_client = chromadb.Client()
# collection = chroma_client.create_collection(
#     name="food_collection",
#     metadata={"hnsw:space": "cosine"}  # l2 is the default
# )
# ":memory:" selects Qdrant's local in-process mode, so nothing is persisted
# between runs and the collection is rebuilt on every start.
client = QdrantClient(":memory:")

print("load data")
# Read the recipe records. The encoding is pinned to UTF-8 so the load does
# not depend on the platform's locale default encoding.
with open("test_json.json", "r", encoding="utf-8") as f:
    payload = json.load(f)
def embedding_function(items_to_embed: list[str]) -> list[list[float]]:
    """Embed each string with the mxbai-embed-large-v1 sentence model.

    Args:
        items_to_embed: Texts to encode.

    Returns:
        One plain Python list per input text, in input order. An empty
        input yields an empty list without loading the model.
    """
    print("embedding")
    if not items_to_embed:
        # Nothing to encode: skip the expensive model load, and avoid the
        # debug prints below, which index element [0] and would raise.
        return []

    sentence_model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
    embedded_items = sentence_model.encode(items_to_embed)

    # Debug output: shape/type of the raw encoder result.
    print(len(embedded_items))
    print(type(embedded_items[0]))
    print(type(embedded_items[0][0]))

    # Convert each row via .tolist() so callers receive plain lists.
    embedded_list = [item.tolist() for item in embedded_items]

    # Debug output: shape/type after conversion.
    print(len(embedded_list))
    print(type(embedded_list[0]))
    print(type(embedded_list[0][0]))
    return embedded_list
print("upserting")
print("printing item:")

# Collect the searchable text of every record once; it feeds both the
# embedding call and the collection insert below.
documents = [entry["doc"] for entry in payload]

# NOTE(review): `embedding` is computed and its type printed, but it is not
# passed to client.add() below (the `embeddings=` argument is commented out).
embedding = embedding_function(documents)
print(type(embedding))

# Each point keeps the whole source record under the "payload" metadata key
# and uses its position in the input as its id.
client.add(
    collection_name="food",
    documents=documents,
    # embeddings=embedding,
    metadata=[{"payload": entry} for entry in payload],
    ids=list(range(len(documents))),
)
def search_chroma(query: str):
    """Search the "food" collection, rerank the hits and render Markdown.

    Despite the name, the lookup goes through the module-level Qdrant
    ``client``. Up to five hits are fetched, reordered by
    ``reranking_results`` and formatted as one Markdown section per dish.

    Args:
        query: Free-text description of the meal being searched for.

    Returns:
        A Markdown string with one section per hit in reranked order, or an
        empty string when the search returns nothing.
    """
    results = client.query(
        collection_name="food",
        query_text=query,
        limit=5,
    )
    if not results:
        # No hits: nothing to rerank or render.
        return ""

    top_k = [item.document for item in results]
    reranked = reranking_results(query, top_k)

    # Each reranked entry carries "corpus_id", the index of the document in
    # the list that was passed in. Mapping back by index (rather than by
    # comparing text) keeps exactly one entry per hit even when two hits
    # share identical document text.
    ordered_results = [results[item["corpus_id"]] for item in reranked]

    text_only = []
    for item in ordered_results:
        instructions = "- " + "<br>- ".join(item.metadata["payload"]["instructions"])
        # The score shown is the one attached to the search hit, not the
        # reranker's score.
        markdown_text = f"# Dish: {item.metadata['payload']['title']}\n\n## Description:\n{item.metadata['payload']['doc']}\n\n ## Instructions:\n{instructions}\n\n### Score: {item.score}\n"
        text_only.append(markdown_text)
    return "\n".join(text_only)
# Lazily-created reranker shared by every query, so the model is loaded once
# per process instead of once per request.
_RERANK_MODEL = None


def reranking_results(query: str, top_k_results: list[str]):
    """Rerank candidate documents against the query with a cross-encoder.

    Args:
        query: The user's search text.
        top_k_results: Candidate document texts to reorder.

    Returns:
        The value of ``CrossEncoder.rank(..., return_documents=True)`` for
        the candidates, or an empty list when there are no candidates.
    """
    global _RERANK_MODEL
    if not top_k_results:
        # Nothing to rank; avoid loading the model for an empty request.
        return []
    if _RERANK_MODEL is None:
        # Load the extra-small mxbai rerank model on first use.
        _RERANK_MODEL = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
    return _RERANK_MODEL.rank(query, top_k_results, return_documents=True)
def run_query(query_string: str):
    """Return the Markdown produced by ``search_chroma`` for the given query.

    This is the callback wired to the UI's Run button.
    """
    return search_chroma(query_string)
# Build the UI: an instruction line, a row holding the query box next to the
# Markdown result pane, and a Run button underneath.
with gr.Blocks() as meal_search:
    gr.Markdown("Start typing below and then click **Run** to see the output.")
    with gr.Row():
        inp = gr.Textbox(placeholder="What sort of meal are you after?")
        out = gr.Markdown()
    btn = gr.Button("Run")
    # Clicking Run sends the textbox contents to run_query and renders the
    # returned Markdown in the output pane.
    btn.click(run_query, inp, out)

# Start the app as soon as the module is executed.
meal_search.launch()