import type {
  ChatCompletionMessageParam,
  ChatOptions,
  InitProgressCallback,
  MLCEngineConfig,
} from "@mlc-ai/web-llm";
import type { ChatMessage } from "gpt-tokenizer/GptEncoding";
import { addLogEntry } from "./logEntries";
import {
  getSettings,
  updateModelLoadingProgress,
  updateModelSizeInMegabytes,
  updateResponse,
  updateTextGenerationState,
} from "./pubSub";
import {
  canStartResponding,
  defaultContextSize,
  getDefaultChatCompletionCreateParamsStreaming,
  getDefaultChatMessages,
  getFormattedSearchResults,
  handleStreamingResponse,
} from "./textGenerationUtilities";

/**
 * Generates the AI response for the current search with WebLLM, streaming
 * partial output to the UI and unloading the engine afterwards.
 */
export async function generateTextWithWebLlm() {
  const engine = await initializeWebLlmEngine();

  if (getSettings().enableAiResponse) {
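    // Wait until the app signals that responding can start, then mark the
    // generation state as "preparingToGenerate".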
    await canStartResponding();
    updateTextGenerationState("preparingToGenerate");

    const completion = await engine.chat.completions.create({
      ...getDefaultChatCompletionCreateParamsStreaming(),
      messages: getDefaultChatMessages(
        getFormattedSearchResults(true),
      ) as ChatCompletionMessageParam[],
    });

    await handleStreamingResponse(completion, updateResponse, {
      shouldUpdateGeneratingState: true,
    });
  }

  addLogEntry(
    `WebLLM finished generating the response. Stats: ${await engine.runtimeStatsText()}`,
  );

  engine.unload();
}

/**
 * Generates a follow-up chat response with WebLLM for the given message
 * history, calling `onUpdate` with each partial response and returning the
 * final text.
 */
export async function generateChatWithWebLlm(
  messages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const engine = await initializeWebLlmEngine();

  const completion = await engine.chat.completions.create({
    ...getDefaultChatCompletionCreateParamsStreaming(),
    messages: messages as ChatCompletionMessageParam[],
  });

  const response = await handleStreamingResponse(completion, onUpdate);

  addLogEntry(
    `WebLLM finished generating the chat response. Stats: ${await engine.runtimeStatsText()}`,
  );

  engine.unload();
  return response;
}

/**
 * Lazily imports WebLLM, reports the selected model's size and loading
 * progress, and creates the engine, preferring a Web Worker when available.
 */
async function initializeWebLlmEngine() {
  const {
    CreateWebWorkerMLCEngine,
    CreateMLCEngine,
    hasModelInCache,
    prebuiltAppConfig,
  } = await import("@mlc-ai/web-llm");

  const selectedModelId = getSettings().webLlmModelId;

  // Surface the selected model's approximate size (its VRAM requirement in
  // megabytes) to the UI; fall back to 0 if the model is not in the prebuilt list.
  updateModelSizeInMegabytes(
    prebuiltAppConfig.model_list.find((m) => m.model_id === selectedModelId)
      ?.vram_required_MB || 0,
  );

  addLogEntry(`Selected WebLLM model: ${selectedModelId}`);

  const isModelCached = await hasModelInCache(selectedModelId);
  let initProgressCallback: InitProgressCallback | undefined;

  if (!isModelCached) {
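    // Report loading progress only when the model still needs to be downloaded.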
    initProgressCallback = (report) => {
      updateModelLoadingProgress(Math.round(report.progress * 100));
    };
  }

  const engineConfig: MLCEngineConfig = {
    initProgressCallback,
    logLevel: "SILENT",
  };

  const chatOptions: ChatOptions = {
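    // Run with the app's default context window; the -1 values leave the
    // sliding window and attention sink disabled.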
    context_window_size: defaultContextSize,
    sliding_window_size: -1,
    attention_sink_size: -1,
  };

  // Prefer running the engine in a Web Worker so model loading and inference
  // stay off the main thread; otherwise fall back to a main-thread engine.
  return Worker
    ? await CreateWebWorkerMLCEngine(
        new Worker(new URL("./webLlmWorker.ts", import.meta.url), {
          type: "module",
        }),
        selectedModelId,
        engineConfig,
        chatOptions,
      )
    : await CreateMLCEngine(selectedModelId, engineConfig, chatOptions);
}