Stop generation button (closes #86) (#88)
Browse filesCo-authored-by: coyotte508 <coyotte508@gmail.com>
- src/lib/components/StopGeneratingBtn.svelte +17 -0
- src/lib/components/chat/ChatWindow.svelte +9 -3
- src/lib/server/abortedGenerations.ts +29 -0
- src/lib/server/database.ts +5 -1
- src/lib/types/AbortedGeneration.ts +9 -0
- src/lib/utils/concatUint8Arrays.ts +12 -0
- src/routes/conversation/[id]/+page.svelte +21 -1
- src/routes/conversation/[id]/+server.ts +37 -9
- src/routes/conversation/[id]/stop-generating/+server.ts +27 -0
src/lib/components/StopGeneratingBtn.svelte
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<script lang="ts">
|
2 |
+
import CarbonPause from "~icons/carbon/pause-filled";
|
3 |
+
|
4 |
+
export let visible: boolean = false;
|
5 |
+
export let className = "";
|
6 |
+
</script>
|
7 |
+
|
8 |
+
<button
|
9 |
+
type="button"
|
10 |
+
on:click
|
11 |
+
class="absolute btn flex rounded-lg border py-1 px-3 shadow-sm bg-white dark:bg-gray-700 hover:bg-gray-100 dark:hover:bg-gray-600 dark:border-gray-600 transition-all
|
12 |
+
{className}
|
13 |
+
{visible ? 'opacity-100 visible' : 'opacity-0 invisible'}
|
14 |
+
"
|
15 |
+
>
|
16 |
+
<CarbonPause class="mr-1 -ml-1 w-[1.1875rem] h-[1.25rem] text-gray-400" /> Stop generating
|
17 |
+
</button>
|
src/lib/components/chat/ChatWindow.svelte
CHANGED
@@ -3,10 +3,11 @@
|
|
3 |
import { createEventDispatcher } from "svelte";
|
4 |
|
5 |
import CarbonSendAltFilled from "~icons/carbon/send-alt-filled";
|
|
|
6 |
|
7 |
import ChatMessages from "./ChatMessages.svelte";
|
8 |
import ChatInput from "./ChatInput.svelte";
|
9 |
-
import
|
10 |
import { PUBLIC_MODEL_ID, PUBLIC_MODEL_NAME } from "$env/static/public";
|
11 |
|
12 |
export let messages: Message[] = [];
|
@@ -16,7 +17,7 @@
|
|
16 |
|
17 |
let message: string;
|
18 |
|
19 |
-
const dispatch = createEventDispatcher<{ message: string; share: void }>();
|
20 |
|
21 |
const handleSubmit = () => {
|
22 |
if (loading) return;
|
@@ -28,8 +29,13 @@
|
|
28 |
<div class="relative min-h-0 min-w-0">
|
29 |
<ChatMessages {loading} {pending} {messages} on:message />
|
30 |
<div
|
31 |
-
class="flex flex-col pointer-events-none [&>*]:pointer-events-auto max-md:border-t dark:border-gray-800 items-center max-md:dark:bg-gray-900 max-md:bg-white bg-gradient-to-t from-white via-white/80 to-white/0 dark:from-gray-900 dark:via-gray-80 dark:to-gray-900/0 justify-center absolute inset-x-0 max-w-3xl xl:max-w-4xl mx-auto px-3.5 sm:px-5 bottom-0 py-4 md:py-8 w-full"
|
32 |
>
|
|
|
|
|
|
|
|
|
|
|
33 |
<form
|
34 |
on:submit|preventDefault={handleSubmit}
|
35 |
class="w-full relative flex items-center rounded-xl flex-1 max-w-4xl border bg-gray-100 focus-within:border-gray-300 dark:bg-gray-700 dark:border-gray-600 dark:focus-within:border-gray-500 "
|
|
|
3 |
import { createEventDispatcher } from "svelte";
|
4 |
|
5 |
import CarbonSendAltFilled from "~icons/carbon/send-alt-filled";
|
6 |
+
import CarbonExport from "~icons/carbon/export";
|
7 |
|
8 |
import ChatMessages from "./ChatMessages.svelte";
|
9 |
import ChatInput from "./ChatInput.svelte";
|
10 |
+
import StopGeneratingBtn from "../StopGeneratingBtn.svelte";
|
11 |
import { PUBLIC_MODEL_ID, PUBLIC_MODEL_NAME } from "$env/static/public";
|
12 |
|
13 |
export let messages: Message[] = [];
|
|
|
17 |
|
18 |
let message: string;
|
19 |
|
20 |
+
const dispatch = createEventDispatcher<{ message: string; share: void; stop: void }>();
|
21 |
|
22 |
const handleSubmit = () => {
|
23 |
if (loading) return;
|
|
|
29 |
<div class="relative min-h-0 min-w-0">
|
30 |
<ChatMessages {loading} {pending} {messages} on:message />
|
31 |
<div
|
32 |
+
class="flex flex-col pointer-events-none [&>*]:pointer-events-auto max-md:border-t dark:border-gray-800 items-center max-md:dark:bg-gray-900 max-md:bg-white bg-gradient-to-t from-white via-white/80 to-white/0 dark:from-gray-900 dark:via-gray-80 dark:to-gray-900/0 justify-center absolute inset-x-0 max-w-3xl xl:max-w-4xl mx-auto px-3.5 sm:px-5 bottom-0 py-4 md:py-8 w-full z-0"
|
33 |
>
|
34 |
+
<StopGeneratingBtn
|
35 |
+
visible={loading}
|
36 |
+
className="right-5 mr-[1px] md:mr-0 md:right-7 top-6 md:top-10 z-10"
|
37 |
+
on:click={() => dispatch("stop")}
|
38 |
+
/>
|
39 |
<form
|
40 |
on:submit|preventDefault={handleSubmit}
|
41 |
class="w-full relative flex items-center rounded-xl flex-1 max-w-4xl border bg-gray-100 focus-within:border-gray-300 dark:bg-gray-700 dark:border-gray-600 dark:focus-within:border-gray-500 "
|
src/lib/server/abortedGenerations.ts
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Shouldn't be needed if we dove into sveltekit internals, see https://github.com/huggingface/chat-ui/pull/88#issuecomment-1523173850
|
2 |
+
|
3 |
+
import { setTimeout } from "node:timers/promises";
|
4 |
+
import { collections } from "./database";
|
5 |
+
|
6 |
+
let closed = false;
|
7 |
+
process.on("SIGINT", () => {
|
8 |
+
closed = true;
|
9 |
+
});
|
10 |
+
|
11 |
+
export let abortedGenerations: Map<string, Date> = new Map();
|
12 |
+
|
13 |
+
async function maintainAbortedGenerations() {
|
14 |
+
while (!closed) {
|
15 |
+
await setTimeout(1000);
|
16 |
+
|
17 |
+
try {
|
18 |
+
const aborts = await collections.abortedGenerations.find({}).sort({ createdAt: 1 }).toArray();
|
19 |
+
|
20 |
+
abortedGenerations = new Map(
|
21 |
+
aborts.map(({ conversationId, createdAt }) => [conversationId.toString(), createdAt])
|
22 |
+
);
|
23 |
+
} catch (err) {
|
24 |
+
console.error(err);
|
25 |
+
}
|
26 |
+
}
|
27 |
+
}
|
28 |
+
|
29 |
+
maintainAbortedGenerations();
|
src/lib/server/database.ts
CHANGED
@@ -2,6 +2,7 @@ import { MONGODB_URL, MONGODB_DB_NAME } from "$env/static/private";
|
|
2 |
import { MongoClient } from "mongodb";
|
3 |
import type { Conversation } from "$lib/types/Conversation";
|
4 |
import type { SharedConversation } from "$lib/types/SharedConversation";
|
|
|
5 |
|
6 |
const client = new MongoClient(MONGODB_URL, {
|
7 |
// directConnection: true
|
@@ -13,11 +14,14 @@ const db = client.db(MONGODB_DB_NAME);
|
|
13 |
|
14 |
const conversations = db.collection<Conversation>("conversations");
|
15 |
const sharedConversations = db.collection<SharedConversation>("sharedConversations");
|
|
|
16 |
|
17 |
export { client, db };
|
18 |
-
export const collections = { conversations, sharedConversations };
|
19 |
|
20 |
client.on("open", () => {
|
21 |
conversations.createIndex({ sessionId: 1, updatedAt: -1 });
|
|
|
|
|
22 |
sharedConversations.createIndex({ hash: 1 }, { unique: true });
|
23 |
});
|
|
|
2 |
import { MongoClient } from "mongodb";
|
3 |
import type { Conversation } from "$lib/types/Conversation";
|
4 |
import type { SharedConversation } from "$lib/types/SharedConversation";
|
5 |
+
import type { AbortedGeneration } from "$lib/types/AbortedGeneration";
|
6 |
|
7 |
const client = new MongoClient(MONGODB_URL, {
|
8 |
// directConnection: true
|
|
|
14 |
|
15 |
const conversations = db.collection<Conversation>("conversations");
|
16 |
const sharedConversations = db.collection<SharedConversation>("sharedConversations");
|
17 |
+
const abortedGenerations = db.collection<AbortedGeneration>("abortedGenerations");
|
18 |
|
19 |
export { client, db };
|
20 |
+
export const collections = { conversations, sharedConversations, abortedGenerations };
|
21 |
|
22 |
client.on("open", () => {
|
23 |
conversations.createIndex({ sessionId: 1, updatedAt: -1 });
|
24 |
+
abortedGenerations.createIndex({ updatedAt: 1 }, { expireAfterSeconds: 30 });
|
25 |
+
abortedGenerations.createIndex({ conversationId: 1 }, { unique: true });
|
26 |
sharedConversations.createIndex({ hash: 1 }, { unique: true });
|
27 |
});
|
src/lib/types/AbortedGeneration.ts
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Ideally shouldn't be needed, see https://github.com/huggingface/chat-ui/pull/88#issuecomment-1523173850
|
2 |
+
|
3 |
+
import type { Conversation } from "./Conversation";
|
4 |
+
|
5 |
+
export interface AbortedGeneration {
|
6 |
+
createdAt: Date;
|
7 |
+
updatedAt: Date;
|
8 |
+
conversationId: Conversation["_id"];
|
9 |
+
}
|
src/lib/utils/concatUint8Arrays.ts
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { sum } from "./sum";
|
2 |
+
|
3 |
+
export function concatUint8Arrays(arrays: Uint8Array[]): Uint8Array {
|
4 |
+
const totalLength = sum(arrays.map((a) => a.length));
|
5 |
+
const result = new Uint8Array(totalLength);
|
6 |
+
let offset = 0;
|
7 |
+
for (const array of arrays) {
|
8 |
+
result.set(array, offset);
|
9 |
+
offset += array.length;
|
10 |
+
}
|
11 |
+
return result;
|
12 |
+
}
|
src/routes/conversation/[id]/+page.svelte
CHANGED
@@ -16,6 +16,7 @@
|
|
16 |
|
17 |
let messages = data.messages;
|
18 |
let lastLoadedMessages = data.messages;
|
|
|
19 |
|
20 |
// Since we modify the messages array locally, we don't want to reset it if an old version is passed
|
21 |
$: if (data.messages !== lastLoadedMessages) {
|
@@ -55,7 +56,24 @@
|
|
55 |
for await (const data of response) {
|
56 |
pending = false;
|
57 |
|
58 |
-
if (!data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
// final message
|
61 |
if (data.generated_text) {
|
@@ -91,6 +109,7 @@
|
|
91 |
if (!message.trim()) return;
|
92 |
|
93 |
try {
|
|
|
94 |
loading = true;
|
95 |
pending = true;
|
96 |
|
@@ -130,4 +149,5 @@
|
|
130 |
{messages}
|
131 |
on:message={(message) => writeMessage(message.detail)}
|
132 |
on:share={() => shareConversation($page.params.id, data.title)}
|
|
|
133 |
/>
|
|
|
16 |
|
17 |
let messages = data.messages;
|
18 |
let lastLoadedMessages = data.messages;
|
19 |
+
let isAborted = false;
|
20 |
|
21 |
// Since we modify the messages array locally, we don't want to reset it if an old version is passed
|
22 |
$: if (data.messages !== lastLoadedMessages) {
|
|
|
56 |
for await (const data of response) {
|
57 |
pending = false;
|
58 |
|
59 |
+
if (!data) {
|
60 |
+
break;
|
61 |
+
}
|
62 |
+
|
63 |
+
if (conversationId !== $page.params.id) {
|
64 |
+
fetch(`${base}/conversation/${conversationId}/stop-generating`, {
|
65 |
+
method: "POST",
|
66 |
+
}).catch(console.error);
|
67 |
+
break;
|
68 |
+
}
|
69 |
+
|
70 |
+
if (isAborted) {
|
71 |
+
isAborted = false;
|
72 |
+
fetch(`${base}/conversation/${conversationId}/stop-generating`, {
|
73 |
+
method: "POST",
|
74 |
+
}).catch(console.error);
|
75 |
+
break;
|
76 |
+
}
|
77 |
|
78 |
// final message
|
79 |
if (data.generated_text) {
|
|
|
109 |
if (!message.trim()) return;
|
110 |
|
111 |
try {
|
112 |
+
isAborted = false;
|
113 |
loading = true;
|
114 |
pending = true;
|
115 |
|
|
|
149 |
{messages}
|
150 |
on:message={(message) => writeMessage(message.detail)}
|
151 |
on:share={() => shareConversation($page.params.id, data.title)}
|
152 |
+
on:stop={() => (isAborted = true)}
|
153 |
/>
|
src/routes/conversation/[id]/+server.ts
CHANGED
@@ -1,18 +1,21 @@
|
|
1 |
import { PUBLIC_SEP_TOKEN } from "$env/static/public";
|
2 |
import { buildPrompt } from "$lib/buildPrompt.js";
|
|
|
3 |
import { collections } from "$lib/server/database.js";
|
4 |
import { modelEndpoint } from "$lib/server/modelEndpoint.js";
|
5 |
import type { Message } from "$lib/types/Message.js";
|
|
|
6 |
import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
|
7 |
-
import { sum } from "$lib/utils/sum";
|
8 |
import { trimPrefix } from "$lib/utils/trimPrefix.js";
|
9 |
import { trimSuffix } from "$lib/utils/trimSuffix.js";
|
|
|
10 |
import { error } from "@sveltejs/kit";
|
11 |
import { ObjectId } from "mongodb";
|
12 |
|
13 |
export async function POST({ request, fetch, locals, params }) {
|
14 |
// todo: add validation on params.id
|
15 |
const convId = new ObjectId(params.id);
|
|
|
16 |
|
17 |
const conv = await collections.conversations.findOne({
|
18 |
_id: convId,
|
@@ -31,6 +34,8 @@ export async function POST({ request, fetch, locals, params }) {
|
|
31 |
|
32 |
const randomEndpoint = modelEndpoint();
|
33 |
|
|
|
|
|
34 |
const resp = await fetch(randomEndpoint.endpoint, {
|
35 |
headers: {
|
36 |
"Content-Type": request.headers.get("Content-Type") ?? "application/json",
|
@@ -41,12 +46,13 @@ export async function POST({ request, fetch, locals, params }) {
|
|
41 |
...json,
|
42 |
inputs: prompt,
|
43 |
}),
|
|
|
44 |
});
|
45 |
|
46 |
const [stream1, stream2] = resp.body!.tee();
|
47 |
|
48 |
async function saveMessage() {
|
49 |
-
let generated_text = await parseGeneratedText(stream2);
|
50 |
|
51 |
// We could also check if PUBLIC_ASSISTANT_MESSAGE_TOKEN is present and use it to slice the text
|
52 |
if (generated_text.startsWith(prompt)) {
|
@@ -97,19 +103,41 @@ export async function DELETE({ locals, params }) {
|
|
97 |
return new Response();
|
98 |
}
|
99 |
|
100 |
-
async function parseGeneratedText(
|
|
|
|
|
|
|
|
|
|
|
101 |
const inputs: Uint8Array[] = [];
|
102 |
for await (const input of streamToAsyncIterable(stream)) {
|
103 |
inputs.push(input);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
}
|
105 |
|
106 |
// Merge inputs into a single Uint8Array
|
107 |
-
const completeInput =
|
108 |
-
let offset = 0;
|
109 |
-
for (const input of inputs) {
|
110 |
-
completeInput.set(input, offset);
|
111 |
-
offset += input.length;
|
112 |
-
}
|
113 |
|
114 |
// Get last line starting with "data:" and parse it as JSON to get the generated text
|
115 |
const message = new TextDecoder().decode(completeInput);
|
|
|
1 |
import { PUBLIC_SEP_TOKEN } from "$env/static/public";
|
2 |
import { buildPrompt } from "$lib/buildPrompt.js";
|
3 |
+
import { abortedGenerations } from "$lib/server/abortedGenerations.js";
|
4 |
import { collections } from "$lib/server/database.js";
|
5 |
import { modelEndpoint } from "$lib/server/modelEndpoint.js";
|
6 |
import type { Message } from "$lib/types/Message.js";
|
7 |
+
import { concatUint8Arrays } from "$lib/utils/concatUint8Arrays.js";
|
8 |
import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
|
|
|
9 |
import { trimPrefix } from "$lib/utils/trimPrefix.js";
|
10 |
import { trimSuffix } from "$lib/utils/trimSuffix.js";
|
11 |
+
import type { TextGenerationStreamOutput } from "@huggingface/inference";
|
12 |
import { error } from "@sveltejs/kit";
|
13 |
import { ObjectId } from "mongodb";
|
14 |
|
15 |
export async function POST({ request, fetch, locals, params }) {
|
16 |
// todo: add validation on params.id
|
17 |
const convId = new ObjectId(params.id);
|
18 |
+
const date = new Date();
|
19 |
|
20 |
const conv = await collections.conversations.findOne({
|
21 |
_id: convId,
|
|
|
34 |
|
35 |
const randomEndpoint = modelEndpoint();
|
36 |
|
37 |
+
const abortController = new AbortController();
|
38 |
+
|
39 |
const resp = await fetch(randomEndpoint.endpoint, {
|
40 |
headers: {
|
41 |
"Content-Type": request.headers.get("Content-Type") ?? "application/json",
|
|
|
46 |
...json,
|
47 |
inputs: prompt,
|
48 |
}),
|
49 |
+
signal: abortController.signal,
|
50 |
});
|
51 |
|
52 |
const [stream1, stream2] = resp.body!.tee();
|
53 |
|
54 |
async function saveMessage() {
|
55 |
+
let generated_text = await parseGeneratedText(stream2, convId, date, abortController);
|
56 |
|
57 |
// We could also check if PUBLIC_ASSISTANT_MESSAGE_TOKEN is present and use it to slice the text
|
58 |
if (generated_text.startsWith(prompt)) {
|
|
|
103 |
return new Response();
|
104 |
}
|
105 |
|
106 |
+
async function parseGeneratedText(
|
107 |
+
stream: ReadableStream,
|
108 |
+
conversationId: ObjectId,
|
109 |
+
promptedAt: Date,
|
110 |
+
abortController: AbortController
|
111 |
+
): Promise<string> {
|
112 |
const inputs: Uint8Array[] = [];
|
113 |
for await (const input of streamToAsyncIterable(stream)) {
|
114 |
inputs.push(input);
|
115 |
+
|
116 |
+
const date = abortedGenerations.get(conversationId.toString());
|
117 |
+
|
118 |
+
if (date && date > promptedAt) {
|
119 |
+
abortController.abort("Cancelled by user");
|
120 |
+
const completeInput = concatUint8Arrays(inputs);
|
121 |
+
|
122 |
+
const lines = new TextDecoder()
|
123 |
+
.decode(completeInput)
|
124 |
+
.split("\n")
|
125 |
+
.filter((line) => line.startsWith("data:"));
|
126 |
+
|
127 |
+
const tokens = lines.map((line) => {
|
128 |
+
try {
|
129 |
+
const json: TextGenerationStreamOutput = JSON.parse(line.slice("data:".length));
|
130 |
+
return json.token.text;
|
131 |
+
} catch {
|
132 |
+
return "";
|
133 |
+
}
|
134 |
+
});
|
135 |
+
return tokens.join("");
|
136 |
+
}
|
137 |
}
|
138 |
|
139 |
// Merge inputs into a single Uint8Array
|
140 |
+
const completeInput = concatUint8Arrays(inputs);
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
// Get last line starting with "data:" and parse it as JSON to get the generated text
|
143 |
const message = new TextDecoder().decode(completeInput);
|
src/routes/conversation/[id]/stop-generating/+server.ts
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { collections } from "$lib/server/database";
|
2 |
+
import { error } from "@sveltejs/kit";
|
3 |
+
import { ObjectId } from "mongodb";
|
4 |
+
|
5 |
+
/**
|
6 |
+
* Ideally, we'd be able to detect the client-side abort, see https://github.com/huggingface/chat-ui/pull/88#issuecomment-1523173850
|
7 |
+
*/
|
8 |
+
export async function POST({ params, locals }) {
|
9 |
+
const conversationId = new ObjectId(params.id);
|
10 |
+
|
11 |
+
const conversation = await collections.conversations.findOne({
|
12 |
+
_id: conversationId,
|
13 |
+
sessionId: locals.sessionId,
|
14 |
+
});
|
15 |
+
|
16 |
+
if (!conversation) {
|
17 |
+
throw error(404, "Conversation not found");
|
18 |
+
}
|
19 |
+
|
20 |
+
await collections.abortedGenerations.updateOne(
|
21 |
+
{ conversationId },
|
22 |
+
{ $set: { updatedAt: new Date() }, $setOnInsert: { createdAt: new Date() } },
|
23 |
+
{ upsert: true }
|
24 |
+
);
|
25 |
+
|
26 |
+
return new Response();
|
27 |
+
}
|