import json import requests from urllib.parse import quote import discord from typing import List from qa_engine import logger, QAEngine from discord_bot.client.utils import split_text_into_chunks class DiscordClient(discord.Client): """ Discord Client class, used for interacting with a Discord server. Args: qa_service_url (str): The URL of the question answering service. num_last_messages (int, optional): The number of previous messages to use as context for generating answers. Defaults to 5. use_names_in_context (bool, optional): Whether to include user names in the message context. Defaults to True. enable_commands (bool, optional): Whether to enable commands for the bot. Defaults to True. Attributes: qa_service_url (str): The URL of the question answering service. num_last_messages (int): The number of previous messages to use as context for generating answers. use_names_in_context (bool): Whether to include user names in the message context. enable_commands (bool): Whether to enable commands for the bot. max_message_len (int): The maximum length of a message. system_prompt (str): The system prompt to be used. """ def __init__( self, qa_engine: QAEngine, channel_ids: list[int] = [], num_last_messages: int = 5, use_names_in_context: bool = True, enable_commands: bool = True, debug: bool = False ): logger.info('Initializing Discord client...') intents = discord.Intents.all() intents.message_content = True super().__init__(intents=intents, command_prefix='!') assert num_last_messages >= 1, \ 'The number of last messages in context should be at least 1' self.qa_engine: QAEngine = qa_engine self.channel_ids: list[int] = DiscordClient._process_channel_ids(channel_ids) self.num_last_messages: int = num_last_messages self.use_names_in_context: bool = use_names_in_context self.enable_commands: bool = enable_commands self.debug: bool = debug self.min_messgae_len: int = 1800 self.max_message_len: int = 2000 @staticmethod def _process_channel_ids(channel_ids) -> list[int]: if isinstance(channel_ids, str): return eval(channel_ids) elif isinstance(channel_ids, list): return channel_ids elif isinstance(channel_ids, int): return [channel_ids] else: return [] async def on_ready(self): """ Callback function to be called when the client is ready. """ logger.info('Successfully logged in as: {0.user}'.format(self)) await self.change_presence(activity=discord.Game(name='Chatting...')) async def get_last_messages(self, message) -> List[str]: """ Method to fetch recent messages from a message's channel. Args: message (Message): The discord Message object used to identify the channel. Returns: List[str]: Reversed list of recent messages from the channel, excluding the input message. Messages may be prefixed with the author's name if `self.use_names_in_context` is True. """ last_messages: List[str] = [] async for msg in message.channel.history( limit=self.num_last_messages): if self.use_names_in_context: last_messages.append(f'{msg.author}: {msg.content}') else: last_messages.append(msg.content) last_messages.reverse() last_messages.pop() # remove last message from context return last_messages async def send_message(self, message, answer: str, sources: str): chunks = split_text_into_chunks( text=answer, split_characters=['. ', ', ', '\n'], min_size=self.min_messgae_len, max_size=self.max_message_len ) for chunk in chunks: await message.channel.send(chunk) await message.channel.send(sources) async def on_message(self, message): if self.channel_ids and message.channel.id not in self.channel_ids: return if message.author == self.user: return """ if self.enable_commands and message.content.startswith('!'): if message.content == '!clear': await message.channel.purge() return """ last_messages = await self.get_last_messages(message) context = '\n'.join(last_messages) logger.info('Received message: {0.content}'.format(message)) response = self.qa_engine.get_response( question=message.content, messages_context=context ) logger.info('Sending response: {0}'.format(response)) try: await self.send_message( message, response.get_answer(), response.get_sources_as_text() ) except Exception as e: logger.error('Failed to send response: {0}'.format(e))