// MiniSearch: client/modules/textGenerationWithWebLlm.ts
// Synced from https://github.com/felladrin/MiniSearch (commit 884908f)
import type {
  ChatCompletionMessageParam,
  ChatOptions,
  InitProgressCallback,
  MLCEngineConfig,
} from "@mlc-ai/web-llm";
import type { ChatMessage } from "gpt-tokenizer/GptEncoding";
import { addLogEntry } from "./logEntries";
import {
  getSettings,
  updateModelLoadingProgress,
  updateModelSizeInMegabytes,
  updateResponse,
  updateTextGenerationState,
} from "./pubSub";
import {
  canStartResponding,
  defaultContextSize,
  getDefaultChatCompletionCreateParamsStreaming,
  getDefaultChatMessages,
  getFormattedSearchResults,
  handleStreamingResponse,
} from "./textGenerationUtilities";
export async function generateTextWithWebLlm() {
  const engine = await initializeWebLlmEngine();

  if (getSettings().enableAiResponse) {
    await canStartResponding();

    updateTextGenerationState("preparingToGenerate");

    const completion = await engine.chat.completions.create({
      ...getDefaultChatCompletionCreateParamsStreaming(),
      messages: getDefaultChatMessages(
        getFormattedSearchResults(true),
      ) as ChatCompletionMessageParam[],
    });

    await handleStreamingResponse(completion, updateResponse, {
      shouldUpdateGeneratingState: true,
    });
  }

  addLogEntry(
    `WebLLM finished generating the response. Stats: ${await engine.runtimeStatsText()}`,
  );

  engine.unload();
}
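
/**
 * Generates a follow-up chat response with WebLLM for the given message
 * history, reporting partial output through `onUpdate` and returning the
 * final text.
 */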
export async function generateChatWithWebLlm(
  messages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const engine = await initializeWebLlmEngine();

  const completion = await engine.chat.completions.create({
    ...getDefaultChatCompletionCreateParamsStreaming(),
    messages: messages as ChatCompletionMessageParam[],
  });

  const response = await handleStreamingResponse(completion, onUpdate);

  addLogEntry(
    `WebLLM finished generating the chat response. Stats: ${await engine.runtimeStatsText()}`,
  );

  engine.unload();

  return response;
}
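
/**
 * Loads the selected WebLLM model, wiring up a download-progress callback when
 * the model is not yet cached, and returns an engine running in a Web Worker
 * when workers are available, falling back to the main thread otherwise.
 */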
async function initializeWebLlmEngine() {
  const {
    CreateWebWorkerMLCEngine,
    CreateMLCEngine,
    hasModelInCache,
    prebuiltAppConfig,
  } = await import("@mlc-ai/web-llm");

  const selectedModelId = getSettings().webLlmModelId;

  updateModelSizeInMegabytes(
    prebuiltAppConfig.model_list.find((m) => m.model_id === selectedModelId)
      ?.vram_required_MB || 0,
  );

  addLogEntry(`Selected WebLLM model: ${selectedModelId}`);

  const isModelCached = await hasModelInCache(selectedModelId);

  // Only report download progress when the model still needs to be fetched.
  let initProgressCallback: InitProgressCallback | undefined;

  if (!isModelCached) {
    initProgressCallback = (report) => {
      updateModelLoadingProgress(Math.round(report.progress * 100));
    };
  }

  const engineConfig: MLCEngineConfig = {
    initProgressCallback,
    logLevel: "SILENT",
  };

  const chatOptions: ChatOptions = {
    context_window_size: defaultContextSize,
    sliding_window_size: -1,
    attention_sink_size: -1,
  };

  // Prefer running the engine in a Web Worker so inference does not block the
  // UI thread; otherwise create the engine on the main thread.
  return Worker
    ? await CreateWebWorkerMLCEngine(
        new Worker(new URL("./webLlmWorker.ts", import.meta.url), {
          type: "module",
        }),
        selectedModelId,
        engineConfig,
        chatOptions,
      )
    : await CreateMLCEngine(selectedModelId, engineConfig, chatOptions);
}