File size: 5,253 Bytes
c69cba4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfdf8df
c69cba4
 
 
 
 
 
 
 
 
 
 
 
 
 
f6c81a6
c69cba4
 
 
 
 
 
f6c81a6
 
 
 
 
 
 
 
 
 
 
 
c69cba4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfdf8df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c69cba4
bfdf8df
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import json
import requests
from urllib.parse import quote
import discord
from typing import List

from qa_engine import logger, QAEngine
from discord_bot.client.utils import split_text_into_chunks


class DiscordClient(discord.Client):
    """
    Discord client that answers channel messages using a QA engine.

    Args:
        qa_engine (QAEngine): Engine whose `get_response` produces the answers.
        channel_ids (list[int] | str | int | None, optional): Channels the bot
            responds in. A string is parsed as a Python literal, e.g.
            "[123, 456]" or "123". None / empty means every channel.
            Defaults to None.
        num_last_messages (int, optional): Size of the message window used for
            context, counting the incoming message itself. The context holds at
            most `num_last_messages - 1` earlier messages. Must be >= 1.
            Defaults to 5.
        use_names_in_context (bool, optional): Whether to prefix context
            messages with their author. Defaults to True.
        enable_commands (bool, optional): Stored flag for bot commands.
            Command handling is currently disabled. Defaults to True.
        debug (bool, optional): Stored debug flag. Defaults to False.

    Attributes:
        qa_engine (QAEngine): The engine used to generate answers.
        channel_ids (list[int]): Normalized list of watched channel ids.
        num_last_messages (int): Context window size (see Args).
        use_names_in_context (bool): Whether to include author names in context.
        enable_commands (bool): Whether commands are enabled (currently unused).
        debug (bool): Debug flag (currently unused in this class).
        min_message_len (int): Passed as `min_size` to `split_text_into_chunks`.
        max_message_len (int): Passed as `max_size` to `split_text_into_chunks`;
            2000 matches Discord's per-message character limit.

    Raises:
        ValueError: If `num_last_messages` is less than 1.
    """
    def __init__(
        self,
        qa_engine: QAEngine,
        channel_ids: 'list[int] | str | int | None' = None,
        num_last_messages: int = 5,
        use_names_in_context: bool = True,
        enable_commands: bool = True,
        debug: bool = False
    ):
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if num_last_messages < 1:
            raise ValueError(
                'The number of last messages in context should be at least 1'
            )

        logger.info('Initializing Discord client...')
        intents = discord.Intents.all()
        intents.message_content = True
        super().__init__(intents=intents, command_prefix='!')

        self.qa_engine: QAEngine = qa_engine
        self.channel_ids: list[int] = DiscordClient._process_channel_ids(channel_ids)
        self.num_last_messages: int = num_last_messages
        self.use_names_in_context: bool = use_names_in_context
        self.enable_commands: bool = enable_commands
        self.debug: bool = debug
        self.min_message_len: int = 1800
        self.max_message_len: int = 2000

    @property
    def min_messgae_len(self) -> int:
        """Backward-compatible alias for the previously misspelled attribute."""
        return self.min_message_len

    @min_messgae_len.setter
    def min_messgae_len(self, value: int) -> None:
        self.min_message_len = value

    @staticmethod
    def _process_channel_ids(channel_ids) -> list[int]:
        """
        Normalize the `channel_ids` constructor argument to a list of ints.

        Accepts a list/tuple/set of ids, a single int, or a string holding a
        Python literal of either (e.g. "[123, 456]", "123"). Any other value,
        including None or a blank string, yields an empty list.

        Raises:
            ValueError, SyntaxError: If a string is not a valid Python literal,
                or an id cannot be converted to int.
        """
        import ast

        if isinstance(channel_ids, str):
            text = channel_ids.strip()
            if not text:
                return []
            # literal_eval only accepts Python literals, so unlike eval() a
            # config/env string cannot execute arbitrary code.
            channel_ids = ast.literal_eval(text)
        # Checked after parsing so that a string like "123" becomes [123]
        # instead of a bare int.
        if isinstance(channel_ids, int):
            return [channel_ids]
        if isinstance(channel_ids, (list, tuple, set)):
            # Ids are compared against `message.channel.id` (an int in
            # discord.py), so coerce values such as "123" to 123.
            return [int(channel_id) for channel_id in channel_ids]
        return []


    async def on_ready(self):
        """
        Callback function to be called when the client is ready.
        """
        logger.info('Successfully logged in as: {0.user}'.format(self))
        await self.change_presence(activity=discord.Game(name='Chatting...'))


    async def get_last_messages(self, message) -> List[str]:
        """
        Fetch the messages that precede `message` in its channel.

        Args:
            message (Message): The discord Message whose channel is read. It is
                excluded from the result, as is anything posted after it.

        Returns:
            List[str]: Up to `self.num_last_messages - 1` earlier messages,
            oldest first. Each is prefixed with "<author>: " when
            `self.use_names_in_context` is True.
        """
        # The incoming message counts towards num_last_messages, so at most
        # num_last_messages - 1 earlier messages are used as context.
        limit = self.num_last_messages - 1
        if limit < 1:
            return []

        last_messages: List[str] = []
        # Anchoring with `before=message` guarantees the incoming message is
        # excluded. Messages posted while this handler runs (e.g. the bot's
        # replies to earlier questions) cannot take its place.
        async for msg in message.channel.history(limit=limit, before=message):
            if self.use_names_in_context:
                last_messages.append(f'{msg.author}: {msg.content}')
            else:
                last_messages.append(msg.content)
        # history() yields newest first; present the context chronologically.
        last_messages.reverse()
        return last_messages


    async def send_message(self, message, answer: str, sources: str):
        """
        Send `answer`, split into chunks, and then `sources` to `message`'s channel.

        Blank chunks and a blank `sources` string are skipped, because Discord
        rejects messages without content.
        """
        chunks = split_text_into_chunks(
            text=answer,
            split_characters=['. ', ', ', '\n'],
            min_size=self.min_message_len,
            max_size=self.max_message_len
        )
        for chunk in chunks:
            if chunk and chunk.strip():
                await message.channel.send(chunk)
        if sources and sources.strip():
            await message.channel.send(sources)


    async def on_message(self, message):
        """
        Reply to `message` with an answer from the QA engine.

        The following messages are ignored:
        - messages outside `self.channel_ids` (when that list is non-empty);
        - messages sent by this bot;
        - messages without text.

        The preceding channel messages are passed to the engine as
        newline-joined context. Failures while sending the reply are logged
        and not re-raised.
        """
        if self.channel_ids and message.channel.id not in self.channel_ids:
            return

        if message.author == self.user:
            return

        # Attachment- or sticker-only messages carry no question to answer.
        if not message.content or not message.content.strip():
            return

        # NOTE: command handling (a `!clear` command that purged the channel)
        # is disabled; `self.enable_commands` is currently not consulted.

        last_messages = await self.get_last_messages(message)
        context = '\n'.join(last_messages)

        logger.info('Received message: {0.content}'.format(message))
        # NOTE(review): get_response is a synchronous call, so it blocks the
        # event loop while it runs; consider asyncio.to_thread if QAEngine is
        # safe to call from a worker thread.
        response = self.qa_engine.get_response(
            question=message.content,
            messages_context=context
        )
        logger.info('Sending response: {0}'.format(response))
        try:
            await self.send_message(
                message,
                response.get_answer(),
                response.get_sources_as_text()
            )
        except Exception as e:
            logger.error('Failed to send response: {0}'.format(e))