dual_window / state_manager.py
HuanzhiMao's picture
Fixing Dual Window Bug and Adding Single Model Demo (#1)
0157229 verified
raw
history blame
5.2 kB
from app_utils import *
from state import State
class StateManager:
    """
    Manage the state of both the single and dual chatbot interfaces.

    Attributes:
        single_bot (State): State object for the single chatbot interface.
        dual_bot1 (State): State object for the first bot in the dual interface.
        dual_bot2 (State): State object for the second bot in the dual interface.
        single_current_category (str): Category currently selected in the single interface.
        dual_current_category (str): Category currently selected in the dual interface.
    """

    # Private sentinel passed as the default to ``next(generator, default)`` so an
    # exhausted response generator can be told apart from any real yielded value.
    _EXHAUSTED = object()

    def __init__(self, database):
        """Create the three bot states and initialize them with the first category.

        Args:
            database: Passed through unchanged to every ``State`` constructor.
        """
        self.single_bot = State(MODELS[0], DEFAULT_TEMPERATURE_1, database, SINGLE_MODEL_BOT_EXAMPLE_SETTING, "single_model_bot")
        self.dual_bot1 = State(MODELS[0], DEFAULT_TEMPERATURE_1, database, DUAL_MODEL_BOT_1_EXAMPLE_SETTING, "dual_model_bot_1")
        self.dual_bot2 = State(MODELS[1], DEFAULT_TEMPERATURE_2, database, DUAL_MODEL_BOT_2_EXAMPLE_SETTING, "dual_model_bot_2")
        self.single_current_category = CATEGORIES[0]
        self.dual_current_category = CATEGORIES[0]
        self.initialize()

    def initialize(self):
        """Initialize every bot with the category currently selected for its interface."""
        self.single_bot.initialize(self.single_current_category)
        self.dual_bot1.initialize(self.dual_current_category)
        self.dual_bot2.initialize(self.dual_current_category)

    def add_message_helper(self, message, bot):
        """Append a user message (attached files first, then text) to ``bot``'s history.

        Args:
            message: Dict with a ``"files"`` iterable of paths and a ``"text"``
                entry that may be ``None`` -- presumably the value of a Gradio
                ``MultimodalTextbox``; confirm against the caller.
            bot: State whose ``history`` list and ``test_entry`` dict are updated.
        """
        for file_path in message["files"]:
            bot.history.append({"role": "user", "content": {"path": file_path}})
        if message["text"] is not None:
            bot.history.append({"role": "user", "content": message["text"]})
            # Only the text part is recorded as the entry's question; files are not.
            bot.test_entry["question"] = [{"role": "user", "content": message["text"]}]

    def add_message(self, message, target=None):
        """Route a user message to the single bot or to one/both dual bots.

        Args:
            message: Message dict as accepted by ``add_message_helper``.
            target: ``None`` for the single interface; otherwise ``"Model 1"``,
                ``"Model 2"`` or ``"Both"`` for the dual interface. Any other
                value adds the message to neither dual bot.

        Returns:
            ``(single_history, textbox)`` when ``target`` is ``None``, otherwise
            ``(bot1_history, bot2_history, textbox)``. ``textbox`` is a
            ``gr.MultimodalTextbox`` with ``value=None`` and ``interactive=False``
            (clears and locks the input while a response is generated).
        """
        if target is None:
            self.add_message_helper(message, self.single_bot)
            return self.single_bot.history, gr.MultimodalTextbox(value=None, interactive=False)
        if target in ["Model 1", "Both"]:
            # f-prefix was missing, so the literal text "{message}" was printed.
            print(f"Adding message to bot1: {message}")
            self.add_message_helper(message, self.dual_bot1)
        if target in ["Model 2", "Both"]:
            print(f"Adding message to bot2: {message}")
            self.add_message_helper(message, self.dual_bot2)
        return self.dual_bot1.history, self.dual_bot2.history, gr.MultimodalTextbox(value=None, interactive=False)

    def get_reponse_single(self):
        """Stream the single bot's response, yielding every value its generator produces."""
        yield from self.single_bot.response()

    def get_reponse_dual(self, target):
        """Stream dual-interface responses as ``(bot1_value, bot2_value)`` pairs.

        Args:
            target: ``"Model 1"`` runs only bot 1 and ``"Both"`` advances both
                bots' generators in lockstep. Any other value (e.g. ``"Model 2"``)
                runs only bot 2. For a single-bot run, the idle side of each
                pair is that bot's current ``history``.

        Yields:
            Tuples ``(bot1_value, bot2_value)``.
        """
        if target != "Both":
            if target == "Model 1":
                for generation_history in self.dual_bot1.response():
                    yield generation_history, self.dual_bot2.history
            else:
                for generation_history in self.dual_bot2.response():
                    yield self.dual_bot1.history, generation_history
            return

        bot1_generation = self.dual_bot1.response()
        bot2_generation = self.dual_bot2.response()
        # Seed with the current histories so a bot whose generator yields nothing
        # still has a value to show (previously this raised UnboundLocalError).
        generation_history_1 = self.dual_bot1.history
        generation_history_2 = self.dual_bot2.history
        while True:
            step_1 = next(bot1_generation, self._EXHAUSTED)
            step_2 = next(bot2_generation, self._EXHAUSTED)
            if step_1 is not self._EXHAUSTED:
                generation_history_1 = step_1
            if step_2 is not self._EXHAUSTED:
                generation_history_2 = step_2
            # A bot that finishes first keeps showing its last value while the
            # other continues. Once both are exhausted, one final (repeated) pair
            # is still yielded before stopping, as in the original loop.
            yield generation_history_1, generation_history_2
            if step_1 is self._EXHAUSTED and step_2 is self._EXHAUSTED:
                break

    def single_update_category_and_load_config(self, category):
        """Switch the single interface to ``category`` and reload the bot's config.

        Returns:
            ``category`` unchanged (presumably echoed back to a UI output; confirm).
        """
        self.single_current_category = category
        self.single_bot.update_category_and_load_config(category)
        return category

    def dual_update_category_and_load_config(self, category):
        """Switch the dual interface to ``category`` and reload both bots' configs.

        Returns:
            ``category`` unchanged (presumably echoed back to a UI output; confirm).
        """
        self.dual_current_category = category
        self.dual_bot1.update_category_and_load_config(category)
        self.dual_bot2.update_category_and_load_config(category)
        return category

    def single_load_example_and_update(self, example):
        """Load ``example`` into the single bot.

        Returns:
            ``(model, temperature, category, message)`` as produced by the bot.
        """
        model, temp, category, message = self.single_bot.load_example_and_update(example)
        return model, temp, category, message

    def dual_load_example_and_update(self, example):
        """Load ``example`` into both dual bots.

        Both bots must report the same category and message for the example.

        Returns:
            ``(model_1, temp_1, model_2, temp_2, category, message)``.

        Raises:
            AssertionError: If the two bots disagree on category or message
                (only when assertions are enabled, i.e. not under ``python -O``).
        """
        model_1, temp_1, category_1, message_1 = self.dual_bot1.load_example_and_update(example)
        model_2, temp_2, category_2, message_2 = self.dual_bot2.load_example_and_update(example)
        assert category_1 == category_2
        assert message_1 == message_2
        return model_1, temp_1, model_2, temp_2, category_1, message_1