from datetime import datetime, timezone
from uuid import UUID, uuid4

import pytest
from langflow.memory import (
    aadd_messages,
    aadd_messagetables,
    add_messages,
    add_messagetables,
    adelete_messages,
    aget_messages,
    astore_message,
    aupdate_messages,
    delete_messages,
    get_messages,
    store_message,
    update_messages,
)
from langflow.schema.content_block import ContentBlock
from langflow.schema.content_types import TextContent, ToolContent
from langflow.schema.message import Message
from langflow.schema.properties import Properties, Source
from langflow.services.database.models.message import MessageCreate, MessageRead
from langflow.services.database.models.message.model import MessageTable
from langflow.services.deps import async_session_scope
from langflow.services.tracing.utils import convert_to_langchain_type


@pytest.fixture
async def created_message():
    """Persist one user message in ``session_id`` and hand it back as a ``MessageRead``.

    The row is written through ``aadd_messagetables`` inside a fresh ``async_session_scope``.
    The first returned table row is then re-validated into the read model.
    """
    draft = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id")
    async with async_session_scope() as db_session:
        rows = await aadd_messagetables([MessageTable.model_validate(draft, from_attributes=True)], db_session)
        return MessageRead.model_validate(rows[0], from_attributes=True)


@pytest.fixture
async def created_messages(session):  # noqa: ARG001
    """Persist three user messages in ``session_id2`` and hand them back as ``MessageRead`` objects.

    ``session`` is not referenced in the body. It is presumably requested only for its setup side
    effect (e.g. database initialisation); confirm against the conftest fixture.
    """
    drafts = [
        MessageCreate(text=text, sender="User", sender_name="User", session_id="session_id2")
        for text in ("Test message 1", "Test message 2", "Test message 3")
    ]
    async with async_session_scope() as db_session:
        rows = await aadd_messagetables(
            [MessageTable.model_validate(draft, from_attributes=True) for draft in drafts],
            db_session,
        )
        return [MessageRead.model_validate(row, from_attributes=True) for row in rows]
# NOTE(review): the ``async def`` tests below carry no explicit asyncio marker. They presumably rely on
# pytest-asyncio (or an equivalent plugin) running in auto mode; confirm in the project's pytest config.


def _user_message(text, session_id):
    """Build a ``Message`` whose ``sender`` and ``sender_name`` are both ``"User"``."""
    return Message(text=text, sender="User", sender_name="User", session_id=session_id)


def _ghost_message_read():
    """Build a ``MessageRead`` with freshly generated ids, so it should match no stored row."""
    return MessageRead(
        id=uuid4(),
        text="Test message",
        sender="User",
        sender_name="User",
        session_id="session_id",
        flow_id=uuid4(),
    )


def _jan_first_2024(hour):
    """Return 2024-01-01 at ``hour``:00:00 UTC."""
    return datetime(2024, 1, 1, hour, 0, 0, tzinfo=timezone.utc)


def _naive(moment):
    """Strip tzinfo so timestamps compare on wall-clock value only.

    The existing tests assume the database does not round-trip timezone info.
    """
    return moment.replace(tzinfo=None)


def _rewrite_texts(messages, *, with_timestamps=False):
    """Set each message's text (and optionally its timestamp) from its position in ``messages``."""
    for index, message in enumerate(messages):
        message.text = f"Updated message {index}"
        if with_timestamps:
            message.timestamp = _jan_first_2024(index)


def _assert_bulk_update(updated, originals, *, with_timestamps=False):
    """Check that ``updated`` mirrors ``originals`` one-to-one, in order, with the rewritten values."""
    assert len(updated) == len(originals)
    for index, row in enumerate(updated):
        assert row.text == f"Updated message {index}"
        if with_timestamps:
            assert _naive(row.timestamp) == _naive(_jan_first_2024(index))
        assert row.id == originals[index].id


def _assert_single_update(updated, expected_text, original):
    """Check that exactly one row came back, carrying ``expected_text`` and ``original``'s id."""
    assert len(updated) == 1
    assert updated[0].text == expected_text
    assert updated[0].id == original.id


def _attach_content_blocks(message):
    """Give ``message`` one content block holding a text part and a tool-use part."""
    text_part = TextContent(
        type="text", text="Test content", duration=5, header={"title": "Test Header", "icon": "TestIcon"}
    )
    tool_part = ToolContent(type="tool_use", name="test_tool", tool_input={"param": "value"}, duration=10)
    block = ContentBlock(title="Test Block", contents=[text_part, tool_part], allow_markdown=True)
    message.content_blocks = [block]
    message.text = "Message with content blocks"


def _assert_content_blocks_round_trip(updated):
    """Check the block written by ``_attach_content_blocks`` survived the update intact."""
    assert len(updated) == 1
    assert updated[0].text == "Message with content blocks"
    assert len(updated[0].content_blocks) == 1

    block = updated[0].content_blocks[0]
    assert block.title == "Test Block"
    assert len(block.contents) == 2
    text_part, tool_part = block.contents

    assert text_part.type == "text"
    assert text_part.text == "Test content"
    assert text_part.duration == 5
    assert text_part.header["title"] == "Test Header"

    assert tool_part.type == "tool_use"
    assert tool_part.name == "test_tool"
    assert tool_part.tool_input == {"param": "value"}
    assert tool_part.duration == 10


def _attach_properties(message):
    """Give ``message`` a fully populated ``Properties`` object plus a single-text content block."""
    text_part = TextContent(
        type="text", text="Test content", header={"title": "Test Header", "icon": "TestIcon"}, duration=15
    )
    block = ContentBlock(
        title="Test Properties",
        contents=[text_part],
        allow_markdown=True,
        media_url=["http://example.com/image.jpg"],
    )
    message.properties = Properties(
        text_color="blue",
        background_color="white",
        edited=False,
        source=Source(id="test_id", display_name="Test Source", source="test"),
        icon="TestIcon",
        allow_markdown=True,
        state="complete",
        targets=[],
    )
    message.text = "Message with nested properties"
    message.content_blocks = [block]


def _assert_properties_round_trip(updated):
    """Check every field set by ``_attach_properties`` was serialised and read back unchanged.

    NOTE(review): the attached content block is not asserted on here.
    """
    assert len(updated) == 1
    stored = updated[0]
    assert stored.text == "Message with nested properties"

    props = stored.properties
    assert props.text_color == "blue"
    assert props.background_color == "white"
    assert props.edited is False
    assert props.source.id == "test_id"
    assert props.source.display_name == "Test Source"
    assert props.source.source == "test"
    assert props.icon == "TestIcon"
    assert props.allow_markdown is True
    assert props.state == "complete"
    assert props.targets == []


@pytest.mark.usefixtures("client")
def test_get_messages():
    """Two messages added together are read back in insertion order (queried with ``limit=2``)."""
    add_messages(
        [
            _user_message("Test message 1", "session_id2"),
            _user_message("Test message 2", "session_id2"),
        ]
    )
    fetched = get_messages(sender="User", session_id="session_id2", limit=2)
    assert len(fetched) == 2
    assert [item.text for item in fetched] == ["Test message 1", "Test message 2"]


@pytest.mark.usefixtures("client")
async def test_aget_messages():
    """Async counterpart of ``test_get_messages``."""
    await aadd_messages(
        [
            _user_message("Test message 1", "session_id2"),
            _user_message("Test message 2", "session_id2"),
        ]
    )
    fetched = await aget_messages(sender="User", session_id="session_id2", limit=2)
    assert len(fetched) == 2
    assert [item.text for item in fetched] == ["Test message 1", "Test message 2"]


@pytest.mark.usefixtures("client")
def test_add_messages():
    """A single (non-list) message is accepted and returned as a one-element collection."""
    added = add_messages(_user_message("New Test message", "new_session_id"))
    assert len(added) == 1
    assert added[0].text == "New Test message"


@pytest.mark.usefixtures("client")
async def test_aadd_messages():
    """Async counterpart of ``test_add_messages``."""
    added = await aadd_messages(_user_message("New Test message", "new_session_id"))
    assert len(added) == 1
    assert added[0].text == "New Test message"


@pytest.mark.usefixtures("client")
def test_add_messagetables(session):
    """Table rows passed with an explicit session are returned after being added."""
    rows = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")]
    saved = add_messagetables(rows, session)
    assert len(saved) == 1
    assert saved[0].text == "New Test message"


@pytest.mark.usefixtures("client")
async def test_aadd_messagetables(async_session):
    """Async counterpart of ``test_add_messagetables``."""
    rows = [MessageTable(text="New Test message", sender="User", sender_name="User", session_id="new_session_id")]
    saved = await aadd_messagetables(rows, async_session)
    assert len(saved) == 1
    assert saved[0].text == "New Test message"


@pytest.mark.usefixtures("client")
def test_delete_messages():
    """Deleting by session id leaves no messages for that session."""
    session_id = "new_session_id"
    add_messages([_user_message("New Test message", session_id)])
    assert len(get_messages(sender="User", session_id=session_id)) == 1

    delete_messages(session_id)
    assert len(get_messages(sender="User", session_id=session_id)) == 0


@pytest.mark.usefixtures("client")
async def test_adelete_messages():
    """Async counterpart of ``test_delete_messages``."""
    session_id = "new_session_id"
    await aadd_messages([_user_message("New Test message", session_id)])
    assert len(await aget_messages(sender="User", session_id=session_id)) == 1

    await adelete_messages(session_id)
    assert len(await aget_messages(sender="User", session_id=session_id)) == 0


@pytest.mark.usefixtures("client")
def test_store_message():
    """A stored message can be fetched back by sender and session id."""
    session_id = "stored_session_id"
    store_message(_user_message("Stored message", session_id))
    stored = get_messages(sender="User", session_id=session_id)
    assert len(stored) == 1
    assert stored[0].text == "Stored message"


@pytest.mark.usefixtures("client")
async def test_astore_message():
    """Async counterpart of ``test_store_message``."""
    session_id = "stored_session_id"
    await astore_message(_user_message("Stored message", session_id))
    stored = await aget_messages(sender="User", session_id=session_id)
    assert len(stored) == 1
    assert stored[0].text == "Stored message"


@pytest.mark.parametrize("method_name", ["message", "convert_to_langchain_type"])
def test_convert_to_langchain(method_name):
    """Both conversion paths map sender to LangChain message type and leave iterator text unconsumed."""
    converters = {
        "message": lambda value: value.to_lc_message(),
        "convert_to_langchain_type": convert_to_langchain_type,
    }

    def convert(value):
        converter = converters.get(method_name)
        if converter is None:
            msg = f"Invalid method: {method_name}"
            raise ValueError(msg)
        return converter(value)

    human = convert(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"))
    assert human.content == "Test message 1"
    assert human.type == "human"

    ai = convert(Message(text="Test message 2", sender="AI", session_id="session_id2"))
    assert ai.content == "Test message 2"
    assert ai.type == "ai"

    # A streaming (iterator) text converts to empty content and must not be drained by the conversion.
    iterator = iter(["stream", "message"])
    streamed = convert(Message(text=iterator, sender="AI", session_id="session_id2"))
    assert streamed.content == ""
    assert streamed.type == "ai"
    assert len(list(iterator)) == 2


@pytest.mark.usefixtures("client")
def test_update_single_message(created_message):
    """Updating one persisted message returns it with the new text and the same id."""
    created_message.text = "Updated message"
    _assert_single_update(update_messages(created_message), "Updated message", created_message)


@pytest.mark.usefixtures("client")
def test_update_multiple_messages(created_messages):
    """Updating a list returns every message, in order, with its new text."""
    _rewrite_texts(created_messages)
    _assert_bulk_update(update_messages(created_messages), created_messages)


@pytest.mark.usefixtures("client")
def test_update_nonexistent_message():
    """Updating a message whose id is not stored returns nothing."""
    updated = update_messages(_ghost_message_read())
    assert len(updated) == 0


@pytest.mark.usefixtures("client")
def test_update_mixed_messages(created_messages):
    """With one stored and one unknown message, only the stored one comes back updated."""
    existing = created_messages[0]
    existing.text = "Updated existing message"
    updated = update_messages([existing, _ghost_message_read()])
    _assert_single_update(updated, "Updated existing message", existing)
    assert isinstance(updated[0].id, UUID)


@pytest.mark.usefixtures("client")
def test_update_message_with_timestamp(created_message):
    """An explicit timestamp is persisted alongside the new text."""
    new_timestamp = _jan_first_2024(12)
    created_message.timestamp = new_timestamp
    created_message.text = "Updated message with timestamp"
    updated = update_messages(created_message)
    _assert_single_update(updated, "Updated message with timestamp", created_message)
    assert _naive(updated[0].timestamp) == _naive(new_timestamp)


@pytest.mark.usefixtures("client")
def test_update_multiple_messages_with_timestamps(created_messages):
    """Per-message timestamps are persisted for a bulk update."""
    _rewrite_texts(created_messages, with_timestamps=True)
    _assert_bulk_update(update_messages(created_messages), created_messages, with_timestamps=True)


@pytest.mark.usefixtures("client")
def test_update_message_with_content_blocks(created_message):
    """Text and tool-use content parts survive an update."""
    _attach_content_blocks(created_message)
    _assert_content_blocks_round_trip(update_messages(created_message))


@pytest.mark.usefixtures("client")
def test_update_message_with_nested_properties(created_message):
    """A nested ``Properties`` object (including ``Source``) survives an update."""
    _attach_properties(created_message)
    _assert_properties_round_trip(update_messages(created_message))


@pytest.mark.usefixtures("client")
async def test_aupdate_single_message(created_message):
    """Async counterpart of ``test_update_single_message``."""
    created_message.text = "Updated message"
    _assert_single_update(await aupdate_messages(created_message), "Updated message", created_message)


@pytest.mark.usefixtures("client")
async def test_aupdate_multiple_messages(created_messages):
    """Async counterpart of ``test_update_multiple_messages``."""
    _rewrite_texts(created_messages)
    _assert_bulk_update(await aupdate_messages(created_messages), created_messages)


@pytest.mark.usefixtures("client")
async def test_aupdate_nonexistent_message():
    """Async counterpart of ``test_update_nonexistent_message``."""
    updated = await aupdate_messages(_ghost_message_read())
    assert len(updated) == 0


@pytest.mark.usefixtures("client")
async def test_aupdate_mixed_messages(created_messages):
    """Async counterpart of ``test_update_mixed_messages``."""
    existing = created_messages[0]
    existing.text = "Updated existing message"
    updated = await aupdate_messages([existing, _ghost_message_read()])
    _assert_single_update(updated, "Updated existing message", existing)
    assert isinstance(updated[0].id, UUID)


@pytest.mark.usefixtures("client")
async def test_aupdate_message_with_timestamp(created_message):
    """Async counterpart of ``test_update_message_with_timestamp``."""
    new_timestamp = _jan_first_2024(12)
    created_message.timestamp = new_timestamp
    created_message.text = "Updated message with timestamp"
    updated = await aupdate_messages(created_message)
    _assert_single_update(updated, "Updated message with timestamp", created_message)
    assert _naive(updated[0].timestamp) == _naive(new_timestamp)


@pytest.mark.usefixtures("client")
async def test_aupdate_multiple_messages_with_timestamps(created_messages):
    """Async counterpart of ``test_update_multiple_messages_with_timestamps``."""
    _rewrite_texts(created_messages, with_timestamps=True)
    _assert_bulk_update(await aupdate_messages(created_messages), created_messages, with_timestamps=True)


@pytest.mark.usefixtures("client")
async def test_aupdate_message_with_content_blocks(created_message):
    """Async counterpart of ``test_update_message_with_content_blocks``."""
    _attach_content_blocks(created_message)
    _assert_content_blocks_round_trip(await aupdate_messages(created_message))


@pytest.mark.usefixtures("client")
async def test_aupdate_message_with_nested_properties(created_message):
    """Async counterpart of ``test_update_message_with_nested_properties``."""
    _attach_properties(created_message)
    _assert_properties_round_trip(await aupdate_messages(created_message))