Spaces:
Sleeping
Sleeping
Introduce elastic search as backend store
Browse files
Idea is to see if it works any better if we use
elastic search based lookup.
Specifically, if we use elastic's sparse encoder
- api/db/vector_store.py +53 -5
- api/routes/admin.py +2 -4
- api/routes/search.py +7 -5
- requirements.txt +1 -1
api/db/vector_store.py
CHANGED
@@ -1,12 +1,60 @@
|
|
|
|
1 |
import os
|
2 |
from qdrant_client import QdrantClient
|
3 |
from langchain.embeddings import OpenAIEmbeddings
|
4 |
-
from langchain.vectorstores import Qdrant
|
|
|
5 |
|
6 |
embeddings = OpenAIEmbeddings()
|
7 |
-
client = QdrantClient(url=os.getenv("QDRANT_URL"),
|
8 |
-
api_key=os.getenv("QDRANT_API_KEY"))
|
9 |
|
10 |
-
|
11 |
-
return Qdrant(client=client,collection_name=collection,embeddings=embeddings)
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
import os
|
3 |
from qdrant_client import QdrantClient
|
4 |
from langchain.embeddings import OpenAIEmbeddings
|
5 |
+
from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore
|
6 |
+
from qdrant_client.models import VectorParams, Distance
|
7 |
|
8 |
embeddings = OpenAIEmbeddings()
|
|
|
|
|
9 |
|
10 |
+
class ToyVectorStore:
|
|
|
11 |
|
12 |
+
@staticmethod
|
13 |
+
def get_instance():
|
14 |
+
vector_store = os.getenv("STORE")
|
15 |
+
if vector_store == "ELASTIC":
|
16 |
+
return ElasticVectorStore()
|
17 |
+
elif vector_store == "QDRANT":
|
18 |
+
return QdrantVectorStore()
|
19 |
+
else:
|
20 |
+
raise ValueError(f"Invalid vector store {vector_store}")
|
21 |
+
|
22 |
+
@abstractmethod
|
23 |
+
def get_collection(self, collection: str="test") -> VectorStore:
|
24 |
+
"""
|
25 |
+
get an instance of vector store
|
26 |
+
of collection
|
27 |
+
"""
|
28 |
+
pass
|
29 |
+
|
30 |
+
@abstractmethod
|
31 |
+
def create_collection(self, collection: str) -> None:
|
32 |
+
"""
|
33 |
+
create an instance of vector store
|
34 |
+
with collection name
|
35 |
+
"""
|
36 |
+
pass
|
37 |
+
|
38 |
+
class ElasticVectorStore(ToyVectorStore):
|
39 |
+
def get_collection(self, collection:str) -> ElasticVectorSearch:
|
40 |
+
return ElasticVectorSearch(elasticsearch_url= os.getenv("ES_URL"),
|
41 |
+
index_name= collection, embedding=embeddings)
|
42 |
+
|
43 |
+
def create_collection(self, collection: str) -> None:
|
44 |
+
store = self.get_collection(collection)
|
45 |
+
store.create_index(store.client,collection, dict())
|
46 |
+
|
47 |
+
|
48 |
+
class QdrantVectorStore(ToyVectorStore):
|
49 |
+
|
50 |
+
def __init__(self):
|
51 |
+
self.client = QdrantClient(url=os.getenv("QDRANT_URL"),
|
52 |
+
api_key=os.getenv("QDRANT_API_KEY"))
|
53 |
+
|
54 |
+
def get_collection(self, collection: str) -> Qdrant:
|
55 |
+
return Qdrant(client=self.client,collection_name=collection,embeddings=embeddings)
|
56 |
+
|
57 |
+
def create_collection(self, collection: str) -> None:
|
58 |
+
self.client.create_collection(collection_name=collection,
|
59 |
+
vectors_config=VectorParams(size=1536, distance=Distance.COSINE))
|
60 |
+
|
api/routes/admin.py
CHANGED
@@ -2,9 +2,8 @@
|
|
2 |
|
3 |
from typing import Annotated
|
4 |
|
5 |
-
from qdrant_client.models import VectorParams, Distance
|
6 |
from fastapi import APIRouter, Body
|
7 |
-
from db import
|
8 |
|
9 |
router = APIRouter()
|
10 |
|
@@ -14,5 +13,4 @@ async def recreate_collection(name: Annotated[str, Body(embed=True)]):
|
|
14 |
If one exists, delete and recreate.
|
15 |
"""
|
16 |
print(f"creating collection {name} in db")
|
17 |
-
return
|
18 |
-
vectors_config=VectorParams(size=1536, distance=Distance.COSINE))
|
|
|
2 |
|
3 |
from typing import Annotated
|
4 |
|
|
|
5 |
from fastapi import APIRouter, Body
|
6 |
+
from db.vector_store import ToyVectorStore
|
7 |
|
8 |
router = APIRouter()
|
9 |
|
|
|
13 |
If one exists, delete and recreate.
|
14 |
"""
|
15 |
print(f"creating collection {name} in db")
|
16 |
+
return ToyVectorStore.get_instance().create_collection(name)
|
|
api/routes/search.py
CHANGED
@@ -11,22 +11,24 @@ from langchain.vectorstores import Qdrant
|
|
11 |
from langchain.schema import Document
|
12 |
from langchain.chains.question_answering import load_qa_chain
|
13 |
from langchain.llms import OpenAI
|
14 |
-
from db import
|
15 |
|
16 |
router = APIRouter()
|
17 |
_chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", verbose=True)
|
18 |
|
19 |
@router.post("/v1/docs")
|
20 |
async def create_or_update(name: Annotated[str, Body()], file_name: Annotated[str, Body()], file: UploadFile = File(...)):
|
21 |
-
"""
|
22 |
`name` of the collection
|
23 |
`file` to upload.
|
24 |
`fileName` name of the file.
|
25 |
"""
|
26 |
|
27 |
-
_db =
|
28 |
if not _db:
|
|
|
29 |
return JSONResponse(status_code=404, content={})
|
|
|
30 |
async for doc in generate_documents(file, file_name):
|
31 |
print(doc)
|
32 |
_db.add_documents([doc])
|
@@ -39,9 +41,9 @@ async def answer(name: str, query: str):
|
|
39 |
`name` of the collection.
|
40 |
`query` to be answered.
|
41 |
"""
|
42 |
-
_db =
|
43 |
print(query)
|
44 |
-
docs = _db.
|
45 |
print(docs)
|
46 |
answer = _chain.run(input_documents=[tup[0] for tup in docs], question=query)
|
47 |
return JSONResponse(status_code=200, content={"answer": answer, "file_score": [[f"{d[0].metadata['file']} : {d[0].metadata['page']}", d[1]] for d in docs]})
|
|
|
11 |
from langchain.schema import Document
|
12 |
from langchain.chains.question_answering import load_qa_chain
|
13 |
from langchain.llms import OpenAI
|
14 |
+
from db.vector_store import ToyVectorStore
|
15 |
|
16 |
router = APIRouter()
|
17 |
_chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", verbose=True)
|
18 |
|
19 |
@router.post("/v1/docs")
|
20 |
async def create_or_update(name: Annotated[str, Body()], file_name: Annotated[str, Body()], file: UploadFile = File(...)):
|
21 |
+
"""Create or update an existing collection with information from the file
|
22 |
`name` of the collection
|
23 |
`file` to upload.
|
24 |
`fileName` name of the file.
|
25 |
"""
|
26 |
|
27 |
+
_db = ToyVectorStore.get_instance().get_collection(name)
|
28 |
if not _db:
|
29 |
+
#todo. fix this to create a collection, may be.
|
30 |
return JSONResponse(status_code=404, content={})
|
31 |
+
|
32 |
async for doc in generate_documents(file, file_name):
|
33 |
print(doc)
|
34 |
_db.add_documents([doc])
|
|
|
41 |
`name` of the collection.
|
42 |
`query` to be answered.
|
43 |
"""
|
44 |
+
_db = ToyVectorStore.get_instance().get_collection(name)
|
45 |
print(query)
|
46 |
+
docs = _db.similarity_search_with_score(query=query)
|
47 |
print(docs)
|
48 |
answer = _chain.run(input_documents=[tup[0] for tup in docs], question=query)
|
49 |
return JSONResponse(status_code=200, content={"answer": answer, "file_score": [[f"{d[0].metadata['file']} : {d[0].metadata['page']}", d[1]] for d in docs]})
|
requirements.txt
CHANGED
@@ -8,4 +8,4 @@ langchain
|
|
8 |
tiktoken
|
9 |
faiss-cpu
|
10 |
qdrant-client
|
11 |
-
|
|
|
8 |
tiktoken
|
9 |
faiss-cpu
|
10 |
qdrant-client
|
11 |
+
elasticsearch
|