File size: 3,051 Bytes
fbc7e49
52a9cd3
 
 
e055325
02b7760
52a9cd3
e055325
 
 
 
 
 
52a9cd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e055325
5e3cc67
52a9cd3
5e3cc67
e055325
 
52a9cd3
 
7967a92
5e3cc67
 
c3f2eff
5e3cc67
 
52a9cd3
4c886ca
 
f5fba4f
 
 
 
 
 
 
4c886ca
a4370d3
52a9cd3
 
 
 
 
 
fbc7e49
e8c22b8
 
c77bb9e
fbc7e49
afc3612
 
 
 
3c9bd97
afc3612
 
 
 
 
 
02b7760
afc3612
fbc7e49
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
import chromadb
from sentence_transformers import CrossEncoder, SentenceTransformer
import json
from qdrant_client import QdrantClient

print("Setup client")
# NOTE: an earlier ChromaDB-based setup was commented out at this point in
# favour of the Qdrant client below.
# ":memory:" selects qdrant-client's local in-memory mode, so nothing is
# persisted between runs.
client = QdrantClient(":memory:")

print("load data")
# JSON text is UTF-8 by specification; pin the encoding so loading does not
# depend on the platform's locale default.
# The rest of the file reads 'doc', 'title' and 'instructions' from each item
# of the loaded payload.
with open("test_json.json", encoding="utf-8") as f:
    payload = json.load(f)

def embedding_function(items_to_embed: list[str]):
    """Embed texts with the ``mixedbread-ai/mxbai-embed-large-v1`` model.

    Args:
        items_to_embed: Texts to encode.

    Returns:
        One entry per input text, each converted with ``tolist()`` (presumably
        a flat list of floats per text -- confirm against the model output).
        An empty list is returned when no texts are given.
    """
    print("embedding")
    # Nothing to encode: return before loading the model. The diagnostics
    # below index element 0 and would raise IndexError on an empty result.
    if not items_to_embed:
        return []

    # NOTE(review): the model is instantiated on every call; consider loading
    # it once if this function is called repeatedly.
    sentence_model = SentenceTransformer(
        "mixedbread-ai/mxbai-embed-large-v1"
    )
    embedded_items = sentence_model.encode(
        items_to_embed
    )
    # Debug output: count and element types before conversion.
    print(len(embedded_items))
    print(type(embedded_items[0]))
    print(type(embedded_items[0][0]))
    # Convert each embedding to plain Python lists.
    embedded_list = [item.tolist() for item in embedded_items]
    # Debug output: count and element types after conversion.
    print(len(embedded_list))
    print(type(embedded_list[0]))
    print(type(embedded_list[0][0]))
    return embedded_list


print("upserting")
print("printing item:")
# Description text of every payload item, in payload order.
_documents = [entry["doc"] for entry in payload]
# NOTE(review): these embeddings are only type-printed; they are not passed to
# client.add below, which receives the raw documents -- presumably the client
# embeds them itself. Confirm whether this call is still needed.
embedding = embedding_function(_documents)
print(type(embedding))
# Store each document with its full source item as metadata, using the
# item's position in the payload as its id.
client.add(
    collection_name="food",
    documents=_documents,
    metadata=[{"payload": entry} for entry in payload],
    ids=list(range(len(payload))),
)

def search_chroma(query: str):
    """Query the "food" collection and render the hits as markdown text.

    Args:
        query: Free-text search string passed to ``client.query`` as
            ``query_text``.

    Returns:
        One markdown section per hit (title, description, instructions),
        joined with newlines.
    """
    hits = client.query(collection_name="food", query_text=query)

    sections = []
    for hit in hits:
        # The original source item is stored under the 'payload' metadata key.
        record = hit.metadata['payload']
        instructions = '\n'.join(record['instructions'])
        sections.append(
            f"# Title:\n{record['title']}\n\n## Description:\n{record['doc']}\n\n ## Instructions:\n{instructions}"
        )

    # Debug output of the rendered sections.
    print(sections)
    return "\n".join(sections)

def reranking_results(query: str, top_k_results: list[str]):
    """Rerank candidate texts against a query with a cross-encoder.

    Args:
        query: The search query.
        top_k_results: Candidate texts to rank.

    Returns:
        Whatever ``CrossEncoder.rank`` returns for the query and candidates
        when called with ``return_documents=True``.
    """
    # Loads the "mixedbread-ai/mxbai-rerank-xsmall-v1" model on each call.
    cross_encoder = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
    return cross_encoder.rank(query, top_k_results, return_documents=True)

def run_query(query_string: str):
    """Return the markdown search results for ``query_string``.

    Delegates to ``search_chroma``; this is the callback wired to the UI's
    Run button.
    """
    return search_chroma(query_string)

# Gradio UI: a text box and a markdown output side by side, with a Run button
# that sends the text box contents through run_query.
with gr.Blocks() as meal_search:
    gr.Markdown("Start typing below and then click **Run** to see the output.")
    with gr.Row():
        inp = gr.Textbox(placeholder="What sort of meal are you after?")
        out = gr.Markdown()
    btn = gr.Button("Run")
    # Arguments are (fn, inputs, outputs).
    btn.click(run_query, inp, out)

# Start the app when the script is executed.
meal_search.launch()