tldw / App_Function_Libraries /Gradio_UI /Character_interaction_tab.py
oceansweep's picture
Upload 169 files
c5b0bb7 verified
raw
history blame
24.8 kB
# Character_Interaction_tab.py
# Description: This file contains the functions that are used for Character Interactions in the Gradio UI.
#
# Imports
import base64
import io
import uuid
from datetime import datetime as datetime
import logging
import json
import os
from typing import List, Dict, Tuple, Union
#
# External Imports
import gradio as gr
from PIL import Image
#
# Local Imports
from App_Function_Libraries.Chat.Chat_Functions import chat, load_characters, save_chat_history_to_db_wrapper
from App_Function_Libraries.Gradio_UI.Chat_ui import chat_wrapper
from App_Function_Libraries.Gradio_UI.Writing_tab import generate_writing_feedback
from App_Function_Libraries.Utils.Utils import default_api_endpoint, format_api_name, global_api_endpoints
#
########################################################################################################################
#
# Single-Character chat Functions:
# FIXME - add these functions to the Personas library
def chat_with_character(user_message, history, char_data, api_name_input, api_key):
    """Append one user/bot exchange to ``history`` for the loaded character.

    The reply is produced by ``generate_writing_feedback``, called with the
    user's message, the character's ``name`` and the fixed label "Overall".

    Returns:
        A ``(history, message)`` pair. ``message`` is an empty string on
        success, or a prompt to import a character card when ``char_data``
        is None (in which case ``history`` is returned untouched).
    """
    if char_data is not None:
        reply = generate_writing_feedback(
            user_message, char_data['name'], "Overall", api_name_input, api_key
        )
        # The history list is extended in place and handed back to the caller.
        history.append((user_message, reply))
        return history, ""
    return history, "Please import a character card first."
def import_character_card(file):
    """Load a character card from an uploaded image or JSON file.

    Files whose name ends in ``.png`` or ``.webp`` are treated as images: the
    embedded JSON is pulled out with ``extract_json_from_image``, parsed with
    ``import_character_card_json``, and the picture itself is re-encoded as PNG
    and stored base64-encoded under the card's ``'image'`` key. Any other file
    is read as UTF-8 text and parsed as JSON directly.

    Returns:
        The parsed card data, or None when no file was given, no card could be
        extracted, or any error occurred (errors are logged, not raised).
    """
    if file is None:
        logging.warning("No file provided for character card import")
        return None
    try:
        looks_like_image = file.name.lower().endswith(('.png', '.webp'))
        if not looks_like_image:
            logging.info(f"Attempting to import character card from JSON file: {file.name}")
            return import_character_card_json(file.read().decode('utf-8'))

        logging.info(f"Attempting to import character card from image: {file.name}")
        embedded_json = extract_json_from_image(file)
        if not embedded_json:
            logging.warning("No JSON data found in the image")
            return None

        logging.info("JSON data extracted from image, attempting to parse")
        card = import_character_card_json(embedded_json)
        if not card:
            return None

        # Keep a PNG copy of the picture alongside the card fields.
        with Image.open(file) as picture:
            png_buffer = io.BytesIO()
            picture.save(png_buffer, format='PNG')
            card['image'] = base64.b64encode(png_buffer.getvalue()).decode('utf-8')
        return card
    except Exception as e:
        logging.error(f"Error importing character card: {e}")
    return None
def import_character_card_json(json_content):
    """Parse character-card JSON text and return the card's data.

    A card whose top-level ``spec`` equals ``'chara_card_v2'`` is treated as a
    V2 card and its ``'data'`` member is returned; anything else is assumed to
    be a V1 card and returned whole.

    Returns:
        The card data, or None if the text is not valid JSON or any other
        error occurs while inspecting it (errors are logged, not raised).
    """
    try:
        # Surrounding whitespace is dropped before parsing and logging.
        json_content = json_content.strip()
        logging.debug(f"JSON content (first 100 chars): {json_content[:100]}...")
        parsed = json.loads(json_content)
        logging.debug(f"Parsed JSON data keys: {list(parsed.keys())}")
        if parsed.get('spec') != 'chara_card_v2':
            logging.info("Assuming V1 character card")
            return parsed
        logging.info("Detected V2 character card")
        return parsed['data']
    except json.JSONDecodeError as e:
        logging.error(f"JSON decode error: {e}")
        logging.error(f"Problematic JSON content: {json_content[:500]}...")
    except Exception as e:
        logging.error(f"Unexpected error parsing JSON: {e}")
    return None
def extract_json_from_image(image_file):
    """Return the character-card JSON text embedded in an image, if any.

    Two strategies are tried in order:

    1. The ``'chara'`` entry of the image's metadata dictionary (``img.info``)
       is base64-decoded and returned as UTF-8 text.
    2. Otherwise (entry missing, or it failed to decode) the bytes from
       ``img.tobytes()`` are searched for a ``{ ... }`` span whose interior is
       base64 text that decodes to a JSON object.

    Args:
        image_file: A file object accepted by ``PIL.Image.open`` that also
            exposes a ``name`` attribute (used only for logging).

    Returns:
        The decoded JSON text, or None when nothing usable was found or an
        error occurred (errors are logged, not raised).
    """
    logging.debug(f"Attempting to extract JSON from image: {image_file.name}")
    try:
        with Image.open(image_file) as img:
            logging.debug("Image opened successfully")
            metadata = img.info
            if 'chara' in metadata:
                logging.debug("Found 'chara' in image metadata")
                chara_content = metadata['chara']
                logging.debug(f"Content of 'chara' metadata (first 100 chars): {chara_content[:100]}...")
                try:
                    decoded_content = base64.b64decode(chara_content).decode('utf-8')
                    logging.debug(f"Decoded content (first 100 chars): {decoded_content[:100]}...")
                    return decoded_content
                except Exception as e:
                    # Fall through to the raw-data scan below.
                    logging.error(f"Error decoding base64 content: {e}")
            else:
                # Only report the entry as missing when it really is missing.
                logging.debug("'chara' not found in metadata, checking for base64 encoded data")

            raw_data = img.tobytes()
            # Everything between the first '{' and the last '}' in the raw bytes.
            possible_json = raw_data.split(b'{', 1)[-1].rsplit(b'}', 1)[0]
            if possible_json:
                try:
                    decoded = base64.b64decode(possible_json).decode('utf-8')
                    if decoded.startswith('{') and decoded.endswith('}'):
                        logging.debug("Found and decoded base64 JSON data")
                        # The decoded text already carries its own braces; wrapping it
                        # in another pair would produce '{{...}}', which is not JSON.
                        return decoded
                except Exception as e:
                    logging.error(f"Error decoding base64 data: {e}")
            logging.warning("No JSON data found in the image")
    except Exception as e:
        logging.error(f"Error extracting JSON from image: {e}")
    return None
def load_chat_history(file):
    """Read a saved chat from a JSON file object.

    The file's bytes are decoded as UTF-8 and parsed as JSON; the ``'history'``
    and ``'character'`` members are returned.

    Returns:
        ``(history, character)`` on success, or ``(None, None)`` if reading,
        decoding, parsing or key lookup fails (the error is logged).
    """
    try:
        payload = json.loads(file.read().decode('utf-8'))
        history, character = payload['history'], payload['character']
    except Exception as e:
        logging.error(f"Error loading chat history: {e}")
        return None, None
    return history, character
#
# End of Single-Character chat Functions
######################################################################################################################
#
# Multi-Character Chat Interface
# FIXME - refactor and move these functions to the Character_Chat library so that it uses the same functions
def character_interaction_setup():
    """Return the initial state for the multi-character chat.

    Returns:
        A 4-tuple: the result of ``load_characters()``, a new empty
        conversation list, and two None placeholders.
    """
    available_characters = load_characters()
    empty_conversation = []
    return available_characters, empty_conversation, None, None
def extract_character_response(response: Union[str, Tuple]) -> str:
    """Reduce a chat response to a single stripped string.

    A string is returned stripped. For a tuple, the first element that is a
    string is returned stripped; if it holds no string, a fallback sentence is
    returned. Any other type yields a different fallback sentence.
    """
    if isinstance(response, str):
        return response.strip()
    if isinstance(response, tuple):
        first_text = next((part for part in response if isinstance(part, str)), None)
        if first_text is None:
            # Tuple carried no text at all.
            return "I'm not sure how to respond."
        return first_text.strip()
    # Neither a string nor a tuple.
    return "I'm having trouble forming a response."
# def process_character_response(response: str) -> str:
# # Remove any leading explanatory text before the first '---'
# parts = response.split('---')
# if len(parts) > 1:
# return '---' + '---'.join(parts[1:])
# return response.strip()
def process_character_response(response: Union[str, Tuple]) -> str:
    """Normalise a chat response and drop any preamble before the first '---'.

    A tuple is first collapsed into one string by joining its string elements
    with single spaces (non-string elements are ignored). If the resulting text
    contains '---', everything from the first '---' onward is returned as-is;
    otherwise the text is returned stripped. Inputs that are neither a string
    nor a tuple yield a fallback sentence.
    """
    if isinstance(response, tuple):
        response = ' '.join(piece for piece in response if isinstance(piece, str))
    if not isinstance(response, str):
        return "I'm having trouble forming a response."
    _preamble, marker, remainder = response.partition('---')
    if marker:
        # Keep the marker itself and everything after it, unstripped.
        return marker + remainder
    return response.strip()
def character_turn(characters: Dict, conversation: List[Tuple[str, str]],
                   current_character: str, other_characters: List[str],
                   api_endpoint: str, api_key: str, temperature: float,
                   scenario: str = "") -> Tuple[List[Tuple[str, str]], str]:
    """Generate one turn for ``current_character`` and append it to ``conversation``.

    The prompt lists the speaking character's personality, then the personality
    of every other known character, then the conversation so far, and asks how
    the speaker would respond. The reply from ``chat_wrapper`` is passed through
    ``process_character_response`` before being appended. If generating the
    reply raises, an error message is appended under the character's name
    instead.

    Returns:
        ``(conversation, current_character)``. The conversation list is
        modified in place. When ``current_character`` is empty or unknown the
        inputs are returned unchanged.
    """
    if not (current_character and current_character in characters):
        return conversation, current_character

    # An empty conversation is seeded with the scenario, when one is given.
    if scenario and not conversation:
        conversation.append(("Scenario", scenario))

    speaker = characters[current_character]
    cast = [speaker] + [
        characters[name]
        for name in other_characters
        if name in characters and name != current_character
    ]
    personalities = "".join(
        f"{member['name']}'s personality: {member['personality']}\n" for member in cast
    )
    transcript = "\n".join(f"{sender}: {message}" for sender, message in conversation)
    prompt = (
        f"{personalities}Conversation so far:\n{transcript}"
        f"\n\nHow would {speaker['name']} respond?"
    )

    try:
        raw_reply = chat_wrapper(prompt, conversation, {}, [], api_endpoint, api_key, "", None, False,
                                 temperature, "")
        conversation.append((speaker['name'], process_character_response(raw_reply)))
    except Exception as e:
        conversation.append((speaker['name'], f"Error generating response: {str(e)}"))
    return conversation, current_character
def character_interaction(character1: str, character2: str, api_endpoint: str, api_key: str,
                          num_turns: int, scenario: str, temperature: float,
                          user_interjection: str = "") -> List[str]:
    """Run an automatic back-and-forth between two characters.

    The two characters alternate for ``num_turns`` turns, starting with
    ``character1``. Each turn's prompt contains both personalities and the
    conversation so far. If ``user_interjection`` is given, it is inserted once,
    on the turn whose index equals ``num_turns // 2``.

    Args:
        character1: Key of the first speaker in ``load_characters()``.
        character2: Key of the second speaker in ``load_characters()``.
        api_endpoint: Forwarded to ``chat_wrapper``.
        api_key: Forwarded to ``chat_wrapper``.
        num_turns: Number of replies to generate.
        scenario: Optional opening line, recorded as ``"Scenario: ..."``.
        temperature: Forwarded to ``chat_wrapper``.
        user_interjection: Optional text injected mid-conversation.

    Returns:
        The conversation as a list of ``"Speaker: text"`` strings.

    Raises:
        KeyError: If either character key is not present in the loaded
            characters.
    """
    characters = load_characters()
    current_speaker = characters[character1]
    other_speaker = characters[character2]
    conversation = []

    # Add scenario to the conversation start
    if scenario:
        conversation.append(f"Scenario: {scenario}")

    for turn in range(num_turns):
        # Construct the prompt for the current speaker
        prompt = f"{current_speaker['name']}'s personality: {current_speaker['personality']}\n"
        prompt += f"{other_speaker['name']}'s personality: {other_speaker['personality']}\n"
        prompt += "Conversation so far:\n" + "\n".join(
            [msg if isinstance(msg, str) else f"{msg[0]}: {msg[1]}" for msg in conversation])

        # Add user interjection if provided
        if user_interjection and turn == num_turns // 2:
            prompt += f"\n\nUser interjection: {user_interjection}\n"
            conversation.append(f"User: {user_interjection}")

        prompt += f"\n\nHow would {current_speaker['name']} respond?"

        # FIXME - figure out why the double print is happening
        # Get response from the LLM
        response = chat_wrapper(prompt, conversation, {}, [], api_endpoint, api_key, "", None, False,
                                temperature, "")

        # Other call sites in this module unpack chat_wrapper's result into three
        # values, so the raw return is not plain text. Reduce it to the reply
        # string before recording it; storing the raw value would put a tuple's
        # repr into the transcript.
        reply = extract_character_response(response)
        conversation.append((current_speaker['name'], reply))

        # Switch speakers
        current_speaker, other_speaker = other_speaker, current_speaker

    # Convert the conversation to a list of strings for output
    return [f"{msg[0]}: {msg[1]}" if isinstance(msg, tuple) else msg for msg in conversation]
def create_multiple_character_chat_tab():
    """Build the "Multi-Character Chat" Gradio tab and wire up its events.

    The tab lets the user pick up to four characters, an API endpoint/key,
    a temperature and an optional scenario, then step the conversation one
    turn at a time, add narrator lines, reset, or save the chat history.

    Returns:
        The ``gr.Blocks`` object created inside the tab.
    """
    # Pick the dropdown's initial value from the configured default endpoint;
    # any failure here is logged and leaves the dropdown without a default.
    try:
        default_value = None
        if default_api_endpoint:
            if default_api_endpoint in global_api_endpoints:
                default_value = format_api_name(default_api_endpoint)
            else:
                logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints")
    except Exception as e:
        logging.error(f"Error setting default API endpoint: {str(e)}")
        default_value = None
    with gr.TabItem("Multi-Character Chat", visible=True):
        # Only `characters` is used below; the other three values are unused here.
        characters, conversation, current_character, other_character = character_interaction_setup()
        with gr.Blocks() as character_interaction:
            gr.Markdown("# Multi-Character Chat")
            with gr.Row():
                num_characters = gr.Dropdown(label="Number of Characters", choices=["2", "3", "4"], value="2")
                # Four selectors are always created; their visibility is toggled below.
                character_selectors = [gr.Dropdown(label=f"Character {i + 1}", choices=list(characters.keys())) for i in
                                       range(4)]
            # Refactored API selection dropdown
            api_endpoint = gr.Dropdown(
                choices=["None"] + [format_api_name(api) for api in global_api_endpoints],
                value=default_value,
                label="API for Interaction (Optional)"
            )
            api_key = gr.Textbox(label="API Key (if required)", type="password")
            temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.7)
            scenario = gr.Textbox(label="Scenario (optional)", lines=3)
            chat_display = gr.Chatbot(label="Character Interaction")
            # Index of the selector whose character speaks on the next turn.
            current_index = gr.State(0)
            next_turn_btn = gr.Button("Next Turn")
            narrator_input = gr.Textbox(label="Narrator Input", placeholder="Add a narration or description...")
            add_narration_btn = gr.Button("Add Narration")
            error_box = gr.Textbox(label="Error Messages", visible=False)
            reset_btn = gr.Button("Reset Conversation")
            chat_media_name = gr.Textbox(label="Custom Chat Name(optional)", visible=True)
            save_chat_history_to_db = gr.Button("Save Chat History to DataBase")

            def update_character_selectors(num):
                # Show the first `num` selectors and hide the rest.
                return [gr.update(visible=True) if i < int(num) else gr.update(visible=False) for i in range(4)]

            num_characters.change(
                update_character_selectors,
                inputs=[num_characters],
                outputs=character_selectors
            )

            def reset_conversation():
                # Values map, in order, to: chat_display, current_index, scenario, narrator_input.
                return [], 0, gr.update(value=""), gr.update(value="")

            def take_turn(conversation, current_index, char1, char2, char3, char4, api_endpoint, api_key, temperature,
                          scenario):
                # Generate one turn for the character at `current_index` and
                # return the updated conversation plus the next speaker index.
                char_selectors = [char for char in [char1, char2, char3, char4] if char]  # Remove None values
                num_chars = len(char_selectors)
                if num_chars == 0:
                    return conversation, current_index  # No characters selected, return without changes
                if not conversation:
                    conversation = []
                    if scenario:
                        conversation.append(("Scenario", scenario))
                current_character = char_selectors[current_index % num_chars]
                next_index = (current_index + 1) % num_chars
                prompt = f"Character speaking: {current_character}\nOther characters: {', '.join(char for char in char_selectors if char != current_character)}\n"
                prompt += "Generate the next part of the conversation, including character dialogues and actions. Characters should speak in first person."
                # chat_wrapper's result is unpacked as (response text, updated history, ignored third value).
                response, new_conversation, _ = chat_wrapper(prompt, conversation, {}, [], api_endpoint, api_key, "",
                                                             None, False, temperature, "")
                # Format the response
                # Lines containing ':' are rendered as "**speaker**: text"; other lines pass through.
                formatted_lines = []
                for line in response.split('\n'):
                    if ':' in line:
                        speaker, text = line.split(':', 1)
                        formatted_lines.append(f"**{speaker.strip()}**: {text.strip()}")
                    else:
                        formatted_lines.append(line)
                formatted_response = '\n'.join(formatted_lines)
                # Update the last message in the conversation with the formatted response
                if new_conversation:
                    new_conversation[-1] = (new_conversation[-1][0], formatted_response)
                else:
                    new_conversation.append((current_character, formatted_response))
                return new_conversation, next_index

            def add_narration(narration, conversation):
                # Append a ("Narrator", text) entry when text was entered, then clear the input box.
                if narration:
                    conversation.append(("Narrator", narration))
                return conversation, ""

            def take_turn_with_error_handling(conversation, current_index, char1, char2, char3, char4, api_endpoint,
                                              api_key, temperature, scenario):
                # Wrap take_turn so failures are shown in the error box
                # while the conversation and index stay as they were.
                try:
                    new_conversation, next_index = take_turn(conversation, current_index, char1, char2, char3, char4,
                                                             api_endpoint, api_key, temperature, scenario)
                    return new_conversation, next_index, gr.update(visible=False, value="")
                except Exception as e:
                    error_message = f"An error occurred: {str(e)}"
                    return conversation, current_index, gr.update(visible=True, value=error_message)

            # Define States for conversation_id and media_content, which are required for saving chat history
            media_content = gr.State({})
            conversation_id = gr.State(str(uuid.uuid4()))
            next_turn_btn.click(
                take_turn_with_error_handling,
                inputs=[chat_display, current_index] + character_selectors + [api_endpoint, api_key, temperature,
                                                                              scenario],
                outputs=[chat_display, current_index, error_box]
            )
            add_narration_btn.click(
                add_narration,
                inputs=[narrator_input, chat_display],
                outputs=[chat_display, narrator_input]
            )
            reset_btn.click(
                reset_conversation,
                outputs=[chat_display, current_index, scenario, narrator_input]
            )
            # FIXME - Implement saving chat history to database; look at Chat_UI.py for reference
            # The second output is a new "Save Status" textbox created inline here.
            save_chat_history_to_db.click(
                save_chat_history_to_db_wrapper,
                inputs=[chat_display, conversation_id, media_content, chat_media_name],
                outputs=[conversation_id, gr.Textbox(label="Save Status")]
            )
    return character_interaction
#
# End of Multi-Character chat tab
########################################################################################################################
#
# Narrator-Controlled Conversation Tab
# From `Fuzzlewumper` on Reddit.
def create_narrator_controlled_conversation_tab():
    """Build the "Narrator-Controlled Conversation" Gradio tab and wire up its events.

    The tab collects narrator text, up to four character name/description
    pairs and optional user input, and generates the next part of the
    conversation through ``chat_wrapper``.

    Returns:
        A tuple of the created components: ``api_endpoint``, ``api_key``,
        ``temperature``, ``narrator_input``, ``conversation_display``,
        ``user_input``, ``generate_btn``, ``reset_btn``, ``error_box``.
    """
    # Pick the dropdown's initial value from the configured default endpoint;
    # any failure here is logged and leaves the dropdown without a default.
    try:
        default_value = None
        if default_api_endpoint:
            if default_api_endpoint in global_api_endpoints:
                default_value = format_api_name(default_api_endpoint)
            else:
                logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints")
    except Exception as e:
        logging.error(f"Error setting default API endpoint: {str(e)}")
        default_value = None
    with gr.TabItem("Narrator-Controlled Conversation", visible=True):
        gr.Markdown("# Narrator-Controlled Conversation")
        with gr.Row():
            with gr.Column(scale=1):
                # Refactored API selection dropdown
                api_endpoint = gr.Dropdown(
                    choices=["None"] + [format_api_name(api) for api in global_api_endpoints],
                    value=default_value,
                    label="API for Chat Interaction (Optional)"
                )
                api_key = gr.Textbox(label="API Key (if required)", type="password")
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.7)
            with gr.Column(scale=2):
                narrator_input = gr.Textbox(
                    label="Narrator Input",
                    placeholder="Set the scene or provide context...",
                    lines=3
                )
        # Each entry is a (name textbox, description textbox) pair.
        character_inputs = []
        for i in range(4):  # Allow up to 4 characters
            with gr.Row():
                name = gr.Textbox(label=f"Character {i + 1} Name")
                description = gr.Textbox(label=f"Character {i + 1} Description", lines=3)
                character_inputs.append((name, description))
        conversation_display = gr.Chatbot(label="Conversation", height=400)
        user_input = gr.Textbox(label="Your Input (optional)", placeholder="Add your own dialogue or action...")
        with gr.Row():
            generate_btn = gr.Button("Generate Next Interaction")
            reset_btn = gr.Button("Reset Conversation")
            chat_media_name = gr.Textbox(label="Custom Chat Name(optional)", visible=True)
            save_chat_history_to_db = gr.Button("Save Chat History to DataBase")
        error_box = gr.Textbox(label="Error Messages", visible=False)
        # Define States for conversation_id and media_content, which are required for saving chat history
        conversation_id = gr.State(str(uuid.uuid4()))
        media_content = gr.State({})

        def generate_interaction(conversation, narrator_text, user_text, api_endpoint, api_key, temperature,
                                 *character_data):
            # `character_data` arrives flattened as name1, desc1, name2, desc2, ...
            # (see the inputs list passed to generate_btn.click below).
            try:
                # Keep only characters with both a non-blank name and description.
                characters = [{"name": name.strip(), "description": desc.strip()}
                              for name, desc in zip(character_data[::2], character_data[1::2])
                              if name.strip() and desc.strip()]
                if not characters:
                    raise ValueError("At least one character must be defined.")
                prompt = f"Narrator: {narrator_text}\n\n"
                for char in characters:
                    prompt += f"Character '{char['name']}': {char['description']}\n"
                prompt += "\nGenerate the next part of the conversation, including character dialogues and actions. "
                prompt += "Characters should speak in first person. "
                if user_text:
                    prompt += f"\nIncorporate this user input: {user_text}"
                prompt += "\nResponse:"
                # chat_wrapper's result is unpacked as (response text, updated history, ignored third value).
                response, conversation, _ = chat_wrapper(prompt, conversation, {}, [], api_endpoint, api_key, "", None,
                                                         False, temperature, "")
                # Format the response
                # Lines containing ':' are rendered as "**speaker**: text"; other lines pass through.
                formatted_lines = []
                for line in response.split('\n'):
                    if ':' in line:
                        speaker, text = line.split(':', 1)
                        formatted_lines.append(f"**{speaker.strip()}**: {text.strip()}")
                    else:
                        formatted_lines.append(line)
                formatted_response = '\n'.join(formatted_lines)
                # Update the last message in the conversation with the formatted response
                if conversation:
                    conversation[-1] = (conversation[-1][0], formatted_response)
                else:
                    conversation.append((None, formatted_response))
                # On success: clear the narrator and user inputs and hide the error box.
                return conversation, gr.update(value=""), gr.update(value=""), gr.update(visible=False, value="")
            except Exception as e:
                # On failure: show the message in the error box; the two input
                # boxes receive argument-less updates.
                error_message = f"An error occurred: {str(e)}"
                return conversation, gr.update(), gr.update(), gr.update(visible=True, value=error_message)

        def reset_conversation():
            # Values map, in order, to: conversation_display, narrator_input, user_input, error_box.
            return [], gr.update(value=""), gr.update(value=""), gr.update(visible=False, value="")

        generate_btn.click(
            generate_interaction,
            inputs=[conversation_display, narrator_input, user_input, api_endpoint, api_key, temperature] +
                   [input for char_input in character_inputs for input in char_input],
            outputs=[conversation_display, narrator_input, user_input, error_box]
        )
        reset_btn.click(
            reset_conversation,
            outputs=[conversation_display, narrator_input, user_input, error_box]
        )
        # FIXME - Implement saving chat history to database; look at Chat_UI.py for reference
        # The second output is a new "Save Status" textbox created inline here.
        save_chat_history_to_db.click(
            save_chat_history_to_db_wrapper,
            inputs=[conversation_display, conversation_id, media_content, chat_media_name],
            outputs=[conversation_id, gr.Textbox(label="Save Status")]
        )
    return api_endpoint, api_key, temperature, narrator_input, conversation_display, user_input, generate_btn, reset_btn, error_box
#
# End of Narrator-Controlled Conversation tab
########################################################################################################################