|
import { VertexAI, HarmCategory, HarmBlockThreshold } from "@google-cloud/vertexai"; |
|
import { buildPrompt } from "$lib/buildPrompt"; |
|
import type { TextGenerationStreamOutput } from "@huggingface/inference"; |
|
import type { Endpoint } from "../endpoints"; |
|
import { z } from "zod"; |
|
|
|
/**
 * Configuration accepted by `endpointVertex` for a `"vertex"` endpoint.
 *
 * Parsed with zod, so fields with `.default(...)` are filled in when omitted.
 */
export const endpointVertexParametersSchema = z.object({
	// Positive integer, defaults to 1. NOTE(review): presumably a relative weight
	// used when several endpoints serve one model — confirm with the caller.
	weight: z.number().int().positive().default(1),
	// Not validated here; endpointVertex reads `model.id ?? model.name` as the
	// Vertex model name and passes the whole object on to buildPrompt.
	model: z.any(),
	// Literal tag; must be exactly "vertex".
	type: z.literal("vertex"),
	// Vertex AI location handed to the VertexAI client; defaults to "europe-west1".
	location: z.string().default("europe-west1"),
	// Google Cloud project handed to the VertexAI client (required).
	project: z.string(),
	// Optional API endpoint override, forwarded to the VertexAI client unchanged.
	apiEndpoint: z.string().optional(),
});
|
|
|
/**
 * Minimal structural view of one chunk of a Vertex AI streaming response.
 *
 * Only the fields this module reads are declared. The SDK's response chunks are
 * assignable to it, and tests can build chunks by hand.
 */
type VertexStreamChunk = {
	candidates?: ReadonlyArray<{
		content?: { parts?: ReadonlyArray<object> };
		finishReason?: unknown;
	}>;
};

/** Concatenates the `text` of every text part; non-text parts are ignored. */
function extractChunkText(parts: ReadonlyArray<object> | undefined): string {
	let text = "";
	for (const part of parts ?? []) {
		// Inline data / function-call parts carry no `text` string.
		if ("text" in part && typeof part.text === "string") {
			text += part.text;
		}
	}
	return text;
}

/**
 * Adapts a Vertex AI response stream into `TextGenerationStreamOutput`s.
 *
 * Only the first candidate of each chunk is used. Chunks are consumed until a
 * candidate reports a `finishReason`, which marks the end of generation.
 * Consumption also stops when a chunk has no candidates. Per the
 * `TextGenerationStreamOutput` contract, `generated_text` is `null` on
 * intermediate outputs and holds the full text only on the final one.
 *
 * @param stream - the `stream` of a Vertex `sendMessageStream` result
 */
async function* toTextGenerationStream(
	stream: AsyncIterable<VertexStreamChunk>
): AsyncGenerator<TextGenerationStreamOutput, void, unknown> {
	let tokenId = 0;
	let generatedText = "";

	for await (const chunk of stream) {
		const candidate = chunk?.candidates?.[0];
		// No candidates: nothing more can be read from this response.
		if (!candidate) break;

		const text = extractChunkText(candidate.content?.parts);
		// An empty finishReason means the model has not stopped generating yet.
		const isLastChunk = Boolean(candidate.finishReason);

		// Skip empty intermediate chunks; a final chunk is always emitted, even
		// without text (e.g. a blocked candidate with no content).
		if (text === "" && !isLastChunk) continue;

		generatedText += text;
		yield {
			token: { id: tokenId++, text, logprob: 0, special: false },
			generated_text: isLastChunk ? generatedText : null,
			details: null,
		};

		if (isLastChunk) return;
	}

	// The stream ended without any finishReason. Still emit a final output, so
	// the consumer always receives the complete text.
	yield {
		token: { id: tokenId++, text: "", logprob: 0, special: false },
		generated_text: generatedText,
		details: null,
	};
}

/**
 * Builds a streaming chat endpoint backed by a Vertex AI generative model.
 *
 * The client and model are created once per call to this factory. Each request
 * renders the conversation into a single prompt with `buildPrompt`, sends it in
 * a fresh chat session, and streams the reply back.
 *
 * @param input - endpoint configuration, validated by `endpointVertexParametersSchema`
 * @returns an `Endpoint` that resolves to an async generator of stream outputs
 * @throws ZodError when `input` does not match the schema
 */
export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
	const { project, location, model, apiEndpoint } = endpointVertexParametersSchema.parse(input);

	const vertexAI = new VertexAI({
		project,
		location,
		apiEndpoint,
	});

	const generativeModel = vertexAI.getGenerativeModel({
		// `model` is unvalidated (z.any); assumes it carries `id` and/or `name` — TODO confirm.
		model: model.id ?? model.name,
		safety_settings: [
			{
				category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
				threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
			},
		],
		generation_config: {},
	});

	return async ({ messages, preprompt, continueMessage }) => {
		const prompt = await buildPrompt({
			messages,
			continueMessage,
			preprompt,
			model,
		});

		// The whole conversation is already in `prompt`, so each request uses a
		// new chat session with no prior history.
		const chat = generativeModel.startChat();
		const result = await chat.sendMessageStream(prompt);

		return toTextGenerationStream(result.stream);
	};
}
|
// Default export of the same factory that is also exported by name above.
export default endpointVertex;
|
|