// Matou-Garou / convex/agent/conversation.ts
// Last commit by Jofthomas (HF staff): "prompt change" (34480d2)
import { v } from 'convex/values';
import { Id } from '../_generated/dataModel';
import { ActionCtx, internalQuery } from '../_generated/server';
import { LLMMessage, chatCompletion } from '../util/llm';
import * as memory from './memory';
import { api, internal } from '../_generated/api';
import * as embeddingsCache from './embeddingsCache';
import { GameId, conversationId, playerId } from '../aiTown/ids';
import { NUM_MEMORIES_TO_SEARCH } from '../constants';
// Shorthand for this module's own internal Convex functions (e.g. `queryPromptData`),
// referenced by the `ctx.runQuery(...)` calls in the message generators below.
const selfInternal = internal.agent.conversation;
/**
 * Generates the opening line `playerId` says when starting a conversation with `otherPlayerId`.
 *
 * Loads both players/agents via `queryPromptData` and pulls related memories using an embedding
 * of "<player> is talking to <other>". If one of those memories is a conversation that involved
 * the other player, it asks the LLM to reference it in the greeting.
 *
 * @returns The LLM completion content for the greeting.
 */
export async function startConversationMessage(
  ctx: ActionCtx,
  worldId: Id<'worlds'>,
  conversationId: GameId<'conversations'>,
  playerId: GameId<'players'>,
  otherPlayerId: GameId<'players'>,
) {
  const { player, otherPlayer, agent, otherAgent, lastConversation } = await ctx.runQuery(
    selfInternal.queryPromptData,
    {
      worldId,
      playerId,
      otherPlayerId,
      conversationId,
    },
  );
  const embedding = await embeddingsCache.fetch(
    ctx,
    `${player.name} is talking to ${otherPlayer.name}`,
  );
  const memories = await memory.searchMemories(
    ctx,
    player.id as GameId<'players'>,
    embedding,
    // Allow overriding the search width via env; fall back to the project constant.
    Number(process.env.NUM_MEMORIES_TO_SEARCH) || NUM_MEMORIES_TO_SEARCH,
  );
  const memoryWithOtherPlayer = memories.find(
    (m) => m.data.type === 'conversation' && m.data.playerIds.includes(otherPlayerId),
  );
  const prompt = [
    // Note the space after the period: the two sentences previously ran together.
    `You are ${player.name}, and you just started a conversation with ${otherPlayer.name}. You should act and speak as a human, and answer with natural and short answers.`,
  ];
  prompt.push(...agentPrompts(otherPlayer, agent, otherAgent ?? null));
  prompt.push(...previousConversationPrompt(otherPlayer, lastConversation));
  prompt.push(...relatedMemoriesPrompt(memories));
  if (memoryWithOtherPlayer) {
    prompt.push(
      `Be sure to include some detail or question about a previous conversation in your greeting.`,
    );
  }
  // End with the speaker tag so the completion continues as this player's line.
  prompt.push(`${player.name}:`);
  const { content } = await chatCompletion({
    messages: [
      {
        role: 'user',
        content: prompt.join('\n'),
      },
    ],
    max_tokens: 50,
    stream: true,
    stop: stopWords(otherPlayer.name, player.name),
  });
  return content;
}
/**
 * Generates `playerId`'s next reply in an ongoing conversation with `otherPlayerId`.
 *
 * The prompt includes when the conversation started, the current time, both agents'
 * descriptions, up to 3 related memories, and then the full message history of the
 * conversation as separate user messages.
 *
 * @returns The LLM completion content for the reply.
 */
export async function continueConversationMessage(
  ctx: ActionCtx,
  worldId: Id<'worlds'>,
  conversationId: GameId<'conversations'>,
  playerId: GameId<'players'>,
  otherPlayerId: GameId<'players'>,
) {
  const { player, otherPlayer, conversation, agent, otherAgent } = await ctx.runQuery(
    selfInternal.queryPromptData,
    {
      worldId,
      playerId,
      otherPlayerId,
      conversationId,
    },
  );
  // Must be a Date: `Date.now()` is a number, whose toLocaleString() yields a
  // digit-grouped millisecond count rather than a human-readable time.
  const now = new Date();
  const started = new Date(conversation.created);
  const embedding = await embeddingsCache.fetch(
    ctx,
    `What do you think about ${otherPlayer.name}?`,
  );
  const memories = await memory.searchMemories(ctx, player.id as GameId<'players'>, embedding, 3);
  const prompt = [
    `You are ${player.name}, and you're currently in a conversation with ${otherPlayer.name}.`,
    `The conversation started at ${started.toLocaleString()}. It's now ${now.toLocaleString()}.`,
  ];
  prompt.push(...agentPrompts(otherPlayer, agent, otherAgent ?? null));
  prompt.push(...relatedMemoriesPrompt(memories));
  prompt.push(
    `Below is the current chat history between you and ${otherPlayer.name}.`,
    `DO NOT greet them again. Do NOT use the word "Hey" too often. Your response should be brief and within 50 characters.`,
  );
  const llmMessages: LLMMessage[] = [
    {
      role: 'user',
      content: prompt.join('\n'),
    },
    ...(await previousMessages(
      ctx,
      worldId,
      player,
      otherPlayer,
      conversation.id as GameId<'conversations'>,
    )),
  ];
  // Trailing speaker tag so the completion is this player's next line.
  llmMessages.push({ role: 'user', content: `${player.name}:` });
  const { content } = await chatCompletion({
    messages: llmMessages,
    max_tokens: 50,
    stream: true,
    stop: stopWords(otherPlayer.name, player.name),
  });
  return content;
}
/**
 * Generates the line `playerId` says to politely leave its conversation with `otherPlayerId`.
 *
 * The prompt includes both agents' descriptions followed by the conversation's full
 * message history.
 *
 * @returns The LLM completion content for the farewell.
 */
export async function leaveConversationMessage(
  ctx: ActionCtx,
  worldId: Id<'worlds'>,
  conversationId: GameId<'conversations'>,
  playerId: GameId<'players'>,
  otherPlayerId: GameId<'players'>,
) {
  const { player, otherPlayer, conversation, agent, otherAgent } = await ctx.runQuery(
    selfInternal.queryPromptData,
    {
      worldId,
      playerId,
      otherPlayerId,
      conversationId,
    },
  );
  const prompt = [
    `You are ${player.name}, and you're currently in a conversation with ${otherPlayer.name}.`,
    // Previously said "leave the question", which does not describe the situation.
    `You've decided to leave the conversation and would like to politely tell them you're leaving.`,
  ];
  prompt.push(...agentPrompts(otherPlayer, agent, otherAgent ?? null));
  prompt.push(
    `Below is the current chat history between you and ${otherPlayer.name}.`,
    `How would you like to tell them that you're leaving? Your response should be brief and within 50 characters.`,
  );
  const llmMessages: LLMMessage[] = [
    {
      role: 'user',
      content: prompt.join('\n'),
    },
    ...(await previousMessages(
      ctx,
      worldId,
      player,
      otherPlayer,
      conversation.id as GameId<'conversations'>,
    )),
  ];
  // Trailing speaker tag so the completion is this player's farewell line.
  llmMessages.push({ role: 'user', content: `${player.name}:` });
  const { content } = await chatCompletion({
    messages: llmMessages,
    max_tokens: 50,
    stream: true,
    stop: stopWords(otherPlayer.name, player.name),
  });
  return content;
}
/**
 * Builds the prompt lines that describe the game, the speaking agent, and the other agent.
 *
 * @param otherPlayer - The conversation partner; only its name is used.
 * @param agent - The speaking agent's identity and plan, or null to omit those lines.
 * @param otherAgent - The partner's agent description, or null if the partner has none.
 * @param role - Optional role of the speaking player (presumably e.g. "villager" or
 *   "werewolf" — TODO confirm against the game model). The role clause is emitted only when
 *   this is given. The previous code interpolated `playerId.type` here, which read a property
 *   of the imported `playerId` validator rather than any player data.
 * @returns Zero to three prompt lines.
 */
function agentPrompts(
  otherPlayer: { name: string },
  agent: { identity: string; plan: string } | null,
  otherAgent: { identity: string; plan: string } | null,
  role?: string,
): string[] {
  const prompt: string[] = [];
  if (agent) {
    const roleClause = role ? `, you are playing as a ${role}` : '';
    prompt.push(
      `You are in a game of werewolf${roleClause}. Villagers shall try to discover who is a werewolf and vote for them, while werewolves shall eliminate villagers to WIN. You are trying to win. About you: ${agent.identity}`,
    );
    prompt.push(`Your goals for the conversation: ${agent.plan}. Do not mention you are in a game`);
  }
  if (otherAgent) {
    prompt.push(`About ${otherPlayer.name}: ${otherAgent.identity}`);
  }
  return prompt;
}
/**
 * Describes when the two players last talked, relative to the current time.
 *
 * @param otherPlayer - The conversation partner; only its name is used.
 * @param conversation - The last archived conversation between them, or null if none.
 * @returns An empty array when there was no previous conversation, otherwise a single line.
 */
function previousConversationPrompt(
  otherPlayer: { name: string },
  conversation: { created: number } | null,
): string[] {
  if (!conversation) {
    return [];
  }
  const lastChat = new Date(conversation.created).toLocaleString();
  const current = new Date().toLocaleString();
  return [
    `Last time you chatted with ${otherPlayer.name} it was ${lastChat}. It's now ${current}.`,
  ];
}
/**
 * Renders retrieved memories as a bulleted prompt section.
 *
 * The memories are listed in the order they are given; the header line tells the LLM they are
 * in decreasing relevance order (assumes the caller's search returns them that way).
 *
 * @param memories - Memories to include; only `description` is used.
 * @returns An empty array when there are no memories, otherwise a header plus one line each.
 */
function relatedMemoriesPrompt(memories: memory.Memory[]): string[] {
  if (memories.length === 0) {
    return [];
  }
  // `mem` avoids shadowing the imported `memory` module.
  const bullets = memories.map((mem) => ' - ' + mem.description);
  return [`Here are some related memories in decreasing relevance order:`, ...bullets];
}
/**
 * Loads the conversation's messages and formats each one as "<author> to <recipient>: <text>".
 *
 * A message counts as authored by `player` when its `author` equals `player.id`; every other
 * message is attributed to `otherPlayer`. All entries use the 'user' role.
 *
 * @returns The formatted messages, in the order `api.messages.listMessages` returns them.
 */
async function previousMessages(
  ctx: ActionCtx,
  worldId: Id<'worlds'>,
  player: { id: string; name: string },
  otherPlayer: { id: string; name: string },
  conversationId: GameId<'conversations'>,
) {
  const history = await ctx.runQuery(api.messages.listMessages, { worldId, conversationId });
  return history.map((message): LLMMessage => {
    const sentByPlayer = message.author === player.id;
    const author = sentByPlayer ? player : otherPlayer;
    const recipient = sentByPlayer ? otherPlayer : player;
    return {
      role: 'user',
      content: `${author.name} to ${recipient.name}: ${message.text}`,
    };
  });
}
/**
 * Gathers everything the conversation prompts need for one (player, otherPlayer, conversation)
 * triple within a world.
 *
 * Returns both players merged with their description names, the live conversation, the
 * speaking agent merged with its identity/plan, the other player's agent (or undefined if the
 * other player has no agent), and the most recent archived conversation between the two
 * players (or null if none).
 *
 * @throws Error if the world, either player, either player description, the conversation, the
 *   speaking player's agent or its description, the other agent's description (when that agent
 *   exists), or the archived conversation referenced by the latest `participatedTogether` edge
 *   cannot be found.
 */
export const queryPromptData = internalQuery({
  args: {
    worldId: v.id('worlds'),
    playerId,
    otherPlayerId: playerId,
    conversationId,
  },
  handler: async (ctx, args) => {
    const world = await ctx.db.get(args.worldId);
    if (!world) {
      throw new Error(`World ${args.worldId} not found`);
    }
    const player = world.players.find((p) => p.id === args.playerId);
    if (!player) {
      throw new Error(`Player ${args.playerId} not found`);
    }
    const playerDescription = await ctx.db
      .query('playerDescriptions')
      .withIndex('worldId', (q) => q.eq('worldId', args.worldId).eq('playerId', args.playerId))
      .first();
    if (!playerDescription) {
      throw new Error(`Player description for ${args.playerId} not found`);
    }
    const otherPlayer = world.players.find((p) => p.id === args.otherPlayerId);
    if (!otherPlayer) {
      throw new Error(`Player ${args.otherPlayerId} not found`);
    }
    const otherPlayerDescription = await ctx.db
      .query('playerDescriptions')
      .withIndex('worldId', (q) => q.eq('worldId', args.worldId).eq('playerId', args.otherPlayerId))
      .first();
    if (!otherPlayerDescription) {
      throw new Error(`Player description for ${args.otherPlayerId} not found`);
    }
    const conversation = world.conversations.find((c) => c.id === args.conversationId);
    if (!conversation) {
      throw new Error(`Conversation ${args.conversationId} not found`);
    }
    const agent = world.agents.find((a) => a.playerId === args.playerId);
    if (!agent) {
      // The player itself was found above; what is missing here is its agent.
      throw new Error(`Agent for player ${args.playerId} not found`);
    }
    const agentDescription = await ctx.db
      .query('agentDescriptions')
      .withIndex('worldId', (q) => q.eq('worldId', args.worldId).eq('agentId', agent.id))
      .first();
    if (!agentDescription) {
      throw new Error(`Agent description for ${agent.id} not found`);
    }
    // The other player may have no agent; only look up a description when it does.
    const otherAgent = world.agents.find((a) => a.playerId === args.otherPlayerId);
    let otherAgentDescription;
    if (otherAgent) {
      otherAgentDescription = await ctx.db
        .query('agentDescriptions')
        .withIndex('worldId', (q) => q.eq('worldId', args.worldId).eq('agentId', otherAgent.id))
        .first();
      if (!otherAgentDescription) {
        throw new Error(`Agent description for ${otherAgent.id} not found`);
      }
    }
    const lastTogether = await ctx.db
      .query('participatedTogether')
      .withIndex('edge', (q) =>
        q
          .eq('worldId', args.worldId)
          .eq('player1', args.playerId)
          .eq('player2', args.otherPlayerId),
      )
      // Order by conversation end time descending.
      .order('desc')
      .first();
    let lastConversation = null;
    if (lastTogether) {
      lastConversation = await ctx.db
        .query('archivedConversations')
        .withIndex('worldId', (q) =>
          q.eq('worldId', args.worldId).eq('id', lastTogether.conversationId),
        )
        .first();
      if (!lastConversation) {
        throw new Error(`Conversation ${lastTogether.conversationId} not found`);
      }
    }
    return {
      player: { name: playerDescription.name, ...player },
      otherPlayer: { name: otherPlayerDescription.name, ...otherPlayer },
      conversation,
      agent: { identity: agentDescription.identity, plan: agentDescription.plan, ...agent },
      // Safe non-null assertions: whenever otherAgent is set, the block above either assigned
      // otherAgentDescription or threw.
      otherAgent: otherAgent && {
        identity: otherAgentDescription!.identity,
        plan: otherAgentDescription!.plan,
        ...otherAgent,
      },
      lastConversation,
    };
  },
});
/**
 * Builds the LLM stop sequences that mark the start of the other player's turn, so the
 * completion does not continue by writing their reply.
 *
 * @param otherPlayer - Name of the conversation partner.
 * @param player - Name of the speaking player.
 * @returns "<other> to <player>:" in original case and lowercased (OpenAI accepts at most 4).
 */
function stopWords(otherPlayer: string, player: string) {
  const turnMarker = `${otherPlayer} to ${player}`;
  const lowered = turnMarker.toLowerCase();
  return [`${turnMarker}:`, `${lowered}:`];
}