Spaces:
Runtime error
Runtime error
import json | |
import requests | |
from urllib.parse import quote | |
import discord | |
from typing import List | |
from qa_engine import logger, Config, QAEngine | |
from discord_bot.client.utils import split_text_into_chunks | |
class DiscordClient(discord.Client):
    """
    Discord client that answers channel messages using a QAEngine.

    Every message posted in one of the configured channels (or in any channel
    when no channel ids are configured) is passed to the QA engine together
    with the preceding messages of that channel as context, and the engine's
    answer and sources are posted back to the same channel.
    """
    def __init__(
        self,
        qa_engine: QAEngine,
        config: Config,
    ):
        """
        Args:
            qa_engine (QAEngine): Engine used to produce answers.
            config (Config): Bot configuration; reads `discord_channel_ids`,
                `num_last_messages`, `use_names_in_context`,
                `enable_commands` and `debug`.

        Raises:
            ValueError: If a channel id is not an int, if
                `num_last_messages` is smaller than 1, or if a string
                `discord_channel_ids` cannot be parsed.
        """
        logger.info('Initializing Discord client...')
        intents = discord.Intents.all()
        intents.message_content = True
        super().__init__(intents=intents, command_prefix='!')

        self.qa_engine: QAEngine = qa_engine
        self.channel_ids: list[int] = DiscordClient._process_channel_ids(
            config.discord_channel_ids
        )
        self.num_last_messages: int = config.num_last_messages
        self.use_names_in_context: bool = config.use_names_in_context
        self.enable_commands: bool = config.enable_commands
        self.debug: bool = config.debug
        # Bounds for splitting long answers into several messages
        # (presumably chosen to stay under Discord's 2000-character limit).
        self.min_message_len: int = 1800
        self.max_message_len: int = 2000
        # Serializes QA engine calls; created lazily inside the running event
        # loop by `_get_response` so it binds to the correct loop.
        self._response_lock = None

        # Explicit checks rather than `assert`, which is stripped under `-O`.
        if not all(isinstance(channel_id, int) for channel_id in self.channel_ids):
            raise ValueError('All channel ids should be of type int')
        if self.num_last_messages < 1:
            raise ValueError(
                'The number of last messages in context should be at least 1'
            )

    @staticmethod
    def _process_channel_ids(channel_ids) -> list[int]:
        """
        Normalize the configured channel ids into a list.

        Args:
            channel_ids: A list/tuple of ids, a single int id, a string
                holding a Python literal of either (e.g. '[123, 456]' or
                '123'), or anything else (e.g. None), which means
                "no channel restriction".

        Returns:
            list[int]: The channel ids. An empty list means messages from
            every channel are handled (see `on_message`).

        Raises:
            ValueError: If a string value is not a literal list, tuple, int
                or None.
        """
        if isinstance(channel_ids, str):
            import ast
            text = channel_ids.strip()
            if not text:
                return []
            # Only parse literals; configuration values must never be
            # executed as code.
            try:
                channel_ids = ast.literal_eval(text)
            except (ValueError, SyntaxError, TypeError) as err:
                raise ValueError(
                    f'Invalid discord channel ids: {text!r}'
                ) from err
            if not isinstance(channel_ids, (list, tuple, int, type(None))):
                raise ValueError(
                    f'Invalid discord channel ids: {text!r}'
                )
        if isinstance(channel_ids, (list, tuple)):
            return list(channel_ids)
        if isinstance(channel_ids, int):
            return [channel_ids]
        return []

    async def on_ready(self):
        """
        Callback function to be called when the client is ready.
        """
        logger.info('Successfully logged in as: {0.user}'.format(self))
        await self.change_presence(activity=discord.Game(name='Chatting...'))

    async def get_last_messages(self, message) -> List[str]:
        """
        Fetch the messages that precede `message` in its channel.

        Args:
            message (Message): The discord Message whose channel is read; the
                message itself is never part of the result.

        Returns:
            List[str]: Up to `self.num_last_messages - 1` messages posted
            before `message`, oldest first. Each entry is prefixed with the
            author's name when `self.use_names_in_context` is True.
        """
        limit = self.num_last_messages - 1
        if limit < 1:
            return []
        last_messages: List[str] = []
        # `before=message` excludes the input message even if newer messages
        # arrived meanwhile; history is yielded newest first.
        async for msg in message.channel.history(limit=limit, before=message):
            if self.use_names_in_context:
                last_messages.append(f'{msg.author}: {msg.content}')
            else:
                last_messages.append(msg.content)
        last_messages.reverse()
        return last_messages

    async def send_message(self, message, answer: str, sources: str):
        """
        Post an answer, split into several messages, followed by its sources.

        Args:
            message (Message): Message whose channel receives the reply.
            answer (str): Answer text, split by `split_text_into_chunks` using
                `min_message_len` / `max_message_len` as size bounds.
            sources (str): Sources text, sent as a separate final message.
        """
        chunks = split_text_into_chunks(
            text=answer,
            split_characters=['. ', ', ', '\n'],
            min_size=self.min_message_len,
            max_size=self.max_message_len
        )
        for chunk in chunks:
            # Skip blank pieces; Discord presumably refuses empty messages.
            if chunk and chunk.strip():
                await message.channel.send(chunk)
        if sources and sources.strip():
            await message.channel.send(sources)

    async def _get_response(self, question: str, context: str):
        """
        Query the QA engine without blocking the event loop.

        `self.qa_engine.get_response` is called synchronously, so it runs in
        a worker thread. The lock keeps engine calls one at a time, because
        the engine is not known to be thread-safe.
        """
        import asyncio
        if getattr(self, '_response_lock', None) is None:
            self._response_lock = asyncio.Lock()
        async with self._response_lock:
            return await asyncio.to_thread(
                self.qa_engine.get_response,
                question=question,
                messages_context=context,
            )

    async def on_message(self, message):
        """
        Answer a message posted in one of the watched channels.

        Messages from other channels (when channel ids are configured) and
        the client's own messages are ignored. Send failures are logged.
        """
        if self.channel_ids and message.channel.id not in self.channel_ids:
            return
        if message.author == self.user:
            return
        # NOTE: a `!clear` command (purging the channel when
        # `self.enable_commands` is set) is currently disabled.

        last_messages = await self.get_last_messages(message)
        context = '\n'.join(last_messages)

        logger.info('Received message: {0.content}'.format(message))
        response = await self._get_response(message.content, context)
        logger.info('Sending response: {0}'.format(response))
        try:
            await self.send_message(
                message,
                response.get_answer(),
                response.get_sources_as_text()
            )
        except Exception as e:
            logger.error('Failed to send response: {0}'.format(e))