Spaces:
Sleeping
Sleeping
refactor
Browse files- api/db/vector_store.py +22 -5
- api/document_parsing.py +35 -0
- api/main.py +2 -3
- api/routes/admin.py +0 -16
- api/routes/embeddings.py +0 -15
- api/routes/search.py +10 -47
- api/routes/upload.py +39 -0
api/db/vector_store.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from abc import abstractmethod
|
|
|
2 |
import os
|
3 |
from qdrant_client import QdrantClient
|
4 |
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
|
@@ -7,7 +8,7 @@ from qdrant_client.models import VectorParams, Distance
|
|
7 |
from db.embedding import Embedding, EMBEDDINGS
|
8 |
|
9 |
|
10 |
-
class
|
11 |
|
12 |
@staticmethod
|
13 |
def get_embedding():
|
@@ -17,13 +18,14 @@ class ToyVectorStore:
|
|
17 |
return EMBEDDINGS[embedding]
|
18 |
|
19 |
@staticmethod
|
|
|
20 |
def get_instance():
|
21 |
vector_store = os.getenv("STORE")
|
22 |
|
23 |
if vector_store == "ELASTIC":
|
24 |
-
return ElasticVectorStore(
|
25 |
elif vector_store == "QDRANT":
|
26 |
-
return QdrantVectorStore(
|
27 |
else:
|
28 |
raise ValueError(f"Invalid vector store {vector_store}")
|
29 |
|
@@ -47,7 +49,14 @@ class ToyVectorStore:
|
|
47 |
"""
|
48 |
pass
|
49 |
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
def __init__(self, embeddings):
|
52 |
super().__init__(embeddings)
|
53 |
|
@@ -59,8 +68,11 @@ class ElasticVectorStore(ToyVectorStore):
|
|
59 |
store = self.get_collection(collection)
|
60 |
store.create_index(store.client,collection, dict())
|
61 |
|
|
|
|
|
|
|
62 |
|
63 |
-
class QdrantVectorStore(
|
64 |
|
65 |
def __init__(self, embeddings):
|
66 |
super().__init__(embeddings)
|
@@ -75,4 +87,9 @@ class QdrantVectorStore(ToyVectorStore):
|
|
75 |
self.client.create_collection(collection_name=collection,
|
76 |
vectors_config=VectorParams(size=self.embedding.dimension,
|
77 |
distance=Distance.COSINE))
|
|
|
|
|
|
|
|
|
|
|
78 |
|
|
|
1 |
from abc import abstractmethod
|
2 |
+
from functools import cache
|
3 |
import os
|
4 |
from qdrant_client import QdrantClient
|
5 |
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
|
|
|
8 |
from db.embedding import Embedding, EMBEDDINGS
|
9 |
|
10 |
|
11 |
+
class Store:
|
12 |
|
13 |
@staticmethod
|
14 |
def get_embedding():
|
|
|
18 |
return EMBEDDINGS[embedding]
|
19 |
|
20 |
@staticmethod
@cache
def get_instance():
    """Build the vector store selected by the ``STORE`` environment variable.

    ``ELASTIC`` and ``QDRANT`` are the recognised values; anything else
    (including an unset variable) raises ``ValueError``. The result is
    memoised by ``functools.cache``, so later calls return the same object.
    """
    vector_store = os.getenv("STORE")

    # Guard clauses: one early return per supported backend.
    if vector_store == "ELASTIC":
        return ElasticVectorStore(Store.get_embedding())
    if vector_store == "QDRANT":
        return QdrantVectorStore(Store.get_embedding())

    raise ValueError(f"Invalid vector store {vector_store}")
|
31 |
|
|
|
49 |
"""
|
50 |
pass
|
51 |
|
52 |
+
@abstractmethod
def list_collections(self) -> list[dict]:
    """Return a list of the collections in the vector store.

    Abstract hook: concrete stores supply the actual listing.
    """
    ...
|
58 |
+
|
59 |
+
class ElasticVectorStore(Store):
|
60 |
def __init__(self, embeddings):
|
61 |
super().__init__(embeddings)
|
62 |
|
|
|
68 |
store = self.get_collection(collection)
|
69 |
store.create_index(store.client,collection, dict())
|
70 |
|
71 |
+
def list_collections(self) -> list[dict]:
    """Return the collections in the store; currently always an empty list."""
    # TODO: not implemented
    collections = []
    return collections
|
74 |
|
75 |
+
class QdrantVectorStore(Store):
|
76 |
|
77 |
def __init__(self, embeddings):
|
78 |
super().__init__(embeddings)
|
|
|
87 |
self.client.create_collection(collection_name=collection,
|
88 |
vectors_config=VectorParams(size=self.embedding.dimension,
|
89 |
distance=Distance.COSINE))
|
90 |
+
|
91 |
+
def list_collections(self) -> list[dict]:
    """Return the collections reported by the Qdrant client.

    The entries are passed through exactly as ``get_collections()`` exposes
    them in its ``collections`` attribute; no conversion is applied here.
    """
    # A plain copy is enough; the previous enumerate() index was never used.
    return list(self.client.get_collections().collections)
|
95 |
|
api/document_parsing.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Annotated
|
2 |
+
|
3 |
+
from fastapi import APIRouter, UploadFile, File, Body
|
4 |
+
from langchain.schema import Document
|
5 |
+
import io
|
6 |
+
import os
|
7 |
+
from pypdf import PdfReader
|
8 |
+
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
|
9 |
+
from db.vector_store import Store
|
10 |
+
|
11 |
+
async def generate_documents(file: UploadFile, file_name: str):
    """Yield one ``Document`` per text chunk of the uploaded file.

    Each batch produced by ``convert_documents`` is numbered from 1, and
    that number is stored with ``file_name`` in the metadata of every chunk
    in the batch.
    """
    page_number = 0
    async for chunks in convert_documents(file):
        page_number += 1
        for chunk in chunks:
            yield Document(
                page_content=chunk,
                metadata={"file": file_name, "page": page_number},
            )
|
18 |
+
|
19 |
+
async def convert_documents(file: UploadFile):
    """Yield lists of text chunks extracted from an uploaded file.

    A PDF upload yields one list of chunks per page. An upload whose content
    type contains "text" yields a single list for the whole body, decoded as
    UTF-8. Any other content type, including a missing one, yields nothing.
    """
    splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)

    # An upload may carry no content type at all; treat that as unsupported
    # rather than letting the substring test below raise a TypeError on None.
    content_type = file.content_type or ""

    # parse pdf document
    if content_type == 'application/pdf':
        content = await file.read()
        pdf_reader = PdfReader(io.BytesIO(content))
        try:
            for page in pdf_reader.pages:
                yield splitter.split_text(page.extract_text())
        except Exception as e:
            # Best effort: pages already yielded stay yielded; the failure is
            # only reported on stdout and iteration stops.
            print(f"Exception {e}")
    elif "text" in content_type:
        content = await file.read()
        yield splitter.split_text(content.decode("utf-8"))
|
api/main.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
from fastapi import FastAPI
|
3 |
-
from routes import
|
4 |
from fastapi.middleware import Middleware
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
from datetime import datetime
|
@@ -16,9 +16,8 @@ logger.addHandler(handler)
|
|
16 |
|
17 |
# Create the FastAPI instance
|
18 |
app = FastAPI()
|
19 |
-
app.include_router(embeddings.router)
|
20 |
app.include_router(search.router)
|
21 |
-
app.include_router(
|
22 |
app.exception_handler(generic_exception_handler)
|
23 |
|
24 |
app.add_middleware(CORSMiddleware, allow_origins = ["*"],
|
|
|
1 |
#!/usr/bin/env python3
|
2 |
from fastapi import FastAPI
|
3 |
+
from routes import search, upload
|
4 |
from fastapi.middleware import Middleware
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
from datetime import datetime
|
|
|
16 |
|
17 |
# Create the FastAPI instance
app = FastAPI()
app.include_router(search.router)
app.include_router(upload.router)
# `app.exception_handler(x)` only builds a decorator keyed on `x`; calling it
# with the handler itself and discarding the result registered nothing.
# Register the handler explicitly instead.
# NOTE(review): assumes generic_exception_handler is meant as the catch-all
# for `Exception` - confirm against its definition.
app.add_exception_handler(Exception, generic_exception_handler)
|
22 |
|
23 |
app.add_middleware(CORSMiddleware, allow_origins = ["*"],
|
api/routes/admin.py
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
#This is to init the vector store
|
2 |
-
|
3 |
-
from typing import Annotated
|
4 |
-
|
5 |
-
from fastapi import APIRouter, Body
|
6 |
-
from db.vector_store import ToyVectorStore
|
7 |
-
|
8 |
-
router = APIRouter()
|
9 |
-
|
10 |
-
@router.put("/admin/v1/db")
|
11 |
-
async def recreate_collection(name: Annotated[str, Body(embed=True)]):
|
12 |
-
""" `name` of the collection to be created.
|
13 |
-
If one exits, delete and recreate.
|
14 |
-
"""
|
15 |
-
print(f"creating collection {name} in db")
|
16 |
-
return ToyVectorStore.get_instance().create_collection(name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api/routes/embeddings.py
DELETED
@@ -1,15 +0,0 @@
|
|
1 |
-
from fastapi import APIRouter, UploadFile, File
|
2 |
-
import openai
|
3 |
-
import io
|
4 |
-
import os
|
5 |
-
from pypdf import PdfReader
|
6 |
-
|
7 |
-
router = APIRouter()
|
8 |
-
|
9 |
-
openai.api_key = os.getenv("OPENAI_API_KEY")
|
10 |
-
|
11 |
-
@router.post("/v1/embeddings")
|
12 |
-
async def embed_doc(file: UploadFile = File(...)):
|
13 |
-
#for now just truncate based on length of words
|
14 |
-
content = await file.read()
|
15 |
-
return openai.Embedding.create(input = content.decode("utf-8"), model = "text-embedding-ada-002")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api/routes/search.py
CHANGED
@@ -10,65 +10,28 @@ from langchain.schema import Document
|
|
10 |
from langchain.chains.question_answering import load_qa_chain
|
11 |
from langchain.llms import OpenAI
|
12 |
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
|
13 |
-
from db.vector_store import
|
14 |
|
15 |
router = APIRouter()
|
16 |
_chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", verbose=True)
|
17 |
|
18 |
-
@router.post("/v1/docs")
|
19 |
-
async def create_or_update(name: Annotated[str, Body()], file_name: Annotated[str, Body()], file: UploadFile = File(...)):
|
20 |
-
"""Create or update an existing collection with information from the file
|
21 |
-
`name` of the collection
|
22 |
-
`file` to upload.
|
23 |
-
`fileName` name of the file.
|
24 |
-
"""
|
25 |
-
|
26 |
-
_db = ToyVectorStore.get_instance().get_collection(name)
|
27 |
-
if not _db:
|
28 |
-
#todo. fix this to create a collection, may be.
|
29 |
-
return JSONResponse(status_code=404, content={})
|
30 |
-
|
31 |
-
async for doc in generate_documents(file, file_name):
|
32 |
-
print(doc)
|
33 |
-
_db.add_documents([doc])
|
34 |
-
#todo return something sensible
|
35 |
-
return JSONResponse(status_code=200, content={"name": name})
|
36 |
|
37 |
-
@router.get("/v1/
|
38 |
async def answer(name: str, query: str):
|
39 |
-
""" Answer a question from the
|
40 |
-
`name` of the
|
41 |
`query` to be answered.
|
42 |
"""
|
43 |
-
_db =
|
44 |
print(query)
|
45 |
docs = _db.similarity_search_with_score(query=query)
|
46 |
print(docs)
|
47 |
answer = _chain.run(input_documents=[tup[0] for tup in docs], question=query)
|
48 |
return JSONResponse(status_code=200, content={"answer": answer, "file_score": [[f"{d[0].metadata['file']} : {d[0].metadata['page']}", d[1]] for d in docs]})
|
49 |
|
50 |
-
async def generate_documents(file: UploadFile, file_name: str):
|
51 |
-
num=0
|
52 |
-
async for txts in convert_documents(file):
|
53 |
-
num += 1
|
54 |
-
for txt in txts:
|
55 |
-
document = Document(page_content=txt,metadata={"file": file_name, "page": num})
|
56 |
-
yield document
|
57 |
-
|
58 |
-
async def convert_documents(file: UploadFile):
|
59 |
-
splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
|
60 |
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
for page in pdf_reader.pages:
|
67 |
-
yield splitter.split_text(page.extract_text())
|
68 |
-
except Exception as e:
|
69 |
-
print(f"Exception {e}")
|
70 |
-
elif "text" in file.content_type:
|
71 |
-
content = await file.read()
|
72 |
-
yield splitter.split_text(content.decode("utf-8"))
|
73 |
-
else:
|
74 |
-
return
|
|
|
10 |
from langchain.chains.question_answering import load_qa_chain
|
11 |
from langchain.llms import OpenAI
|
12 |
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
|
13 |
+
from db.vector_store import Store
|
14 |
|
15 |
router = APIRouter()
|
16 |
_chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", verbose=True)
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
@router.get("/v1/docs/{name}/answer")
async def answer(name: str, query: str):
    """ Answer a question from the doc
    `name` of the doc.
    `query` to be answered.

    Responds with 404 and an empty body when no collection is found for
    `name`; otherwise 200 with the answer and, per matched chunk, a
    "file : page" label paired with its similarity score.
    """
    _db = Store.get_instance().get_collection(name)
    # Same guard the upload route applies.
    # NOTE(review): assumes get_collection returns a falsy value for an
    # unknown name - confirm against the store implementation.
    if not _db:
        return JSONResponse(status_code=404, content={})

    print(query)
    docs = _db.similarity_search_with_score(query=query)
    print(docs)

    # `docs` holds (document, score) pairs; only the documents go to the chain.
    reply = _chain.run(input_documents=[doc for doc, _ in docs], question=query)
    file_score = [
        [f"{doc.metadata['file']} : {doc.metadata['page']}", score]
        for doc, score in docs
    ]
    return JSONResponse(status_code=200, content={"answer": reply, "file_score": file_score})
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
+
@router.get("/v1/docs")
async def list() -> list[dict]:
    """List all the docs (the collections of the configured vector store)."""
    # NOTE(review): this function rebinds the module-level name `list`, so any
    # later use of the builtin `list` in this module would get the route
    # function instead - consider renaming once callers are checked.
    store = Store.get_instance()
    return store.list_collections()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
api/routes/upload.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#This is to init the vector store
|
2 |
+
|
3 |
+
from typing import Annotated
|
4 |
+
|
5 |
+
from db.vector_store import Store
|
6 |
+
from document_parsing import generate_documents
|
7 |
+
|
8 |
+
from fastapi import APIRouter, Body
|
9 |
+
from fastapi.responses import JSONResponse
|
10 |
+
from fastapi import APIRouter, UploadFile, File, Body
|
11 |
+
|
12 |
+
router = APIRouter()
|
13 |
+
|
14 |
+
@router.put("/v1/docs")
async def recreate_collection(name: Annotated[str, Body(embed=True)]):
    """ `name` of the doc to be created.

    NOTE(review): intended to delete and recreate the doc if one already
    exists; that behaviour lives in the store's `create_collection` and is
    not visible here - confirm.
    """
    print(f"creating collection {name} in db")
    store = Store.get_instance()
    return store.create_collection(name)
|
21 |
+
|
22 |
+
@router.post("/v1/docs")
async def update(name: Annotated[str, Body()], file_name: Annotated[str, Body()], file: UploadFile = File(...)):
    """Add the contents of an uploaded file to an existing collection.

    `name` of the collection to add to.
    `file` to upload.
    `file_name` name of the file. This is used for metadata purposes only.

    Responds with 404 and an empty body when the lookup for `name` yields
    nothing; otherwise 200 with the collection name.
    """
    collection = Store.get_instance().get_collection(name)
    if not collection:
        return JSONResponse(status_code=404, content={})

    # Documents are stored one at a time as they are produced.
    async for document in generate_documents(file, file_name):
        print(document)
        collection.add_documents([document])

    return JSONResponse(status_code=200, content={"name": name})
|
39 |
+
|