Spaces:
Running
Running
class ChatState:
    """
    Manages the conversation history for a turn-based chatbot.

    Each turn is stored as a pre-formatted string: user turns are prefixed
    with ``"Instruction:\\n"`` and model turns with ``"\\n\\nResponse:\\n"``.
    The prompt sent to the model is the optional system text, the
    concatenated history, and a trailing model-turn marker.

    NOTE(review): the original author referenced the Gemma turn-formatting
    guide (https://ai.google.dev/gemma/docs/formatting), but the markers
    used here are plain-text Instruction/Response headers, not the
    ``<start_of_turn>`` / ``<end_of_turn>`` control tokens described there.
    Confirm these markers match what the loaded model was tuned on.
    """

    __START_TURN_USER__ = "Instruction:\n"
    __START_TURN_MODEL__ = "\n\nResponse:\n"
    # Empty: no explicit end-of-turn text is appended to user turns.
    __END_TURN__ = ""

    def __init__(self, model, system=""):
        """
        Initializes the chat state.

        Args:
            model: The language model to use for generating responses.
                Must expose ``generate(prompt, max_length=...)`` returning
                a string.
            system: (Optional) System instructions or bot description.
                An empty string or ``None`` means no system text.
        """
        self.model = model
        self.system = system
        self.history = []

    def add_to_history_as_user(self, message):
        """
        Adds a user message to the history with start/end turn markers.
        """
        self.history.append(
            self.__START_TURN_USER__ + message + self.__END_TURN__
        )

    def add_to_history_as_model(self, message):
        """
        Adds a model response to the history.

        The response is prefixed with the model-turn marker and terminated
        with a newline so the next user turn starts on its own line.
        """
        self.history.append(self.__START_TURN_MODEL__ + message + "\n")

    def get_history(self):
        """
        Returns the entire chat history as a single string.
        """
        return "".join(self.history)

    def get_full_prompt(self):
        """
        Builds the prompt for the language model.

        Returns:
            The history followed by an open model-turn marker, preceded by
            the system text and a newline when system text is set.
        """
        prompt = self.get_history() + self.__START_TURN_MODEL__
        # Truthiness check so both "" and None mean "no system text".
        if self.system:
            prompt = self.system + "\n" + prompt
        return prompt

    def send_message(self, message, max_length=4096):
        """
        Handles sending a user message and getting a model response.

        Args:
            message: The user's message.
            max_length: Forwarded to ``model.generate`` as ``max_length``.

        Returns:
            The model's response with the echoed prompt removed.
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        response = self.model.generate(prompt, max_length=max_length)
        # Strip only the echoed leading prompt; a blanket str.replace would
        # also delete any later occurrence of the prompt text in the reply.
        if response.startswith(prompt):
            result = response[len(prompt):]
        else:
            result = response
        self.add_to_history_as_model(result)
        return result