Mikeplockhart commited on
Commit
52a9cd3
1 Parent(s): 304394c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -6
app.py CHANGED
@@ -1,13 +1,63 @@
1
  import gradio as gr
2
- import utils
 
 
3
 
4
- def setup(collection):
5
- data_loads = utils.load_data()
6
- #print(data_loads)
7
- utils.chroma_upserting(collection, data_loads)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def run_query(query_string: str, collection):
10
- meal_string = utils.search_chroma(collection, query_string)
11
  return meal_string
12
 
13
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import chromadb
3
+ from sentence_transformers import CrossEncoder, SentenceTransformer
4
+ import json
5
 
6
+ print("Setup client")
7
+ chroma_client = chromadb.Client()
8
+ collection = chroma_client.create_collection(
9
+ name="food_collection",
10
+ metadata={"hnsw:space": "cosine"} # l2 is the default
11
+ )
12
+
13
+ print("load data")
14
+ with open("test_json.json", "r") as f:
15
+ payload = json.load(f)
16
+
17
+ def embedding_function(items_to_embed: list[str]):
18
+ print("embedding")
19
+ sentence_model = SentenceTransformer(
20
+ "mixedbread-ai/mxbai-embed-large-v1"
21
+ )
22
+ embedded_items = sentence_model.encode(
23
+ items_to_embed
24
+ )
25
+ print(len(embedded_items))
26
+ print(type(embedded_items[0]))
27
+ print(type(embedded_items[0][0]))
28
+ embedded_list = [item.tolist() for item in embedded_items]
29
+ print(len(embedded_list))
30
+ print(type(embedded_list[0]))
31
+ print(type(embedded_list[0][0]))
32
+ return embedded_list
33
+
34
+
35
+ print('upserting')
36
+ print("printing item:")
37
+ embedding = embedding_function([item['doc'] for item in payload])
38
+ print(type(embedding))
39
+ collection.add(
40
+ documents=[item['doc'] for item in payload],
41
+ embeddings=embedding,
42
+ metadatas=[{'payload':item} for item in payload],
43
+ ids=[f"id_{idx}" for idx, _ in enumerate(payload)]
44
+ )
45
+
46
+ def search_chroma(collection, query:str):
47
+ results = collection.query(
48
+ query_embeddings=embedding_function([query]),
49
+ n_results=5
50
+ )
51
+ return results
52
+
53
+ def reranking_results(query: str, top_k_results: list[str]):
54
+ # Load the model, here we use our base sized model
55
+ rerank_model = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
56
+ reranked_results = rerank_model.rank(query, top_k_results, return_documents=True)
57
+ return reranked_results
58
 
59
  def run_query(query_string: str, collection):
60
+ meal_string = search_chroma(collection, query_string)
61
  return meal_string
62
 
63
  if __name__ == "__main__":