kikuepi's picture
Upload 4913 files
4304c6d verified
import os
import re
from typing import Union
import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests import Response
from requests.exceptions import ConnectionError
from requests.sessions import Session
from xinference_client.client.restful.restful_client import (
Client,
RESTfulChatglmCppChatModelHandle,
RESTfulChatModelHandle,
RESTfulEmbeddingModelHandle,
RESTfulGenerateModelHandle,
RESTfulRerankModelHandle,
)
from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
class MockXinferenceClass:
    """Fake xinference-client / requests methods for the Xinference tests.

    The methods below are not called on ``MockXinferenceClass`` instances:
    they are monkeypatched onto ``Client``, ``Session``,
    ``RESTfulEmbeddingModelHandle`` and ``RESTfulRerankModelHandle``, so at
    call time ``self`` is an instance of the patched class. Shared helpers
    are therefore reached through ``MockXinferenceClass.<name>``, never
    through ``self``.
    """

    # Lower-case hyphenated UUID; accepted as a model uid besides the
    # fixed names ('generate', 'chat', 'embedding', 'rerank').
    _UUID_RE = re.compile(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}')
    # Loose http(s) URL check used to reject malformed server URLs.
    _URL_RE = re.compile(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$')

    # Canned body for GET .../v1/models/{generate,chat}.
    _LLM_MODEL_INFO = b'''{
        "model_type": "LLM",
        "address": "127.0.0.1:43877",
        "accelerators": [
            "0",
            "1"
        ],
        "model_name": "chatglm3-6b",
        "model_lang": [
            "en"
        ],
        "model_ability": [
            "generate",
            "chat"
        ],
        "model_description": "latest chatglm3",
        "model_format": "pytorch",
        "model_size_in_billions": 7,
        "quantization": "none",
        "model_hub": "huggingface",
        "revision": null,
        "context_length": 2048,
        "replica": 1
    }'''

    # Canned body for GET .../v1/models/embedding.
    _EMBEDDING_MODEL_INFO = b'''{
        "model_type": "embedding",
        "address": "127.0.0.1:43877",
        "accelerators": [
            "0",
            "1"
        ],
        "model_name": "bge",
        "model_lang": [
            "en"
        ],
        "revision": null,
        "max_tokens": 512
    }'''

    # Canned body for GET .../v1/cluster/auth.
    _CLUSTER_AUTH_INFO = b'''{
        "auth": true
    }'''

    @staticmethod
    def _response(status_code: int, content: bytes) -> Response:
        """Build a fresh ``requests.Response`` with the given status and raw body."""
        response = Response()
        response.status_code = status_code
        response._content = content
        return response

    @staticmethod
    def _is_known_uid(model_uid: str, *aliases: str) -> bool:
        """Return True if ``model_uid`` is one of ``aliases`` or a whole lower-case UUID."""
        return model_uid in aliases or MockXinferenceClass._UUID_RE.fullmatch(model_uid) is not None

    def get_chat_model(self: Client, model_uid: str) -> Union[
        RESTfulChatglmCppChatModelHandle,
        RESTfulGenerateModelHandle,
        RESTfulChatModelHandle,
        RESTfulEmbeddingModelHandle,
        RESTfulRerankModelHandle,
    ]:
        """Mock of ``Client.get_model``: map fixed model uids to handle classes.

        Args:
            model_uid: one of 'generate', 'chat', 'embedding' or 'rerank'.

        Returns:
            A handle of the matching class bound to ``self.base_url``.

        Raises:
            RuntimeError: ``'404 Not Found'`` if ``self.base_url`` is not an
                http(s) URL or ``model_uid`` is not a known name.
        """
        if not MockXinferenceClass._URL_RE.match(self.base_url):
            raise RuntimeError('404 Not Found')
        handle_classes = {
            'generate': RESTfulGenerateModelHandle,
            'chat': RESTfulChatModelHandle,
            'embedding': RESTfulEmbeddingModelHandle,
            'rerank': RESTfulRerankModelHandle,
        }
        handle_cls = handle_classes.get(model_uid)
        if handle_cls is None:
            raise RuntimeError('404 Not Found')
        return handle_cls(model_uid, base_url=self.base_url, auth_headers={})

    def get(self: Session, url: str, **kwargs) -> Response:
        """Mock of ``requests.Session.get`` for the Xinference REST endpoints.

        * ``.../v1/models/<uid>`` returns 200 with a canned model description
          for the 'generate'/'chat' and 'embedding' uids. It returns 404 for
          unknown uids, for malformed URLs, and for recognised uids that have
          no canned description ('rerank' or a bare UUID).
        * ``.../v1/cluster/auth`` returns 200 with ``{"auth": true}``.
        * Any other URL returns 404 with body ``{}``.

        A ``Response`` is always returned, never ``None``. ``kwargs``
        (timeout, headers, ...) are accepted and ignored.
        """
        mock = MockXinferenceClass
        if 'v1/models/' in url:
            # The model uid is the last path segment.
            model_uid = url.split('/')[-1]
            if not mock._is_known_uid(model_uid, 'generate', 'chat', 'embedding', 'rerank'):
                return mock._response(404, b'{}')
            if not mock._URL_RE.match(url):
                return mock._response(404, b'{}')
            if model_uid in ('generate', 'chat'):
                return mock._response(200, mock._LLM_MODEL_INFO)
            if model_uid == 'embedding':
                return mock._response(200, mock._EMBEDDING_MODEL_INFO)
        elif 'v1/cluster/auth' in url:
            return mock._response(200, mock._CLUSTER_AUTH_INFO)
        # Fall-through: answer 404 rather than None, which would crash any
        # caller that reads ``response.status_code``.
        return mock._response(404, b'{}')

    def _check_cluster_authenticated(self):
        """Mock of ``Client._check_cluster_authenticated``: always mark the cluster as authed."""
        self._cluster_authed = True

    def rerank(
        self: RESTfulRerankModelHandle,
        documents: list[str],
        query: str,
        top_n: Union[int, None] = None,
        **kwargs,
    ) -> dict:
        """Mock of ``RESTfulRerankModelHandle.rerank``.

        Returns the first ``top_n`` documents (1 when ``top_n`` is None), in
        input order, each with a fixed ``relevance_score`` of 0.9. ``query``
        and any extra keyword arguments are accepted and ignored.

        Raises:
            RuntimeError: ``'404 Not Found'`` if the handle's model uid is
                neither 'rerank' nor a UUID, or its base URL is not http(s).
        """
        if not MockXinferenceClass._is_known_uid(self._model_uid, 'rerank'):
            raise RuntimeError('404 Not Found')
        if not MockXinferenceClass._URL_RE.match(self._base_url):
            raise RuntimeError('404 Not Found')
        if top_n is None:
            top_n = 1
        return {
            'results': [
                {
                    'index': i,
                    'document': doc,
                    'relevance_score': 0.9,
                }
                for i, doc in enumerate(documents[:top_n])
            ]
        }

    def create_embedding(
        self: RESTfulEmbeddingModelHandle,
        input: Union[str, list[str]],
        **kwargs
    ) -> Embedding:
        """Mock of ``RESTfulEmbeddingModelHandle.create_embedding``.

        Returns one 768-dimensional constant vector per input string. A bare
        string is treated as a single input. ``usage`` reports the number of
        inputs as both prompt and total tokens.

        Raises:
            RuntimeError: ``'404 Not Found'`` if the handle's model uid is
                neither 'embedding' nor a UUID, or its base URL is not http(s)
                (the same checks ``rerank`` applies).
        """
        if not MockXinferenceClass._is_known_uid(self._model_uid, 'embedding'):
            raise RuntimeError('404 Not Found')
        if not MockXinferenceClass._URL_RE.match(self._base_url):
            raise RuntimeError('404 Not Found')
        if isinstance(input, str):
            input = [input]
        input_count = len(input)
        return Embedding(
            object="list",
            model=self._model_uid,
            data=[
                EmbeddingData(
                    index=i,
                    object="embedding",
                    embedding=[1919.810] * 768,
                )
                for i in range(input_count)
            ],
            usage=EmbeddingUsage(
                prompt_tokens=input_count,
                total_tokens=input_count,
            ),
        )
# Mocks are enabled only when the MOCK_SWITCH environment variable equals
# "true" (case-insensitive); an unset variable counts as "false".
MOCK = os.environ.get('MOCK_SWITCH', 'false').lower() == 'true'
@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
    """Swap Xinference client / HTTP methods for the MockXinferenceClass fakes.

    When ``MOCK`` is true, the patches are installed before the test runs and
    explicitly undone afterwards. When it is false, the fixture only yields.
    ``request`` is not used; it stays in the fixture signature unchanged.
    """
    if MOCK:
        # (target class, attribute name, replacement), applied in this order.
        replacements = (
            (Client, 'get_model', MockXinferenceClass.get_chat_model),
            (Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated),
            (Session, 'get', MockXinferenceClass.get),
            (RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding),
            (RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank),
        )
        for target, attr_name, replacement in replacements:
            monkeypatch.setattr(target, attr_name, replacement)

    yield

    if MOCK:
        monkeypatch.undo()