Spaces:
Sleeping
Sleeping
File size: 3,118 Bytes
f2932e2 ca2fff7 51a7f02 9c7a6f3 f2932e2 1bec7d8 51a7f02 ca2fff7 c7e10e4 9c7a6f3 1bec7d8 f2932e2 ca2fff7 f2932e2 9c7a6f3 f2932e2 ca2fff7 f2932e2 ca2fff7 f2932e2 12a040e 9c7a6f3 1bec7d8 12a040e f2932e2 1bec7d8 f2932e2 ca2fff7 1bec7d8 12a040e f2932e2 1bec7d8 f2932e2 ca2fff7 f2932e2 ca2fff7 f2932e2 1bec7d8 f2932e2 12a040e 1bec7d8 f2932e2 1bec7d8 ca2fff7 e258771 f2932e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import os
from abc import ABC, abstractmethod
from functools import cache

from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance

from db.embedding import Embedding, EMBEDDINGS
class Store(ABC):
    """Abstract base for vector-store backends, plus a factory for the configured one.

    Concrete backends must implement ``get_collection``, ``create_collection``
    and ``list_collections``. The class derives from ``ABC`` so that the
    ``@abstractmethod`` markers are actually enforced: a subclass that omits
    one of them cannot be instantiated instead of silently inheriting a stub
    that returns ``None``.
    """

    @staticmethod
    def get_embedding():
        """Return the embedding named by the ``EMBEDDING`` environment variable.

        Falls back to ``EMBEDDINGS["OPEN_AI"]`` when the variable is unset or
        empty.

        Raises:
            KeyError: if the variable names an entry missing from ``EMBEDDINGS``.
        """
        return EMBEDDINGS[os.getenv("EMBEDDING") or "OPEN_AI"]

    @staticmethod
    @cache
    def get_instance():
        """Return the backend selected by the ``STORE`` environment variable.

        The result is memoized by ``functools.cache``: once a backend has been
        built, every later call returns that same object and the environment
        is not consulted again.

        Raises:
            ValueError: if ``STORE`` is neither ``"ELASTIC"`` nor ``"QDRANT"``.
        """
        vector_store = os.getenv("STORE")
        if vector_store == "ELASTIC":
            return ElasticVectorStore(Store.get_embedding())
        if vector_store == "QDRANT":
            return QdrantVectorStore(Store.get_embedding())
        raise ValueError(f"Invalid vector store {vector_store}")

    def __init__(self, embedding: Embedding):
        """Remember the embedding the backend should use for its collections."""
        self.embedding = embedding

    @abstractmethod
    def get_collection(self, collection: str = "test") -> VectorStore:
        """Return a vector store bound to the collection named ``collection``."""

    @abstractmethod
    def create_collection(self, collection: str) -> None:
        """Create a collection named ``collection`` in the backend."""

    @abstractmethod
    def list_collections(self) -> list[dict]:
        """Return a list of the collections in the vector store."""
class ElasticVectorStore(Store):
    """``Store`` backend that keeps collections as Elasticsearch indices."""

    def __init__(self, embeddings):
        """Store ``embeddings`` on the instance via the base-class initializer."""
        super().__init__(embeddings)

    def get_collection(self, collection: str = "test") -> ElasticVectorSearch:
        """Return an ``ElasticVectorSearch`` bound to the index ``collection``.

        The default of ``"test"`` mirrors ``Store.get_collection`` so that a
        call without arguments works on this backend as the base signature
        promises. The server URL is read from the ``ES_URL`` environment
        variable each time this method is called.
        """
        return ElasticVectorSearch(
            elasticsearch_url=os.getenv("ES_URL"),
            index_name=collection,
            embedding=self.embedding.embedding,
        )

    def create_collection(self, collection: str) -> None:
        """Create the index ``collection``, passing an empty mapping."""
        store = self.get_collection(collection)
        store.create_index(store.client, collection, {})

    def list_collections(self) -> list[dict]:
        """Return an empty list; listing is not implemented for this backend."""
        # TODO: not implemented
        return []
class QdrantVectorStore(Store):
    """``Store`` backend that keeps collections in Qdrant."""

    def __init__(self, embeddings):
        """Store ``embeddings`` and open a Qdrant client.

        The client is configured from the ``QDRANT_URL`` and
        ``QDRANT_API_KEY`` environment variables.
        """
        super().__init__(embeddings)
        self.client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )

    def get_collection(self, collection: str = "test") -> Qdrant:
        """Return a ``Qdrant`` vector store bound to ``collection``.

        The default of ``"test"`` mirrors ``Store.get_collection`` so that a
        call without arguments works on this backend as the base signature
        promises.
        """
        return Qdrant(
            client=self.client,
            collection_name=collection,
            embeddings=self.embedding.embedding,
        )

    def create_collection(self, collection: str) -> None:
        """Create ``collection`` sized for this store's embedding.

        The vector size is taken from ``self.embedding.dimension`` and the
        distance metric is cosine.
        """
        self.client.create_collection(
            collection_name=collection,
            vectors_config=VectorParams(
                size=self.embedding.dimension,
                distance=Distance.COSINE,
            ),
        )

    def list_collections(self) -> list[dict]:
        """Return one dict per collection reported by the Qdrant client."""
        return [c.dict() for c in self.client.get_collections().collections]
|