import re
import random
from .config import AI_PHRASES
from .response_generation import generate_sim

def parse_model_response(response: dict, name: str = "") -> str:
    """
    Parse the LLM response to extract the assistant's message and apply initial post-processing.

    Args:
        response (dict): The raw response from the LLM.
        name (str, optional): Name to strip from the beginning of the text. Defaults to "".

    Returns:
        str: The cleaned and parsed assistant's message.
    """
    assistant_message = response["choices"][0]["message"]["content"]
    cleaned_text = postprocess_text(
        assistant_message, 
        name=name, 
        human_prefix="user:",
        assistant_prefix="assistant:"
    )
    return cleaned_text
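
# A minimal sketch of the expected behavior, using the OpenAI-style response
# shape this function indexes into (the content string is made up):
#   parse_model_response({"choices": [{"message": {"content": 'assistant: "Hi there..."'}}]})
#   -> 'Hi there.'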

def postprocess_text(
    text: str,
    name: str = "",
    human_prefix: str = "user:",
    assistant_prefix: str = "assistant:",
    strip_name: bool = True
) -> str:
    """Eliminates whispers, reactions, ellipses, and quotation marks from generated text by LLMs.

    Args:
        text (str): The text to process.
        name (str, optional): Name to strip from the beginning of the text. Defaults to "".
        human_prefix (str, optional): The user prefix to remove. Defaults to "user:".
        assistant_prefix (str, optional): The assistant prefix to remove. Defaults to "assistant:".
        strip_name (bool, optional): Whether to remove the name at the beginning of the text. Defaults to True.

    Returns:
        str: Cleaned text.
    """
    if text:
        # Strip the speaker's name from the start of the text, e.g. "Kit: hi" -> "hi"
        # (previously documented in the signature but never applied)
        if strip_name and name:
            text = re.sub(rf"^\s*{re.escape(name)}\s*:?\s*", "", text)

        # Collapse ellipses (three or more dots) into a single period
        text = re.sub(r'\.{3,}', '.', text)

        # Remove role prefixes that leaked into the generation
        text = text.replace(human_prefix, "").replace(assistant_prefix, "")

        # Remove whispers or other marked reactions
        whispers = re.compile(r"(\([\w\s]+\))")   # e.g. "(whispers)"
        reactions = re.compile(r"(\*[\w\s]+\*)")  # e.g. "*stutters*"
        text = whispers.sub("", text)
        text = reactions.sub("", text)

        # Remove quotation marks, single and double
        # (note: this also strips apostrophes in contractions)
        text = text.replace('"', '').replace("'", "")

        # Collapse runs of whitespace into single spaces
        text = re.sub(r"\s+", " ", text).strip()

    return text
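
# A minimal sketch of the cleaning steps on a made-up generation:
#   postprocess_text('assistant: *sighs* I... I am not sure (pauses)')
#   -> 'I. I am not sure'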

def apply_guardrails(model_input: dict, response: str, endpoint_url: str, endpoint_bearer_token: str) -> str:
    """Regenerate responses that contain 'I am an AI' style phrases.

    Retries up to two times with the full conversation history; if the phrase
    persists, falls back to regenerating from only the system message and the
    last user message.
    """
    attempt = 0
    max_attempts = 2

    while attempt < max_attempts and contains_ai_phrase(response):
        # Regenerate the response without modifying the conversation history
        completion = generate_sim(model_input, endpoint_url, endpoint_bearer_token)
        response = parse_model_response(completion)
        attempt += 1

    if contains_ai_phrase(response):
        # Use only the last user message for regeneration
        memory = model_input['messages']
        last_user_message = next((msg for msg in reversed(memory) if msg['role'] == 'user'), None)
        if last_user_message:
            # Create a new conversation with system message and last user message
            model_input_copy = {
                **model_input,
                'messages': [memory[0], last_user_message]  # memory[0] is the system message
            }
            completion = generate_sim(model_input_copy, endpoint_url, endpoint_bearer_token)
            response = parse_model_response(completion)

    return response
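
# Hypothetical usage (endpoint values are placeholders; whether regeneration
# triggers depends on the phrases configured in AI_PHRASES):
#   reply = apply_guardrails(model_input, "As an AI, I cannot advise you.",
#                            "https://example.com/v1/chat", "TOKEN")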


def contains_ai_phrase(text: str) -> bool:
    """Check if the text contains any 'I am an AI' phrases."""
    text_lower = text.lower()
    return any(phrase.lower() in text_lower for phrase in AI_PHRASES)
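
# Matching is case-insensitive. Assuming "i am an ai" is listed in AI_PHRASES:
#   contains_ai_phrase("Remember, I am an AI and cannot feel.")  # -> True
#   contains_ai_phrase("I hear how hard this has been for you.") # -> False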

def truncate_response(text: str, punctuation_marks: tuple = ('.', '!', '?', '…')) -> str:
    """
    Truncate the text at the last occurrence of a specified punctuation mark.

    Args:
        text (str): The text to truncate.
        punctuation_marks (tuple, optional): A tuple of punctuation marks to use for truncation. Defaults to ('.', '!', '?', '…').

    Returns:
        str: The truncated text.
    """
    # Find the last position of any punctuation mark from the provided set
    last_punct_position = max(text.rfind(p) for p in punctuation_marks)

    # Check if any punctuation mark is found
    if last_punct_position == -1:
        # No punctuation found, return the original text
        return text.strip()

    # Return the truncated text up to and including the last punctuation mark
    return text[:last_punct_position + 1].strip()
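
# Quick sketch: drops a trailing fragment left behind by a token-limit cutoff,
# and leaves punctuation-free text untouched:
#   truncate_response("I hear you. That sounds overwhel")  # -> 'I hear you.'
#   truncate_response("no punctuation at all")             # -> 'no punctuation at all'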

def split_texter_response(text: str) -> str:
    """
    Splits the texter's response into multiple messages,
    introducing '\ntexter:' prefixes after punctuation.

    The number of messages is randomly chosen based on specified probabilities:
    - 1 message: 30% chance
    - 2 messages: 25% chance
    - 3 messages: 45% chance

    The first message does not include the '\ntexter:' prefix.
    """
    # Use regex to split text into sentences, keeping the punctuation
    sentences = re.findall(r'[^.!?]+[.!?]*', text)
    # Remove empty strings from sentences
    sentences = [s.strip() for s in sentences if s.strip()]

    # Decide number of messages based on specified probabilities
    num_messages = random.choices([1, 2, 3], weights=[0.3, 0.25, 0.45], k=1)[0]

    # If not enough sentences to make the splits, adjust num_messages
    if len(sentences) < num_messages:
        num_messages = len(sentences)

    # If num_messages is 1, return the original text
    if num_messages == 1:
        return text.strip()

    # Calculate split points
    # We need to divide the sentences into num_messages parts
    avg = len(sentences) / num_messages
    split_indices = [int(round(avg * i)) for i in range(1, num_messages)]

    # Build the new text
    new_text = ''
    start = 0
    for i, end in enumerate(split_indices + [len(sentences)]):
        segment_sentences = sentences[start:end]
        segment_text = ' '.join(segment_sentences).strip()
        if i == 0:
            # First segment, do not add '\ntexter:'
            new_text += segment_text
        else:
            # Subsequent segments, add '\ntexter:'
            new_text += f"\ntexter: {segment_text}"
        start = end
    return new_text.strip()
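
# One possible outcome (the message count is sampled, so outputs vary run to run):
#   split_texter_response("I feel sad. I cant sleep. Nobody listens.")
#   -> 'I feel sad.\ntexter: I cant sleep.\ntexter: Nobody listens.'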

def process_model_response(completion: dict, model_input: dict, endpoint_url: str, endpoint_bearer_token: str) -> str:
    """
    Process the raw model response, including parsing, applying guardrails,
    truncation, and splitting the response into multiple messages if necessary.

    Args:
        completion (dict): Raw response from the LLM.
        model_input (dict): The model input containing the conversation history.
        endpoint_url (str): The URL of the model endpoint.
        endpoint_bearer_token (str): The bearer token used to authenticate with the endpoint.

    Returns:
        str: Final processed response, ready to be sent to the app.
    """
    # Step 1: Parse the raw response to extract the assistant's message
    assistant_message = parse_model_response(completion)

    # Step 2: Apply guardrails (handle possible AI responses)
    guardrail_message = apply_guardrails(model_input, assistant_message, endpoint_url, endpoint_bearer_token)

    # Step 3: Apply response truncation
    truncated_message = truncate_response(guardrail_message)

    # Step 4: Split the response into multiple messages if needed
    final_response = split_texter_response(truncated_message)

    return final_response
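
# End-to-end sketch of the pipeline (endpoint values are placeholders):
#   completion = generate_sim(model_input, "https://example.com/v1/chat", "TOKEN")
#   reply = process_model_response(completion, model_input,
#                                  "https://example.com/v1/chat", "TOKEN")
#   # reply has been parsed, guardrailed, truncated, and split into texter messages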