|
import os |
|
import re |
|
from typing import Union |
|
|
|
import pytest |
|
from _pytest.monkeypatch import MonkeyPatch |
|
from requests import Response |
|
from requests.sessions import Session |
|
from xinference_client.client.restful.restful_client import ( |
|
Client, |
|
RESTfulChatModelHandle, |
|
RESTfulEmbeddingModelHandle, |
|
RESTfulGenerateModelHandle, |
|
RESTfulRerankModelHandle, |
|
) |
|
from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage |
|
|
|
|
|
# Message raised by every mocked lookup failure.
_NOT_FOUND = "404 Not Found"

# Model UIDs that the mocked HTTP endpoint recognises by name.
_KNOWN_MODEL_UIDS = frozenset({"generate", "chat", "embedding", "rerank"})

# UUID-shaped model UID; always applied with ``fullmatch`` so that trailing
# characters after an otherwise valid UID are rejected.
_MODEL_UID_PATTERN = re.compile(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}")

# Minimal http(s) URL shape shared by all the mocked entry points.
_HTTP_URL_PATTERN = re.compile(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$")

# Canned body served for the "generate" and "chat" model descriptions.
_LLM_MODEL_PAYLOAD = b"""{
    "model_type": "LLM",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "chatglm3-6b",
    "model_lang": [
        "en"
    ],
    "model_ability": [
        "generate",
        "chat"
    ],
    "model_description": "latest chatglm3",
    "model_format": "pytorch",
    "model_size_in_billions": 7,
    "quantization": "none",
    "model_hub": "huggingface",
    "revision": null,
    "context_length": 2048,
    "replica": 1
}"""

# Canned body served for the "embedding" model description.
_EMBEDDING_MODEL_PAYLOAD = b"""{
    "model_type": "embedding",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "bge",
    "model_lang": [
        "en"
    ],
    "revision": null,
    "max_tokens": 512
}"""

# Canned body served for the cluster auth probe.
_CLUSTER_AUTH_PAYLOAD = b"""{
    "auth": true
}"""


def _is_model_uid(model_uid: str, *named: str) -> bool:
    """Return True if *model_uid* is one of *named* or is exactly UUID-shaped."""
    return model_uid in named or _MODEL_UID_PATTERN.fullmatch(model_uid) is not None


def _fill_response(response, status_code: int, content: bytes):
    """Set status code and raw body on *response* and return it."""
    response.status_code = status_code
    response._content = content
    return response


class MockXinferenceClass:
    """Stand-in implementations that are monkeypatched onto the xinference client.

    The functions are never called on an instance of this class. Each one is
    installed on another class, so ``self`` is an instance of the class named
    in its annotation. Annotations that mention those classes are written as
    string forward references so they are not evaluated when this class is
    defined.
    """

    def get_chat_model(
        self: "Client", model_uid: str
    ) -> "Union[RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulEmbeddingModelHandle, RESTfulRerankModelHandle]":
        """Return the model handle that corresponds to a well-known *model_uid*.

        Args:
            model_uid: One of "generate", "chat", "embedding" or "rerank".

        Returns:
            A handle of the matching type, bound to ``self.base_url`` with
            empty auth headers.

        Raises:
            RuntimeError: "404 Not Found" when ``self.base_url`` is not an
                http(s) URL or *model_uid* is not one of the known names.
        """
        if not _HTTP_URL_PATTERN.match(self.base_url):
            raise RuntimeError(_NOT_FOUND)
        if model_uid not in _KNOWN_MODEL_UIDS:
            raise RuntimeError(_NOT_FOUND)

        handle_types = {
            "generate": RESTfulGenerateModelHandle,
            "chat": RESTfulChatModelHandle,
            "embedding": RESTfulEmbeddingModelHandle,
            "rerank": RESTfulRerankModelHandle,
        }
        return handle_types[model_uid](model_uid, base_url=self.base_url, auth_headers={})

    def get(self: "Session", url: str, **kwargs):
        """Serve canned responses for the model-description and auth endpoints.

        Args:
            url: Requested URL. The model UID is taken from the last path
                segment of ``v1/models/`` URLs.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            A ``Response`` in every case: 200 with a canned JSON body for the
            "generate", "chat" and "embedding" models and for
            ``v1/cluster/auth``; 404 with an empty JSON object otherwise.
        """
        response = Response()

        if "v1/models/" in url:
            model_uid = url.rsplit("/", 1)[-1]
            if not _is_model_uid(model_uid, *_KNOWN_MODEL_UIDS):
                return _fill_response(response, 404, b"{}")
            if not _HTTP_URL_PATTERN.match(url):
                return _fill_response(response, 404, b"{}")
            if model_uid in {"generate", "chat"}:
                return _fill_response(response, 200, _LLM_MODEL_PAYLOAD)
            if model_uid == "embedding":
                return _fill_response(response, 200, _EMBEDDING_MODEL_PAYLOAD)
        elif "v1/cluster/auth" in url:
            return _fill_response(response, 200, _CLUSTER_AUTH_PAYLOAD)

        # Nothing canned for this request (unknown endpoint, or a valid UID
        # without a stored description). Answer 404 instead of falling off the
        # end and handing the caller None.
        return _fill_response(response, 404, b"{}")

    def _check_cluster_authenticated(self):
        """Mark the client as cluster-authenticated without any network call."""
        self._cluster_authed = True

    def rerank(
        self: "RESTfulRerankModelHandle",
        documents: list[str],
        query: str,
        top_n: Union[int, None],
        return_documents: bool,
    ) -> dict:
        """Return a fixed-score ranking of the first *top_n* documents.

        Args:
            documents: Candidate documents, returned in their input order.
            query: Ignored by the mock.
            top_n: Number of results to return; ``None`` means 1.
            return_documents: Ignored by the mock; documents are always
                included in the result.

        Returns:
            ``{"results": [...]}`` where each entry has "index", "document"
            and a constant "relevance_score" of 0.9.

        Raises:
            RuntimeError: "404 Not Found" when the handle's model UID is
                neither "rerank" nor UUID-shaped, or its base URL is not an
                http(s) URL.
        """
        if not _is_model_uid(self._model_uid, "rerank"):
            raise RuntimeError(_NOT_FOUND)
        if not _HTTP_URL_PATTERN.match(self._base_url):
            raise RuntimeError(_NOT_FOUND)

        limit = 1 if top_n is None else top_n
        return {
            "results": [
                {"index": index, "document": document, "relevance_score": 0.9}
                for index, document in enumerate(documents[:limit])
            ]
        }

    # ``input`` shadows the builtin but is kept as the parameter name because
    # callers may pass it by keyword.
    def create_embedding(self: "RESTfulEmbeddingModelHandle", input: Union[str, list[str]], **kwargs) -> dict:
        """Return one constant 768-dimensional embedding per input text.

        Args:
            input: A single text or a list of texts.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            An ``Embedding`` whose usage counts equal the number of inputs.

        Raises:
            RuntimeError: "404 Not Found" when the handle's model UID is
                neither "embedding" nor UUID-shaped.
        """
        if not _is_model_uid(self._model_uid, "embedding"):
            raise RuntimeError(_NOT_FOUND)

        texts = [input] if isinstance(input, str) else input
        count = len(texts)

        return Embedding(
            object="list",
            model=self._model_uid,
            data=[
                EmbeddingData(index=index, object="embedding", embedding=[1919.810] * 768)
                for index in range(count)
            ],
            usage=EmbeddingUsage(prompt_tokens=count, total_tokens=count),
        )
|
|
|
|
|
# Mocking is opt-in: enabled only when the MOCK_SWITCH environment variable is
# set to "true" (compared case-insensitively); unset or any other value disables it.
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"
|
|
|
|
|
@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
    """Swap the xinference client's network entry points for the mock ones.

    When ``MOCK`` is false the fixture does nothing and simply yields. When it
    is true, the listed attributes are replaced before the test runs and the
    patches are undone as soon as the test has finished.
    """
    if not MOCK:
        yield
        return

    # (owner, attribute, replacement), applied in this order.
    replacements = (
        (Client, "get_model", MockXinferenceClass.get_chat_model),
        (Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated),
        (Session, "get", MockXinferenceClass.get),
        (RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding),
        (RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank),
    )
    for owner, attribute, replacement in replacements:
        monkeypatch.setattr(owner, attribute, replacement)

    yield

    # Revert the patches explicitly once the test body has run.
    monkeypatch.undo()
|
|