Spaces:
Sleeping
Sleeping
import gradio as gr | |
import chromadb | |
from sentence_transformers import CrossEncoder, SentenceTransformer | |
import json | |
from qdrant_client import QdrantClient | |
print("Setup client")
# NOTE: an earlier revision created a ChromaDB collection here
# ("food_collection", with "hnsw:space" set to "cosine"); it was commented
# out and the script now uses a Qdrant client instead.
client = QdrantClient(":memory:")

print("load data")
# JSON text is UTF-8; state the encoding explicitly so decoding does not
# depend on the platform's default locale.
with open("test_json.json", "r", encoding="utf-8") as f:
    payload = json.load(f)
def embedding_function(items_to_embed: list[str]):
    """Encode texts with the ``mixedbread-ai/mxbai-embed-large-v1`` model.

    Args:
        items_to_embed: The strings to embed.

    Returns:
        A list with one entry per input string, each produced by calling
        ``tolist()`` on the corresponding encoder output. An empty input
        yields an empty list without loading the model.
    """
    print("embedding")
    if not items_to_embed:
        # The diagnostics below index element 0, which raises IndexError
        # when there is nothing to encode.
        return []

    # Constructing the model is expensive; build it once and keep it on the
    # function object so later calls reuse the same instance.
    sentence_model = getattr(embedding_function, "_sentence_model", None)
    if sentence_model is None:
        sentence_model = SentenceTransformer(
            "mixedbread-ai/mxbai-embed-large-v1"
        )
        embedding_function._sentence_model = sentence_model

    embedded_items = sentence_model.encode(items_to_embed)

    # Diagnostics: report count and container/element types before conversion.
    print(len(embedded_items))
    print(type(embedded_items[0]))
    print(type(embedded_items[0][0]))

    embedded_list = [item.tolist() for item in embedded_items]

    # Diagnostics: same report after conversion to plain Python lists.
    print(len(embedded_list))
    print(type(embedded_list[0]))
    print(type(embedded_list[0][0]))
    return embedded_list
# Index every record of the loaded payload into the "food" collection.
print('upserting')
print("printing item:")

# NOTE(review): these embeddings are computed but not handed to client.add()
# below (the `embeddings=` argument was commented out in the original).
embedding = embedding_function([record['doc'] for record in payload])
print(type(embedding))

client.add(
    collection_name="food",
    documents=[record['doc'] for record in payload],
    # Keep the full source record reachable from each search hit.
    metadata=[{'payload': record} for record in payload],
    # Positional index of each record doubles as its point id.
    ids=list(range(len(payload))),
)
def search_chroma(query: str) -> str:
    """Search the "food" collection and render the hits as Markdown.

    Args:
        query: Free-text search string, passed to the client as ``query_text``.

    Returns:
        One Markdown section per hit (title, description, instructions),
        joined with a newline. Each hit's ``metadata['payload']`` is read
        for the keys ``'title'``, ``'doc'`` and ``'instructions'``; the
        instructions are joined with newlines.
    """
    hits = client.query(collection_name="food", query_text=query)

    def _render(hit):
        # Join the instruction steps first, then build the Markdown section.
        recipe = hit.metadata['payload']
        steps = '\n'.join(recipe['instructions'])
        return f"# Title:\n{recipe['title']}\n\n## Description:\n{recipe['doc']}\n\n ## Instructions:\n{steps}"

    text_only = [_render(hit) for hit in hits]
    print(text_only)
    return "\n".join(text_only)
def reranking_results(query: str, top_k_results: list[str]):
    """Rerank candidate documents against a query with a cross-encoder.

    Args:
        query: The search query.
        top_k_results: Candidate document texts to rank.

    Returns:
        The value of ``CrossEncoder.rank(query, top_k_results,
        return_documents=True)``, or an empty list when there are no
        candidates to rank.
    """
    if not top_k_results:
        # Nothing to rank; skip loading the model altogether.
        return []

    # The extra-small reranker checkpoint is used here. Loading it is
    # expensive, so cache the instance on the function object for reuse.
    rerank_model = getattr(reranking_results, "_rerank_model", None)
    if rerank_model is None:
        rerank_model = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
        reranking_results._rerank_model = rerank_model

    return rerank_model.rank(query, top_k_results, return_documents=True)
def run_query(query_string: str):
    """Handle a UI query by returning ``search_chroma``'s result for it."""
    return search_chroma(query_string)
# Build the UI: a prompt, a row holding the query box and the Markdown
# output, and a button that sends the query through run_query.
with gr.Blocks() as meal_search:
    gr.Markdown("Start typing below and then click **Run** to see the output.")

    with gr.Row():
        inp = gr.Textbox(placeholder="What sort of meal are you after?")
        out = gr.Markdown()

    btn = gr.Button("Run")
    # Clicking the button renders run_query's return value into `out`.
    btn.click(fn=run_query, inputs=inp, outputs=out)

meal_search.launch()