|
const crypto = require('crypto'); |
|
const TextStream = require('./TextStream'); |
|
const { RecursiveCharacterTextSplitter } = require('langchain/text_splitter'); |
|
const { ChatOpenAI } = require('langchain/chat_models/openai'); |
|
const { loadSummarizationChain } = require('langchain/chains'); |
|
const { refinePrompt } = require('./prompts/refinePrompt'); |
|
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('../../models'); |
|
|
|
/**
 * Shared base for chat clients.
 *
 * Subclasses must implement `setOptions`, `getCompletion`, `sendCompletion`,
 * `getSaveOptions`, `buildMessages` and `getBuildMessagesOptions`. The base
 * class supplies the common lifecycle used by `sendMessage`: deriving message
 * ids, loading conversation history, fitting messages into
 * `this.maxContextTokens` (optionally refining the overflow into a summary),
 * and persisting messages.
 *
 * NOTE(review): methods here read `this.options`, `this.modelOptions`,
 * `this.maxContextTokens`, `this.shouldRefineContext` and `this.getTokenCount`,
 * none of which are assigned in this class — presumably subclasses provide them
 * (likely from `setOptions`); confirm against the concrete clients.
 */
class BaseClient {
  /**
   * @param {string} apiKey - API key, stored as `this.apiKey`.
   * @param {object} [options={}] - Client options; only `options.sender` is read here
   *   (defaults to 'AI').
   */
  constructor(apiKey, options = {}) {
    this.apiKey = apiKey;
    this.sender = options.sender || 'AI';
    this.contextStrategy = null;
    this.currentDateString = new Date().toLocaleDateString('en-us', {
      year: 'numeric',
      month: 'long',
      day: 'numeric',
    });
  }

  /** Abstract: apply per-request options to the client. */
  setOptions() {
    throw new Error('Method \'setOptions\' must be implemented.');
  }

  /** Abstract: perform a raw completion request. */
  getCompletion() {
    throw new Error('Method \'getCompletion\' must be implemented.');
  }

  /** Abstract: send the prepared payload and resolve to the response text. */
  async sendCompletion() {
    throw new Error('Method \'sendCompletion\' must be implemented.');
  }

  /** Abstract: options persisted alongside the conversation by `saveMessageToDatabase`. */
  getSaveOptions() {
    throw new Error('Subclasses must implement getSaveOptions');
  }

  /** Abstract: build `{ prompt, tokenCountMap?, promptTokens? }` from the message chain. */
  async buildMessages() {
    throw new Error('Subclasses must implement buildMessages');
  }

  /** Abstract: derive the options object passed to `buildMessages` from the call options. */
  getBuildMessagesOptions() {
    throw new Error('Subclasses must implement getBuildMessagesOptions');
  }

  /**
   * Stream `text` to `onProgress` through a `TextStream`.
   * @param {string} text - Text to stream.
   * @param {Function} onProgress - Progress callback handed to `processTextStream`.
   * @param {object} [options={}] - Options forwarded to the `TextStream` constructor.
   */
  async generateTextStream(text, onProgress, options = {}) {
    const stream = new TextStream(text, options);
    await stream.processTextStream(onProgress);
  }

  /**
   * Apply call options, derive the ids for this turn, and load the conversation
   * history into `this.currentMessages`.
   *
   * @param {object} [opts={}] - Per-call options (user, conversationId, parentMessageId,
   *   overrideParentMessageId, abortController, ...).
   * @returns {Promise<object>} `opts` extended with the derived ids and `saveOptions`.
   */
  async setMessageOptions(opts = {}) {
    if (opts && typeof opts === 'object') {
      this.setOptions(opts);
    }

    const user = opts.user || null;
    const conversationId = opts.conversationId || crypto.randomUUID();
    // The all-zero id is used as the parent when none is supplied.
    const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000';
    const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
    const responseMessageId = crypto.randomUUID();
    const saveOptions = this.getSaveOptions();
    this.abortController = opts.abortController || new AbortController();
    this.currentMessages = (await this.loadHistory(conversationId, parentMessageId)) ?? [];

    return {
      ...opts,
      user,
      conversationId,
      parentMessageId,
      userMessageId,
      responseMessageId,
      saveOptions,
    };
  }

  /**
   * Build the record for a message authored by the user.
   * @returns {object} Message with `sender: 'User'` and `isCreatedByUser: true`.
   */
  createUserMessage({ messageId, parentMessageId, conversationId, text }) {
    const userMessage = {
      messageId,
      parentMessageId,
      conversationId,
      sender: 'User',
      text,
      isCreatedByUser: true,
    };
    return userMessage;
  }

  /**
   * Prepare a turn: resolve options/ids, create the user message, and fire the
   * optional `opts.getIds` / `opts.onStart` callbacks.
   *
   * @param {string} message - The user's message text.
   * @param {object} opts - Call options forwarded to `setMessageOptions`.
   * @returns {Promise<object>} `opts` extended with user, ids, saveOptions and userMessage.
   */
  async handleStartMethods(message, opts) {
    const { user, conversationId, parentMessageId, userMessageId, responseMessageId, saveOptions } =
      await this.setMessageOptions(opts);

    const userMessage = this.createUserMessage({
      messageId: userMessageId,
      parentMessageId,
      conversationId,
      text: message,
    });

    if (typeof opts?.getIds === 'function') {
      opts.getIds({
        userMessage,
        conversationId,
        responseMessageId,
      });
    }

    if (typeof opts?.onStart === 'function') {
      opts.onStart(userMessage);
    }

    return {
      ...opts,
      user,
      conversationId,
      responseMessageId,
      saveOptions,
      userMessage,
    };
  }

  /**
   * Insert `instructions` immediately before the last message.
   *
   * @param {Array} messages - Message list (not mutated).
   * @param {*} instructions - Instruction message; when falsy, `messages` is returned as-is.
   * @returns {Array} A new array with the instructions inserted, or `messages` itself.
   */
  addInstructions(messages, instructions) {
    const payload = [];
    if (!instructions) {
      return messages;
    }
    if (messages.length > 1) {
      payload.push(...messages.slice(0, -1));
    }

    payload.push(instructions);

    if (messages.length > 0) {
      payload.push(messages[messages.length - 1]);
    }

    return payload;
  }

  /**
   * Persist token counts (and any refined summary) from `tokenCountMap` onto the
   * earlier messages in `this.currentMessages`.
   *
   * The last entry of `this.currentMessages` (the message just pushed by
   * `sendMessage`) is skipped; `sendMessage` assigns and saves its count itself.
   * Messages that already have a `tokenCount` are skipped unless they receive
   * refined props.
   *
   * @param {object} tokenCountMap - messageId → token count, plus an optional `refined` entry.
   */
  async handleTokenCountMap(tokenCountMap) {
    const lastIndex = this.currentMessages.length - 1;

    for (let i = 0; i < lastIndex; i++) {
      const message = this.currentMessages[i];
      const { messageId } = message;
      const update = {};

      if (messageId === tokenCountMap.refined?.messageId) {
        if (this.options.debug) {
          console.debug(`Adding refined props to ${messageId}.`);
        }

        update.refinedMessageText = tokenCountMap.refined.content;
        update.refinedTokenCount = tokenCountMap.refined.tokenCount;
      }

      if (message.tokenCount && !update.refinedTokenCount) {
        if (this.options.debug) {
          console.debug(`Skipping ${messageId}: already had a token count.`);
        }
        continue;
      }

      const tokenCount = tokenCountMap[messageId];
      if (tokenCount) {
        message.tokenCount = tokenCount;
        update.tokenCount = tokenCount;
        await this.updateMessageInDatabase({ messageId, ...update });
      }
    }
  }

  /**
   * Render messages as `"<name or role>:\n<content>\n\n"` blocks, concatenated.
   * @param {Array<{name?: string, role?: string, content: string}>} messages
   * @returns {string}
   */
  concatenateMessages(messages) {
    return messages.reduce((acc, message) => {
      const nameOrRole = message.name ?? message.role;
      return acc + `${nameOrRole}:\n${message.content}\n\n`;
    }, '');
  }

  /**
   * Summarize messages that did not fit into the context window, using a
   * LangChain 'refine' summarization chain over user and assistant text split
   * into chunks.
   *
   * @param {Array} messagesToRefine - Messages to summarize.
   * @param {number} remainingContextTokens - Remaining budget; used only for debug output.
   * @returns {Promise<object|null>} `{ role: 'assistant', content, tokenCount }`, or `null`
   *   if the chain call fails (the error is logged).
   */
  async refineMessages(messagesToRefine, remainingContextTokens) {
    const model = new ChatOpenAI({ temperature: 0 });
    const chain = loadSummarizationChain(model, {
      type: 'refine',
      verbose: this.options.debug,
      refinePrompt,
    });
    const splitter = new RecursiveCharacterTextSplitter({
      chunkSize: 1500,
      chunkOverlap: 100,
    });
    const userMessages = this.concatenateMessages(
      messagesToRefine.filter((m) => m.role === 'user'),
    );
    const assistantMessages = this.concatenateMessages(
      messagesToRefine.filter((m) => m.role !== 'user'),
    );
    const userDocs = await splitter.createDocuments([userMessages], [], {
      chunkHeader: 'DOCUMENT NAME: User Message\n\n---\n\n',
      appendChunkOverlapHeader: true,
    });
    const assistantDocs = await splitter.createDocuments([assistantMessages], [], {
      chunkHeader: 'DOCUMENT NAME: Assistant Message\n\n---\n\n',
      appendChunkOverlapHeader: true,
    });

    const input_documents = userDocs.concat(assistantDocs);
    if (this.options.debug) {
      console.debug('Refining messages...');
    }
    try {
      const res = await chain.call({
        input_documents,
        signal: this.abortController.signal,
      });

      const refinedMessage = {
        role: 'assistant',
        content: res.output_text,
        tokenCount: this.getTokenCount(res.output_text),
      };

      if (this.options.debug) {
        console.debug('Refined messages', refinedMessage);
        console.debug(
          `remainingContextTokens: ${remainingContextTokens}, after refining: ${
            remainingContextTokens - refinedMessage.tokenCount
          }`,
        );
      }

      return refinedMessage;
    } catch (e) {
      console.error('Error refining messages');
      console.error(e);
      return null;
    }
  }

  /**
   * Walk `messages` from newest to oldest, keeping as many as fit within
   * `this.maxContextTokens`.
   *
   * Without `this.shouldRefineContext` the walk stops at the first message that
   * would exceed the limit. With it, overflowing messages (and, while more than
   * two older messages remain, the oldest kept message) are collected into
   * `messagesToRefine` instead, and the walk continues.
   *
   * @param {Array<{tokenCount: number}>} messages - Messages in chronological order.
   * @returns {Promise<{context: Array, remainingContextTokens: number,
   *   messagesToRefine: Array, refineIndex: number}>} `context` and `messagesToRefine`
   *   are in chronological order; `refineIndex` is -1 when nothing was refined.
   */
  async getMessagesWithinTokenLimit(messages) {
    let currentTokenCount = 0;
    const context = [];
    const messagesToRefine = [];
    let refineIndex = -1;
    let remainingContextTokens = this.maxContextTokens;

    for (let i = messages.length - 1; i >= 0; i--) {
      const message = messages[i];
      const newTokenCount = currentTokenCount + message.tokenCount;
      const exceededLimit = newTokenCount > this.maxContextTokens;
      const shouldRefine = exceededLimit && this.shouldRefineContext;
      const refineNextMessage = i !== 0 && i !== 1 && context.length > 0;

      if (shouldRefine) {
        messagesToRefine.push(message);

        if (refineIndex === -1) {
          refineIndex = i;
        }

        if (refineNextMessage) {
          refineIndex = i + 1;
          const removedMessage = context.pop();
          messagesToRefine.push(removedMessage);
          currentTokenCount -= removedMessage.tokenCount;
          remainingContextTokens = this.maxContextTokens - currentTokenCount;
        }

        continue;
      }

      if (exceededLimit) {
        break;
      }

      context.push(message);
      currentTokenCount = newTokenCount;
      remainingContextTokens = this.maxContextTokens - currentTokenCount;
      // Yield to the event loop between messages so long histories don't block it.
      await new Promise((resolve) => setImmediate(resolve));
    }

    return {
      context: context.reverse(),
      remainingContextTokens,
      messagesToRefine: messagesToRefine.reverse(),
      refineIndex,
    };
  }

  /**
   * Fit the (instruction-augmented) conversation into the context window and
   * build the token-count bookkeeping for it.
   *
   * `orderedMessages` and `formattedMessages` are treated as parallel arrays:
   * entry i of one describes the same message as entry i of the other.
   *
   * @param {object} params
   * @param {*} params.instructions - Optional instruction message (see `addInstructions`).
   * @param {Array} params.orderedMessages - Stored messages (carry `messageId`).
   * @param {Array} params.formattedMessages - Payload-ready messages (carry `tokenCount`).
   * @returns {Promise<{payload: Array, tokenCountMap: object, promptTokens: number,
   *   messages: Array}>}
   */
  async handleContextStrategy({ instructions, orderedMessages, formattedMessages }) {
    const fullPayload = this.addInstructions(formattedMessages, instructions);
    let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
    const limited = await this.getMessagesWithinTokenLimit(fullPayload);
    const { context, messagesToRefine, refineIndex } = limited;
    let { remainingContextTokens } = limited;

    const payload = context;
    let refinedMessage = null;

    // Number of leading messages that did not fit into the context window.
    const diff = orderedWithInstructions.length - payload.length;

    if (this.options.debug) {
      console.debug('<---------------------------------DIFF--------------------------------->');
      console.debug(
        `Difference between payload (${payload.length}) and orderedWithInstructions (${orderedWithInstructions.length}): ${diff}`,
      );
      console.debug(
        'remainingContextTokens, this.maxContextTokens (1/2)',
        remainingContextTokens,
        this.maxContextTokens,
      );
    }

    // Keep orderedWithInstructions aligned index-for-index with the trimmed payload.
    if (diff > 0) {
      orderedWithInstructions = orderedWithInstructions.slice(diff);
    }

    if (messagesToRefine.length > 0) {
      refinedMessage = await this.refineMessages(messagesToRefine, remainingContextTokens);
      // refineMessages returns null on failure; fall back to the un-refined context
      // instead of inserting null into the payload.
      if (refinedMessage) {
        payload.unshift(refinedMessage);
        remainingContextTokens -= refinedMessage.tokenCount;
      }
    }

    if (this.options.debug) {
      console.debug(
        'remainingContextTokens, this.maxContextTokens (2/2)',
        remainingContextTokens,
        this.maxContextTokens,
      );
    }

    // A prepended summary shifts the payload by one relative to orderedWithInstructions.
    const payloadOffset = refinedMessage ? 1 : 0;

    const tokenCountMap = orderedWithInstructions.reduce((map, message, index) => {
      if (!message.messageId) {
        return map;
      }

      if (refinedMessage && index === refineIndex) {
        map.refined = { ...refinedMessage, messageId: message.messageId };
      }

      map[message.messageId] = payload[index + payloadOffset].tokenCount;
      return map;
    }, {});

    const promptTokens = this.maxContextTokens - remainingContextTokens;

    if (this.options.debug) {
      console.debug('<-------------------------PAYLOAD/TOKEN COUNT MAP------------------------->');
      console.debug('Payload:', payload);
      console.debug('Token Count Map:', tokenCountMap);
      console.debug('Prompt Tokens', promptTokens, remainingContextTokens, this.maxContextTokens);
    }

    return { payload, tokenCountMap, promptTokens, messages: orderedWithInstructions };
  }

  /**
   * Send `message` as a new user turn and return the saved response message.
   *
   * Saves the user message (with its token count when `buildMessages` supplies
   * one), requests a completion via `sendCompletion`, then saves the response.
   * Token counts for earlier messages are persisted in the background; failures
   * there are logged rather than left as unhandled rejections.
   *
   * @param {string} message - The user's message text.
   * @param {object} [opts={}] - Call options (see `setMessageOptions`/`handleStartMethods`).
   * @returns {Promise<object>} The response message (without `tokenCount`).
   */
  async sendMessage(message, opts = {}) {
    const { user, conversationId, responseMessageId, saveOptions, userMessage } =
      await this.handleStartMethods(message, opts);

    this.user = user;
    this.currentMessages.push(userMessage);

    let {
      prompt: payload,
      tokenCountMap,
      promptTokens,
    } = await this.buildMessages(
      this.currentMessages,
      userMessage.messageId,
      this.getBuildMessagesOptions(opts),
    );

    if (this.options.debug) {
      console.debug('payload');
      console.debug(payload);
    }

    if (tokenCountMap) {
      // Gate on debug like every other log here: these dump message contents.
      if (this.options.debug) {
        console.dir(tokenCountMap, { depth: null });
      }
      if (tokenCountMap[userMessage.messageId]) {
        userMessage.tokenCount = tokenCountMap[userMessage.messageId];
        if (this.options.debug) {
          console.debug('userMessage.tokenCount', userMessage.tokenCount);
          console.debug('userMessage', userMessage);
        }
      }

      // Strip tokenCount from copies so caller-owned message objects (possibly the
      // userMessage about to be saved) keep their counts.
      payload = payload.map(
        ({ tokenCount: _omittedTokenCount, ...messageWithoutTokenCount }) =>
          messageWithoutTokenCount,
      );

      // Deliberately not awaited so the completion request isn't delayed by DB writes.
      this.handleTokenCountMap(tokenCountMap).catch((err) => {
        console.error('Error updating message token counts');
        console.error(err);
      });
    }

    await this.saveMessageToDatabase(userMessage, saveOptions, user);
    const responseMessage = {
      messageId: responseMessageId,
      conversationId,
      parentMessageId: userMessage.messageId,
      isCreatedByUser: false,
      model: this.modelOptions.model,
      sender: this.sender,
      text: await this.sendCompletion(payload, opts),
      promptTokens,
    };

    if (tokenCountMap && this.getTokenCountForResponse) {
      responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
      responseMessage.completionTokens = responseMessage.tokenCount;
    }
    await this.saveMessageToDatabase(responseMessage, saveOptions, user);
    delete responseMessage.tokenCount;
    return responseMessage;
  }

  /**
   * Fetch a conversation record via `getConvo`.
   * @param {string} conversationId
   * @param {*} [user=null]
   */
  async getConversation(conversationId, user = null) {
    return await getConvo(user, conversationId);
  }

  /**
   * Load the message chain ending at `parentMessageId` for a conversation.
   *
   * Uses `this.getMessageMapMethod()` (when defined) to map each message.
   *
   * @param {string} conversationId
   * @param {string|null} [parentMessageId=null]
   * @returns {Promise<Array>} Messages from the root down to `parentMessageId`.
   */
  async loadHistory(conversationId, parentMessageId = null) {
    if (this.options.debug) {
      console.debug('Loading history for conversation', conversationId, parentMessageId);
    }

    const messages = (await getMessages({ conversationId })) || [];

    if (messages.length === 0) {
      return [];
    }

    let mapMethod = null;
    if (this.getMessageMapMethod) {
      mapMethod = this.getMessageMapMethod();
    }

    return this.constructor.getMessagesForConversation(messages, parentMessageId, mapMethod);
  }

  /**
   * Save `message` (marked finished and not cancelled), then upsert its conversation.
   * @param {object} message
   * @param {object} endpointOptions - Merged into the conversation record.
   * @param {*} [user=null]
   */
  async saveMessageToDatabase(message, endpointOptions, user = null) {
    await saveMessage({ ...message, unfinished: false, cancelled: false });
    await saveConvo(user, {
      conversationId: message.conversationId,
      endpoint: this.options.endpoint,
      ...endpointOptions,
    });
  }

  /** Apply a partial update (must include `messageId`) to a stored message. */
  async updateMessageInDatabase(message) {
    await updateMessage(message);
  }

  /**
   * Reconstruct the chain of messages ending at `parentMessageId` by following
   * `parentMessageId` links, returned root-first.
   *
   * A message's id is `messageId`, falling back to `id`. If several messages share
   * an id, the first one in `messages` is used. The walk stops at a missing parent
   * or at a parent cycle.
   *
   * @param {Array<object>} messages - Flat list of messages.
   * @param {string} parentMessageId - Id of the newest message in the chain.
   * @param {Function|null} [mapMethod=null] - Optional mapper applied to the result.
   * @returns {Array} The ordered (and optionally mapped) chain.
   */
  static getMessagesForConversation(messages, parentMessageId, mapMethod = null) {
    if (!messages || messages.length === 0) {
      return [];
    }

    // Index once instead of scanning the array on every hop (was O(n^2)).
    const messagesById = new Map();
    for (const msg of messages) {
      const messageId = msg.messageId ?? msg.id;
      if (!messagesById.has(messageId)) {
        messagesById.set(messageId, msg);
      }
    }

    const chain = [];
    // Corrupt data with a parent cycle would otherwise loop forever.
    const visited = new Set();
    let currentMessageId = parentMessageId;
    while (currentMessageId && !visited.has(currentMessageId)) {
      const message = messagesById.get(currentMessageId);
      if (!message) {
        break;
      }
      visited.add(currentMessageId);
      chain.push(message);
      currentMessageId = message.parentMessageId;
    }

    const orderedMessages = chain.reverse();
    return mapMethod ? orderedMessages.map(mapMethod) : orderedMessages;
  }

  /**
   * Estimate the token count of one chat message.
   *
   * Sums `this.getTokenCount` over every string-valued property except
   * `tokenCount`, adds a per-message overhead, and adjusts for `name`.
   * Models starting with 'gpt-4' use 3 / +1; all others use 4 / -1.
   * NOTE(review): these constants presumably follow OpenAI's published
   * token-counting guidance — verify when adding models.
   *
   * @param {object} message
   * @returns {number}
   */
  getTokenCountForMessage(message) {
    let tokensPerMessage;
    let nameAdjustment;
    if (this.modelOptions.model.startsWith('gpt-4')) {
      tokensPerMessage = 3;
      nameAdjustment = 1;
    } else {
      tokensPerMessage = 4;
      nameAdjustment = -1;
    }

    if (this.options.debug) {
      console.debug('getTokenCountForMessage', message);
    }

    const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
      if (key === 'tokenCount' || typeof value !== 'string') {
        return 0;
      }

      const numTokens = this.getTokenCount(value);
      const adjustment = key === 'name' ? nameAdjustment : 0;
      return numTokens + adjustment;
    });

    if (this.options.debug) {
      console.debug('propertyTokenCounts', propertyTokenCounts);
    }

    return propertyTokenCounts.reduce((a, b) => a + b, tokensPerMessage);
  }
}
|
|
|
module.exports = BaseClient; |
|
|