Spaces:
Sleeping
Sleeping
import re | |
import random | |
from .config import AI_PHRASES | |
from .response_generation import generate_sim | |
def parse_model_response(response: dict, name: str = "") -> str:
    """
    Pull the assistant text out of a chat-completion payload and clean it.

    Args:
        response (dict): Raw LLM response. The text is read from
            ``response["choices"][0]["message"]["content"]``.
        name (str, optional): Forwarded unchanged to ``postprocess_text``.
            Defaults to "".

    Returns:
        str: The post-processed assistant message, as returned by
        ``postprocess_text`` with the "user:" / "assistant:" role prefixes.
    """
    first_choice = response["choices"][0]
    raw_content = first_choice["message"]["content"]
    return postprocess_text(
        raw_content,
        name=name,
        human_prefix="user:",
        assistant_prefix="assistant:",
    )
def postprocess_text(
    text: str,
    name: str = "",
    human_prefix: str = "user:",
    assistant_prefix: str = "assistant:",
    strip_name: bool = True
) -> str:
    """Eliminates whispers, reactions, ellipses, and quotation marks from generated text by LLMs.

    Processing steps, in order:
      1. Collapse runs of three or more periods, and the Unicode ellipsis
         character, into a single period.
      2. Remove every occurrence of ``human_prefix`` and ``assistant_prefix``.
      3. If ``strip_name`` is true and ``name`` is non-empty, remove a leading
         ``"<name>:"`` speaker label (case-insensitive; surrounding whitespace
         allowed).
      4. Remove parenthesised asides such as "(whispers)" and starred
         reactions such as "*stutters*".
      5. Remove all double and single quote characters. Note that this also
         drops apostrophes, e.g. "don't" -> "dont".
      6. Collapse whitespace runs to a single space and trim both ends.

    Args:
        text (str): The text to process. ``None`` or "" yields "".
        name (str, optional): Speaker name to strip from the beginning of the
            text when followed by a colon. Defaults to "".
        human_prefix (str, optional): The user prefix to remove. Defaults to "user:".
        assistant_prefix (str, optional): The assistant prefix to remove. Defaults to "assistant:".
        strip_name (bool, optional): Whether to remove the leading name label. Defaults to True.

    Returns:
        str: Cleaned text.
    """
    if not text:
        # Honour the str return type even when the model returned no content.
        return ""

    # Replace ellipses ("...", "....", "…") with a single period.
    text = re.sub(r"\.{3,}|…", ".", text)

    # Remove role prefixes wherever they occur.
    text = text.replace(human_prefix, "").replace(assistant_prefix, "")

    # Remove a leading "<name>:" speaker label. A colon is required so that a
    # sentence which merely starts with the name is left alone.
    if strip_name and name:
        label = rf"^\s*{re.escape(name)}\s*:\s*"
        text = re.sub(label, "", text, count=1, flags=re.IGNORECASE)

    # Remove whispers "(whispers)" and marked reactions "*stutters*".
    text = re.sub(r"(\([\w\s]+\))", "", text)
    text = re.sub(r"(\*[\w\s]+\*)", "", text)

    # Remove all quotation marks (both single and double).
    text = text.replace('"', '').replace("'", "")

    # Normalize spaces.
    return re.sub(r"\s+", " ", text).strip()
def apply_guardrails(
    model_input: dict,
    response: str,
    endpoint_url: str,
    endpoint_bearer_token: str,
    max_attempts: int = 2,
) -> str:
    """Apply the 'I am an AI' guardrail to model responses.

    If ``response`` contains one of the configured AI phrases, the unchanged
    ``model_input`` is resent via ``generate_sim`` up to ``max_attempts``
    times. If the response still contains an AI phrase after that, one final
    request is made with a reduced conversation: the first message (only if
    its role is "system") followed by the most recent user message.

    Args:
        model_input (dict): Request payload passed to ``generate_sim``; its
            "messages" list holds the conversation.
        response (str): The already-parsed response to check.
        endpoint_url (str): Endpoint URL forwarded to ``generate_sim``.
        endpoint_bearer_token (str): Auth token forwarded to ``generate_sim``.
        max_attempts (int, optional): Number of full-history retries before
            the reduced-history fallback. Defaults to 2.

    Returns:
        str: The first response that passes the check. Otherwise, the last
        response obtained, which may still contain an AI phrase.
    """
    attempt = 0
    while attempt < max_attempts and contains_ai_phrase(response):
        # Resend the same input; presumably relies on sampling variability
        # to yield a different reply.
        completion = generate_sim(model_input, endpoint_url, endpoint_bearer_token)
        response = parse_model_response(completion)
        attempt += 1

    if contains_ai_phrase(response):
        messages = model_input.get("messages") or []
        last_user_message = next(
            (msg for msg in reversed(messages) if msg.get("role") == "user"),
            None,
        )
        if last_user_message is not None:
            # Keep messages[0] only when it really is the system prompt.
            # Otherwise a conversation without one would resend an unrelated
            # first turn, or the same user turn twice.
            reduced_messages = [last_user_message]
            if messages[0].get("role") == "system":
                reduced_messages.insert(0, messages[0])
            reduced_input = {**model_input, "messages": reduced_messages}
            completion = generate_sim(reduced_input, endpoint_url, endpoint_bearer_token)
            response = parse_model_response(completion)
    return response
def contains_ai_phrase(text: str, phrases=None) -> bool:
    """Check if the text contains any 'I am an AI' phrases.

    Matching is a case-insensitive substring test.

    Args:
        text (str): Text to inspect. ``None`` or "" is treated as containing
            no phrase, instead of raising AttributeError.
        phrases (iterable of str, optional): Phrases to look for. Defaults to
            ``AI_PHRASES`` from the package config.

    Returns:
        bool: True if any phrase occurs in ``text``.
    """
    if not text:
        return False
    if phrases is None:
        phrases = AI_PHRASES
    text_lower = text.lower()
    return any(phrase.lower() in text_lower for phrase in phrases)
def truncate_response(text: str, punctuation_marks: tuple = ('.', '!', '?', '…')) -> str:
    """
    Truncate the text right after the last occurrence of a punctuation mark.

    Args:
        text (str): The text to truncate.
        punctuation_marks (tuple, optional): Marks that may end the kept text.
            Multi-character marks are kept whole. An empty tuple disables
            truncation. Defaults to ('.', '!', '?', '…').

    Returns:
        str: The text up to and including the last mark, stripped. If no mark
        occurs, the whole text, stripped.
    """
    # Exclusive end index of the last occurrence of each mark that appears.
    # Adding len(mark) rather than 1 keeps multi-character marks intact.
    cut = max(
        (text.rfind(mark) + len(mark) for mark in punctuation_marks if mark in text),
        default=None,
    )
    if cut is None:
        # No punctuation found (or no marks given): keep the original text.
        return text.strip()
    return text[:cut].strip()
def split_texter_response(text: str, *, rng=None) -> str:
    """
    Split the texter's response into up to three messages.

    Sentences are separated at '.', '!' and '?' (the punctuation stays with
    its sentence) and then grouped into consecutive segments of roughly equal
    size. Every segment after the first is placed on a new line prefixed with
    "texter: ". The number of messages is drawn at random:
      - 1 message: 30% chance
      - 2 messages: 25% chance
      - 3 messages: 45% chance
    and capped at the number of sentences found. When the result is a single
    message, or no sentences are found (empty, whitespace-only or
    punctuation-only text), the stripped input is returned unchanged.

    Args:
        text (str): The response text to split.
        rng (optional): Object providing ``choices`` (e.g. a
            ``random.Random`` instance) used to pick the message count.
            Defaults to the ``random`` module.

    Returns:
        str: The possibly multi-message text.
    """
    chooser = random if rng is None else rng

    # Split into sentences, keeping trailing punctuation, and drop blanks.
    sentences = [s.strip() for s in re.findall(r'[^.!?]+[.!?]*', text) if s.strip()]

    num_messages = chooser.choices([1, 2, 3], weights=[0.3, 0.25, 0.45], k=1)[0]
    num_messages = min(num_messages, len(sentences))

    # "<= 1" also covers zero sentences, which would otherwise divide by zero below.
    if num_messages <= 1:
        return text.strip()

    # Divide the sentences into num_messages contiguous, near-equal groups.
    avg = len(sentences) / num_messages
    boundaries = [int(round(avg * i)) for i in range(1, num_messages)] + [len(sentences)]

    segments = []
    start = 0
    for end in boundaries:
        segments.append(' '.join(sentences[start:end]).strip())
        start = end

    # The first segment has no prefix; each later one starts a "texter:" line.
    return "\ntexter: ".join(segments).strip()
def process_model_response(completion: dict, model_input: dict, endpoint_url: str, endpoint_bearer_token: str) -> str:
    """
    Turn a raw LLM completion into the final text sent to the app.

    Pipeline: parse the assistant message, run the AI-phrase guardrail
    (which may call the endpoint again), cut the text after its last
    sentence-ending mark, then split it into one or more "texter:" messages.

    Args:
        completion (dict): Raw response from the LLM.
        model_input (dict): The model input containing the conversation history.
        endpoint_url (str): The URL of the endpoint.
        endpoint_bearer_token (str): The authentication token for the endpoint.

    Returns:
        str: Final processed response ready for the APP.
    """
    message = parse_model_response(completion)
    message = apply_guardrails(model_input, message, endpoint_url, endpoint_bearer_token)
    message = truncate_response(message)
    return split_texter_response(message)