# OpenLLM/tools/memory.py
# Provenance: uploaded by xusenlin ("Upload 13 files", revision 20048e3)
from typing import Dict, Any, List
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import (
BaseMessage,
HumanMessage,
AIMessage,
)
from .registry import BaseParent
class ChatGLMConversationBufferWindowMemory(ConversationBufferWindowMemory):
    """Windowed conversation memory rendered in the ChatGLM prompt format.

    History is serialized as numbered rounds::

        [Round 0]
        问<question>
        答<answer>
        ...
        [Round n]

    where the trailing ``[Round n]`` marker opens the round the next user
    message will fill.
    """

    human_prefix: str = "问"
    ai_prefix: str = "答"

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Return history buffer limited to the last ``k`` exchanges.

        Bug fix: the original slice ``buffer[-self.k * 2:]`` returned the
        WHOLE history when ``k == 0`` (``-0 == 0`` makes it a full slice)
        instead of an empty window; an explicit ``k > 0`` guard is required.
        """
        window = self.buffer[-self.k * 2:] if self.k > 0 else []
        if self.return_messages:
            buffer: Any = window
        else:
            buffer = self.get_buffer_string(
                window,
                human_prefix=self.human_prefix,
                ai_prefix=self.ai_prefix,
            )
        return {self.memory_key: buffer}

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage], human_prefix: str = "问", ai_prefix: str = "答"
    ) -> str:
        """Serialize messages into ChatGLM's ``[Round i]`` prompt format.

        Raises:
            ValueError: if a message is neither Human nor AI.
        """
        string_messages, i = [], 0
        for m in messages:
            if isinstance(m, HumanMessage):
                # Each human turn opens a new numbered round.
                string_messages.append(f"[Round {i}]\n{human_prefix}{m.content}")
                i += 1
            elif isinstance(m, AIMessage):
                string_messages.append(f"{ai_prefix}{m.content}")
            else:
                raise ValueError(f"Got unsupported message type: {m}")
        # Trailing marker announces the upcoming round for the next query.
        return "\n".join(string_messages) + f"\n[Round {i}]"
class ChineseAlpacaConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Windowed memory rendered in the Chinese-LLaMA-Alpaca instruction format."""

    human_prefix: str = "### Instruction"
    ai_prefix: str = "### Response"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage], human_prefix: str = "### Instruction", ai_prefix: str = "### Response"
    ) -> str:
        """Render each turn as ``prefix:\\n\\ncontent``, joined by blank lines.

        Raises:
            ValueError: if a message is neither Human nor AI.
        """
        rendered: List[str] = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                prefix = human_prefix
            elif isinstance(msg, AIMessage):
                prefix = ai_prefix
            else:
                raise ValueError(f"Got unsupported message type: {msg}")
            # Human and AI turns share one layout, only the prefix differs.
            rendered.append(f"{prefix}:\n\n{msg.content}")
        return "\n\n".join(rendered)
class FireFlyConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Windowed memory rendered in the Firefly sentence-token format."""

    human_prefix: str = "Human"
    ai_prefix: str = "Assistant"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "Assistant"
    ) -> str:
        """Concatenate turns wrapped in ``<s>``/``</s>`` tokens, no separator.

        Raises:
            ValueError: if a message is neither Human nor AI.
        """
        def render(msg: BaseMessage) -> str:
            if isinstance(msg, HumanMessage):
                return f"<s>{msg.content}</s>"
            if isinstance(msg, AIMessage):
                # NOTE(review): assistant turns open with a closing </s>
                # token — presumably mirroring Firefly's training template;
                # confirm against the model's prompt format.
                return f"</s>{msg.content}</s>"
            raise ValueError(f"Got unsupported message type: {msg}")

        return "".join(render(msg) for msg in messages)
class PhoenixConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Windowed memory rendered in the Phoenix role-prefixed token format."""

    human_prefix: str = "Human"
    ai_prefix: str = "Assistant"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "Assistant"
    ) -> str:
        """Concatenate turns as ``role: <s>content</s>``, no separator.

        Raises:
            ValueError: if a message is neither Human nor AI.
        """
        rendered: List[str] = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                rendered.append(f"{human_prefix}: <s>{msg.content}</s>")
            elif isinstance(msg, AIMessage):
                rendered.append(f"{ai_prefix}: <s>{msg.content}</s>")
            else:
                raise ValueError(f"Got unsupported message type: {msg}")
        return "".join(rendered)
class MossConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Windowed memory rendered in the MOSS special-token format."""

    human_prefix: str = "<|Human|>"
    ai_prefix: str = "<|MOSS|>"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage], human_prefix: str = "<|Human|>", ai_prefix: str = "<|MOSS|>"
    ) -> str:
        """Render newline-joined turns, each closed by ``<eoh>`` or ``<eom>``.

        Raises:
            ValueError: if a message is neither Human nor AI.
        """
        lines: List[str] = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                prefix, terminator = human_prefix, "<eoh>"
            elif isinstance(msg, AIMessage):
                prefix, terminator = ai_prefix, "<eom>"
            else:
                raise ValueError(f"Got unsupported message type: {msg}")
            lines.append(f"{prefix}: {msg.content}{terminator}")
        return "\n".join(lines)
class GuanacoConversationBufferWindowMemory(ChatGLMConversationBufferWindowMemory):
    """Windowed memory rendered as ``### Human:`` / ``### Assistant:`` lines."""

    human_prefix: str = "### Human"
    ai_prefix: str = "### Assistant"

    @staticmethod
    def get_buffer_string(
        messages: List[BaseMessage], human_prefix: str = "### Human", ai_prefix: str = "### Assistant"
    ) -> str:
        """Render one ``prefix: content`` line per message, newline-joined.

        Raises:
            ValueError: if a message is neither Human nor AI.
        """
        def line(msg: BaseMessage) -> str:
            if isinstance(msg, HumanMessage):
                return f"{human_prefix}: {msg.content}"
            if isinstance(msg, AIMessage):
                return f"{ai_prefix}: {msg.content}"
            raise ValueError(f"Got unsupported message type: {msg}")

        return "\n".join(line(msg) for msg in messages)
class CustomConversationBufferWindowMemory(BaseParent):
    # Registry mapping a model name (e.g. "chatglm") to its memory class.
    # Populated at import time below via add_to_registry, which is
    # presumably provided by BaseParent (defined in .registry — TODO confirm).
    registry = {}
# Register one window-memory implementation per supported model family.
# Insertion order matches the original one-call-per-line registrations.
_MEMORY_CLASSES = {
    "gpt-3.5-turbo": ConversationBufferWindowMemory,
    "chatglm": ChatGLMConversationBufferWindowMemory,
    "chinese-llama-alpaca": ChineseAlpacaConversationBufferWindowMemory,
    "firefly": FireFlyConversationBufferWindowMemory,
    "phoenix": PhoenixConversationBufferWindowMemory,
    "moss": MossConversationBufferWindowMemory,
    "guanaco": GuanacoConversationBufferWindowMemory,
}
for _model_name, _memory_cls in _MEMORY_CLASSES.items():
    CustomConversationBufferWindowMemory.add_to_registry(_model_name, _memory_cls)