rchrdgwr committed on
Commit
0614fbf
·
1 Parent(s): ba0bf00

cleaned up code

Browse files
BuildingAChainlitApp.md CHANGED
@@ -257,34 +257,53 @@ Code was modified to support pdf documents in the following areas:
257
 
258
  2) change process_text_file() function to handle .pdf files
259
 
260
- - identify the file extension
261
- - read the uploaded document into a temporary file
262
- - process a .txt file as before resulting in the texts list
263
- - if the file is .pdf use the PyMuPDF library to read each page and extract the text and add it to texts list
 
 
 
 
 
 
 
 
264
 
265
  ```python
266
- file_extension = os.path.splitext(file.name)[1].lower()
267
-
268
- with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_extension) as temp_file:
269
- temp_file_path = temp_file.name
270
- temp_file.write(file.content)
271
-
272
- if file_extension == ".txt":
273
- with open(temp_file_path, "r", encoding="utf-8") as f:
274
- text_loader = TextFileLoader(temp_file_path)
275
- documents = text_loader.load_documents()
276
- texts = text_splitter.split_texts(documents)
277
-
278
- elif file_extension == ".pdf":
279
- pdf_document = fitz.open(temp_file_path)
280
- documents = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  for page_num in range(len(pdf_document)):
282
  page = pdf_document.load_page(page_num)
283
  text = page.get_text()
284
- documents.append(text)
285
- texts = text_splitter.split_texts(documents)
286
- else:
287
- raise ValueError("Unsupported file type")
288
  ```
289
 
290
  3) Test the handling of .pdf and .txt files
 
257
 
258
  2) change process_text_file() function to handle .pdf files
259
 
260
+ - refactor the code to do all file handling in richard.text_utils
261
+ - app calls process_file, which delegates to FileLoader.load_file(); load_file() optionally accepts a text splitter
262
+ - default text splitter function is CharacterTextSplitter
263
+ ```python
264
+ texts = process_file(file)
265
+ ```
266
+ - load_file() function does the following
267
+ - read the uploaded document into a temporary file
268
+ - identify the file extension
269
+ - for a .txt file, read the whole file and add its text to the documents list
270
+ - if the file is .pdf, use the PyMuPDF library to read each page, extract its text, and add it to the documents list
271
+ - use the passed in text splitter function to split the documents
272
 
273
  ```python
274
+ def load_file(self, file, text_splitter=CharacterTextSplitter()):
275
+ file_extension = os.path.splitext(file.name)[1].lower()
276
+ with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_extension) as temp_file:
277
+ self.temp_file_path = temp_file.name
278
+ temp_file.write(file.content)
279
+
280
+ if os.path.isfile(self.temp_file_path):
281
+ if self.temp_file_path.endswith(".txt"):
282
+ self.load_text_file()
283
+ elif self.temp_file_path.endswith(".pdf"):
284
+ self.load_pdf_file()
285
+ else:
286
+ raise ValueError(
287
+ f"Unsupported file type: {self.temp_file_path}"
288
+ )
289
+ return text_splitter.split_texts(self.documents)
290
+ else:
291
+ raise ValueError(
292
+ "Not a file"
293
+ )
294
+
295
+ def load_text_file(self):
296
+ with open(self.temp_file_path, "r", encoding=self.encoding) as f:
297
+ self.documents.append(f.read())
298
+
299
+ def load_pdf_file(self):
300
+ print("load_pdf_file()")
301
+ pdf_document = fitz.open(self.temp_file_path)
302
+ print(len(pdf_document))
303
  for page_num in range(len(pdf_document)):
304
  page = pdf_document.load_page(page_num)
305
  text = page.get_text()
306
+ self.documents.append(text)
 
 
 
307
  ```
308
 
309
  3) Test the handling of .pdf and .txt files
aimakerspace/vectordatabase.py CHANGED
@@ -52,77 +52,9 @@ class VectorDatabase:
52
  for text, embedding in zip(list_of_text, embeddings):
53
  self.insert(text, np.array(embedding))
54
  return self
55
- import hashlib
56
- from qdrant_client import QdrantClient
57
- from qdrant_client.http.models import PointStruct
58
- class QdrantDatabase:
59
- def __init__(self, qdrant_client: QdrantClient, collection_name: str, embedding_model=None):
60
- self.qdrant_client = qdrant_client
61
- self.collection_name = collection_name
62
- self.embedding_model = embedding_model or EmbeddingModel()
63
- self.vectors = defaultdict(np.array) # Still keeps a local copy if needed
64
-
65
- def string_to_int_id(self, s: str) -> int:
66
- return int(hashlib.sha256(s.encode('utf-8')).hexdigest(), 16) % (10**8)
67
-
68
- def insert(self, key: str, vector: np.array) -> None:
69
-
70
- point_id = self.string_to_int_id(key)
71
- # Insert vector into Qdrant
72
- payload = {"text": key} # Storing the key (text) as payload
73
- point = PointStruct(
74
- id=point_id,
75
- vector={"default": vector.tolist()}, # Use the vector name defined in the collection
76
- payload=payload
77
- )
78
-
79
- # Insert the vector into Qdrant with the associated document
80
- self.qdrant_client.upsert(
81
- collection_name=self.collection_name,
82
- points=[point] # Qdrant expects a list of PointStruct
83
- )
84
-
85
- def search(
86
- self,
87
- query_vector: np.array,
88
- k: int,
89
- distance_measure: Callable = None,
90
- ) -> List[Tuple[str, float]]:
91
- # Perform search in Qdrant
92
- print(query_vector)
93
- if isinstance(query_vector, list):
94
- query_vector = np.array(query_vector)
95
-
96
- search_results = self.qdrant_client.search(
97
- collection_name=self.collection_name,
98
- query_vector={"name": "default", "vector": query_vector.tolist()},# Convert numpy array to list
99
- limit=k
100
- )
101
-
102
- # Extract and return results
103
- return [(result.payload['text'], result.score) for result in search_results]
104
 
105
- def search_by_text(
106
- self,
107
- query_text: str,
108
- k: int,
109
- distance_measure: Callable = None,
110
- return_as_text: bool = False,
111
- ) -> List[Tuple[str, float]]:
112
- query_vector = self.embedding_model.get_embedding(query_text)
113
- results = self.search(query_vector, k, distance_measure)
114
- return [result[0] for result in results] if return_as_text else results
115
 
116
- def retrieve_from_key(self, key: str) -> np.array:
117
- # Retrieve from local cache
118
- return self.vectors.get(key, None)
119
 
120
- async def abuild_from_list(self, list_of_text: List[str]) -> "QdrantDatabase":
121
- embeddings = await self.embedding_model.async_get_embeddings(list_of_text)
122
- for text, embedding in zip(list_of_text, embeddings):
123
- self.insert(text, np.array(embedding))
124
- return self
125
-
126
  if __name__ == "__main__":
127
  list_of_text = [
128
  "I like to eat broccoli and bananas.",
 
52
  for text, embedding in zip(list_of_text, embeddings):
53
  self.insert(text, np.array(embedding))
54
  return self
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
 
 
 
 
 
 
 
 
 
 
56
 
 
 
 
57
 
 
 
 
 
 
 
58
  if __name__ == "__main__":
59
  list_of_text = [
60
  "I like to eat broccoli and bananas.",
app.py CHANGED
@@ -1,20 +1,27 @@
1
  import os
2
- from typing import List
3
  from chainlit.types import AskFileResponse
4
- from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader
5
  from aimakerspace.openai_utils.prompts import (
6
  UserRolePrompt,
7
  SystemRolePrompt,
8
  AssistantRolePrompt,
9
  )
10
  from aimakerspace.openai_utils.embedding import EmbeddingModel
11
- from aimakerspace.vectordatabase import VectorDatabase, QdrantDatabase
12
  from aimakerspace.openai_utils.chatmodel import ChatOpenAI
13
  import chainlit as cl
14
- import fitz
 
 
 
 
 
15
 
16
  system_template = """\
17
- Use the following context to answer a users question. If you cannot find the answer in the context, say you don't know the answer."""
 
 
 
18
  system_role_prompt = SystemRolePrompt(system_template)
19
 
20
  user_prompt_template = """\
@@ -26,65 +33,39 @@ Question:
26
  """
27
  user_role_prompt = UserRolePrompt(user_prompt_template)
28
 
29
- class RetrievalAugmentedQAPipeline:
30
- def __init__(self, llm, vector_db_retriever: VectorDatabase) -> None:
31
- self.llm = llm
32
- self.vector_db_retriever = vector_db_retriever
33
-
34
- async def arun_pipeline(self, user_query: str):
35
- context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
36
-
37
- context_prompt = ""
38
- for context in context_list:
39
- context_prompt += context[0] + "\n"
40
-
41
- formatted_system_prompt = system_role_prompt.create_message()
42
-
43
- formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)
44
-
45
- async def generate_response():
46
- async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
47
- yield chunk
48
-
49
- return {"response": generate_response(), "context": context_list}
50
-
51
- text_splitter = CharacterTextSplitter()
52
-
53
-
54
- def process_text_file(file: AskFileResponse):
55
- import tempfile
56
-
57
- file_extension = os.path.splitext(file.name)[1].lower()
58
-
59
- with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_extension) as temp_file:
60
- temp_file_path = temp_file.name
61
- temp_file.write(file.content)
62
-
63
- if file_extension == ".txt":
64
- with open(temp_file_path, "r", encoding="utf-8") as f:
65
- text_loader = TextFileLoader(temp_file_path)
66
- documents = text_loader.load_documents()
67
- texts = text_splitter.split_texts(documents)
68
-
69
- elif file_extension == ".pdf":
70
- pdf_document = fitz.open(temp_file_path)
71
- documents = []
72
- for page_num in range(len(pdf_document)):
73
- page = pdf_document.load_page(page_num)
74
- text = page.get_text()
75
- documents.append(text)
76
- texts = text_splitter.split_texts(documents)
77
- else:
78
- raise ValueError("Unsupported file type")
79
-
80
- return texts
81
-
82
 
83
 
84
  @cl.on_chat_start
85
  async def on_chat_start():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  files = None
87
-
88
  # Wait for the user to upload a file
89
  while not files:
90
  files = await cl.AskFileMessage(
@@ -102,63 +83,65 @@ async def on_chat_start():
102
  await msg.send()
103
 
104
  # load the file
105
- texts = process_text_file(file)
106
 
107
  msg = cl.Message(
108
  content=f"Resulted in {len(texts)} chunks", disable_human_feedback=True
109
  )
110
  await msg.send()
111
 
112
- print(f"Processing {len(texts)} text chunks")
113
-
114
  # decide whether to use the dict vector store or the Qdrant vector store
115
-
116
- use_qdrant = True
117
- from qdrant_client import QdrantClient
118
- from qdrant_client.http.models import VectorParams, Distance
119
  # Create a dict vector store
120
- if use_qdrant:
 
 
 
121
  embedding_model = EmbeddingModel()
122
- qdrant_client = QdrantClient(
123
- url='https://6b3eac94-adfe-42cb-98f8-9f068538243c.europe-west3-0.gcp.cloud.qdrant.io:6333', # Replace with your cluster URL
124
- api_key='YrnApyEfdNAt41N7WkcZwjhjKqiIQQbXHBtzk_04guNyRLa83J0hOw' # Replace with your API key
125
- )
126
- vectors_config = {
127
- "default": VectorParams(size=1536, distance="Cosine") # Adjust size as per your model's output
128
- }
129
- if not qdrant_client.collection_exists("my_collection"):
130
- qdrant_client.create_collection(
 
 
131
  collection_name="my_collection",
132
- vectors_config=vectors_config
133
  )
134
 
135
- vector_db = QdrantDatabase(
136
- qdrant_client=qdrant_client,
137
- collection_name="my_collection",
138
- embedding_model=embedding_model # Replace with your embedding model instance
139
- )
140
- vector_db = await vector_db.abuild_from_list(texts)
141
 
142
- else:
143
- vector_db = VectorDatabase()
144
- vector_db = await vector_db.abuild_from_list(texts)
145
 
146
  msg = cl.Message(
147
  content=f"The Vector store has been created", disable_human_feedback=True
148
  )
149
  await msg.send()
 
150
  chat_openai = ChatOpenAI()
151
 
152
  # Create a chain
153
  retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
154
  vector_db_retriever=vector_db,
155
- llm=chat_openai
 
 
156
  )
157
 
158
  # Let the user know that the system is ready
159
- msg.content = f"Processing `{file.name}` done. You can now ask questions!"
 
 
160
  await msg.update()
161
-
162
  cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
163
 
164
 
 
1
  import os
 
2
  from chainlit.types import AskFileResponse
3
+
4
  from aimakerspace.openai_utils.prompts import (
5
  UserRolePrompt,
6
  SystemRolePrompt,
7
  AssistantRolePrompt,
8
  )
9
  from aimakerspace.openai_utils.embedding import EmbeddingModel
10
+ from aimakerspace.vectordatabase import VectorDatabase
11
  from aimakerspace.openai_utils.chatmodel import ChatOpenAI
12
  import chainlit as cl
13
+ from richard.text_utils import FileLoader
14
+ from richard.pipeline import RetrievalAugmentedQAPipeline
15
+ # from richard.vector_database import QdrantDatabase
16
+ from qdrant_client import QdrantClient
17
+ from langchain.vectorstores import Qdrant
18
+
19
 
20
  system_template = """\
21
+ Use the following context to answer a users question.
22
+ If you cannot find the answer in the context, say you don't know the answer.
23
+ The context contains the text from a document. Refer to it as the document not the context.
24
+ """
25
  system_role_prompt = SystemRolePrompt(system_template)
26
 
27
  user_prompt_template = """\
 
33
  """
34
  user_role_prompt = UserRolePrompt(user_prompt_template)
35
 
36
+ def process_file(file: AskFileResponse):
37
+ fileLoader = FileLoader()
38
+ return fileLoader.load_file(file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  @cl.on_chat_start
42
  async def on_chat_start():
43
+ res = await cl.AskActionMessage(
44
+ content="Do you want to use Qdrant?",
45
+ actions=[
46
+ cl.Action(name="yes", value="yes", label="✅ Yes"),
47
+ cl.Action(name="no", value="no", label="❌ No"),
48
+ ],
49
+ ).send()
50
+ use_qdrant = False
51
+ use_qdrant_type = "Local"
52
+ if res and res.get("value") == "yes":
53
+ use_qdrant = True
54
+ local_res = await cl.AskActionMessage(
55
+ content="Do you want to use local or cloud?",
56
+ actions=[
57
+ cl.Action(name="Local", value="Local", label="✅ Local"),
58
+ cl.Action(name="Cloud", value="Cloud", label="❌ Cloud"),
59
+ ],
60
+ ).send()
61
+ if local_res and local_res.get("value") == "Cloud":
62
+ use_qdrant_type = "Cloud"
63
+ msg = cl.Message(
64
+ content=f"Sorry - the Qdrant processing has been temporarily disconnected"
65
+ )
66
+ await msg.send()
67
+ use_qdrant = False
68
  files = None
 
69
  # Wait for the user to upload a file
70
  while not files:
71
  files = await cl.AskFileMessage(
 
83
  await msg.send()
84
 
85
  # load the file
86
+ texts = process_file(file)
87
 
88
  msg = cl.Message(
89
  content=f"Resulted in {len(texts)} chunks", disable_human_feedback=True
90
  )
91
  await msg.send()
92
 
 
 
93
  # decide whether to use the dict vector store or the Qdrant vector store
94
+ from qdrant_client.models import PointStruct, VectorParams
 
 
 
95
  # Create a dict vector store
96
+ if use_qdrant == False:
97
+ vector_db = VectorDatabase()
98
+ vector_db = await vector_db.abuild_from_list(texts)
99
+ else:
100
  embedding_model = EmbeddingModel()
101
+ if use_qdrant_type == "Local":
102
+ from qdrant_client.http.models import OptimizersConfig
103
+ print("Using qdrant local")
104
+ qdrant_client = QdrantClient(location=":memory:")
105
+
106
+ vector_params = VectorParams(
107
+ size=1536, # vector size
108
+ distance="Cosine" # distance metric
109
+ )
110
+
111
+ qdrant_client.recreate_collection(
112
  collection_name="my_collection",
113
+ vectors_config={"default": vector_params},
114
  )
115
 
116
+ from richard.vector_database import QdrantDatabase
117
+ vector_db = QdrantDatabase(
118
+ qdrant_client=qdrant_client,
119
+ collection_name="my_collection",
120
+ embedding_model=embedding_model
121
+ )
122
 
123
+ vector_db = await vector_db.abuild_from_list(texts)
 
 
124
 
125
  msg = cl.Message(
126
  content=f"The Vector store has been created", disable_human_feedback=True
127
  )
128
  await msg.send()
129
+
130
  chat_openai = ChatOpenAI()
131
 
132
  # Create a chain
133
  retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
134
  vector_db_retriever=vector_db,
135
+ llm=chat_openai,
136
+ system_role_prompt=system_role_prompt,
137
+ user_role_prompt=user_role_prompt
138
  )
139
 
140
  # Let the user know that the system is ready
141
+ msg.content = f"Processing `{file.name}` is complete."
142
+ await msg.update()
143
+ msg.content = f"You can now ask questions about `{file.name}`."
144
  await msg.update()
 
145
  cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
146
 
147
 
richard/__init__.py ADDED
File without changes
richard/pipeline.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from aimakerspace.vectordatabase import VectorDatabase
2
+
3
+ class RetrievalAugmentedQAPipeline:
4
+ def __init__(self, llm, vector_db_retriever: VectorDatabase,
5
+ system_role_prompt, user_role_prompt
6
+ ) -> None:
7
+ self.llm = llm
8
+ self.vector_db_retriever = vector_db_retriever
9
+ self.system_role_prompt = system_role_prompt
10
+ self.user_role_prompt = user_role_prompt
11
+
12
+ async def arun_pipeline(self, user_query: str):
13
+ context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
14
+
15
+ context_prompt = ""
16
+ for context in context_list:
17
+ context_prompt += context[0] + "\n"
18
+
19
+ formatted_system_prompt = self.system_role_prompt.create_message()
20
+
21
+ formatted_user_prompt = self.user_role_prompt.create_message(question=user_query, context=context_prompt)
22
+
23
+ async def generate_response():
24
+ async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
25
+ yield chunk
26
+
27
+ return {"response": generate_response(), "context": context_list}
richard/text_utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import fitz
3
+ import tempfile
4
+ from aimakerspace.text_utils import CharacterTextSplitter
5
+
6
+ class FileLoader:
7
+
8
+ def __init__(self, encoding: str = "utf-8"):
9
+ self.documents = []
10
+ self.encoding = encoding
11
+ self.temp_file_path = ""
12
+
13
+
14
+ def load_file(self, file, text_splitter=CharacterTextSplitter()):
15
+ file_extension = os.path.splitext(file.name)[1].lower()
16
+ with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_extension) as temp_file:
17
+ self.temp_file_path = temp_file.name
18
+ temp_file.write(file.content)
19
+
20
+ if os.path.isfile(self.temp_file_path):
21
+ if self.temp_file_path.endswith(".txt"):
22
+ self.load_text_file()
23
+ elif self.temp_file_path.endswith(".pdf"):
24
+ self.load_pdf_file()
25
+ else:
26
+ raise ValueError(
27
+ f"Unsupported file type: {self.temp_file_path}"
28
+ )
29
+ return text_splitter.split_texts(self.documents)
30
+ else:
31
+ raise ValueError(
32
+ "Not a file"
33
+ )
34
+
35
+ def load_text_file(self):
36
+ with open(self.temp_file_path, "r", encoding=self.encoding) as f:
37
+ self.documents.append(f.read())
38
+
39
+ def load_pdf_file(self):
40
+ print("load_pdf_file()")
41
+ pdf_document = fitz.open(self.temp_file_path)
42
+ print(len(pdf_document))
43
+ for page_num in range(len(pdf_document)):
44
+ page = pdf_document.load_page(page_num)
45
+ text = page.get_text()
46
+ self.documents.append(text)
richard/vector_database.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from collections import defaultdict
3
+ from typing import List, Tuple, Callable
4
+ from aimakerspace.openai_utils.embedding import EmbeddingModel
5
+ import hashlib
6
+ from qdrant_client import QdrantClient
7
+ from qdrant_client.http.models import PointStruct
8
+
9
+ def cosine_similarity(vector_a: np.array, vector_b: np.array) -> float:
10
+ """Computes the cosine similarity between two vectors."""
11
+ dot_product = np.dot(vector_a, vector_b)
12
+ norm_a = np.linalg.norm(vector_a)
13
+ norm_b = np.linalg.norm(vector_b)
14
+ return dot_product / (norm_a * norm_b)
15
+
16
+
17
+ class QdrantDatabase:
18
+ def __init__(self, qdrant_client: QdrantClient, collection_name: str, embedding_model=None):
19
+ self.qdrant_client = qdrant_client
20
+ self.collection_name = collection_name
21
+ self.embedding_model = embedding_model or EmbeddingModel()
22
+ self.vectors = defaultdict(np.array) # Still keeps a local copy if needed
23
+
24
+ def string_to_int_id(self, s: str) -> int:
25
+ return int(hashlib.sha256(s.encode('utf-8')).hexdigest(), 16) % (10**8)
26
+ def get_test_vector(self):
27
+ retrieved_vector = self.qdrant_client.retrieve(
28
+ collection_name="my_collection",
29
+ ids=[self.string_to_int_id("test_key")]
30
+ )
31
+ return retrieved_vector
32
+ def insert(self, key: str, vector: np.array) -> None:
33
+ point_id = self.string_to_int_id(key)
34
+ payload = {"text": key}
35
+
36
+ point = PointStruct(
37
+ id=point_id,
38
+ vector={"default": vector.tolist()},
39
+ payload=payload
40
+ )
41
+ print(f"Inserting vector for key: {key}, ID: {point_id}")
42
+ # Insert the vector into Qdrant with the associated document
43
+ self.qdrant_client.upsert(
44
+ collection_name=self.collection_name,
45
+ points=[point] # Qdrant expects a list of PointStruct
46
+ )
47
+ print(f"Inserted vector for key: {key} with ID: {point_id}")
48
+ retrieved_vector = self.qdrant_client.retrieve(
49
+ collection_name=self.collection_name,
50
+ ids=[point_id]
51
+ )
52
+ print(f"Inserted vector with ID: {point_id}, retrieved: {retrieved_vector}")
53
+ self.list_vectors()
54
+
55
+
56
+ def list_vectors(self):
57
+ # List all vectors in the collection for debugging
58
+ collection_info = self.qdrant_client.get_collection(self.collection_name)
59
+ print(f"Collection info: {collection_info}")
60
+
61
+ def search(
62
+ self,
63
+ query_vector: np.array,
64
+ k: int,
65
+ distance_measure: Callable = None,
66
+ ) -> List[Tuple[str, float]]:
67
+ # Perform search in Qdrant
68
+ if isinstance(query_vector, list):
69
+ query_vector = np.array(query_vector)
70
+ print(self.collection_name)
71
+ print(f"Searching in collection: {self.collection_name} with vector: {query_vector}")
72
+ collection_info = self.qdrant_client.get_collection(self.collection_name)
73
+ print(f"Collection info: {collection_info}")
74
+
75
+ search_results = self.qdrant_client.search(
76
+ collection_name=self.collection_name,
77
+ query_vector=query_vector.tolist(), # Pass the vector as a list
78
+ limit=k
79
+ )
80
+
81
+ print(f"Search results: {search_results}")
82
+ # print(query_vector.tolist())
83
+ # search_results = self.qdrant_client.query_points(
84
+ # collection_name=self.collection_name,
85
+ # query=query_vector.tolist(), # Pass the vector as a list
86
+ # limit=k,
87
+ # )
88
+ # Extract and return results
89
+ return [(result.payload['text'], result.score) for result in search_results]
90
+
91
+ def search_by_text(
92
+ self,
93
+ query_text: str,
94
+ k: int,
95
+ distance_measure: Callable = None,
96
+ return_as_text: bool = False,
97
+ ) -> List[Tuple[str, float]]:
98
+ self.list_vectors()
99
+ query_vector = self.embedding_model.get_embedding(query_text)
100
+ results = self.search(query_vector, k, distance_measure)
101
+ return [result[0] for result in results] if return_as_text else results
102
+
103
+ def retrieve_from_key(self, key: str) -> np.array:
104
+ # Retrieve from local cache
105
+ return self.vectors.get(key, None)
106
+
107
+ async def abuild_from_list(self, list_of_text: List[str]) -> "QdrantDatabase":
108
+ embeddings = await self.embedding_model.async_get_embeddings(list_of_text)
109
+ for text, embedding in zip(list_of_text, embeddings):
110
+ self.insert(text, np.array(embedding))
111
+ return self
112
+