Spaces:
Running
Running
from datetime import datetime, timezone | |
from uuid import UUID, uuid4 | |
import pytest | |
from langflow.memory import ( | |
aadd_messages, | |
aadd_messagetables, | |
add_messages, | |
add_messagetables, | |
adelete_messages, | |
aget_messages, | |
astore_message, | |
aupdate_messages, | |
delete_messages, | |
get_messages, | |
store_message, | |
update_messages, | |
) | |
from langflow.schema.content_block import ContentBlock | |
from langflow.schema.content_types import TextContent, ToolContent | |
from langflow.schema.message import Message | |
from langflow.schema.properties import Properties, Source | |
# Assuming you have these imports available | |
from langflow.services.database.models.message import MessageCreate, MessageRead | |
from langflow.services.database.models.message.model import MessageTable | |
from langflow.services.deps import async_session_scope | |
from langflow.services.tracing.utils import convert_to_langchain_type | |
async def created_message(): | |
async with async_session_scope() as session: | |
message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id") | |
messagetable = MessageTable.model_validate(message, from_attributes=True) | |
messagetables = await aadd_messagetables([messagetable], session) | |
return MessageRead.model_validate(messagetables[0], from_attributes=True) | |
async def created_messages(session): # noqa: ARG001 | |
async with async_session_scope() as _session: | |
messages = [ | |
MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), | |
MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), | |
MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"), | |
] | |
messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] | |
messagetables = await aadd_messagetables(messagetables, _session) | |
return [MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables] | |
def test_get_messages(): | |
add_messages( | |
[ | |
Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), | |
Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), | |
] | |
) | |
messages = get_messages(sender="User", session_id="session_id2", limit=2) | |
assert len(messages) == 2 | |
assert messages[0].text == "Test message 1" | |
assert messages[1].text == "Test message 2" | |
async def test_aget_messages(): | |
await aadd_messages( | |
[ | |
Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), | |
Message(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), | |
] | |
) | |
messages = await aget_messages(sender="User", session_id="session_id2", limit=2) | |
assert len(messages) == 2 | |
assert messages[0].text == "Test message 1" | |
assert messages[1].text == "Test message 2" | |
def test_add_messages(): | |
message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id") | |
messages = add_messages(message) | |
assert len(messages) == 1 | |
assert messages[0].text == "New Test message" | |
async def test_aadd_messages(): | |
message = Message(text="New Test message", sender="User", sender_name="User", session_id="new_session_id") | |
messages = await aadd_messages(message) | |
assert len(messages) == 1 | |
assert messages[0].text == "New Test message" | |
def test_add_messagetables(session): | |
messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")] | |
added_messages = add_messagetables(messages, session) | |
assert len(added_messages) == 1 | |
assert added_messages[0].text == "New Test message" | |
async def test_aadd_messagetables(async_session): | |
messages = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")] | |
added_messages = await aadd_messagetables(messages, async_session) | |
assert len(added_messages) == 1 | |
assert added_messages[0].text == "New Test message" | |
def test_delete_messages(): | |
session_id = "new_session_id" | |
message = Message(text="New Test message", sender="User", sender_name="User", session_id=session_id) | |
add_messages([message]) | |
messages = get_messages(sender="User", session_id=session_id) | |
assert len(messages) == 1 | |
delete_messages(session_id) | |
messages = get_messages(sender="User", session_id=session_id) | |
assert len(messages) == 0 | |
async def test_adelete_messages(): | |
session_id = "new_session_id" | |
message = Message(text="New Test message", sender="User", sender_name="User", session_id=session_id) | |
await aadd_messages([message]) | |
messages = await aget_messages(sender="User", session_id=session_id) | |
assert len(messages) == 1 | |
await adelete_messages(session_id) | |
messages = await aget_messages(sender="User", session_id=session_id) | |
assert len(messages) == 0 | |
def test_store_message(): | |
session_id = "stored_session_id" | |
message = Message(text="Stored message", sender="User", sender_name="User", session_id=session_id) | |
store_message(message) | |
stored_messages = get_messages(sender="User", session_id=session_id) | |
assert len(stored_messages) == 1 | |
assert stored_messages[0].text == "Stored message" | |
async def test_astore_message(): | |
session_id = "stored_session_id" | |
message = Message(text="Stored message", sender="User", sender_name="User", session_id=session_id) | |
await astore_message(message) | |
stored_messages = await aget_messages(sender="User", session_id=session_id) | |
assert len(stored_messages) == 1 | |
assert stored_messages[0].text == "Stored message" | |
def test_convert_to_langchain(method_name): | |
def convert(value): | |
if method_name == "message": | |
return value.to_lc_message() | |
if method_name == "convert_to_langchain_type": | |
return convert_to_langchain_type(value) | |
msg = f"Invalid method: {method_name}" | |
raise ValueError(msg) | |
lc_message = convert(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2")) | |
assert lc_message.content == "Test message 1" | |
assert lc_message.type == "human" | |
lc_message = convert(Message(text="Test message 2", sender="AI", session_id="session_id2")) | |
assert lc_message.content == "Test message 2" | |
assert lc_message.type == "ai" | |
iterator = iter(["stream", "message"]) | |
lc_message = convert(Message(text=iterator, sender="AI", session_id="session_id2")) | |
assert lc_message.content == "" | |
assert lc_message.type == "ai" | |
assert len(list(iterator)) == 2 | |
def test_update_single_message(created_message): | |
# Modify the message | |
created_message.text = "Updated message" | |
updated = update_messages(created_message) | |
assert len(updated) == 1 | |
assert updated[0].text == "Updated message" | |
assert updated[0].id == created_message.id | |
def test_update_multiple_messages(created_messages): | |
# Modify the messages | |
for i, message in enumerate(created_messages): | |
message.text = f"Updated message {i}" | |
updated = update_messages(created_messages) | |
assert len(updated) == len(created_messages) | |
for i, message in enumerate(updated): | |
assert message.text == f"Updated message {i}" | |
assert message.id == created_messages[i].id | |
def test_update_nonexistent_message(): | |
# Create a message with a non-existent UUID | |
message = MessageRead( | |
id=uuid4(), # Generate a random UUID that won't exist in the database | |
text="Test message", | |
sender="User", | |
sender_name="User", | |
session_id="session_id", | |
flow_id=uuid4(), | |
) | |
updated = update_messages(message) | |
assert len(updated) == 0 | |
def test_update_mixed_messages(created_messages): | |
# Create a mix of existing and non-existing messages | |
nonexistent_message = MessageRead( | |
id=uuid4(), # Generate a random UUID that won't exist in the database | |
text="Test message", | |
sender="User", | |
sender_name="User", | |
session_id="session_id", | |
flow_id=uuid4(), | |
) | |
messages_to_update = created_messages[:1] + [nonexistent_message] | |
created_messages[0].text = "Updated existing message" | |
updated = update_messages(messages_to_update) | |
assert len(updated) == 1 | |
assert updated[0].text == "Updated existing message" | |
assert updated[0].id == created_messages[0].id | |
assert isinstance(updated[0].id, UUID) # Verify ID is UUID type | |
def test_update_message_with_timestamp(created_message): | |
# Set a specific timestamp | |
new_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) | |
created_message.timestamp = new_timestamp | |
created_message.text = "Updated message with timestamp" | |
updated = update_messages(created_message) | |
assert len(updated) == 1 | |
assert updated[0].text == "Updated message with timestamp" | |
# Compare timestamps without timezone info since DB doesn't preserve it | |
assert updated[0].timestamp.replace(tzinfo=None) == new_timestamp.replace(tzinfo=None) | |
assert updated[0].id == created_message.id | |
def test_update_multiple_messages_with_timestamps(created_messages): | |
# Modify messages with different timestamps | |
for i, message in enumerate(created_messages): | |
message.text = f"Updated message {i}" | |
message.timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc) | |
updated = update_messages(created_messages) | |
assert len(updated) == len(created_messages) | |
for i, message in enumerate(updated): | |
assert message.text == f"Updated message {i}" | |
# Compare timestamps without timezone info | |
expected_timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc) | |
assert message.timestamp.replace(tzinfo=None) == expected_timestamp.replace(tzinfo=None) | |
assert message.id == created_messages[i].id | |
def test_update_message_with_content_blocks(created_message): | |
# Create a content block using proper models | |
text_content = TextContent( | |
type="text", text="Test content", duration=5, header={"title": "Test Header", "icon": "TestIcon"} | |
) | |
tool_content = ToolContent(type="tool_use", name="test_tool", tool_input={"param": "value"}, duration=10) | |
content_block = ContentBlock(title="Test Block", contents=[text_content, tool_content], allow_markdown=True) | |
created_message.content_blocks = [content_block] | |
created_message.text = "Message with content blocks" | |
updated = update_messages(created_message) | |
assert len(updated) == 1 | |
assert updated[0].text == "Message with content blocks" | |
assert len(updated[0].content_blocks) == 1 | |
# Verify the content block structure | |
updated_block = updated[0].content_blocks[0] | |
assert updated_block.title == "Test Block" | |
assert len(updated_block.contents) == 2 | |
# Verify text content | |
text_content = updated_block.contents[0] | |
assert text_content.type == "text" | |
assert text_content.text == "Test content" | |
assert text_content.duration == 5 | |
assert text_content.header["title"] == "Test Header" | |
# Verify tool content | |
tool_content = updated_block.contents[1] | |
assert tool_content.type == "tool_use" | |
assert tool_content.name == "test_tool" | |
assert tool_content.tool_input == {"param": "value"} | |
assert tool_content.duration == 10 | |
def test_update_message_with_nested_properties(created_message): | |
# Create a text content with nested properties | |
text_content = TextContent( | |
type="text", text="Test content", header={"title": "Test Header", "icon": "TestIcon"}, duration=15 | |
) | |
content_block = ContentBlock( | |
title="Test Properties", | |
contents=[text_content], | |
allow_markdown=True, | |
media_url=["http://example.com/image.jpg"], | |
) | |
# Set properties according to the Properties model structure | |
created_message.properties = Properties( | |
text_color="blue", | |
background_color="white", | |
edited=False, | |
source=Source(id="test_id", display_name="Test Source", source="test"), | |
icon="TestIcon", | |
allow_markdown=True, | |
state="complete", | |
targets=[], | |
) | |
created_message.text = "Message with nested properties" | |
created_message.content_blocks = [content_block] | |
updated = update_messages(created_message) | |
assert len(updated) == 1 | |
assert updated[0].text == "Message with nested properties" | |
# Verify the properties were properly serialized and stored | |
assert updated[0].properties.text_color == "blue" | |
assert updated[0].properties.background_color == "white" | |
assert updated[0].properties.edited is False | |
assert updated[0].properties.source.id == "test_id" | |
assert updated[0].properties.source.display_name == "Test Source" | |
assert updated[0].properties.source.source == "test" | |
assert updated[0].properties.icon == "TestIcon" | |
assert updated[0].properties.allow_markdown is True | |
assert updated[0].properties.state == "complete" | |
assert updated[0].properties.targets == [] | |
async def test_aupdate_single_message(created_message): | |
# Modify the message | |
created_message.text = "Updated message" | |
updated = await aupdate_messages(created_message) | |
assert len(updated) == 1 | |
assert updated[0].text == "Updated message" | |
assert updated[0].id == created_message.id | |
async def test_aupdate_multiple_messages(created_messages): | |
# Modify the messages | |
for i, message in enumerate(created_messages): | |
message.text = f"Updated message {i}" | |
updated = await aupdate_messages(created_messages) | |
assert len(updated) == len(created_messages) | |
for i, message in enumerate(updated): | |
assert message.text == f"Updated message {i}" | |
assert message.id == created_messages[i].id | |
async def test_aupdate_nonexistent_message(): | |
# Create a message with a non-existent UUID | |
message = MessageRead( | |
id=uuid4(), # Generate a random UUID that won't exist in the database | |
text="Test message", | |
sender="User", | |
sender_name="User", | |
session_id="session_id", | |
flow_id=uuid4(), | |
) | |
updated = await aupdate_messages(message) | |
assert len(updated) == 0 | |
async def test_aupdate_mixed_messages(created_messages): | |
# Create a mix of existing and non-existing messages | |
nonexistent_message = MessageRead( | |
id=uuid4(), # Generate a random UUID that won't exist in the database | |
text="Test message", | |
sender="User", | |
sender_name="User", | |
session_id="session_id", | |
flow_id=uuid4(), | |
) | |
messages_to_update = created_messages[:1] + [nonexistent_message] | |
created_messages[0].text = "Updated existing message" | |
updated = await aupdate_messages(messages_to_update) | |
assert len(updated) == 1 | |
assert updated[0].text == "Updated existing message" | |
assert updated[0].id == created_messages[0].id | |
assert isinstance(updated[0].id, UUID) # Verify ID is UUID type | |
async def test_aupdate_message_with_timestamp(created_message): | |
# Set a specific timestamp | |
new_timestamp = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) | |
created_message.timestamp = new_timestamp | |
created_message.text = "Updated message with timestamp" | |
updated = await aupdate_messages(created_message) | |
assert len(updated) == 1 | |
assert updated[0].text == "Updated message with timestamp" | |
# Compare timestamps without timezone info since DB doesn't preserve it | |
assert updated[0].timestamp.replace(tzinfo=None) == new_timestamp.replace(tzinfo=None) | |
assert updated[0].id == created_message.id | |
async def test_aupdate_multiple_messages_with_timestamps(created_messages): | |
# Modify messages with different timestamps | |
for i, message in enumerate(created_messages): | |
message.text = f"Updated message {i}" | |
message.timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc) | |
updated = await aupdate_messages(created_messages) | |
assert len(updated) == len(created_messages) | |
for i, message in enumerate(updated): | |
assert message.text == f"Updated message {i}" | |
# Compare timestamps without timezone info | |
expected_timestamp = datetime(2024, 1, 1, i, 0, 0, tzinfo=timezone.utc) | |
assert message.timestamp.replace(tzinfo=None) == expected_timestamp.replace(tzinfo=None) | |
assert message.id == created_messages[i].id | |
async def test_aupdate_message_with_content_blocks(created_message): | |
# Create a content block using proper models | |
text_content = TextContent( | |
type="text", text="Test content", duration=5, header={"title": "Test Header", "icon": "TestIcon"} | |
) | |
tool_content = ToolContent(type="tool_use", name="test_tool", tool_input={"param": "value"}, duration=10) | |
content_block = ContentBlock(title="Test Block", contents=[text_content, tool_content], allow_markdown=True) | |
created_message.content_blocks = [content_block] | |
created_message.text = "Message with content blocks" | |
updated = await aupdate_messages(created_message) | |
assert len(updated) == 1 | |
assert updated[0].text == "Message with content blocks" | |
assert len(updated[0].content_blocks) == 1 | |
# Verify the content block structure | |
updated_block = updated[0].content_blocks[0] | |
assert updated_block.title == "Test Block" | |
assert len(updated_block.contents) == 2 | |
# Verify text content | |
text_content = updated_block.contents[0] | |
assert text_content.type == "text" | |
assert text_content.text == "Test content" | |
assert text_content.duration == 5 | |
assert text_content.header["title"] == "Test Header" | |
# Verify tool content | |
tool_content = updated_block.contents[1] | |
assert tool_content.type == "tool_use" | |
assert tool_content.name == "test_tool" | |
assert tool_content.tool_input == {"param": "value"} | |
assert tool_content.duration == 10 | |
async def test_aupdate_message_with_nested_properties(created_message): | |
# Create a text content with nested properties | |
text_content = TextContent( | |
type="text", text="Test content", header={"title": "Test Header", "icon": "TestIcon"}, duration=15 | |
) | |
content_block = ContentBlock( | |
title="Test Properties", | |
contents=[text_content], | |
allow_markdown=True, | |
media_url=["http://example.com/image.jpg"], | |
) | |
# Set properties according to the Properties model structure | |
created_message.properties = Properties( | |
text_color="blue", | |
background_color="white", | |
edited=False, | |
source=Source(id="test_id", display_name="Test Source", source="test"), | |
icon="TestIcon", | |
allow_markdown=True, | |
state="complete", | |
targets=[], | |
) | |
created_message.text = "Message with nested properties" | |
created_message.content_blocks = [content_block] | |
updated = await aupdate_messages(created_message) | |
assert len(updated) == 1 | |
assert updated[0].text == "Message with nested properties" | |
# Verify the properties were properly serialized and stored | |
assert updated[0].properties.text_color == "blue" | |
assert updated[0].properties.background_color == "white" | |
assert updated[0].properties.edited is False | |
assert updated[0].properties.source.id == "test_id" | |
assert updated[0].properties.source.display_name == "Test Source" | |
assert updated[0].properties.source.source == "test" | |
assert updated[0].properties.icon == "TestIcon" | |
assert updated[0].properties.allow_markdown is True | |
assert updated[0].properties.state == "complete" | |
assert updated[0].properties.targets == [] | |