|
const BaseClient = require('./BaseClient'); |
|
const { google } = require('googleapis'); |
|
const { Agent, ProxyAgent } = require('undici'); |
|
const { |
|
encoding_for_model: encodingForModel, |
|
get_encoding: getEncoding, |
|
} = require('@dqbd/tiktoken'); |
|
|
|
/**
 * Module-level cache of tokenizer instances, shared by every GoogleClient so
 * that each tokenizer is only constructed once per process.
 *
 * Created without a prototype so that lookups with arbitrary caller-supplied
 * keys (e.g. a model literally named "constructor" or "toString") can never
 * resolve to an inherited Object.prototype member and be mistaken for a
 * cached tokenizer.
 */
const tokenizersCache = Object.create(null);
|
|
|
/**
 * Client for Google-published models served through the
 * `us-central1-aiplatform.googleapis.com` `:predict` endpoint.
 *
 * Requests are authenticated with a service-account JWT. Models whose name
 * starts with `chat-` are sent a message list; models whose name starts with
 * `text-` are sent a single prompt string.
 */
class GoogleClient extends BaseClient {
  /**
   * @param {object} credentials - Service-account key material. Only
   *   `client_email`, `project_id` and `private_key` are read.
   * @param {object} [options={}] - Client options; forwarded to BaseClient and
   *   then applied through {@link GoogleClient#setOptions}.
   */
  constructor(credentials, options = {}) {
    super('apiKey', options);
    this.client_email = credentials.client_email;
    this.project_id = credentials.project_id;
    this.private_key = credentials.private_key;
    this.sender = 'PaLM2';
    this.setOptions(options);
  }

  /**
   * Builds the `:predict` URL for the configured project and model.
   *
   * @returns {string} Fully qualified endpoint URL.
   */
  constructUrl() {
    return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`;
  }

  /**
   * Creates a JWT auth client for the service account and authorizes it.
   *
   * The promise form of `authorize()` is awaited so that the client is only
   * returned once authorization has finished, and so that an authorization
   * failure rejects this method instead of being thrown from inside a
   * callback where no caller could catch it.
   *
   * @returns {Promise<object>} The authorized JWT client.
   * @throws Whatever `authorize()` rejects with.
   */
  async getClient() {
    const scopes = ['https://www.googleapis.com/auth/cloud-platform'];
    const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes);
    await jwtClient.authorize();
    return jwtClient;
  }

  /**
   * Applies (or merges) client options and derives every dependent setting:
   * model parameters, token limits, labels, tokenizer, stop sequences and the
   * completions URL.
   *
   * When options already exist and `replaceOptions` is not set on them, the
   * new options are merged over the old ones, with `modelOptions` merged
   * key-by-key. The `options` argument itself is not modified by the merge.
   *
   * @param {object} options - New options.
   * @returns {GoogleClient} `this`, for chaining.
   * @throws {Error} When maxPromptTokens + maxOutputTokens exceeds maxContextTokens.
   */
  setOptions(options) {
    if (this.options && !this.options.replaceOptions) {
      // Split modelOptions off with destructuring rather than `delete`, so the
      // caller's object is left intact (it may be the very object BaseClient
      // stored as `this.options`).
      const { modelOptions: incomingModelOptions, ...otherOptions } = options;
      this.options = {
        ...this.options,
        ...otherOptions,
        modelOptions: {
          ...this.options.modelOptions,
          ...incomingModelOptions,
        },
      };
    } else {
      this.options = options;
    }

    // Drop example pairs with an empty side; tolerate `examples` being omitted.
    this.options.examples = (this.options.examples ?? []).filter(
      (example) => example.input.content !== '' && example.output.content !== '',
    );

    const modelOptions = this.options.modelOptions || {};
    // Destructuring defaults apply only when the value is `undefined`, so an
    // explicit 0 (or any other defined value) is preserved.
    const { temperature = 0.2, topP = 0.95, topK = 40 } = modelOptions;
    this.modelOptions = {
      ...modelOptions,
      model: modelOptions.model || 'chat-bison',
      temperature,
      topP,
      topK,
    };

    this.isChatModel = this.modelOptions.model.startsWith('chat-');
    this.isTextModel = this.modelOptions.model.startsWith('text-');

    this.maxContextTokens = this.options.maxContextTokens || (this.isTextModel ? 8000 : 4096);
    this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1024;
    this.maxPromptTokens =
      this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;

    if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
      throw new Error(
        `maxPromptTokens + maxOutputTokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
          this.maxPromptTokens + this.maxResponseTokens
        }) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
      );
    }

    this.userLabel = this.options.userLabel || 'User';
    this.modelLabel = this.options.modelLabel || 'Assistant';

    if (this.isChatModel) {
      this.startToken = '||>';
      this.endToken = '';
      this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
    } else if (this.isTextModel) {
      this.startToken = '<|im_start|>';
      this.endToken = '<|im_end|>';
      this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
        '<|im_start|>': 100264,
        '<|im_end|>': 100265,
      });
    } else {
      this.startToken = '||>';
      this.endToken = '';
      try {
        this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
      } catch {
        // No tokenizer is available under this model name; fall back to a known one.
        this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true);
      }
    }

    // Only derive stop sequences when the caller did not provide any.
    if (!this.modelOptions.stop) {
      const stopTokens = [this.startToken];
      if (this.endToken && this.endToken !== this.startToken) {
        stopTokens.push(this.endToken);
      }
      stopTokens.push(`\n${this.userLabel}:`);
      stopTokens.push('<|diff_marker|>');
      this.modelOptions.stop = stopTokens;
    }

    this.completionsUrl = this.options.reverseProxyUrl
      ? this.options.reverseProxyUrl
      : this.constructUrl();

    return this;
  }

  /**
   * Returns the mapper that converts a stored message into the
   * `{ author, content }` shape used in the request payload.
   *
   * `author` falls back to the user/model label chosen by
   * `message.isCreatedByUser`; `content` falls back to `message.text`.
   *
   * @returns {(message: object) => {author: string, content: string}}
   */
  getMessageMapMethod() {
    // An arrow function already captures `this`; no explicit bind is required.
    return (message) => ({
      author: message.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel),
      content: message.content ?? message.text,
    });
  }

  /**
   * Builds the request payload for the configured model.
   *
   * Chat-style payloads carry the formatted message list plus the optional
   * `context` (from `options.promptPrefix`) and `examples`. Text models
   * receive a single instance holding only the last message as `prompt`.
   *
   * @param {object[]} [messages=[]] - Conversation messages, oldest first.
   * @returns {{prompt: object}} The payload, wrapped under `prompt`.
   */
  buildMessages(messages = []) {
    const formattedMessages = messages.map(this.getMessageMapMethod());

    let instance;
    if (this.isTextModel) {
      // Use the formatted message so the `text` fallback also applies here,
      // and tolerate an empty message list instead of throwing.
      const lastMessage = formattedMessages[formattedMessages.length - 1];
      instance = { prompt: lastMessage?.content ?? '' };
    } else {
      instance = { messages: formattedMessages };
      if (this.options.promptPrefix) {
        instance.context = this.options.promptPrefix;
      }
      if (this.options.examples.length > 0) {
        instance.examples = this.options.examples;
      }
    }

    const payload = {
      instances: [instance],
      parameters: this.options.modelOptions,
    };

    if (this.options.debug) {
      console.debug('GoogleClient buildMessages');
      console.dir(payload, { depth: null });
    }

    return { prompt: payload };
  }

  /**
   * POSTs the payload to the completions URL using an authorized JWT client.
   *
   * NOTE(review): `options.proxy` is not applied to this request. The previous
   * code constructed an undici agent for it but never handed that agent to
   * the request, so it had no effect.
   *
   * @param {object} payload - Request body.
   * @param {AbortController|null} [abortController=null] - Controller whose
   *   signal is attached to the request; a fresh one is created when omitted.
   * @returns {Promise<object>} The response body (`res.data`).
   */
  async getCompletion(payload, abortController = null) {
    const controller = abortController ?? new AbortController();
    const { debug } = this.options;
    const url = this.completionsUrl;

    if (debug) {
      console.debug();
      console.debug(url);
      console.debug(this.modelOptions);
      console.debug();
    }

    const client = await this.getClient();
    const res = await client.request({
      url,
      method: 'POST',
      data: payload,
      signal: controller.signal,
    });

    if (debug) {
      console.dir(res.data, { depth: null });
    }

    return res.data;
  }

  /**
   * Options persisted alongside a conversation.
   *
   * @returns {object} `promptPrefix`, `modelLabel` and all effective model options.
   */
  getSaveOptions() {
    return {
      promptPrefix: this.options.promptPrefix,
      modelLabel: this.options.modelLabel,
      ...this.modelOptions,
    };
  }

  /**
   * Intentionally empty: this client supplies no extra build-messages options.
   *
   * @returns {undefined}
   */
  getBuildMessagesOptions() {}

  /**
   * Requests a completion and returns the reply text.
   *
   * If the response is flagged as blocked, the reply is replaced by a notice
   * containing the safety attributes (plus any text that was still returned)
   * and nothing is streamed. Otherwise the reply is passed through
   * `generateTextStream`. Request errors are logged and result in an empty
   * reply rather than a rejection.
   *
   * @param {object} payload - Request body for {@link GoogleClient#getCompletion}.
   * @param {object} [opts={}] - `abortController` and `onProgress` are read.
   * @returns {Promise<string>} The trimmed reply.
   */
  async sendCompletion(payload, opts = {}) {
    if (this.options.debug) {
      console.debug('GoogleClient: sendcompletion', payload, opts);
    }

    let reply = '';
    let blocked = false;

    try {
      const result = await this.getCompletion(payload, opts.abortController);
      const prediction = result?.predictions?.[0];
      blocked = prediction?.safetyAttributes?.blocked;
      reply = prediction?.candidates?.[0]?.content || prediction?.content || '';

      if (blocked === true) {
        reply = `Google blocked a proper response to your message:\n${JSON.stringify(
          prediction.safetyAttributes,
        )}${reply.length > 0 ? `\nAI Response:\n${reply}` : ''}`;
      }

      if (this.options.debug) {
        console.debug('result');
        console.debug(result);
      }
    } catch (err) {
      // Best-effort: report the failure and fall through with whatever reply we have.
      console.error(err);
    }

    if (!blocked) {
      await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 });
    }

    return reply.trim();
  }

  /**
   * Returns a tokenizer, creating and caching it on first use.
   *
   * The cache key covers all three arguments, so the same name requested with
   * different special tokens (or as an encoding rather than a model name)
   * yields a distinct tokenizer instead of reusing whichever was built first.
   *
   * @param {string} encoding - Encoding name, or a model name when `isModelName` is true.
   * @param {boolean} [isModelName=false] - Resolve `encoding` as a model name.
   * @param {object} [extendSpecialTokens={}] - Extra special tokens passed to the tokenizer factory.
   * @returns {object} The tokenizer instance.
   */
  static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
    const cacheKey = JSON.stringify([encoding, Boolean(isModelName), extendSpecialTokens]);
    const cached = tokenizersCache[cacheKey];
    if (cached) {
      return cached;
    }

    const tokenizer = isModelName
      ? encodingForModel(encoding, extendSpecialTokens)
      : getEncoding(encoding, extendSpecialTokens);

    tokenizersCache[cacheKey] = tokenizer;
    return tokenizer;
  }

  /**
   * Counts the tokens in `text` using the tokenizer selected by setOptions().
   *
   * @param {string} text - Text to measure.
   * @returns {number} Number of tokens.
   */
  getTokenCount(text) {
    return this.gptEncoder.encode(text, 'all').length;
  }
}
|
|
|
// The client class is this module's sole export (CommonJS).
module.exports = GoogleClient;
|
|