File size: 4,382 Bytes
c69cba4
 
 
 
 
 
988981a
c69cba4
 
 
 
 
 
 
 
 
 
988981a
 
c69cba4
 
 
 
 
 
988981a
 
 
 
 
 
 
 
c69cba4
f6c81a6
988981a
 
 
 
 
f6c81a6
 
 
 
 
 
 
 
 
 
 
c69cba4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
988981a
c69cba4
 
 
 
 
 
 
 
bfdf8df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c69cba4
bfdf8df
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import json
import requests
from urllib.parse import quote
import discord
from typing import List

from qa_engine import logger, Config, QAEngine
from discord_bot.client.utils import split_text_into_chunks


class DiscordClient(discord.Client):
    """
    Discord Client class, used for interacting with a Discord server.
    """
    def __init__(
        self,
        qa_engine: QAEngine,
        config: Config,  
    ):  
        logger.info('Initializing Discord client...')
        intents = discord.Intents.all()
        intents.message_content = True
        super().__init__(intents=intents, command_prefix='!')

        self.qa_engine: QAEngine = qa_engine
        self.channel_ids: list[int] = DiscordClient._process_channel_ids(
            config.discord_channel_ids
        )
        self.num_last_messages: int = config.num_last_messages
        self.use_names_in_context: bool = config.use_names_in_context
        self.enable_commands: bool = config.enable_commands
        self.debug: bool = config.debug
        self.min_message_len: int = 1800
        self.max_message_len: int = 2000
        
        assert all([isinstance(id, int) for id in self.channel_ids]), \
            'All channel ids should be of type int'
        assert self.num_last_messages >= 1, \
            'The number of last messages in context should be at least 1'
        
    
    @staticmethod
    def _process_channel_ids(channel_ids) -> list[int]:
        if isinstance(channel_ids, str):
            return eval(channel_ids)
        elif isinstance(channel_ids, list):
            return channel_ids
        elif isinstance(channel_ids, int):
            return [channel_ids]
        else:
            return []


    async def on_ready(self):
        """
        Callback function to be called when the client is ready.
        """
        logger.info('Successfully logged in as: {0.user}'.format(self))
        await self.change_presence(activity=discord.Game(name='Chatting...'))


    async def get_last_messages(self, message) -> List[str]:
        """
        Method to fetch recent messages from a message's channel.

        Args:
            message (Message): The discord Message object used to identify the channel.

        Returns:
            List[str]: Reversed list of recent messages from the channel,
            excluding the input message. Messages may be prefixed with the author's name 
            if `self.use_names_in_context` is True.
        """
        last_messages: List[str] = []
        async for msg in message.channel.history(
            limit=self.num_last_messages):
            if self.use_names_in_context:
                last_messages.append(f'{msg.author}: {msg.content}')
            else:
                last_messages.append(msg.content)
        last_messages.reverse()
        last_messages.pop() # remove last message from context
        return last_messages


    async def send_message(self, message, answer: str, sources: str):
        chunks = split_text_into_chunks(
            text=answer,
            split_characters=['. ', ', ', '\n'],
            min_size=self.min_message_len,
            max_size=self.max_message_len
        )
        for chunk in chunks:
            await message.channel.send(chunk)
        await message.channel.send(sources)


    async def on_message(self, message):

        if self.channel_ids and message.channel.id not in self.channel_ids:
            return
        
        if message.author == self.user:
            return
        
        """
        if self.enable_commands and message.content.startswith('!'):
            if message.content == '!clear':
                await message.channel.purge()
                return        
        """

        last_messages = await self.get_last_messages(message)
        context = '\n'.join(last_messages)

        logger.info('Received message: {0.content}'.format(message))
        response = self.qa_engine.get_response(
            question=message.content,
            messages_context=context
        )
        logger.info('Sending response: {0}'.format(response))
        try:
            await self.send_message(
                message,
                response.get_answer(),
                response.get_sources_as_text()
            )
        except Exception as e:
            logger.error('Failed to send response: {0}'.format(e))