import pytest

from h2ogpt_client import Client


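# `server_url` and `h2ogpt_key` are assumed to be pytest fixtures supplied by the
# surrounding test setup (e.g. a conftest.py); this fixture wraps them in a Client.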
@pytest.fixture
def client(server_url, h2ogpt_key) -> Client:
    return Client(server_url, h2ogpt_key=h2ogpt_key)


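# Helper: create a text completion bound to the last model the server lists.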
def _create_text_completion(client):
    model = client.models.list()[-1]
    return client.text_completion.create(model=model)


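# Plain (non-streaming) text completion through the async API.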
@pytest.mark.asyncio
async def test_text_completion(client):
    text_completion = _create_text_completion(client)
    response = await text_completion.complete(prompt="Hello world")
    assert response
    print(response)


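# With enable_streaming=True, the awaited result is an async iterator yielding tokens.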
@pytest.mark.asyncio
async def test_text_completion_stream(client):
    text_completion = _create_text_completion(client)
    response = await text_completion.complete(
        prompt="Write a poem about the Amazon rainforest. End it with an emoji.",
        enable_streaming=True,
    )
    async for token in response:
        assert token
        print(token, end="")


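# Synchronous counterparts of the tests above, using complete_sync().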
def test_text_completion_sync(client):
    text_completion = _create_text_completion(client)
    response = text_completion.complete_sync(prompt="Hello world")
    assert response
    print(response)


def test_text_completion_sync_stream(client):
    text_completion = _create_text_completion(client)
    response = text_completion.complete_sync(
        prompt="Write a poem about the Amazon rainforest. End it with an emoji.",
        enable_streaming=True,
    )
    for token in response:
        assert token
        print(token, end="")


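# Helper: create a chat completion bound to the last model the server lists.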
def _create_chat_completion(client):
    model = client.models.list()[-1]
    return client.chat_completion.create(model=model)


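# Each chat() reply is a dict with "user" and "gpt" keys; the session accumulates the
# conversation, which chat_history() returns in order.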
@pytest.mark.asyncio
async def test_chat_completion(client):
    chat_completion = _create_chat_completion(client)

    chat1 = await chat_completion.chat(prompt="Hey!")
    assert chat1["user"] == "Hey!"
    assert chat1["gpt"]

    chat2 = await chat_completion.chat(prompt="What is the capital of the USA?")
    assert chat2["user"] == "What is the capital of the USA?"
    assert chat2["gpt"]

    chat3 = await chat_completion.chat(prompt="What is the population there?")
    assert chat3["user"] == "What is the population there?"
    assert chat3["gpt"]

    chat_history = chat_completion.chat_history()
    assert chat_history == [chat1, chat2, chat3]
    print(chat_history)


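# Same conversation flow as above, using the blocking chat_sync() API.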
def test_chat_completion_sync(client):
    chat_completion = _create_chat_completion(client)

    chat1 = chat_completion.chat_sync(prompt="What is UNESCO?")
    assert chat1["user"] == "What is UNESCO?"
    assert chat1["gpt"]

    chat2 = chat_completion.chat_sync(prompt="Is it a part of the UN?")
    assert chat2["user"] == "Is it a part of the UN?"
    assert chat2["gpt"]

    chat3 = chat_completion.chat_sync(prompt="Where is the headquarters?")
    assert chat3["user"] == "Where is the headquarters?"
    assert chat3["gpt"]

    chat_history = chat_completion.chat_history()
    assert chat_history == [chat1, chat2, chat3]
    print(chat_history)


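# The server should report at least one available model.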
def test_available_models(client):
    models = client.models.list()
    assert len(models)
    print(models)


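# Sanity checks on the server metadata exposed by the client.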
def test_server_properties(client, server_url):
    assert client.server.address.startswith(server_url)
    assert client.server.hash


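# `eval_func_param_names` is assumed to be a fixture listing the server-side parameter
# names in order; the private `_parameters` mapping should preserve that ordering.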
def test_parameters_order(client, eval_func_param_names):
    text_completion = client.text_completion.create()
    assert eval_func_param_names == list(text_completion._parameters.keys())
    chat_completion = client.chat_completion.create()
    assert eval_func_param_names == list(chat_completion._parameters.keys())


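# Mirrors the client README usage example; the remote variant needs an H2OGPT key
# from the environment and is skipped when none is set.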
@pytest.mark.parametrize("local_server", [True, False])
def test_readme_example(local_server):
    import asyncio
    import os

    from h2ogpt_client import Client

    if local_server:
        client = Client("http://0.0.0.0:7860")
    else:
        h2ogpt_key = os.getenv("H2OGPT_KEY") or os.getenv("H2OGPT_H2OGPT_KEY")
        if h2ogpt_key is None:
            pytest.skip("H2OGPT_KEY (or H2OGPT_H2OGPT_KEY) is not set")
        client = Client("https://gpt.h2o.ai", h2ogpt_key=h2ogpt_key)

    text_completion = client.text_completion.create()
    response = asyncio.run(text_completion.complete("Hello world"))
    print("asyncio text completion response: %s" % response)

    response = text_completion.complete_sync("Hello world")
    print("sync text completion response: %s" % response)

    chat_completion = client.chat_completion.create()
    reply = asyncio.run(chat_completion.chat("Hey!"))
    print("asyncio chat completion user: %s gpt: %s" % (reply["user"], reply["gpt"]))
    chat_history = chat_completion.chat_history()
    print("chat_history: %s" % chat_history)

    reply = chat_completion.chat_sync("Hey!")
    print("sync chat completion gpt: %s" % reply["gpt"])