janar commited on
Commit
f2932e2
·
1 Parent(s): edd979e

Introduce Elasticsearch as backend store

Browse files

Idea is to see if it works any better if we use
elastic search based lookup.

Specifically, we want to try Elasticsearch's sparse encoder.

api/db/vector_store.py CHANGED
@@ -1,12 +1,60 @@
 
1
  import os
2
  from qdrant_client import QdrantClient
3
  from langchain.embeddings import OpenAIEmbeddings
4
- from langchain.vectorstores import Qdrant
 
5
 
6
  embeddings = OpenAIEmbeddings()
7
- client = QdrantClient(url=os.getenv("QDRANT_URL"),
8
- api_key=os.getenv("QDRANT_API_KEY"))
9
 
10
- def get_instance(collection: str = "test"):
11
- return Qdrant(client=client,collection_name=collection,embeddings=embeddings)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from abc import ABC, abstractmethod
import os

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import ElasticVectorSearch, Qdrant, VectorStore
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
7
 
# Shared embedding function used by every vector-store backend below.
embeddings = OpenAIEmbeddings()
 
 
9
 
10
class ToyVectorStore(ABC):
    """Abstract interface (and factory) for the app's vector-store backends.

    The concrete backend is chosen at runtime via the ``STORE`` environment
    variable: ``"ELASTIC"`` or ``"QDRANT"``.

    Fix: the original declared ``@abstractmethod`` methods without inheriting
    ``abc.ABC``, so the abstract contract was never enforced — the base class
    (or a subclass missing an implementation) could be instantiated and the
    "abstract" methods would silently return ``None``. Inheriting ``ABC``
    makes instantiation of incomplete implementations a ``TypeError``.
    """

    @staticmethod
    def get_instance():
        """Return the vector store selected by the ``STORE`` env variable.

        Raises:
            ValueError: if ``STORE`` is unset or names an unknown backend.
        """
        vector_store = os.getenv("STORE")
        if vector_store == "ELASTIC":
            return ElasticVectorStore()
        elif vector_store == "QDRANT":
            return QdrantVectorStore()
        else:
            raise ValueError(f"Invalid vector store {vector_store}")

    @abstractmethod
    def get_collection(self, collection: str = "test") -> VectorStore:
        """
        get an instance of vector store
        of collection
        """

    @abstractmethod
    def create_collection(self, collection: str) -> None:
        """
        create an instance of vector store
        with collection name
        """
37
+
38
class ElasticVectorStore(ToyVectorStore):
    """Vector-store backend backed by an Elasticsearch cluster."""

    def get_collection(self, collection: str) -> ElasticVectorSearch:
        """Return an ``ElasticVectorSearch`` bound to the given index name."""
        url = os.getenv("ES_URL")
        return ElasticVectorSearch(
            elasticsearch_url=url,
            index_name=collection,
            embedding=embeddings,
        )

    def create_collection(self, collection: str) -> None:
        """Create the Elasticsearch index backing ``collection``."""
        store = self.get_collection(collection)
        # Empty mapping: let Elasticsearch infer the field types.
        store.create_index(store.client, collection, dict())
46
+
47
+
48
class QdrantVectorStore(ToyVectorStore):
    """Vector-store backend backed by a Qdrant server."""

    def __init__(self):
        # Endpoint and credentials are taken from the environment.
        self.client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )

    def get_collection(self, collection: str) -> Qdrant:
        """Return a ``Qdrant`` wrapper over the named collection."""
        return Qdrant(
            client=self.client,
            collection_name=collection,
            embeddings=embeddings,
        )

    def create_collection(self, collection: str) -> None:
        """Create the collection, sized for OpenAI embeddings (1536 dims)."""
        config = VectorParams(size=1536, distance=Distance.COSINE)
        self.client.create_collection(collection_name=collection,
                                      vectors_config=config)
60
+
api/routes/admin.py CHANGED
@@ -2,9 +2,8 @@
2
 
3
  from typing import Annotated
4
 
5
- from qdrant_client.models import VectorParams, Distance
6
  from fastapi import APIRouter, Body
7
- from db import vector_store
8
 
9
  router = APIRouter()
10
 
@@ -14,5 +13,4 @@ async def recreate_collection(name: Annotated[str, Body(embed=True)]):
14
 If one exists, delete and recreate.
15
  """
16
  print(f"creating collection {name} in db")
17
- return vector_store.client.recreate_collection(collection_name=name,
18
- vectors_config=VectorParams(size=1536, distance=Distance.COSINE))
 
2
 
3
  from typing import Annotated
4
 
 
5
  from fastapi import APIRouter, Body
6
+ from db.vector_store import ToyVectorStore
7
 
8
  router = APIRouter()
9
 
 
13
 If one exists, delete and recreate.
14
  """
15
  print(f"creating collection {name} in db")
16
+ return ToyVectorStore.get_instance().create_collection(name)
 
api/routes/search.py CHANGED
@@ -11,22 +11,24 @@ from langchain.vectorstores import Qdrant
11
  from langchain.schema import Document
12
  from langchain.chains.question_answering import load_qa_chain
13
  from langchain.llms import OpenAI
14
- from db import vector_store
15
 
16
  router = APIRouter()
17
  _chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", verbose=True)
18
 
19
  @router.post("/v1/docs")
20
  async def create_or_update(name: Annotated[str, Body()], file_name: Annotated[str, Body()], file: UploadFile = File(...)):
21
- """create or update a collection
22
  `name` of the collection
23
  `file` to upload.
24
  `fileName` name of the file.
25
  """
26
 
27
- _db = vector_store.get_instance(name)
28
  if not _db:
 
29
  return JSONResponse(status_code=404, content={})
 
30
  async for doc in generate_documents(file, file_name):
31
  print(doc)
32
  _db.add_documents([doc])
@@ -39,9 +41,9 @@ async def answer(name: str, query: str):
39
  `name` of the collection.
40
  `query` to be answered.
41
  """
42
- _db = vector_store.get_instance(name)
43
  print(query)
44
- docs = _db.similarity_search_with_relevance_scores(query=query)
45
  print(docs)
46
  answer = _chain.run(input_documents=[tup[0] for tup in docs], question=query)
47
  return JSONResponse(status_code=200, content={"answer": answer, "file_score": [[f"{d[0].metadata['file']} : {d[0].metadata['page']}", d[1]] for d in docs]})
 
11
  from langchain.schema import Document
12
  from langchain.chains.question_answering import load_qa_chain
13
  from langchain.llms import OpenAI
14
+ from db.vector_store import ToyVectorStore
15
 
16
  router = APIRouter()
17
  _chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", verbose=True)
18
 
19
  @router.post("/v1/docs")
20
  async def create_or_update(name: Annotated[str, Body()], file_name: Annotated[str, Body()], file: UploadFile = File(...)):
21
+ """Create or update an existing collection with information from the file
22
  `name` of the collection
23
  `file` to upload.
24
  `fileName` name of the file.
25
  """
26
 
27
+ _db = ToyVectorStore.get_instance().get_collection(name)
28
  if not _db:
29
+ #todo. fix this to create a collection, may be.
30
  return JSONResponse(status_code=404, content={})
31
+
32
  async for doc in generate_documents(file, file_name):
33
  print(doc)
34
  _db.add_documents([doc])
 
41
  `name` of the collection.
42
  `query` to be answered.
43
  """
44
+ _db = ToyVectorStore.get_instance().get_collection(name)
45
  print(query)
46
+ docs = _db.similarity_search_with_score(query=query)
47
  print(docs)
48
  answer = _chain.run(input_documents=[tup[0] for tup in docs], question=query)
49
  return JSONResponse(status_code=200, content={"answer": answer, "file_score": [[f"{d[0].metadata['file']} : {d[0].metadata['page']}", d[1]] for d in docs]})
requirements.txt CHANGED
@@ -8,4 +8,4 @@ langchain
8
  tiktoken
9
  faiss-cpu
10
  qdrant-client
11
-
 
8
  tiktoken
9
  faiss-cpu
10
  qdrant-client
11
+ elasticsearch