File size: 9,668 Bytes
15411db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class ChatState:
    """Manage the conversation history for a turn-based chatbot.

    Each user turn is stored as the ``Instruction:`` header, a newline and the
    message. Each model turn is stored as a blank line, the ``Response:``
    header, a newline, the reply and a trailing newline. The prompt sent to
    the model is built from three parts, in order:

    - the optional system text,
    - the concatenated history,
    - a final ``Response:`` cue, so the model continues with its reply.

    The turn-based structure is inspired by the Gemma formatting guide at
    https://ai.google.dev/gemma/docs/formatting. Note, however, that these
    markers are plain-text ``Instruction:``/``Response:`` headers, not the
    ``<start_of_turn>``/``<end_of_turn>`` control tokens that guide describes.
    """

    # Plain-text turn markers prepended to each user / model message.
    __START_TURN_USER__ = "Instruction:\n"
    __START_TURN_MODEL__ = "\n\nResponse:\n"
    # Appended after each user message; currently empty.
    __END_TURN__ = ""

    def __init__(self, model, system=""):
        """Initialize the chat state.

        Args:
            model: Object exposing ``generate(prompt, max_length=...)`` that
                returns generated text. Presumably the returned text begins
                with the prompt (as causal-LM ``generate`` APIs commonly do);
                ``send_message`` strips that echoed prefix if present.
            system: Optional system instructions or bot description, prepended
                to every prompt. An empty string or ``None`` disables it.
        """
        self.model = model
        self.system = system
        self.history = []  # formatted turn strings, oldest first

    def add_to_history_as_user(self, message):
        """Append a user message wrapped in the user start/end turn markers."""
        self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

    def add_to_history_as_model(self, message):
        """Append a model reply prefixed with the model start-turn marker.

        A trailing newline is appended here (the model is not relied on to
        emit one), so the next user turn starts on its own line.
        """
        self.history.append(self.__START_TURN_MODEL__ + message + "\n")

    def get_history(self):
        """Return the entire chat history as a single string."""
        return "".join(self.history)

    def get_full_prompt(self):
        """Build the model prompt: optional system text, history, then the model cue."""
        prompt = self.get_history() + self.__START_TURN_MODEL__
        # A truthiness check skips both "" and None (len(None) would raise).
        if self.system:
            prompt = self.system + "\n" + prompt
        return prompt

    def send_message(self, message, max_length=4096):
        """Send a user message and return the model's reply.

        Both the user message and the reply are recorded in the history.

        Args:
            message: The user's message.
            max_length: Passed through to ``model.generate``. Defaults to 4096,
                the value previously hard-coded here.

        Returns:
            The model's reply: the generated text with the echoed prompt
            prefix removed, if the generated text starts with the prompt.
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        response = self.model.generate(prompt, max_length=max_length)
        # Strip only the leading echo of the prompt. str.replace would also
        # delete any later occurrence of the prompt text inside the reply.
        if response.startswith(prompt):
            response = response[len(prompt):]
        self.add_to_history_as_model(response)
        return response