"""
This module is responsible for modifying the chat prompt and history.
"""
import re

import extensions.superboogav2.parameters as parameters

from modules import chat, shared
from modules.text_generation import get_encoded_length
from modules.logging_colors import logger
from modules.chat import load_character_memoized
from extensions.superboogav2.utils import create_context_text, create_metadata_source

from .data_processor import process_and_add_to_collector
from .chromadb import ChromaCollector


CHAT_METADATA = create_metadata_source('automatic-chat-insert')


def _remove_tag_if_necessary(user_input: str):
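    """Strip the `!c` trigger tag from the start or end of the user input when manual mode is enabled."""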
    if not parameters.get_is_manual():
        return user_input
    
    return re.sub(r'^\s*!c\s*|\s*!c\s*$', '', user_input)


def _should_query(user_input: str):
    """Return True if the collector should be queried.

    In automatic mode this is always the case; in manual mode the `!c` tag must be present.
    """
    if not parameters.get_is_manual():
        return True

    if re.search(r'^\s*!c|!c\s*$', user_input, re.MULTILINE):
        return True

    return False


def _format_single_exchange(name, text):
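    """Format a single message as a 'name: text' line, skipping the colon if the name already ends with one."""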
    if re.search(r':\s*$', name):
        return '{} {}\n'.format(name, text)
    else:
        return '{}: {}\n'.format(name, text)


def _get_names(state: dict):
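    """Resolve the user and bot names for the current character, falling back to the global settings."""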
    default_char = shared.settings.get('character', "Assistant")
    default_user = shared.settings.get('name1', "You")
    character = state.get('character', default_char)
    user_name = state.get('name1', default_user)
    user_name, bot_name, _, _, _ = load_character_memoized(character, user_name, '')

    return user_name, bot_name


def _concatenate_history(history: dict, state: dict):
    """Flatten the internal chat history into a single newline-separated string of 'name: text' lines."""
    full_history_text = ''
    user_name, bot_name = _get_names(state)

    # Grab the internal history.
    internal_history = history['internal']
    assert isinstance(internal_history, list)

    # Iterate through the history.
    for exchange in internal_history:
        assert isinstance(exchange, list)

        if len(exchange) >= 1:
            full_history_text += _format_single_exchange(user_name, exchange[0])
        if len(exchange) >= 2:
            full_history_text += _format_single_exchange(bot_name, exchange[1])

    return full_history_text[:-1]  # Strip the trailing newline.


def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
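    """Replace one history message with `context_text`.

    The chosen message is the oldest one for which `context_text` plus all newer messages
    still fit within `max_len` tokens.
    """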
    num_context_tokens = get_encoded_length(context_text)

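    # Reverse (user, bot) to (bot, user) because messages within each exchange are iterated in reverse order below.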
    names = _get_names(state)[::-1]

    history_tokens = 0
    replace_position = None
    for i, messages in enumerate(reversed(history['internal'])):
        for j, message in enumerate(reversed(messages)):
            num_message_tokens = get_encoded_length(_format_single_exchange(names[j], message))
            
            # TODO: This is an extremely naive solution. A more robust implementation must be made.
            if history_tokens + num_context_tokens <= max_len:
                # This message can be replaced
                replace_position = (i, j)
            
            history_tokens += num_message_tokens
    
    if replace_position is None:
        logger.warning("The provided context_text is too long to replace any message in the history.")
    else:
        # replace the message at replace_position with context_text
        i, j = replace_position
        history['internal'][-i-1][-j-1] = context_text


def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs):
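    """Build the chat prompt, optionally indexing the chat history and injecting retrieved context.

    Intended to be called from the extension's `custom_generate_chat_prompt` hook; a minimal
    wiring sketch (the `collector` object here is assumed to be kept alive by the extension):

        def custom_generate_chat_prompt(user_input, state, **kwargs):
            return custom_generate_chat_prompt_internal(user_input, state, collector, **kwargs)
    """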
    if parameters.get_add_chat_to_data():
        # Get the whole history as one string
        history_as_text = _concatenate_history(kwargs['history'], state)

        if history_as_text:
            # Delete all documents that were auto-inserted
            collector.delete(ids_to_delete=None, where=CHAT_METADATA)
            # Insert the processed history
            process_and_add_to_collector(history_as_text, collector, False, CHAT_METADATA)

    if _should_query(user_input):
        user_input = _remove_tag_if_necessary(user_input)
        results = collector.get_sorted_by_dist(
            user_input,
            n_results=parameters.get_chunk_count(),
            max_token_count=int(parameters.get_max_token_count())
        )

        # Apply the configured injection strategy: append or prepend to the user query, or hijack the last message that fits in the context.
        if parameters.get_injection_strategy() == parameters.APPEND_TO_LAST:
            user_input = user_input + create_context_text(results)
        elif parameters.get_injection_strategy() == parameters.PREPEND_TO_LAST:
            user_input = create_context_text(results) + user_input
        elif parameters.get_injection_strategy() == parameters.HIJACK_LAST_IN_CONTEXT:
            _hijack_last(create_context_text(results), kwargs['history'], state['truncation_length'], state)
    
    return chat.generate_chat_prompt(user_input, state, **kwargs)