|
"""
Tests for the manifest clients.

Only the DummyClient is exercised here; it returns canned completions, so the
expected responses below are fixed values.
"""
|
from manifest.clients.dummy import DummyClient |
|
|
|
|
|
def test_init() -> None:
    """Check that DummyClient defaults ``n`` to 1 and honors an ``n`` override."""
    default_client = DummyClient(connection_str=None)
    assert default_client.n == 1

    # Passing client_args should override the default number of completions.
    override_client = DummyClient(connection_str=None, client_args={"n": 3})
    assert override_client.n == 3
|
|
|
|
|
def test_get_params() -> None:
    """Check the model params and the accepted input names reported by DummyClient."""
    expected_params = {"engine": "dummy", "model": "text-davinci-003"}
    expected_inputs = [
        "engine",
        "temperature",
        "max_tokens",
        "n",
        "top_p",
        "top_k",
        "batch_size",
    ]

    dummy = DummyClient(connection_str=None)
    assert dummy.get_model_params() == expected_params
    assert dummy.get_model_inputs() == expected_inputs
|
|
|
|
|
def _expected_cache_key(prompt: object, n: int) -> dict:
    """Return the cache key DummyClient is expected to build for ``prompt``/``n``.

    ``prompt`` is either a single prompt string or a list of prompt strings.
    """
    return {
        "prompt": prompt,
        "model": "text-davinci-003",
        "n": n,
        "temperature": 0.0,
        "max_tokens": 10,
        "top_p": 1.0,
        "best_of": 1,
        "engine": "dummy",
        "request_cls": "LMRequest",
    }


def _expected_json_response(num_choices: int) -> dict:
    """Return the expected JSON response containing ``num_choices`` copies of the canned choice."""
    # The same canned completion is expected for every returned choice.
    choice = {
        "text": " probsuib.FirstName>- commodityting segunda inserted signals Religious",
        "token_logprobs": [
            -0.2649905035732101,
            -1.210794839387105,
            -1.2173929801003434,
            -0.7758233850171001,
            -0.7165940659570416,
            -1.7430328887209088,
            -1.5379414228820203,
            -1.7838011423472508,
            -1.139095076944217,
            -0.6321855879833425,
        ],
        "tokens": [
            "70470",
            "80723",
            "52693",
            "39743",
            "38983",
            "1303",
            "56072",
            "22306",
            "17738",
            "53176",
        ],
    }
    return {"choices": [choice] * num_choices}


def _expected_usage(num_choices: int) -> dict:
    """Return the expected usage dict with one identical usage entry per choice."""
    usage = {"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}
    return {"usages": [usage] * num_choices}


def _check_request(
    client: "DummyClient",
    prompt: object,
    request_args: dict,
    expected_n: int,
    expected_choices: int,
) -> None:
    """Run one request through ``client`` and assert its cache key, response and usage.

    Args:
        client: The DummyClient under test.
        prompt: A prompt string or a list of prompt strings.
        request_args: Per-request overrides passed to ``get_request``.
        expected_n: The ``n`` value expected in the cache key.
        expected_choices: The number of choices/usages expected in the response.
    """
    request_params = client.get_request(prompt, request_args)
    response = client.run_request(request_params)
    assert client.get_cache_key(request_params) == _expected_cache_key(
        prompt, expected_n
    )
    assert response.get_json_response() == _expected_json_response(expected_choices)
    assert response.get_usage_obj().dict() == _expected_usage(expected_choices)


def test_get_request() -> None:
    """Test client get request."""
    client = DummyClient(connection_str=None, client_args={"n": 3})

    # No per-request override: the client-level n=3 applies.
    _check_request(client, "hello", {}, expected_n=3, expected_choices=3)

    # A per-request n overrides the client-level default.
    _check_request(client, "hello", {"n": 5}, expected_n=5, expected_choices=5)

    # A batch of 5 prompts with n=1 still yields 5 choices (one per prompt).
    _check_request(client, ["hello"] * 5, {"n": 1}, expected_n=1, expected_choices=5)
|
|