File size: 3,118 Bytes
f2932e2
ca2fff7
51a7f02
 
9c7a6f3
f2932e2
 
1bec7d8
51a7f02
 
ca2fff7
c7e10e4
9c7a6f3
 
 
1bec7d8
 
 
 
f2932e2
ca2fff7
f2932e2
 
9c7a6f3
f2932e2
ca2fff7
f2932e2
ca2fff7
f2932e2
 
12a040e
9c7a6f3
1bec7d8
 
12a040e
f2932e2
 
 
 
 
 
 
1bec7d8
f2932e2
 
 
 
 
 
 
 
ca2fff7
 
 
 
 
 
 
 
1bec7d8
 
12a040e
f2932e2
 
1bec7d8
 
f2932e2
 
 
 
ca2fff7
 
 
f2932e2
ca2fff7
f2932e2
1bec7d8
 
f2932e2
 
 
 
12a040e
1bec7d8
f2932e2
 
 
1bec7d8
 
ca2fff7
 
 
 
e258771
f2932e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from abc import abstractmethod
from functools import cache
import os
from qdrant_client import QdrantClient
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Qdrant, ElasticVectorSearch, VectorStore
from qdrant_client.models import VectorParams, Distance
from db.embedding import Embedding, EMBEDDINGS


class Store:
    """Common interface for vector-store backends, plus factory helpers.

    Concrete backends supply ``get_collection``, ``create_collection`` and
    ``list_collections``.

    NOTE: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers below are not enforced at instantiation time.
    """

    @staticmethod
    def get_embedding():
        """Return the ``EMBEDDINGS`` entry selected by the ``EMBEDDING`` env var.

        An unset or empty variable selects ``EMBEDDINGS["OPEN_AI"]``. A name
        that is not present in ``EMBEDDINGS`` propagates the lookup error
        unchanged.
        """
        name = os.getenv("EMBEDDING") or "OPEN_AI"
        return EMBEDDINGS[name]

    @staticmethod
    @cache
    def get_instance():
        """Build the backend named by the ``STORE`` environment variable.

        Recognised values are ``"ELASTIC"`` and ``"QDRANT"``. The result is
        memoized with ``functools.cache``, so once a call has succeeded the
        same instance is returned and the environment is not read again.

        Raises:
            ValueError: if ``STORE`` is unset or holds any other value.
        """
        backend = os.getenv("STORE")

        if backend == "ELASTIC":
            return ElasticVectorStore(Store.get_embedding())
        if backend == "QDRANT":
            return QdrantVectorStore(Store.get_embedding())
        raise ValueError(f"Invalid vector store {backend}")

    def __init__(self, embedding: Embedding):
        """Remember the embedding configuration used by this store."""
        self.embedding = embedding

    @abstractmethod
    def get_collection(self, collection: str = "test") -> VectorStore:
        """Return a vector store instance bound to ``collection``."""

    @abstractmethod
    def create_collection(self, collection: str) -> None:
        """Create a collection named ``collection`` in the backend."""

    @abstractmethod
    def list_collections(self) -> list[dict]:
        """Return a list of collections in the vector store."""

class ElasticVectorStore(Store):
    """``Store`` backend built on LangChain's ``ElasticVectorSearch``."""

    def __init__(self, embeddings):
        """Forward the embedding configuration to the ``Store`` base class."""
        super().__init__(embeddings)

    def get_collection(self, collection: str = "test") -> ElasticVectorSearch:
        """Return an ``ElasticVectorSearch`` bound to the index ``collection``.

        The default of ``"test"`` matches the signature declared on
        ``Store.get_collection`` so both can be called without arguments.
        The Elasticsearch URL is read from the ``ES_URL`` environment
        variable on every call.
        """
        return ElasticVectorSearch(
            elasticsearch_url=os.getenv("ES_URL"),
            index_name=collection,
            embedding=self.embedding.embedding,
        )

    def create_collection(self, collection: str) -> None:
        """Create the index ``collection`` using an empty mapping."""
        store = self.get_collection(collection)
        store.create_index(store.client, collection, {})

    def list_collections(self) -> list[dict]:
        """Return an empty list; listing is not implemented for this backend."""
        # TODO: not implemented
        return []

class QdrantVectorStore(Store):
    """``Store`` backend built on a Qdrant server via ``QdrantClient``."""

    def __init__(self, embeddings):
        """Store the embedding configuration and open a Qdrant client.

        Connection settings come from the ``QDRANT_URL`` and
        ``QDRANT_API_KEY`` environment variables.
        """
        super().__init__(embeddings)
        self.client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )

    def get_collection(self, collection: str = "test") -> Qdrant:
        """Return a LangChain ``Qdrant`` store bound to ``collection``.

        The default of ``"test"`` matches the signature declared on
        ``Store.get_collection`` so both can be called without arguments.
        """
        return Qdrant(
            client=self.client,
            collection_name=collection,
            embeddings=self.embedding.embedding,
        )

    def create_collection(self, collection: str) -> None:
        """Create ``collection`` with cosine distance.

        The vector size is taken from ``self.embedding.dimension``.
        """
        self.client.create_collection(
            collection_name=collection,
            vectors_config=VectorParams(
                size=self.embedding.dimension,
                distance=Distance.COSINE,
            ),
        )

    def list_collections(self) -> list[dict]:
        """Return each collection reported by the client, as a dict."""
        return [c.dict() for c in self.client.get_collections().collections]