medrag / tests /assistant /test_llm_client.py
geekyrakshit's picture
update: app
170d9a9
from PIL import Image
from pydantic import BaseModel
from medrag_multi_modal.assistant.llm_client import ClientType, LLMClient
class CalendarEvent(BaseModel):
name: str
date: str
participants: list[str]
class ImageDescription(BaseModel):
description: str
def test_openai_llm_client():
llm_client = LLMClient(model_name="gpt-4o-mini", client_type=ClientType.OPENAI)
event = llm_client.predict(
system_prompt="Extract the event information",
user_prompt="Alice and Bob are going to a science fair on Friday.",
schema=CalendarEvent,
)
assert event.name.lower() == "science fair"
assert event.date.lower() == "friday"
assert [item.lower() for item in event.participants] == ["alice", "bob"]
def test_openai_image_description():
llm_client = LLMClient(model_name="gpt-4o-mini", client_type=ClientType.OPENAI)
description = llm_client.predict(
system_prompt="Describe the image",
user_prompt=[Image.open("./assets/test_image.png")],
schema=ImageDescription,
)
assert "astronaut" in description.description.lower()
def test_google_llm_client():
llm_client = LLMClient(
model_name="gemini-1.5-flash-latest", client_type=ClientType.GEMINI
)
event = llm_client.predict(
system_prompt="Extract the event information",
user_prompt="Alice and Bob are going to a science fair on Friday.",
schema=CalendarEvent,
)
event = event[0] if isinstance(event, list) else event
assert event["name"].lower() == "science fair"
assert event["date"].lower() == "friday"
assert [item.lower() for item in event["participants"]] == ["alice", "bob"]
def test_google_image_client():
llm_client = LLMClient(
model_name="gemini-1.5-flash-latest", client_type=ClientType.GEMINI
)
description = llm_client.predict(
system_prompt="Describe the image",
user_prompt=[Image.open("./assets/test_image.png")],
schema=ImageDescription,
)
description = description[0] if isinstance(description, list) else description
assert "astronaut" in description["description"].lower()