|
import os |
|
from typing import Optional |
|
|
|
import pytest |
|
from _pytest.monkeypatch import MonkeyPatch |
|
from requests.adapters import HTTPAdapter |
|
from tcvectordb import VectorDBClient |
|
from tcvectordb.model.database import Collection, Database |
|
from tcvectordb.model.document import Document, Filter |
|
from tcvectordb.model.enum import ReadConsistency |
|
from tcvectordb.model.index import Index |
|
from xinference_client.types import Embedding |
|
|
|
|
|
class MockTcvectordbClass: |
|
def mock_vector_db_client( |
|
self, |
|
url=None, |
|
username="", |
|
key="", |
|
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, |
|
timeout=5, |
|
adapter: HTTPAdapter = None, |
|
): |
|
self._conn = None |
|
self._read_consistency = read_consistency |
|
|
|
def list_databases(self) -> list[Database]: |
|
return [ |
|
Database( |
|
conn=self._conn, |
|
read_consistency=self._read_consistency, |
|
name="dify", |
|
) |
|
] |
|
|
|
def list_collections(self, timeout: Optional[float] = None) -> list[Collection]: |
|
return [] |
|
|
|
def drop_collection(self, name: str, timeout: Optional[float] = None): |
|
return {"code": 0, "msg": "operation success"} |
|
|
|
def create_collection( |
|
self, |
|
name: str, |
|
shard: int, |
|
replicas: int, |
|
description: str, |
|
index: Index, |
|
embedding: Embedding = None, |
|
timeout: Optional[float] = None, |
|
) -> Collection: |
|
return Collection( |
|
self, |
|
name, |
|
shard, |
|
replicas, |
|
description, |
|
index, |
|
embedding=embedding, |
|
read_consistency=self._read_consistency, |
|
timeout=timeout, |
|
) |
|
|
|
def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection: |
|
collection = Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout) |
|
return collection |
|
|
|
def collection_upsert( |
|
self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs |
|
): |
|
return {"code": 0, "msg": "operation success"} |
|
|
|
def collection_search( |
|
self, |
|
vectors: list[list[float]], |
|
filter: Filter = None, |
|
params=None, |
|
retrieve_vector: bool = False, |
|
limit: int = 10, |
|
output_fields: Optional[list[str]] = None, |
|
timeout: Optional[float] = None, |
|
) -> list[list[dict]]: |
|
return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]] |
|
|
|
def collection_query( |
|
self, |
|
document_ids: Optional[list] = None, |
|
retrieve_vector: bool = False, |
|
limit: Optional[int] = None, |
|
offset: Optional[int] = None, |
|
filter: Optional[Filter] = None, |
|
output_fields: Optional[list[str]] = None, |
|
timeout: Optional[float] = None, |
|
) -> list[dict]: |
|
return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}] |
|
|
|
def collection_delete( |
|
self, |
|
document_ids: Optional[list[str]] = None, |
|
filter: Filter = None, |
|
timeout: Optional[float] = None, |
|
): |
|
return {"code": 0, "msg": "operation success"} |
|
|
|
|
|
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" |
|
|
|
|
|
@pytest.fixture |
|
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): |
|
if MOCK: |
|
monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client) |
|
monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) |
|
monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) |
|
monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) |
|
monkeypatch.setattr(Database, "drop_collection", MockTcvectordbClass.drop_collection) |
|
monkeypatch.setattr(Database, "create_collection", MockTcvectordbClass.create_collection) |
|
monkeypatch.setattr(Collection, "upsert", MockTcvectordbClass.collection_upsert) |
|
monkeypatch.setattr(Collection, "search", MockTcvectordbClass.collection_search) |
|
monkeypatch.setattr(Collection, "query", MockTcvectordbClass.collection_query) |
|
monkeypatch.setattr(Collection, "delete", MockTcvectordbClass.collection_delete) |
|
|
|
yield |
|
|
|
if MOCK: |
|
monkeypatch.undo() |
|
|