|
from literalai import LiteralClient |
|
from dotenv import load_dotenv |
|
import os |
|
from typing import List, Dict, Any |
|
load_dotenv() |
|
|
|
|
|
class LiteralThreadManager:
    """
    Manage and extract conversation threads stored in LiteralAI.

    Provides helpers to fetch threads (all of them, by participant, or by ID),
    read and update per-thread metadata such as the conflict-resolution flag,
    turn a thread's steps into a chat history for HarmonyAI, and post assistant
    messages back to a thread.
    """

    # Display name attached to every assistant/LLM message this class produces or sends.
    ASSISTANT_NAME = 'HarmonyAI'

    def __init__(self, api_key: str):
        """
        Initialize the LiteralThreadManager with a LiteralClient.

        Args:
            api_key (str): The API key for the LiteralClient.
        """
        self.literal_client = LiteralClient(api_key=api_key)

    @staticmethod
    def threads_to_dict(threads_input) -> List[Dict[str, Any]]:
        """
        Convert threads to a list of dictionaries.

        Args:
            threads_input: Either a response exposing the threads on a ``data``
                attribute (the shape returned by ``api.get_threads()``), or a plain
                iterable of thread objects. Each thread must provide ``to_dict()``.

        Returns:
            List[Dict[str, Any]]: One dictionary per thread, in input order.
        """
        threads = getattr(threads_input, 'data', threads_input)
        # A response whose ``data`` is None is treated as "no threads".
        return [a_thread.to_dict() for a_thread in threads or []]

    def _get_thread(self, thread_id: str):
        """
        Fetch a single thread object from the LiteralAI API.

        Raises:
            LookupError: If the API returns no thread for ``thread_id``
                (previously this surfaced as an opaque AttributeError on None).
        """
        thread = self.literal_client.api.get_thread(thread_id)
        if thread is None:
            raise LookupError(f"LiteralAI thread not found: {thread_id}")
        return thread

    def _fetch_thread_pages(self) -> List[Any]:
        """
        Fetch every page of threads from the LiteralAI API.

        The first request is the plain ``get_threads()`` call; further pages are
        requested with ``after=<end_cursor>`` for as long as the response's
        ``page_info`` reports ``has_next_page``. When a response carries no
        ``page_info`` only that first page is returned.

        Returns:
            List[Any]: The raw page responses, in fetch order.
        """
        api = self.literal_client.api
        page = api.get_threads()
        pages = [page]
        seen_cursors = set()
        while True:
            page_info = getattr(page, 'page_info', None)
            if not getattr(page_info, 'has_next_page', False):
                break
            cursor = getattr(page_info, 'end_cursor', None)
            # Stop on a missing or repeated cursor instead of looping forever.
            if not cursor or cursor in seen_cursors:
                break
            seen_cursors.add(cursor)
            page = api.get_threads(after=cursor)
            pages.append(page)
        return pages

    @staticmethod
    def _participant_identifier(thread: Dict[str, Any]) -> Any:
        """Return the thread's participant identifier, or None when there is no participant."""
        participant = thread.get('participant') or {}
        return participant.get('identifier')

    def filter_threads_by_participant(self, participant_name: str) -> List[Dict[str, Any]]:
        """
        Filter threads by participant name (case-insensitive).

        All pages of threads are fetched, converted to dictionaries and cached on
        ``self.thread_dict_list``. Threads without a participant never match.

        Args:
            participant_name (str): The name of the participant.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries representing the filtered threads.
        """
        target = participant_name.lower()
        self.thread_dict_list = [
            thread_dict
            for page in self._fetch_thread_pages()
            for thread_dict in self.threads_to_dict(page)
        ]
        matches = []
        for thread in self.thread_dict_list:
            identifier = self._participant_identifier(thread)
            if identifier is not None and identifier.lower() == target:
                matches.append(thread)
        return matches

    def filter_thread_by_id(self, thread_id: str) -> List[Dict[str, Any]]:
        """
        Fetch a thread by its ID.

        Args:
            thread_id (str): The ID of the thread.

        Returns:
            List[Dict[str, Any]]: A one-element list holding the thread's dictionary
            (the list shape is what the other ``*_from_thread`` helpers expect).

        Raises:
            LookupError: If no thread exists for ``thread_id``.
        """
        return [self._get_thread(thread_id).to_dict()]

    def get_other_partner_thread_id(self, thread_id: str) -> str:
        """
        Get the partner thread ID stored in a thread's metadata.

        Args:
            thread_id (str): The ID of the thread.

        Returns:
            str: The value of ``metadata['partner_thread_id']``.

        Raises:
            LookupError: If the thread does not exist.
            KeyError: If the thread has no metadata or no ``partner_thread_id`` entry.
        """
        metadata = self.filter_thread_by_id(thread_id)[0].get('metadata') or {}
        try:
            return metadata['partner_thread_id']
        except KeyError:
            raise KeyError(
                f"Thread {thread_id} has no 'partner_thread_id' in its metadata"
            ) from None

    def get_messages_from_thread(self, input_thread) -> List[Dict[str, Any]]:
        """
        Get all steps of a thread.

        Args:
            input_thread: A one-element list holding a thread dictionary, as
                returned by ``filter_thread_by_id``.

        Returns:
            List[Dict[str, Any]]: The thread's step dictionaries (empty when the
            thread's ``steps`` is None).
        """
        return input_thread[0]['steps'] or []

    def get_user_name_from_thread(self, thread) -> str:
        """
        Get the participant name from a thread.

        Args:
            thread: A one-element list holding a thread dictionary, as returned by
                ``filter_thread_by_id``.

        Returns:
            str: The participant's identifier.
        """
        return thread[0]['participant']['identifier']

    def is_conflict_resolved(self, thread_id: str) -> bool:
        """
        Check whether the conflict in a thread is resolved.

        If the thread has no ``isResolved`` metadata yet, it is initialised to
        False on the server and False is returned.

        Args:
            thread_id (str): The ID of the thread.

        Returns:
            bool: True if the conflict is resolved, False otherwise.
        """
        metadata = self._get_thread(thread_id).metadata or {}
        is_resolved = metadata.get('isResolved')
        if is_resolved is None:
            self.set_conflict_resolved(thread_id, False)
            # No need to re-fetch: we just stored False.
            return False
        return is_resolved

    def set_conflict_resolved(self, thread_id: str, is_resolved: bool):
        """
        Set the conflict resolution status of a thread.

        Other metadata keys are preserved; a thread with no metadata gets a new
        dictionary holding only ``isResolved``.

        Args:
            thread_id (str): The ID of the thread.
            is_resolved (bool): The resolution status to be set.
        """
        thread = self._get_thread(thread_id)
        metadata = dict(thread.metadata or {})
        metadata['isResolved'] = is_resolved
        self.literal_client.api.update_thread(id=thread.id, metadata=metadata)

    def _step_to_messages(self, step: Dict[str, Any], name: str) -> List[Dict[str, Any]]:
        """Translate one thread step into zero or more chat-history entries."""
        entries = []
        step_type = step.get('type')
        if step_type in ('user_message', 'assistant_message'):
            role = 'user' if step_type == 'user_message' else 'assistant'
            entries.append({
                'role': role,
                'name': name if role == 'user' else self.ASSISTANT_NAME,
                'content': step['output']['content'],
            })

        generation = step.get('generation')
        if generation is not None:
            # Prompt messages sent to the LLM; system/tool messages are ignored.
            for message in generation.get('messages') or []:
                if message['role'] in ('user', 'assistant'):
                    entries.append({
                        'role': message['role'],
                        'name': name if message['role'] == 'user' else self.ASSISTANT_NAME,
                        'content': message['content'],
                    })
            completion = generation.get('messageCompletion')
            if completion is not None:
                entries.append({
                    'role': completion['role'],
                    'name': self.ASSISTANT_NAME,
                    'content': completion['content'],
                })
        return entries

    @staticmethod
    def _dedupe_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Keep the first occurrence of each (role, content) pair, preserving order."""
        unique = []
        seen = []  # list, not set: content may be unhashable (e.g. a list of parts)
        for message in messages:
            key = (message['role'], message['content'])
            if key not in seen:
                seen.append(key)
                unique.append(message)
        return unique

    def extract_chat_history_from_thread(self, input_thread) -> List[Dict[str, str]]:
        """
        Extract a de-duplicated chat history from a thread.

        Each step contributes its own user/assistant message, plus the user and
        assistant prompt messages and the completion recorded in its ``generation``
        (if any). Entries repeating an earlier (role, content) pair are dropped.
        Finally, the most recent user message is appended so it is the last entry,
        unless it already is (appending it again would only duplicate it).

        Args:
            input_thread: A one-element list holding a thread dictionary, as
                returned by ``filter_thread_by_id``.

        Returns:
            List[Dict[str, str]]: Entries with ``role``, ``name`` and ``content`` keys.
        """
        conversation = self.get_messages_from_thread(input_thread)
        name = self.get_user_name_from_thread(input_thread)

        chat_history = []
        for step in conversation:
            chat_history.extend(self._step_to_messages(step, name))

        filtered_chat_history = self._dedupe_messages(chat_history)

        user_steps = [step for step in conversation if step.get('type') == 'user_message']
        if user_steps:
            last_user_message = user_steps[-1]['output']['content']
            already_last = (
                bool(filtered_chat_history)
                and filtered_chat_history[-1]['role'] == 'user'
                and filtered_chat_history[-1]['content'] == last_user_message
            )
            if not already_last:
                filtered_chat_history.append({
                    'role': 'user',
                    'name': name,
                    'content': last_user_message,
                })

        return filtered_chat_history

    def count_llm_messages(self, thread_id: str) -> int:
        """
        Count the number of AI-generated (``type == 'llm'``) steps in a thread.

        Args:
            thread_id (str): The ID of the thread.

        Returns:
            int: The number of LLM steps (0 when the thread has no steps).

        Raises:
            LookupError: If no thread exists for ``thread_id``.
        """
        thread = self._get_thread(thread_id)
        return sum(1 for step in thread.steps or [] if step.type == 'llm')

    def send_message(self, thread_id: str, message: str):
        """
        Post an assistant message to a thread.

        Args:
            thread_id (str): The ID of the thread.
            message (str): The message content to be sent.
        """
        self.literal_client.api.create_step(
            thread_id=thread_id,
            type='assistant_message',
            name=self.ASSISTANT_NAME,
            output={'content': message},
        )
|
|
|
|
|
|
|
if __name__ == "__main__":
    api_key = os.getenv("LITERAL_API_KEY")
    if not api_key:
        # Fail fast with a clear message instead of an opaque client/auth error later.
        raise SystemExit("LITERAL_API_KEY is not set (add it to the environment or a .env file).")

    manager = LiteralThreadManager(api_key=api_key)

    # Example thread to inspect; '1bb44dc5-0b81-42e9-84a2-cd34b6fe8480' is another
    # example thread ID referenced by this script.
    linda_thread_id = '76d03507-4084-46ba-ba7e-734be1f58304'

    thread_content = manager.filter_thread_by_id(linda_thread_id)
    chat = manager.extract_chat_history_from_thread(thread_content)

    # Print the ID of the thread that was actually fetched.
    print(f"\nChat history for thread ID {linda_thread_id}:")
    for message in chat:
        print(f"{message['role']} - {message['name']}: {message['content']}")