|
import requests
|
|
import json
|
|
from typing import Union, Dict, Generator
|
|
import time
|
|
|
|
class ChatCompletionTester:
    """Smoke-test a ``/chat/completions`` endpoint.

    Sends one fixed prompt twice: once as a plain (non-streaming) request and
    once as a streaming request whose body is read as server-sent-event
    ``data:`` lines. Results are printed to stdout.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        timeout: float = 300.0,
        model: str = "mistralai/Mixtral-8x22B-Instruct-v0.1",
    ):
        """Configure the tester.

        Args:
            base_url: Server root URL; ``/chat/completions`` is appended to it.
            timeout: Seconds passed to ``requests`` as the connect/read timeout
                of the completion requests, so a stalled server cannot hang a
                test indefinitely.
            model: Model name placed in every test payload.
        """
        self.base_url = base_url
        # Drop a trailing slash so "http://host/" does not yield "http://host//chat/completions".
        self.endpoint = f"{base_url.rstrip('/')}/chat/completions"
        self.timeout = timeout
        self.model = model

    def create_test_payload(self, stream: bool = False) -> Dict:
        """Create a sample payload for testing.

        Args:
            stream: Value of the ``stream`` field in the request body.

        Returns:
            A chat-completions request body asking a fixed geography question.
        """
        return {
            "model": self.model,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            "temperature": 0.7,
            "max_tokens": 4096,
            "stream": stream,
        }

    @staticmethod
    def _print_http_error(response) -> None:
        """Print the status code and body of a non-200 response."""
        print(f"Error: Status code {response.status_code}")
        print(f"Response: {response.text}")

    def test_non_streaming(self) -> Union[Dict, None]:
        """Test non-streaming response.

        Returns:
            The decoded JSON body when the server answers HTTP 200 and the body
            contains ``choices[0].message.content``; otherwise ``None``.
        """
        print("\n=== Testing Non-Streaming Response ===")
        try:
            payload = self.create_test_payload(stream=False)
            print("Sending request...")
            response = requests.post(
                self.endpoint,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=self.timeout,
            )
            if response.status_code != 200:
                self._print_http_error(response)
                return None

            result = response.json()
            content = result['choices'][0]['message']['content']
            print("\nResponse received successfully!")
            print(f"Content: {content}")
            return result
        # Test boundary: report any failure (network, bad JSON, missing keys)
        # as a failed test instead of aborting the whole run.
        except Exception as e:
            print(f"Error during non-streaming test: {str(e)}")
            return None

    @staticmethod
    def _extract_delta_content(payload: str) -> str:
        """Return ``choices[0].delta.content`` from one SSE JSON payload, or ``""``.

        Malformed JSON, a non-object body, an empty or missing ``choices``
        list, a missing ``delta`` and a ``null`` content all yield ``""``
        instead of raising, so one odd chunk does not abort the stream.
        """
        try:
            data = json.loads(payload)
        except json.JSONDecodeError:
            return ""
        if not isinstance(data, dict):
            return ""
        choices = data.get("choices")
        if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
            return ""
        delta = choices[0].get("delta")
        if not isinstance(delta, dict):
            return ""
        content = delta.get("content")
        return content if isinstance(content, str) else ""

    @classmethod
    def _iter_stream_content(cls, lines) -> Generator[str, None, None]:
        """Yield the non-empty content deltas found in server-sent-event lines.

        Args:
            lines: Iterable of SSE lines, as ``bytes`` (what
                ``Response.iter_lines()`` yields) or ``str``.

        Blank lines, comments and non-``data:`` fields are skipped; iteration
        stops at the ``data: [DONE]`` sentinel.
        """
        for raw_line in lines:
            if not raw_line:
                continue
            # SSE is UTF-8 by specification. Decoding here avoids requests'
            # ISO-8859-1 fallback for text/* responses that omit a charset.
            if isinstance(raw_line, bytes):
                line = raw_line.decode("utf-8", errors="replace")
            else:
                line = raw_line
            if not line.startswith("data:"):
                continue
            payload = line[len("data:"):].strip()
            # Check the sentinel before JSON parsing: "[DONE]" is not valid JSON.
            if payload == "[DONE]":
                break
            content = cls._extract_delta_content(payload)
            if content:
                yield content

    def test_streaming(self) -> Union[str, None]:
        """Test streaming response.

        Returns:
            The concatenated streamed content on HTTP 200 (possibly ``""`` if
            no content deltas arrived), otherwise ``None``.
        """
        print("\n=== Testing Streaming Response ===")
        try:
            payload = self.create_test_payload(stream=True)
            print("Sending request...")
            # The ``with`` block releases the streamed connection even if
            # iteration fails midway.
            with requests.post(
                self.endpoint,
                json=payload,
                headers={"Content-Type": "application/json"},
                stream=True,
                timeout=self.timeout,
            ) as response:
                if response.status_code != 200:
                    self._print_http_error(response)
                    return None

                print("\nReceiving streaming response:")
                chunks = []
                # Read raw bytes; decoding is done explicitly as UTF-8 in the helper.
                for content in self._iter_stream_content(response.iter_lines()):
                    print(content, end="", flush=True)
                    chunks.append(content)

            print("\n\nStreaming completed!")
            return "".join(chunks)
        # Test boundary: report any failure as a failed test.
        except Exception as e:
            print(f"Error during streaming test: {str(e)}")
            return None

    def run_all_tests(self) -> bool:
        """Run both streaming and non-streaming tests.

        First checks that ``base_url`` answers at all. Any HTTP response, even
        a 404, counts as reachable.

        Returns:
            ``True`` if both tests passed, ``False`` otherwise (including when
            the server is unreachable).
        """
        print("Starting API endpoint tests...")

        try:
            # Short fixed timeout: this is only a reachability probe.
            requests.get(self.base_url, timeout=10)
            print("✓ Server is accessible")
        except requests.exceptions.RequestException:
            print("✗ Server is not accessible. Please ensure the FastAPI server is running.")
            return False

        start_time = time.perf_counter()

        non_streaming_ok = bool(self.test_non_streaming())
        if non_streaming_ok:
            print("✓ Non-streaming test passed")
        else:
            print("✗ Non-streaming test failed")

        streaming_ok = bool(self.test_streaming())
        if streaming_ok:
            print("✓ Streaming test passed")
        else:
            print("✗ Streaming test failed")

        end_time = time.perf_counter()
        print(f"\nAll tests completed in {end_time - start_time:.2f} seconds")
        return non_streaming_ok and streaming_ok
|
|
|
|
def main():
    """Script entry point: run every endpoint test against the default local server."""
    ChatCompletionTester().run_all_tests()
|
|
|
|
if __name__ == "__main__":
    # Run the tests only when executed directly, not when imported as a module.
    main()