KonradSzafer's picture
config update
988981a
raw
history blame
4.38 kB
import ast
import json
from typing import List
from urllib.parse import quote

import discord
import requests

from qa_engine import logger, Config, QAEngine
from discord_bot.client.utils import split_text_into_chunks
class DiscordClient(discord.Client):
    """
    Discord Client class, used for interacting with a Discord server.

    Listens for messages in the configured channels, forwards each one
    (together with a window of recent channel messages as context) to the
    QA engine, and posts the engine's answer and sources back to the channel.
    """

    def __init__(
        self,
        qa_engine: QAEngine,
        config: Config,
    ):
        """
        Args:
            qa_engine (QAEngine): Engine used to answer incoming questions.
            config (Config): Bot configuration; this class reads
                `discord_channel_ids`, `num_last_messages`,
                `use_names_in_context`, `enable_commands` and `debug`.

        Raises:
            AssertionError: If a channel id is not an int or
                `num_last_messages` is smaller than 1.
        """
        logger.info('Initializing Discord client...')
        intents = discord.Intents.all()
        # on_message reads message text, so the message-content intent must be on.
        intents.message_content = True
        # NOTE(review): `command_prefix` is a `commands.Bot` option; a plain
        # `discord.Client` does not appear to use it.
        super().__init__(intents=intents, command_prefix='!')
        self.qa_engine: QAEngine = qa_engine
        self.channel_ids: list[int] = DiscordClient._process_channel_ids(
            config.discord_channel_ids
        )
        self.num_last_messages: int = config.num_last_messages
        self.use_names_in_context: bool = config.use_names_in_context
        self.enable_commands: bool = config.enable_commands
        self.debug: bool = config.debug
        # Size bounds handed to split_text_into_chunks; 2000 presumably
        # matches Discord's per-message character limit.
        self.min_message_len: int = 1800
        self.max_message_len: int = 2000

        assert all(isinstance(channel_id, int) for channel_id in self.channel_ids), \
            'All channel ids should be of type int'
        assert self.num_last_messages >= 1, \
            'The number of last messages in context should be at least 1'

    @staticmethod
    def _process_channel_ids(channel_ids) -> list[int]:
        """
        Normalize the configured channel ids into a list.

        Args:
            channel_ids: A list/tuple of ids, a single int id, or a string
                holding a Python literal of either (e.g. "[123, 456]").

        Returns:
            list[int]: The channel ids. Empty for a blank string or any other
            type, which makes the bot respond in every channel.

        Raises:
            ValueError, SyntaxError: If a string is not a valid literal, or its
                literal is not a list, tuple or int.
        """
        if isinstance(channel_ids, str):
            text = channel_ids.strip()
            if not text:
                return []
            # literal_eval only parses literals; unlike eval() it cannot run
            # arbitrary code taken from the config / environment.
            parsed = ast.literal_eval(text)
            if not isinstance(parsed, (list, tuple, int)):
                raise ValueError(
                    f'Unsupported discord_channel_ids value: {channel_ids!r}'
                )
            channel_ids = parsed
        if isinstance(channel_ids, (list, tuple)):
            return list(channel_ids)
        if isinstance(channel_ids, int):
            return [channel_ids]
        return []

    async def on_ready(self):
        """
        Callback function to be called when the client is ready.
        """
        logger.info('Successfully logged in as: {0.user}'.format(self))
        await self.change_presence(activity=discord.Game(name='Chatting...'))

    async def get_last_messages(self, message) -> list[str]:
        """
        Method to fetch recent messages from a message's channel.

        The incoming message counts toward `num_last_messages`, so at most
        `num_last_messages - 1` earlier messages are returned.

        Args:
            message (Message): The discord Message object used to identify the
                channel; only messages posted before it are fetched.

        Returns:
            list[str]: Recent messages in chronological order (oldest first),
            excluding the input message. Messages are prefixed with the
            author's name if `self.use_names_in_context` is True.
        """
        limit = self.num_last_messages - 1
        if limit < 1:
            return []
        last_messages: list[str] = []
        # `before=message` excludes the incoming message explicitly. Popping the
        # newest history entry instead drops the wrong message if another one
        # arrived meanwhile, and raises IndexError on an empty history.
        async for msg in message.channel.history(limit=limit, before=message):
            if self.use_names_in_context:
                last_messages.append(f'{msg.author}: {msg.content}')
            else:
                last_messages.append(msg.content)
        # history() without `after` yields newest first; restore chronological order.
        last_messages.reverse()
        return last_messages

    async def send_message(self, message, answer: str, sources: str):
        """
        Post an answer, split into chunks, followed by its sources.

        Args:
            message (Message): Message whose channel receives the reply.
            answer (str): Answer text. It is split via split_text_into_chunks
                at '. ', ', ' or newlines, with `min_message_len` and
                `max_message_len` as size bounds.
            sources (str): Sources text, sent as a separate final message.
        """
        chunks = split_text_into_chunks(
            text=answer,
            split_characters=['. ', ', ', '\n'],
            min_size=self.min_message_len,
            max_size=self.max_message_len
        )
        for chunk in chunks:
            # Discord rejects empty/blank messages; skip them instead of
            # failing the rest of the reply.
            if chunk and chunk.strip():
                await message.channel.send(chunk)
        if sources and sources.strip():
            await message.channel.send(sources)

    async def on_message(self, message):
        """
        Answer a message posted in one of the watched channels.

        Messages from other channels (when `channel_ids` is non-empty) and the
        bot's own messages are ignored. Failures while sending the reply are
        logged, not raised.
        """
        if self.channel_ids and message.channel.id not in self.channel_ids:
            return
        if message.author == self.user:
            return
        # NOTE: a `!clear` command (purging the channel when `enable_commands`
        # is set) used to sit here as a disabled string literal; commands are
        # currently not handled.

        last_messages = await self.get_last_messages(message)
        context = '\n'.join(last_messages)
        logger.info('Received message: {0.content}'.format(message))
        # NOTE(review): get_response is called synchronously inside this
        # coroutine, so the event loop is blocked while it runs; consider
        # offloading it (e.g. asyncio.to_thread) if the engine is thread-safe.
        response = self.qa_engine.get_response(
            question=message.content,
            messages_context=context
        )
        logger.info('Sending response: {0}'.format(response))
        try:
            await self.send_message(
                message,
                response.get_answer(),
                response.get_sources_as_text()
            )
        except Exception as e:
            logger.error('Failed to send response: {0}'.format(e))