Commit ad02fa3
coyotte508 (HF staff) committed
Parent: e5eb656

✨ Save messages in backend (#31)

.eslintrc.cjs CHANGED
@@ -12,6 +12,9 @@ module.exports = {
 		sourceType: 'module',
 		ecmaVersion: 2020
 	},
+	rules: {
+		'no-shadow': ['error']
+	},
 	env: {
 		browser: true,
 		es2017: true,
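
Note: the newly enabled no-shadow rule reports inner-scope declarations that reuse an outer variable's name. A minimal illustrative snippet (not from this repository):

const content = 'outer';
function logAll(items: string[]) {
	for (const content of items) {
		// eslint reports here: 'content' is already declared in the upper scope (no-shadow)
		console.log(content);
	}
}
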
src/lib/buildPrompt.ts ADDED
@@ -0,0 +1,25 @@
+import {
+	PUBLIC_ASSISTANT_MESSAGE_TOKEN,
+	PUBLIC_SEP_TOKEN,
+	PUBLIC_USER_MESSAGE_TOKEN
+} from '$env/static/public';
+import type { Message } from './types/Message';
+
+/**
+ * Convert [{from: "assistant", content: "hi"}, {from: "user", content: "hello"}] to:
+ *
+ * <|assistant|>hi<|endoftext|><|prompter|>hello<|endoftext|><|assistant|>
+ */
+export function buildPrompt(messages: Message[]): string {
+	return (
+		messages
+			.map(
+				(m) =>
+					(m.from === 'user'
+						? PUBLIC_USER_MESSAGE_TOKEN + m.content
+						: PUBLIC_ASSISTANT_MESSAGE_TOKEN + m.content) +
+					(m.content.endsWith(PUBLIC_SEP_TOKEN) ? '' : PUBLIC_SEP_TOKEN)
+			)
+			.join('') + PUBLIC_ASSISTANT_MESSAGE_TOKEN
+	);
+}
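
Note: a quick usage sketch, assuming Message has the { from: 'user' | 'assistant'; content: string } shape used elsewhere in this diff and that the public token env vars hold the values shown in the doc comment:

import { buildPrompt } from '$lib/buildPrompt';
import type { Message } from '$lib/types/Message';

const messages: Message[] = [
	{ from: 'user', content: 'hello' },
	{ from: 'assistant', content: 'hi' }
];

// With the assumed token values this evaluates to:
// <|prompter|>hello<|endoftext|><|assistant|>hi<|endoftext|><|assistant|>
const prompt = buildPrompt(messages);
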
src/lib/utils/streamToAsyncIterable.ts ADDED
@@ -0,0 +1,15 @@
+// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Statements/for-await...of#iterating_over_async_generators
+export async function* streamToAsyncIterable(
+	stream: ReadableStream<Uint8Array>
+): AsyncIterableIterator<Uint8Array> {
+	const reader = stream.getReader();
+	try {
+		while (true) {
+			const { done, value } = await reader.read();
+			if (done) return;
+			yield value;
+		}
+	} finally {
+		reader.releaseLock();
+	}
+}
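
Note: illustrative usage inside an async context, against a hypothetical streaming endpoint:

import { streamToAsyncIterable } from '$lib/utils/streamToAsyncIterable';

const res = await fetch('/some-streaming-endpoint'); // hypothetical URL
if (res.body) {
	const decoder = new TextDecoder();
	for await (const chunk of streamToAsyncIterable(res.body)) {
		console.log(decoder.decode(chunk, { stream: true }));
	}
}
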
src/lib/utils/sum.ts ADDED
@@ -0,0 +1,3 @@
+export function sum(nums: number[]): number {
+	return nums.reduce((a, b) => a + b, 0);
+}
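
Note: trivial helper, e.g. sum([3, 5, 8]) returns 16; it is used in the new conversation endpoint below to size the merged Uint8Array.
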
src/routes/+page.svelte CHANGED
@@ -1,5 +1,5 @@
 <script lang="ts">
-	import { goto, invalidate, invalidateAll } from '$app/navigation';
+	import { goto } from '$app/navigation';
 	import ChatWindow from '$lib/components/chat/ChatWindow.svelte';
 	import { pendingMessage } from '$lib/stores/pendingMessage';
 
src/routes/api/conversation/+server.ts DELETED
@@ -1,19 +0,0 @@
-import { HF_TOKEN } from '$env/static/private';
-import { PUBLIC_MODEL_ENDPOINT } from '$env/static/public';
-
-export async function POST({ request, fetch }) {
-	const resp = await fetch(PUBLIC_MODEL_ENDPOINT, {
-		headers: {
-			'Content-Type': request.headers.get('Content-Type') ?? 'application/json',
-			Authorization: `Basic ${HF_TOKEN}`
-		},
-		method: 'POST',
-		body: await request.text()
-	});
-
-	return new Response(resp.body, {
-		headers: Object.fromEntries(resp.headers.entries()),
-		status: resp.status,
-		statusText: resp.statusText
-	});
-}
src/routes/conversation/[id]/+page.svelte CHANGED
@@ -4,23 +4,14 @@
 	import { onMount } from 'svelte';
 	import type { PageData } from './$types';
 	import { page } from '$app/stores';
-	import {
-		PUBLIC_ASSISTANT_MESSAGE_TOKEN,
-		PUBLIC_SEP_TOKEN,
-		PUBLIC_USER_MESSAGE_TOKEN
-	} from '$env/static/public';
 	import { HfInference } from '@huggingface/inference';
 
 	export let data: PageData;
 
 	$: messages = data.messages;
 
-	const userToken = PUBLIC_USER_MESSAGE_TOKEN;
-	const assistantToken = PUBLIC_ASSISTANT_MESSAGE_TOKEN;
-	const sepToken = PUBLIC_SEP_TOKEN;
-
 	const hf = new HfInference();
-	const model = hf.endpoint(`${$page.url.origin}/api/conversation`);
+	const model = hf.endpoint($page.url.href);
 
 	let loading = false;
 
@@ -76,16 +67,7 @@
 
 		messages = [...messages, { from: 'user', content: message }];
 
-		const inputs =
-			messages
-				.map(
-					(m) =>
-						(m.from === 'user' ? userToken + m.content : assistantToken + m.content) +
-						(m.content.endsWith(sepToken) ? '' : sepToken)
-				)
-				.join('') + assistantToken;
-
-		await getTextGenerationStream(inputs);
+		await getTextGenerationStream(message);
 	} finally {
 		loading = false;
 	}
src/routes/conversation/[id]/+server.ts ADDED
@@ -0,0 +1,110 @@
+import { HF_TOKEN } from '$env/static/private';
+import { PUBLIC_MODEL_ENDPOINT } from '$env/static/public';
+import { buildPrompt } from '$lib/buildPrompt.js';
+import { collections } from '$lib/server/database.js';
+import type { Message } from '$lib/types/Message.js';
+import { streamToAsyncIterable } from '$lib/utils/streamToAsyncIterable';
+import { sum } from '$lib/utils/sum';
+import { error } from '@sveltejs/kit';
+import { ObjectId } from 'mongodb';
+
+export async function POST({ request, fetch, locals, params }) {
+	// todo: add validation on params.id
+	const convId = new ObjectId(params.id);
+
+	const conv = await collections.conversations.findOne({
+		_id: convId,
+		sessionId: locals.sessionId
+	});
+
+	if (!conv) {
+		throw error(404, 'Conversation not found');
+	}
+
+	// Todo: validate prompt with zod? or arktype
+	const json = await request.json();
+
+	const messages = [...conv.messages, { from: 'user', content: json.inputs }] satisfies Message[];
+
+	json.inputs = buildPrompt(messages);
+
+	const resp = await fetch(PUBLIC_MODEL_ENDPOINT, {
+		headers: {
+			'Content-Type': request.headers.get('Content-Type') ?? 'application/json',
+			Authorization: `Basic ${HF_TOKEN}`
+		},
+		method: 'POST',
+		body: JSON.stringify(json)
+	});
+
+	const [stream1, stream2] = resp.body!.tee();
+
+	async function saveMessage() {
+		const generated_text = await parseGeneratedText(stream2);
+
+		messages.push({ from: 'assistant', content: generated_text });
+
+		console.log('updating conversation', convId, messages);
+
+		await collections.conversations.updateOne(
+			{
+				_id: convId
+			},
+			{
+				$set: {
+					messages,
+					updatedAt: new Date()
+				}
+			}
+		);
+	}
+
+	saveMessage().catch(console.error);
+
+	// Todo: maybe we should wait for the message to be saved before ending the response - in case of errors
+	return new Response(stream1, {
+		headers: Object.fromEntries(resp.headers.entries()),
+		status: resp.status,
+		statusText: resp.statusText
+	});
+}
+
+async function parseGeneratedText(stream: ReadableStream): Promise<string> {
+	const inputs: Uint8Array[] = [];
+	for await (const input of streamToAsyncIterable(stream)) {
+		inputs.push(input);
+	}
+
+	// Merge inputs into a single Uint8Array
+	const completeInput = new Uint8Array(sum(inputs.map((input) => input.length)));
+	let offset = 0;
+	for (const input of inputs) {
+		completeInput.set(input, offset);
+		offset += input.length;
+	}
+
+	// Get last line starting with "data:" and parse it as JSON to get the generated text
+	const message = new TextDecoder().decode(completeInput);
+
+	let lastIndex = message.lastIndexOf('\ndata:');
+	if (lastIndex === -1) {
+		lastIndex = message.indexOf('data');
+	}
+
+	if (lastIndex === -1) {
+		console.error('Could not parse last message');
+	}
+
+	let lastMessage = message.slice(lastIndex).trim().slice('data:'.length);
+	if (lastMessage.includes('\n')) {
+		lastMessage = lastMessage.slice(0, lastMessage.indexOf('\n'));
+	}
+
+	const res = JSON.parse(lastMessage).generated_text;
+
+	if (typeof res !== 'string') {
+		throw new Error('Could not parse generated text');
+	}
+
+	return res;
+}
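
Note: parseGeneratedText expects a server-sent-events-style body whose last data: line carries the final generated_text. A sketch of the expected input shape, inferred from the parsing code above rather than from the model endpoint's documentation:

// Hypothetical payload resembling what the parser above handles.
const body =
	'data:{"token":{"text":"Hi"},"generated_text":null}\n' +
	'data:{"token":{"text":" there"},"generated_text":"Hi there"}\n';

const stream = new ReadableStream<Uint8Array>({
	start(controller) {
		controller.enqueue(new TextEncoder().encode(body));
		controller.close();
	}
});

// parseGeneratedText is private to this module; fed this stream, it would resolve to "Hi there".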