import type {
  ChatCompletionMessageParam,
  ChatOptions,
  InitProgressCallback,
  MLCEngineConfig,
} from "@mlc-ai/web-llm";
import type { ChatMessage } from "gpt-tokenizer/GptEncoding";
import { addLogEntry } from "./logEntries";
import {
  getSettings,
  updateModelLoadingProgress,
  updateModelSizeInMegabytes,
  updateResponse,
  updateTextGenerationState,
} from "./pubSub";
import {
  canStartResponding,
  defaultContextSize,
  getDefaultChatCompletionCreateParamsStreaming,
  getDefaultChatMessages,
  getFormattedSearchResults,
  handleStreamingResponse,
} from "./textGenerationUtilities";
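
/**
 * Generates the AI response for the current search with WebLLM,
 * streaming partial text into the UI, then logs the engine's
 * runtime stats and unloads it.
 */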
export async function generateTextWithWebLlm() {
  const engine = await initializeWebLlmEngine();

  if (getSettings().enableAiResponse) {
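    // Wait until the app signals that response generation can start.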
    await canStartResponding();
    updateTextGenerationState("preparingToGenerate");
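
    // Stream a completion that uses the default parameters plus the
    // formatted search results as the conversation context.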
    const completion = await engine.chat.completions.create({
      ...getDefaultChatCompletionCreateParamsStreaming(),
      messages: getDefaultChatMessages(
        getFormattedSearchResults(true),
      ) as ChatCompletionMessageParam[],
    });
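
    // Forward each streamed chunk to the UI, keeping the generating
    // state in sync while tokens arrive.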
    await handleStreamingResponse(completion, updateResponse, {
      shouldUpdateGeneratingState: true,
    });
  }

  addLogEntry(
    `WebLLM finished generating the response. Stats: ${await engine.runtimeStatsText()}`,
  );

  engine.unload();
}
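
/**
 * Generates a chat response with WebLLM for the given message history,
 * streaming partial text through `onUpdate` and returning the final
 * response once the stream completes.
 */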
export async function generateChatWithWebLlm(
  messages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const engine = await initializeWebLlmEngine();

  const completion = await engine.chat.completions.create({
    ...getDefaultChatCompletionCreateParamsStreaming(),
    messages: messages as ChatCompletionMessageParam[],
  });

  const response = await handleStreamingResponse(completion, onUpdate);

  addLogEntry(
    `WebLLM finished generating the chat response. Stats: ${await engine.runtimeStatsText()}`,
  );

  engine.unload();

  return response;
}
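
/**
 * Creates a WebLLM engine for the model selected in the settings,
 * preferring a Web Worker-backed engine when Workers are available.
 */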
async function initializeWebLlmEngine() {
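  // Import WebLLM lazily so the library is only fetched when needed.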
  const {
    CreateWebWorkerMLCEngine,
    CreateMLCEngine,
    hasModelInCache,
    prebuiltAppConfig,
  } = await import("@mlc-ai/web-llm");

  const selectedModelId = getSettings().webLlmModelId;
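
  // Publish the selected model's size, based on the VRAM requirement
  // reported by the prebuilt app config.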
  updateModelSizeInMegabytes(
    prebuiltAppConfig.model_list.find((m) => m.model_id === selectedModelId)
      ?.vram_required_MB || 0,
  );

  addLogEntry(`Selected WebLLM model: ${selectedModelId}`);
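
  // Only report loading progress when the model still has to be downloaded.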
  const isModelCached = await hasModelInCache(selectedModelId);
  let initProgressCallback: InitProgressCallback | undefined;

  if (!isModelCached) {
    initProgressCallback = (report) => {
      updateModelLoadingProgress(Math.round(report.progress * 100));
    };
  }

  const engineConfig: MLCEngineConfig = {
    initProgressCallback,
    logLevel: "SILENT",
  };
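
  // Use a fixed context window; -1 disables the sliding window and
  // attention sinks.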
  const chatOptions: ChatOptions = {
    context_window_size: defaultContextSize,
    sliding_window_size: -1,
    attention_sink_size: -1,
  };
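
  // Prefer running the engine in a Web Worker to keep the main thread
  // responsive; fall back to a main-thread engine otherwise.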
  return typeof Worker !== "undefined"
    ? await CreateWebWorkerMLCEngine(
        new Worker(new URL("./webLlmWorker.ts", import.meta.url), {
          type: "module",
        }),
        selectedModelId,
        engineConfig,
        chatOptions,
      )
    : await CreateMLCEngine(selectedModelId, engineConfig, chatOptions);
} | |