|
"""Tests for the OpenAI-compatible API server: launch the server (optionally backed by a
local h2oGPT instance) and exercise the chat and completions endpoints through the
official `openai` client, with and without streaming."""
import os
import time

import pytest

from tests.utils import wrap_test_forked
|
|
def launch_openai_server():
    from openai_server.server import run
    run()
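    # run() is expected to serve until interrupted; tests that use local_server=False
    # assume an endpoint is already listening at the base_url in run_openai_client()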
|
|
|
|
|
def test_openai_server():
    launch_openai_server()


# Number of repeats for the fully parametrized client test below
repeat0 = 1
|
|
@pytest.mark.parametrize("stream_output", [False, True]) |
|
@pytest.mark.parametrize("chat", [False, True]) |
|
@pytest.mark.parametrize("local_server", [False]) |
|
@wrap_test_forked |
|
def test_openai_client_test2(stream_output, chat, local_server): |
|
prompt = "Who are you?" |
|
api_key = 'EMPTY' |
|
enforce_h2ogpt_api_key = False |
|
repeat = 1 |
|
run_openai_client(stream_output, chat, local_server, prompt, api_key, enforce_h2ogpt_api_key, repeat) |
|
|
|
|
|
@pytest.mark.parametrize("stream_output", [False, True]) |
|
@pytest.mark.parametrize("chat", [False, True]) |
|
@pytest.mark.parametrize("local_server", [True]) |
|
@pytest.mark.parametrize("prompt", ["Who are you?", "Tell a very long kid's story about birds."]) |
|
@pytest.mark.parametrize("api_key", [None, "EMPTY", os.environ.get('H2OGPT_H2OGPT_KEY', 'EMPTY')]) |
|
@pytest.mark.parametrize("enforce_h2ogpt_api_key", [False, True]) |
|
@pytest.mark.parametrize("repeat", list(range(0, repeat0))) |
|
@wrap_test_forked |
|
def test_openai_client(stream_output, chat, local_server, prompt, api_key, enforce_h2ogpt_api_key, repeat): |
|
run_openai_client(stream_output, chat, local_server, prompt, api_key, enforce_h2ogpt_api_key, repeat) |
|
|
|
|
|
def run_openai_client(stream_output, chat, local_server, prompt, api_key, enforce_h2ogpt_api_key, repeat):
    base_model = 'openchat/openchat-3.5-1210'

    if local_server:
        # Launch a local h2oGPT server in-process
        from src.gen import main
        main(base_model=base_model, chat=False,
             stream_output=stream_output, gradio=True,
             num_beams=1, block_gradio_exit=False,
             add_disk_models_to_ui=False,
             enable_tts=False,
             enable_stt=False,
             enforce_h2ogpt_api_key=enforce_h2ogpt_api_key,
             h2ogpt_api_keys=[api_key] if api_key else None,
             )
        # Give the server a moment to finish loading before issuing requests
        time.sleep(10)
    else:
        # No local server: rely on an already-running OpenAI-compatible endpoint at base_url
        pass
|
    print('api_key: %s' % api_key)

    base_url = 'http://localhost:5000/v1'
    verbose = True
    system_prompt = "You are a helpful assistant."
    chat_conversation = []
    add_chat_history_to_context = True

    client_kwargs = dict(model=base_model,
                         max_tokens=200,
                         stream=stream_output)

    from openai import OpenAI, AsyncOpenAI
    client_args = dict(base_url=base_url, api_key=api_key)
    openai_client = OpenAI(**client_args)
    async_client = AsyncOpenAI(**client_args)
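    # Note: the AsyncOpenAI client is constructed and passed through, but only the
    # synchronous client is actually exercised below.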
|
|
|
    try:
        run_chat(chat, openai_client, async_client, system_prompt, chat_conversation, add_chat_history_to_context,
                 prompt, client_kwargs, stream_output, verbose)
    except AssertionError:
        if enforce_h2ogpt_api_key and api_key is None:
            print("Expected to fail since no key but enforcing.")
        else:
            raise

    # Check the models endpoints as well
    model_info = openai_client.models.retrieve(base_model)
    assert model_info.base_model == base_model
    model_list = openai_client.models.list()
    assert model_list.data[0].id == base_model
|
|
def run_chat(chat, openai_client, async_client, system_prompt, chat_conversation, add_chat_history_to_context,
             prompt, client_kwargs, stream_output, verbose):
    if chat:
        client = openai_client.chat.completions
        async_client = async_client.chat.completions

        messages0 = []
        if system_prompt:
            messages0.append({"role": "system", "content": system_prompt})
        if chat_conversation and add_chat_history_to_context:
            for message1 in chat_conversation:
                if len(message1) == 2:
                    messages0.append(
                        {'role': 'user', 'content': message1[0] if message1[0] is not None else ''})
                    messages0.append(
                        {'role': 'assistant', 'content': message1[1] if message1[1] is not None else ''})
        messages0.append({'role': 'user', 'content': prompt if prompt is not None else ''})
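        # messages0 now follows the OpenAI chat format: optional system message,
        # prior turns from chat_conversation, then the new user prompt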
|
|
|
        client_kwargs.update(dict(messages=messages0))
    else:
        client = openai_client.completions
        async_client = async_client.completions

        client_kwargs.update(dict(prompt=prompt))

    responses = client.create(**client_kwargs)
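    # With stream=True, create() returns an iterator of chunk events; otherwise a
    # single completed response object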
|
|
|
    if not stream_output:
        if chat:
            text = responses.choices[0].message.content
        else:
            text = responses.choices[0].text
        print(text)
    else:
        collected_events = []
        text = ''
        for event in responses:
            collected_events.append(event)
            if chat:
                delta = event.choices[0].delta.content
            else:
                delta = event.choices[0].text
            # delta can be None on the final chunk; treat it as empty
            text += delta or ''
            if verbose:
                print('delta: %s' % delta)
        print(text)

    # Loose sanity checks on the generated text
    if "Who" in prompt:
        assert 'OpenAI' in text or 'chatbot' in text
    else:
        assert 'birds' in text
|
|
if __name__ == '__main__':
    launch_openai_server()
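# Running this module directly only starts the server; run the tests themselves with
# pytest, e.g. (invocation assumed): pytest -s -v <this test file> -k test_openai_client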
|
|