|
import os |
|
from typing import Union |
|
from unittest.mock import MagicMock |
|
|
|
import pytest |
|
from _pytest.monkeypatch import MonkeyPatch |
|
from volcengine.viking_db import ( |
|
Collection, |
|
Data, |
|
DistanceType, |
|
Field, |
|
FieldType, |
|
Index, |
|
IndexType, |
|
QuantType, |
|
VectorIndexParams, |
|
VikingDBService, |
|
) |
|
|
|
from core.rag.datasource.vdb.field import Field as vdb_Field |
|
|
|
|
|
class MockVikingDBClass:
    """Offline stand-ins for the VikingDB SDK calls exercised by the tests.

    The methods here are grafted onto ``VikingDBService``, ``Collection`` and
    ``Index`` by the ``setup_vikingdb_mock`` fixture, so at call time ``self``
    is whichever SDK object received the method rather than an instance of
    this class.  The shared builders below are therefore static and are
    always reached through the class name, never through ``self``.
    """

    def __init__(
        self,
        host="api-vikingdb.volces.com",
        region="cn-north-1",
        ak="",
        sk="",
        scheme="http",
        connection_timeout=30,
        socket_timeout=30,
        proxy=None,
    ):
        """Accept the connection arguments, ignore them, and attach a mock transport."""
        transport = MagicMock()
        # Canned JSON string handed back by every ``get_exception`` call.
        transport.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')
        self._viking_db_service = transport

    @staticmethod
    def _hnsw_l2_params() -> VectorIndexParams:
        """Build the vector-index settings shared by every mocked index."""
        return VectorIndexParams(
            distance=DistanceType.L2,
            index_type=IndexType.HNSW,
            quant=QuantType.Float,
        )

    @staticmethod
    def _canned_hit(vector) -> Data:
        """Build the single fixed search result, embedding ``vector`` as its vector field."""
        metadata_json = (
            '{"source": "/var/folders/ml/xxx/xxx.txt", '
            '"document_id": "test_document_id", '
            '"dataset_id": "test_dataset_id", '
            '"doc_id": "test_id", '
            '"doc_hash": "test_hash"}'
        )
        payload = {
            vdb_Field.GROUP_KEY.value: "test_group",
            vdb_Field.METADATA_KEY.value: metadata_json,
            vdb_Field.CONTENT_KEY.value: "content",
            vdb_Field.PRIMARY_KEY.value: "test_id",
            vdb_Field.VECTOR.value: vector,
        }
        return Data(fields=payload, id="test_id", score=0.10)

    def get_collection(self, collection_name) -> Collection:
        """Return a five-field collection carrying one ``<collection_name>_idx`` index."""
        transport = self._viking_db_service
        schema = [
            Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
            Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
            Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
            Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
            Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768),
        ]
        default_index = Index(
            collection_name=collection_name,
            index_name=f"{collection_name}_idx",
            vector_index=MockVikingDBClass._hnsw_l2_params(),
            scalar_index=None,
            stat=None,
            viking_db_service=transport,
        )
        return Collection(
            collection_name=collection_name,
            description="Collection For Dify",
            viking_db_service=transport,
            primary_key=vdb_Field.PRIMARY_KEY.value,
            fields=schema,
            indexes=[default_index],
        )

    def drop_collection(self, collection_name):
        """Check only that a non-empty collection name was supplied."""
        assert collection_name != ""

    def create_collection(self, collection_name, fields, description="") -> Collection:
        """Return a collection built from the caller's name, fields and description."""
        return Collection(
            collection_name=collection_name,
            description=description,
            primary_key=vdb_Field.PRIMARY_KEY.value,
            viking_db_service=self._viking_db_service,
            fields=fields,
        )

    def get_index(self, collection_name, index_name) -> Index:
        """Return an index with the requested names and the shared vector settings."""
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            viking_db_service=self._viking_db_service,
            stat=None,
            scalar_index=None,
            vector_index=MockVikingDBClass._hnsw_l2_params(),
        )

    def create_index(
        self,
        collection_name,
        index_name,
        vector_index=None,
        cpu_quota=2,
        description="",
        partition_by="",
        scalar_index=None,
        shard_count=None,
        shard_policy=None,
    ):
        """Return an index that simply echoes every argument it was given."""
        echoed_options = {
            "vector_index": vector_index,
            "cpu_quota": cpu_quota,
            "description": description,
            "partition_by": partition_by,
            "scalar_index": scalar_index,
            "shard_count": shard_count,
            "shard_policy": shard_policy,
        }
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            viking_db_service=self._viking_db_service,
            stat=None,
            **echoed_options,
        )

    def drop_index(self, collection_name, index_name):
        """Check only that both names are non-empty."""
        assert collection_name != ""
        assert index_name != ""

    def upsert_data(self, data: Union[Data, list[Data]]):
        """Check only that some data was supplied; nothing is stored."""
        assert data is not None

    def fetch_data(self, id: Union[str, list[str], int, list[int]]):
        """Return one fixed record whose primary key and id echo ``id``."""
        record_fields = {
            vdb_Field.GROUP_KEY.value: "test_group",
            vdb_Field.METADATA_KEY.value: "{}",
            vdb_Field.CONTENT_KEY.value: "content",
            vdb_Field.PRIMARY_KEY.value: id,
            vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
        }
        return Data(fields=record_fields, id=id)

    def delete_data(self, id: Union[str, list[str], int, list[int]]):
        """Check only that an id was supplied; nothing is deleted."""
        assert id is not None

    def search_by_vector(
        self,
        vector,
        sparse_vectors=None,
        filter=None,
        limit=10,
        output_fields=None,
        partition="default",
        dense_weight=None,
    ) -> list[Data]:
        """Return the one canned hit, carrying the query ``vector`` back to the caller."""
        hit = MockVikingDBClass._canned_hit(vector)
        return [hit]

    def search(
        self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
    ) -> list[Data]:
        """Return the one canned hit with a fixed three-component vector."""
        fixed_vector = [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398]
        return [MockVikingDBClass._canned_hit(fixed_vector)]
|
|
|
|
|
# Mocking is opt-in: it is enabled only when MOCK_SWITCH is set to "true"
# (compared case-insensitively); an unset variable counts as "false".
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"
|
|
|
|
|
@pytest.fixture
def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
    """Replace the VikingDB SDK entry points with ``MockVikingDBClass`` stubs.

    Patching happens only when the module-level ``MOCK`` flag is true;
    otherwise the fixture yields immediately and the real SDK is used.

    No explicit teardown is performed: the ``monkeypatch`` fixture reverts
    every ``setattr`` on its own when the test finishes.  Calling
    ``monkeypatch.undo()`` here would also revert changes that the test or
    other fixtures made through the same ``monkeypatch`` instance, possibly
    before their own teardown has run.
    """
    if MOCK:
        # (SDK class, methods to take from MockVikingDBClass), in patch order.
        patch_plan = (
            (
                VikingDBService,
                (
                    "__init__",
                    "get_collection",
                    "create_collection",
                    "drop_collection",
                    "get_index",
                    "create_index",
                    "drop_index",
                ),
            ),
            (Collection, ("upsert_data", "fetch_data", "delete_data")),
            (Index, ("search_by_vector", "search")),
        )
        for sdk_class, method_names in patch_plan:
            for method_name in method_names:
                monkeypatch.setattr(sdk_class, method_name, getattr(MockVikingDBClass, method_name))

    yield
|
|