from typing import List

import tiktoken
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
    trim_messages,
)


def str_token_counter(text: str) -> int:
    """Count the tokens in a plain string using tiktoken's o200k_base encoding."""
    enc = tiktoken.get_encoding("o200k_base")
    return len(enc.encode(text))
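
# Illustrative note (not part of the original code): a short phrase such as
# "hello world" encodes to only a couple of o200k_base tokens.
# print(str_token_counter("hello world"))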


def tiktoken_counter(messages: List[BaseMessage]) -> int:
    """Approximately reproduce https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

    For simplicity only supports str Message.contents.
    """
    num_tokens = 3  # every reply is primed with <|start|>assistant<|message|>
    tokens_per_message = 3
    tokens_per_name = 1
    for msg in messages:
        if isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "assistant"
        elif isinstance(msg, ToolMessage):
            role = "tool"
        elif isinstance(msg, SystemMessage):
            role = "system"
        else:
            raise ValueError(f"Unsupported message type {msg.__class__}")
        num_tokens += (
            tokens_per_message
            + str_token_counter(role)
            + str_token_counter(msg.content)
        )
        if msg.name:
            num_tokens += tokens_per_name + str_token_counter(msg.name)
    return num_tokens
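
# Hedged usage sketch (not part of the original code): the messages and the
# printed total are illustrative only; exact counts depend on the tiktoken encoding.
# example_messages = [
#     SystemMessage(content="You are a helpful assistant."),
#     HumanMessage(content="What is LangChain?"),
# ]
# print(tiktoken_counter(example_messages))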


def convert_to_openai_messages(messages: List[BaseMessage]) -> List[dict]:
    """Convert LangChain messages to OpenAI format."""
    openai_messages = []
    for msg in messages:
        message_dict = {"content": msg.content}
        if isinstance(msg, HumanMessage):
            message_dict["role"] = "user"
        elif isinstance(msg, AIMessage):
            message_dict["role"] = "assistant"
        elif isinstance(msg, SystemMessage):
            message_dict["role"] = "system"
        elif isinstance(msg, ToolMessage):
            message_dict["role"] = "tool"
        else:
            raise ValueError(f"Unsupported message type: {msg.__class__}")
        if msg.name:
            message_dict["name"] = msg.name
        openai_messages.append(message_dict)
    return openai_messages
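
# Assumed output shape, shown for reference only: each LangChain message becomes
# a plain dict suitable for the OpenAI Chat Completions `messages` argument,
# e.g. {"content": "You are a helpful assistant.", "role": "system"}.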


def trim_messages_openai(messages: List[BaseMessage]) -> List[dict]:
    """Trim LangChain messages and convert to OpenAI format."""
    trimmed_messages = trim_messages(
        messages,
        token_counter=tiktoken_counter,
        strategy="last",           # keep the most recent messages that fit
        max_tokens=45,             # budget enforced via tiktoken_counter
        start_on="human",          # trimmed history must start with a human turn
        end_on=("human", "tool"),  # and end on a human or tool message
        include_system=True,       # keep the system message if one is present
    )
    openai_format_messages = convert_to_openai_messages(trimmed_messages)
    return openai_format_messages


# Test
# messages = [SystemMessage(content="You are a helpful assistant."), HumanMessage(query)]
# openai_format_messages = trim_messages_openai(messages)
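
# Hedged end-to-end sketch (an assumption, not part of the original app): the
# trimmed OpenAI-format messages can be sent with the official `openai` v1 client.
# The model name below is a placeholder chosen for illustration.
# from openai import OpenAI
# client = OpenAI()
# response = client.chat.completions.create(
#     model="gpt-4o-mini",  # assumed model name
#     messages=openai_format_messages,
# )
# print(response.choices[0].message.content)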