nsarrazin HF staff Mishig commited on
Commit
77399ca
·
unverified ·
1 Parent(s): 6e0b0ea

Continue generation feature (#707)

Browse files

* Initial work on continue feature

* Move continue button

* Fix websearch with continue

* Make it work with every model

* Update src/routes/conversation/[id]/+server.ts

Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>

* fixes

* async all the things

* add reduce comment

* remove log

* Only show loading indicator if not continuing

---------

Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>

.env.template CHANGED
@@ -57,7 +57,8 @@ MODELS=`[
57
  "repetition_penalty": 1.2,
58
  "top_k": 50,
59
  "truncate": 3072,
60
- "max_new_tokens": 1024
 
61
  }
62
  },
63
  {
@@ -116,7 +117,8 @@ MODELS=`[
116
  "repetition_penalty": 1.2,
117
  "top_k": 50,
118
  "truncate": 4096,
119
- "max_new_tokens": 4096
 
120
  }
121
  },
122
  {
 
57
  "repetition_penalty": 1.2,
58
  "top_k": 50,
59
  "truncate": 3072,
60
+ "max_new_tokens": 1024,
61
+ "stop" : ["</s>", " </s><s>[INST] "]
62
  }
63
  },
64
  {
 
117
  "repetition_penalty": 1.2,
118
  "top_k": 50,
119
  "truncate": 4096,
120
+ "max_new_tokens": 4096,
121
+ "stop": [" </s><s>[INST] "]
122
  }
123
  },
124
  {
src/lib/buildPrompt.ts CHANGED
@@ -13,6 +13,7 @@ interface buildPromptOptions {
13
  webSearch?: WebSearch;
14
  preprompt?: string;
15
  files?: File[];
 
16
  }
17
 
18
  export async function buildPrompt({
@@ -22,37 +23,38 @@ export async function buildPrompt({
22
  preprompt,
23
  id,
24
  }: buildPromptOptions): Promise<string> {
 
 
25
  if (webSearch && webSearch.context) {
26
- const lastMsg = messages.slice(-1)[0];
27
- const messagesWithoutLastUsrMsg = messages.slice(0, -1);
28
- const previousUserMessages = messages.filter((el) => el.from === "user").slice(0, -1);
29
 
 
 
30
  const previousQuestions =
31
  previousUserMessages.length > 0
32
  ? `Previous questions: \n${previousUserMessages
33
  .map(({ content }) => `- ${content}`)
34
  .join("\n")}`
35
  : "";
 
36
  const currentDate = format(new Date(), "MMMM d, yyyy");
37
- messages = [
38
- ...messagesWithoutLastUsrMsg,
39
- {
40
- from: "user",
41
- content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results:
42
  =====================
43
  ${webSearch.context}
44
  =====================
45
  ${previousQuestions}
46
- Answer the question: ${lastMsg.content}
47
- `,
48
- },
49
- ];
50
  }
51
-
52
  // section to handle potential files input
53
  if (model.multimodal) {
54
- messages = await Promise.all(
55
- messages.map(async (el) => {
56
  let content = el.content;
57
 
58
  if (el.from === "user") {
@@ -83,7 +85,7 @@ export async function buildPrompt({
83
 
84
  return (
85
  model
86
- .chatPromptRender({ messages, preprompt })
87
  // Not super precise, but it's truncated in the model's backend anyway
88
  .split(" ")
89
  .slice(-(model.parameters?.truncate ?? 0))
 
13
  webSearch?: WebSearch;
14
  preprompt?: string;
15
  files?: File[];
16
+ continue?: boolean;
17
  }
18
 
19
  export async function buildPrompt({
 
23
  preprompt,
24
  id,
25
  }: buildPromptOptions): Promise<string> {
26
+ let modifiedMessages = [...messages];
27
+
28
  if (webSearch && webSearch.context) {
29
+ // find index of the last user message
30
+ const lastUsrMsgIndex = modifiedMessages.map((el) => el.from).lastIndexOf("user");
 
31
 
32
+ // combine all the other previous questions into one string
33
+ const previousUserMessages = modifiedMessages.filter((el) => el.from === "user").slice(0, -1);
34
  const previousQuestions =
35
  previousUserMessages.length > 0
36
  ? `Previous questions: \n${previousUserMessages
37
  .map(({ content }) => `- ${content}`)
38
  .join("\n")}`
39
  : "";
40
+
41
  const currentDate = format(new Date(), "MMMM d, yyyy");
42
+
43
+ // update the last user message directly (that way if the last message is an assistant partial answer, we keep the beginning of that answer)
44
+ modifiedMessages[lastUsrMsgIndex] = {
45
+ from: "user",
46
+ content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results:
47
  =====================
48
  ${webSearch.context}
49
  =====================
50
  ${previousQuestions}
51
+ Answer the question: ${messages[lastUsrMsgIndex].content} `,
52
+ };
 
 
53
  }
 
54
  // section to handle potential files input
55
  if (model.multimodal) {
56
+ modifiedMessages = await Promise.all(
57
+ modifiedMessages.map(async (el) => {
58
  let content = el.content;
59
 
60
  if (el.from === "user") {
 
85
 
86
  return (
87
  model
88
+ .chatPromptRender({ messages: modifiedMessages, preprompt })
89
  // Not super precise, but it's truncated in the model's backend anyway
90
  .split(" ")
91
  .slice(-(model.parameters?.truncate ?? 0))
src/lib/components/ContinueBtn.svelte ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <script lang="ts">
2
+ import CarbonContinue from "~icons/carbon/continue";
3
+
4
+ export let classNames = "";
5
+ </script>
6
+
7
+ <button
8
+ type="button"
9
+ on:click
10
+ class="btn flex h-8 rounded-lg border bg-white px-3 py-1 text-gray-500 shadow-sm transition-all hover:bg-gray-100 dark:border-gray-600 dark:bg-gray-700 dark:text-gray-300 dark:hover:bg-gray-600 {classNames}"
11
+ >
12
+ <CarbonContinue class="mr-2 text-xs " /> Continue
13
+ </button>
src/lib/components/chat/ChatMessage.svelte CHANGED
@@ -13,6 +13,7 @@
13
  import CarbonDownload from "~icons/carbon/download";
14
  import CarbonThumbsUp from "~icons/carbon/thumbs-up";
15
  import CarbonThumbsDown from "~icons/carbon/thumbs-down";
 
16
  import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
17
  import type { Model } from "$lib/types/Model";
18
 
 
13
  import CarbonDownload from "~icons/carbon/download";
14
  import CarbonThumbsUp from "~icons/carbon/thumbs-up";
15
  import CarbonThumbsDown from "~icons/carbon/thumbs-down";
16
+
17
  import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
18
  import type { Model } from "$lib/types/Model";
19
 
src/lib/components/chat/ChatMessages.svelte CHANGED
@@ -54,11 +54,12 @@
54
  webSearchMessages={i === messages.length - 1 ? webSearchMessages : []}
55
  on:retry
56
  on:vote
 
57
  />
58
  {:else}
59
  <ChatIntroduction {models} {currentModel} on:message />
60
  {/each}
61
- {#if pending}
62
  <ChatMessage
63
  message={{ from: "assistant", content: "", id: randomUUID() }}
64
  model={currentModel}
 
54
  webSearchMessages={i === messages.length - 1 ? webSearchMessages : []}
55
  on:retry
56
  on:vote
57
+ on:continue
58
  />
59
  {:else}
60
  <ChatIntroduction {models} {currentModel} on:message />
61
  {/each}
62
+ {#if pending && messages[messages.length - 1]?.from === "user"}
63
  <ChatMessage
64
  message={{ from: "assistant", content: "", id: randomUUID() }}
65
  model={currentModel}
src/lib/components/chat/ChatWindow.svelte CHANGED
@@ -24,6 +24,7 @@
24
  import UploadBtn from "../UploadBtn.svelte";
25
  import file2base64 from "$lib/utils/file2base64";
26
  import { useSettingsStore } from "$lib/stores/settings";
 
27
 
28
  export let messages: Message[] = [];
29
  export let loading = false;
@@ -48,6 +49,7 @@
48
  share: void;
49
  stop: void;
50
  retry: { id: Message["id"]; content: string };
 
51
  }>();
52
 
53
  const handleSubmit = () => {
@@ -124,6 +126,7 @@
124
  }
125
  }}
126
  on:vote
 
127
  on:retry={(ev) => {
128
  if (!loading) dispatch("retry", ev.detail);
129
  }}
@@ -173,8 +176,20 @@
173
  content: messages[messages.length - 1].content,
174
  })}
175
  />
176
- {:else if currentModel.multimodal}
177
- <UploadBtn bind:files classNames="ml-auto" />
 
 
 
 
 
 
 
 
 
 
 
 
178
  {/if}
179
  </div>
180
  <form
 
24
  import UploadBtn from "../UploadBtn.svelte";
25
  import file2base64 from "$lib/utils/file2base64";
26
  import { useSettingsStore } from "$lib/stores/settings";
27
+ import ContinueBtn from "../ContinueBtn.svelte";
28
 
29
  export let messages: Message[] = [];
30
  export let loading = false;
 
49
  share: void;
50
  stop: void;
51
  retry: { id: Message["id"]; content: string };
52
+ continue: { id: Message["id"] };
53
  }>();
54
 
55
  const handleSubmit = () => {
 
126
  }
127
  }}
128
  on:vote
129
+ on:continue
130
  on:retry={(ev) => {
131
  if (!loading) dispatch("retry", ev.detail);
132
  }}
 
176
  content: messages[messages.length - 1].content,
177
  })}
178
  />
179
+ {:else}
180
+ <div class="ml-auto gap-2">
181
+ {#if currentModel.multimodal}
182
+ <UploadBtn bind:files classNames="ml-auto" />
183
+ {/if}
184
+ {#if messages && messages[messages.length - 1]?.interrupted && !isReadOnly}
185
+ <ContinueBtn
186
+ on:click={() =>
187
+ dispatch("continue", {
188
+ id: messages[messages.length - 1].id,
189
+ })}
190
+ />
191
+ {/if}
192
+ </div>
193
  {/if}
194
  </div>
195
  <form
src/lib/server/endpoints/endpoints.ts CHANGED
@@ -14,6 +14,7 @@ interface EndpointParameters {
14
  preprompt?: Conversation["preprompt"];
15
  _id?: Conversation["_id"];
16
  };
 
17
  }
18
 
19
  interface CommonEndpoint {
 
14
  preprompt?: Conversation["preprompt"];
15
  _id?: Conversation["_id"];
16
  };
17
+ continue?: boolean;
18
  }
19
 
20
  interface CommonEndpoint {
src/lib/server/endpoints/tgi/endpointTgi.ts CHANGED
@@ -15,8 +15,9 @@ export const endpointTgiParametersSchema = z.object({
15
 
16
  export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
17
  const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
18
- return async ({ conversation }) => {
19
- const prompt = await buildPrompt({
 
20
  messages: conversation.messages,
21
  webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
22
  preprompt: conversation.preprompt,
@@ -24,6 +25,16 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
24
  id: conversation._id,
25
  });
26
 
 
 
 
 
 
 
 
 
 
 
27
  return textGenerationStream(
28
  {
29
  parameters: { ...model.parameters, return_full_text: false },
 
15
 
16
  export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
17
  const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
18
+
19
+ return async ({ conversation, continue: messageContinue }) => {
20
+ let prompt = await buildPrompt({
21
  messages: conversation.messages,
22
  webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
23
  preprompt: conversation.preprompt,
 
25
  id: conversation._id,
26
  });
27
 
28
+ if (messageContinue) {
29
+ // start with the full prompt, and for each stop token, try to remove it from the end of the prompt
30
+ prompt = model.parameters.stop.reduce((acc: string, curr: string) => {
31
+ if (acc.endsWith(curr)) {
32
+ return acc.slice(0, acc.length - curr.length);
33
+ }
34
+ return acc;
35
+ }, prompt.trimEnd());
36
+ }
37
+
38
  return textGenerationStream(
39
  {
40
  parameters: { ...model.parameters, return_full_text: false },
src/lib/types/Message.ts CHANGED
@@ -11,4 +11,5 @@ export type Message = Partial<Timestamps> & {
11
  webSearch?: WebSearch;
12
  score?: -1 | 0 | 1;
13
  files?: string[]; // can contain either the hash of the file or the b64 encoded image data on the client side when uploading
 
14
  };
 
11
  webSearch?: WebSearch;
12
  score?: -1 | 0 | 1;
13
  files?: string[]; // can contain either the hash of the file or the b64 encoded image data on the client side when uploading
14
+ interrupted?: boolean;
15
  };
src/routes/conversation/[id]/+page.svelte CHANGED
@@ -64,9 +64,17 @@
64
  }
65
  }
66
  // this function is used to send new message to the backends
67
- async function writeMessage(message: string, messageId = randomUUID()) {
68
- if (!message.trim()) return;
69
-
 
 
 
 
 
 
 
 
70
  try {
71
  $isAborted = false;
72
  loading = true;
@@ -74,13 +82,21 @@
74
 
75
  // first we check if the messageId already exists, indicating a retry
76
 
77
- let retryMessageIndex = messages.findIndex((msg) => msg.id === messageId);
78
- const isRetry = retryMessageIndex !== -1;
79
- // if it's not a retry we just use the whole array
80
- if (!isRetry) {
81
- retryMessageIndex = messages.length;
 
 
 
 
 
 
82
  }
83
 
 
 
84
  const module = await import("browser-image-resizer");
85
 
86
  // currently, only IDEFICS is supported by TGI
@@ -99,15 +115,31 @@
99
  );
100
 
101
  // slice up to the point of the retry
102
- messages = [
103
- ...messages.slice(0, retryMessageIndex),
104
- {
105
- from: "user",
106
- content: message,
107
- id: messageId,
108
- files: isRetry ? messages[retryMessageIndex].files : resizedImages,
109
- },
110
- ];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  files = [];
113
 
@@ -115,9 +147,10 @@
115
  method: "POST",
116
  headers: { "Content-Type": "application/json" },
117
  body: JSON.stringify({
118
- inputs: message,
119
  id: messageId,
120
  is_retry: isRetry,
 
121
  web_search: $webSearchParameters.useSearch,
122
  files: isRetry ? undefined : resizedImages,
123
  }),
@@ -282,37 +315,54 @@
282
  // only used in case of creating new conversations (from the parent POST endpoint)
283
  if ($pendingMessage) {
284
  files = $pendingMessage.files;
285
- await writeMessage($pendingMessage.content);
286
  $pendingMessage = undefined;
287
  }
288
  });
289
 
290
  async function onMessage(event: CustomEvent<string>) {
291
  if (!data.shared) {
292
- writeMessage(event.detail);
293
  } else {
294
- convFromShared()
295
  .then(async (convId) => {
296
  await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
297
  })
298
- .then(() => writeMessage(event.detail))
299
  .finally(() => (loading = false));
300
  }
301
  }
302
 
303
  async function onRetry(event: CustomEvent<{ id: Message["id"]; content: string }>) {
304
  if (!data.shared) {
305
- writeMessage(event.detail.content, event.detail.id);
 
 
 
 
306
  } else {
307
- convFromShared()
308
  .then(async (convId) => {
309
  await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
310
  })
311
- .then(() => writeMessage(event.detail.content, event.detail.id))
 
 
 
 
 
 
 
312
  .finally(() => (loading = false));
313
  }
314
  }
315
 
 
 
 
 
 
 
316
  $: $page.params.id, (($isAborted = true), (loading = false));
317
  $: title = data.conversations.find((conv) => conv.id === $page.params.id)?.title ?? data.title;
318
  </script>
@@ -337,6 +387,7 @@
337
  bind:files
338
  on:message={onMessage}
339
  on:retry={onRetry}
 
340
  on:vote={(event) => voteMessage(event.detail.score, event.detail.id)}
341
  on:share={() => shareConversation($page.params.id, data.title)}
342
  on:stop={() => (($isAborted = true), (loading = false))}
 
64
  }
65
  }
66
  // this function is used to send new message to the backends
67
+ async function writeMessage({
68
+ prompt,
69
+ messageId = randomUUID(),
70
+ isRetry = false,
71
+ isContinue = false,
72
+ }: {
73
+ prompt?: string;
74
+ messageId?: ReturnType<typeof randomUUID>;
75
+ isRetry?: boolean;
76
+ isContinue?: boolean;
77
+ }): Promise<void> {
78
  try {
79
  $isAborted = false;
80
  loading = true;
 
82
 
83
  // first we check if the messageId already exists, indicating a retry
84
 
85
+ let msgIndex = messages.findIndex((msg) => msg.id === messageId);
86
+
87
+ if (msgIndex === -1) {
88
+ msgIndex = messages.length - 1;
89
+ }
90
+ if (isRetry && messages[msgIndex].from === "assistant") {
91
+ throw new Error("Trying to retry a message that is not from user");
92
+ }
93
+
94
+ if (isContinue && messages[msgIndex].from === "user") {
95
+ throw new Error("Trying to continue a message that is not from assistant");
96
  }
97
 
98
+ // const isNewMessage = !isRetry && !isContinue;
99
+
100
  const module = await import("browser-image-resizer");
101
 
102
  // currently, only IDEFICS is supported by TGI
 
115
  );
116
 
117
  // slice up to the point of the retry
118
+ if (isRetry) {
119
+ messages = [
120
+ ...messages.slice(0, msgIndex),
121
+ {
122
+ from: "user",
123
+ content: messages[msgIndex].content,
124
+ id: messageId,
125
+ files: messages[msgIndex].files,
126
+ },
127
+ ];
128
+ } else if (!isContinue) {
129
+ // or add a new message if its not a continue request
130
+ if (!prompt) {
131
+ throw new Error("Prompt is undefined");
132
+ }
133
+ messages = [
134
+ ...messages,
135
+ {
136
+ from: "user",
137
+ content: prompt ?? "",
138
+ id: messageId,
139
+ files: resizedImages,
140
+ },
141
+ ];
142
+ }
143
 
144
  files = [];
145
 
 
147
  method: "POST",
148
  headers: { "Content-Type": "application/json" },
149
  body: JSON.stringify({
150
+ inputs: prompt,
151
  id: messageId,
152
  is_retry: isRetry,
153
+ is_continue: isContinue,
154
  web_search: $webSearchParameters.useSearch,
155
  files: isRetry ? undefined : resizedImages,
156
  }),
 
315
  // only used in case of creating new conversations (from the parent POST endpoint)
316
  if ($pendingMessage) {
317
  files = $pendingMessage.files;
318
+ await writeMessage({ prompt: $pendingMessage.content });
319
  $pendingMessage = undefined;
320
  }
321
  });
322
 
323
  async function onMessage(event: CustomEvent<string>) {
324
  if (!data.shared) {
325
+ await writeMessage({ prompt: event.detail });
326
  } else {
327
+ await convFromShared()
328
  .then(async (convId) => {
329
  await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
330
  })
331
+ .then(async () => await writeMessage({ prompt: event.detail }))
332
  .finally(() => (loading = false));
333
  }
334
  }
335
 
336
  async function onRetry(event: CustomEvent<{ id: Message["id"]; content: string }>) {
337
  if (!data.shared) {
338
+ await writeMessage({
339
+ prompt: event.detail.content,
340
+ messageId: event.detail.id,
341
+ isRetry: true,
342
+ });
343
  } else {
344
+ await convFromShared()
345
  .then(async (convId) => {
346
  await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
347
  })
348
+ .then(
349
+ async () =>
350
+ await writeMessage({
351
+ prompt: event.detail.content,
352
+ messageId: event.detail.id,
353
+ isRetry: true,
354
+ })
355
+ )
356
  .finally(() => (loading = false));
357
  }
358
  }
359
 
360
+ async function onContinue(event: CustomEvent<{ id: Message["id"] }>) {
361
+ if (!data.shared) {
362
+ writeMessage({ messageId: event.detail.id, isContinue: true });
363
+ }
364
+ }
365
+
366
  $: $page.params.id, (($isAborted = true), (loading = false));
367
  $: title = data.conversations.find((conv) => conv.id === $page.params.id)?.title ?? data.title;
368
  </script>
 
387
  bind:files
388
  on:message={onMessage}
389
  on:retry={onRetry}
390
+ on:continue={onContinue}
391
  on:vote={(event) => voteMessage(event.detail.score, event.detail.id)}
392
  on:share={() => shareConversation($page.params.id, data.title)}
393
  on:stop={() => (($isAborted = true), (loading = false))}
src/routes/conversation/[id]/+server.ts CHANGED
@@ -91,14 +91,16 @@ export async function POST({ request, locals, params, getClientAddress }) {
91
  const {
92
  inputs: newPrompt,
93
  id: messageId,
94
- is_retry,
 
95
  web_search: webSearch,
96
  files: b64files,
97
  } = z
98
  .object({
99
- inputs: z.string().trim().min(1),
100
  id: z.optional(z.string().uuid()),
101
  is_retry: z.optional(z.boolean()),
 
102
  web_search: z.optional(z.boolean()),
103
  files: z.optional(z.array(z.string())),
104
  })
@@ -136,38 +138,50 @@ export async function POST({ request, locals, params, getClientAddress }) {
136
  hashes = await Promise.all(files.map(async (file) => await uploadFile(file, conv)));
137
  }
138
 
 
 
 
 
 
139
  // get the list of messages
140
  // while checking for retries
141
  let messages = (() => {
142
- if (is_retry && messageId) {
 
143
  // if the message is a retry, replace the message and remove the messages after it
144
  let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId);
 
145
  if (retryMessageIdx === -1) {
146
  retryMessageIdx = conv.messages.length;
147
  }
 
148
  return [
149
  ...conv.messages.slice(0, retryMessageIdx),
150
  {
151
- content: newPrompt,
152
  from: "user",
153
  id: messageId as Message["id"],
154
  updatedAt: new Date(),
155
  files: conv.messages[retryMessageIdx]?.files,
156
  },
157
  ];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  } // else append the message at the bottom
159
-
160
- return [
161
- ...conv.messages,
162
- {
163
- content: newPrompt,
164
- from: "user",
165
- id: (messageId as Message["id"]) || crypto.randomUUID(),
166
- createdAt: new Date(),
167
- updatedAt: new Date(),
168
- files: hashes,
169
- },
170
- ];
171
  })() satisfies Message[];
172
 
173
  await collections.conversations.updateOne(
@@ -183,10 +197,14 @@ export async function POST({ request, locals, params, getClientAddress }) {
183
  }
184
  );
185
 
 
 
186
  // we now build the stream
187
  const stream = new ReadableStream({
188
  async start(controller) {
189
- const updates: MessageUpdate[] = [];
 
 
190
 
191
  function update(newUpdate: MessageUpdate) {
192
  if (newUpdate.type !== "stream") {
@@ -209,7 +227,7 @@ export async function POST({ request, locals, params, getClientAddress }) {
209
  const summarizeIfNeeded = (async () => {
210
  if (conv.title === "New Chat" && messages.length === 1) {
211
  try {
212
- conv.title = (await summarize(newPrompt)) ?? conv.title;
213
  update({ type: "status", status: "title", message: conv.title });
214
  } catch (e) {
215
  console.error(e);
@@ -232,17 +250,22 @@ export async function POST({ request, locals, params, getClientAddress }) {
232
 
233
  let webSearchResults: WebSearch | undefined;
234
 
235
- if (webSearch) {
236
- webSearchResults = await runWebSearch(conv, newPrompt, update);
 
 
 
237
  }
238
 
239
- messages[messages.length - 1].webSearch = webSearchResults;
240
-
241
  conv.messages = messages;
242
 
 
 
 
 
243
  try {
244
  const endpoint = await model.getEndpoint();
245
- for await (const output of await endpoint({ conversation: conv })) {
246
  // if not generated_text is here it means the generation is not done
247
  if (!output.generated_text) {
248
  // else we get the next token
@@ -292,7 +315,8 @@ export async function POST({ request, locals, params, getClientAddress }) {
292
  ...messages.slice(0, -1),
293
  {
294
  ...messages[messages.length - 1],
295
- content: output.generated_text,
 
296
  updates,
297
  updatedAt: new Date(),
298
  },
@@ -302,6 +326,7 @@ export async function POST({ request, locals, params, getClientAddress }) {
302
  } catch (e) {
303
  update({ type: "status", status: "error", message: (e as Error).message });
304
  }
 
305
  await collections.conversations.updateOne(
306
  {
307
  _id: convId,
@@ -315,6 +340,9 @@ export async function POST({ request, locals, params, getClientAddress }) {
315
  }
316
  );
317
 
 
 
 
318
  update({
319
  type: "finalAnswer",
320
  text: messages[messages.length - 1].content,
@@ -324,18 +352,20 @@ export async function POST({ request, locals, params, getClientAddress }) {
324
  return;
325
  },
326
  async cancel() {
327
- await collections.conversations.updateOne(
328
- {
329
- _id: convId,
330
- },
331
- {
332
- $set: {
333
- messages,
334
- title: conv.title,
335
- updatedAt: new Date(),
336
  },
337
- }
338
- );
 
 
 
 
 
 
 
339
  },
340
  });
341
 
 
91
  const {
92
  inputs: newPrompt,
93
  id: messageId,
94
+ is_retry: isRetry,
95
+ is_continue: isContinue,
96
  web_search: webSearch,
97
  files: b64files,
98
  } = z
99
  .object({
100
+ inputs: z.optional(z.string().trim().min(1)),
101
  id: z.optional(z.string().uuid()),
102
  is_retry: z.optional(z.boolean()),
103
+ is_continue: z.optional(z.boolean()),
104
  web_search: z.optional(z.boolean()),
105
  files: z.optional(z.array(z.string())),
106
  })
 
138
  hashes = await Promise.all(files.map(async (file) => await uploadFile(file, conv)));
139
  }
140
 
141
+ // can only call isContinue on the last message id
142
+ if (isContinue && conv.messages[conv.messages.length - 1].id !== messageId) {
143
+ throw error(400, "Can only continue the last message");
144
+ }
145
+
146
  // get the list of messages
147
  // while checking for retries
148
  let messages = (() => {
149
+ // for retries we slice and rewrite the last user message
150
+ if (isRetry && messageId) {
151
  // if the message is a retry, replace the message and remove the messages after it
152
  let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId);
153
+
154
  if (retryMessageIdx === -1) {
155
  retryMessageIdx = conv.messages.length;
156
  }
157
+
158
  return [
159
  ...conv.messages.slice(0, retryMessageIdx),
160
  {
161
+ content: conv.messages[retryMessageIdx]?.content,
162
  from: "user",
163
  id: messageId as Message["id"],
164
  updatedAt: new Date(),
165
  files: conv.messages[retryMessageIdx]?.files,
166
  },
167
  ];
168
+ } else if (isContinue && messageId) {
169
+ // for continue we do nothing and expand the last assistant message
170
+ return conv.messages;
171
+ } else {
172
+ // in normal conversation we add an extra user message
173
+ return [
174
+ ...conv.messages,
175
+ {
176
+ content: newPrompt ?? "",
177
+ from: "user",
178
+ id: (messageId as Message["id"]) || crypto.randomUUID(),
179
+ createdAt: new Date(),
180
+ updatedAt: new Date(),
181
+ files: hashes,
182
+ },
183
+ ];
184
  } // else append the message at the bottom
 
 
 
 
 
 
 
 
 
 
 
 
185
  })() satisfies Message[];
186
 
187
  await collections.conversations.updateOne(
 
197
  }
198
  );
199
 
200
+ let doneStreaming = false;
201
+
202
  // we now build the stream
203
  const stream = new ReadableStream({
204
  async start(controller) {
205
+ const updates: MessageUpdate[] = isContinue
206
+ ? conv.messages[conv.messages.length - 1].updates ?? []
207
+ : [];
208
 
209
  function update(newUpdate: MessageUpdate) {
210
  if (newUpdate.type !== "stream") {
 
227
  const summarizeIfNeeded = (async () => {
228
  if (conv.title === "New Chat" && messages.length === 1) {
229
  try {
230
+ conv.title = (await summarize(messages[0].content)) ?? conv.title;
231
  update({ type: "status", status: "title", message: conv.title });
232
  } catch (e) {
233
  console.error(e);
 
250
 
251
  let webSearchResults: WebSearch | undefined;
252
 
253
+ if (webSearch && !isContinue) {
254
+ webSearchResults = await runWebSearch(conv, messages[messages.length - 1].content, update);
255
+ messages[messages.length - 1].webSearch = webSearchResults;
256
+ } else if (isContinue) {
257
+ webSearchResults = messages[messages.length - 1].webSearch;
258
  }
259
 
 
 
260
  conv.messages = messages;
261
 
262
+ const previousContent = isContinue
263
+ ? conv.messages.find((message) => message.id === messageId)?.content ?? ""
264
+ : "";
265
+
266
  try {
267
  const endpoint = await model.getEndpoint();
268
+ for await (const output of await endpoint({ conversation: conv, continue: isContinue })) {
269
  // if not generated_text is here it means the generation is not done
270
  if (!output.generated_text) {
271
  // else we get the next token
 
315
  ...messages.slice(0, -1),
316
  {
317
  ...messages[messages.length - 1],
318
+ content: previousContent + output.generated_text,
319
+ interrupted: !output.token.special, // if its a special token it finished on its own, else it was interrupted
320
  updates,
321
  updatedAt: new Date(),
322
  },
 
326
  } catch (e) {
327
  update({ type: "status", status: "error", message: (e as Error).message });
328
  }
329
+
330
  await collections.conversations.updateOne(
331
  {
332
  _id: convId,
 
340
  }
341
  );
342
 
343
+ // used to detect if cancel() is called bc of interrupt or just because the connection closes
344
+ doneStreaming = true;
345
+
346
  update({
347
  type: "finalAnswer",
348
  text: messages[messages.length - 1].content,
 
352
  return;
353
  },
354
  async cancel() {
355
+ if (!doneStreaming) {
356
+ await collections.conversations.updateOne(
357
+ {
358
+ _id: convId,
 
 
 
 
 
359
  },
360
+ {
361
+ $set: {
362
+ messages,
363
+ title: conv.title,
364
+ updatedAt: new Date(),
365
+ },
366
+ }
367
+ );
368
+ }
369
  },
370
  });
371