radames commited on
Commit
43eb4d6
·
1 Parent(s): fca36d1

fix firstToken

Browse files
Files changed (7) hide show
  1. code.js +4 -0
  2. index.html +25 -16
  3. lib/m.d.ts +3 -2
  4. lib/m.js +3 -2
  5. lib/m_bg.wasm +2 -2
  6. lib/m_bg.wasm.d.ts +1 -1
  7. llama2c.worker.js +10 -2
code.js CHANGED
@@ -87,6 +87,7 @@ async function generateSequence({
87
  maxSeqLen,
88
  temp,
89
  repeatPenalty,
 
90
  contentEl,
91
  controller,
92
  }) {
@@ -104,6 +105,7 @@ async function generateSequence({
104
  prompt,
105
  temp,
106
  repeatPenalty,
 
107
  seed: seed,
108
  maxSeqLen,
109
  command: "start",
@@ -201,6 +203,7 @@ async function run(containers, controller) {
201
  const maxSeqLen = document.querySelector("#max-seq");
202
  const temp = document.querySelector("#temperature");
203
  const repeatPenalty = document.querySelector("#repeat-penalty");
 
204
  const modelID = document.querySelector("#model");
205
 
206
  const weightsURL = `${MODELS_BASE_URL}/${MODELS[getValue(modelID)].url}`;
@@ -223,6 +226,7 @@ async function run(containers, controller) {
223
  modelID: getValue(modelID),
224
  maxSeqLen: getValue(maxSeqLen),
225
  temp: getValue(temp),
 
226
  repeatPenalty: getValue(repeatPenalty),
227
  contentEl: container,
228
  controller,
 
87
  maxSeqLen,
88
  temp,
89
  repeatPenalty,
90
+ top_p,
91
  contentEl,
92
  controller,
93
  }) {
 
105
  prompt,
106
  temp,
107
  repeatPenalty,
108
+ top_p,
109
  seed: seed,
110
  maxSeqLen,
111
  command: "start",
 
203
  const maxSeqLen = document.querySelector("#max-seq");
204
  const temp = document.querySelector("#temperature");
205
  const repeatPenalty = document.querySelector("#repeat-penalty");
206
+ const topP = document.querySelector("#top-p");
207
  const modelID = document.querySelector("#model");
208
 
209
  const weightsURL = `${MODELS_BASE_URL}/${MODELS[getValue(modelID)].url}`;
 
226
  modelID: getValue(modelID),
227
  maxSeqLen: getValue(maxSeqLen),
228
  temp: getValue(temp),
229
+ top_p: getValue(topP),
230
  repeatPenalty: getValue(repeatPenalty),
231
  contentEl: container,
232
  controller,
index.html CHANGED
@@ -3,20 +3,23 @@
3
  <meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
4
  <title>TinyStories - Candle Llama.c Rust/WASM</title>
5
  </head>
 
6
  <body></body>
7
  </html>
8
 
9
- <!doctype html>
10
  <html>
11
  <head>
12
  <meta charset="UTF-8" />
13
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
14
  <style>
15
  @import url("https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@200;400;500;700&family=Source+Sans+3:wght@200;400;500;600;700;800;900&display=swap");
 
16
  html,
17
  body {
18
  font-family: "Source Sans 3", sans-serif;
19
  }
 
20
  .mono {
21
  font-family: "IBM Plex Mono", monospace;
22
  }
@@ -44,6 +47,7 @@
44
  </style>
45
  <script type="module" src="./code.js"></script>
46
  </head>
 
47
  <body class="container mx-auto max-w-2xl p-4 bg-[#020058]">
48
  <img src="./imgs/cat.png" class="fixed top-0 left-0 w-20 -z-10" />
49
  <header class="py-2 mb-6">
@@ -66,20 +70,17 @@
66
  </header>
67
  <form
68
  id="form"
69
- class="flex text-normal px-1 py-2 border-2 border-white rounded-md items-center"
70
- >
71
  <input type="submit" hidden="" />
72
  <input
73
  type="text"
74
  id="prompt"
75
  class="w-full px-3 py-2 mx-1 resize-none outline-none bg-[#020058] text-white"
76
  placeholder="Add your prompt here..."
77
- value="Once upon a time"
78
- />
79
  <button
80
  id="run"
81
- class="bg-white hover:bg-gray-400 text-black font-normal py-2 w-20 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
82
- >
83
  Create
84
  </button>
85
  </form>
@@ -89,8 +90,7 @@
89
  <label for="model" class="text-sm">Models Options: </label>
90
  <select
91
  id="model"
92
- class="border-2 border-white rounded-md font-light text-white bg-[#020058] px-1"
93
- >
94
  <option value="stories15M" selected>stories 15M (60.8 MB)</option>
95
  <option value="stories42M">stories 42M (167 MB)</option>
96
  <option value="stories110M">stories 110M (438 MB) WARNING</option>
@@ -104,8 +104,7 @@
104
  max="256"
105
  step="1"
106
  value="150"
107
- oninput="this.nextElementSibling.value = Number(this.value)"
108
- />
109
  <output class="n-block"> 150</output>
110
  <label class="text-sm font-medium" for="temperature">Temperature</label>
111
  <input
@@ -115,23 +114,33 @@
115
  max="2"
116
  step="0.01"
117
  value="0.2"
118
- oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
119
- />
120
  <output class="n-block">0.2</output>
121
 
122
  <label class="text-sm font-medium" for="repeat_penalty"
123
  >Repeat Penalty</label
124
  >
 
125
  <input
126
  type="range"
127
  id="repeat-penalty"
128
- min="-2"
129
  max="2"
130
  step="0.01"
131
  value="1.10"
132
- oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
133
- />
134
  <output class="n-block">1.10</output>
 
 
 
 
 
 
 
 
 
 
 
135
  </div>
136
  </details>
137
  <div class="text-base grid gap-3 py-5" id="container"></div>
 
3
  <meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
4
  <title>TinyStories - Candle Llama.c Rust/WASM</title>
5
  </head>
6
+
7
  <body></body>
8
  </html>
9
 
10
+ <!DOCTYPE html>
11
  <html>
12
  <head>
13
  <meta charset="UTF-8" />
14
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
15
  <style>
16
  @import url("https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@200;400;500;700&family=Source+Sans+3:wght@200;400;500;600;700;800;900&display=swap");
17
+
18
  html,
19
  body {
20
  font-family: "Source Sans 3", sans-serif;
21
  }
22
+
23
  .mono {
24
  font-family: "IBM Plex Mono", monospace;
25
  }
 
47
  </style>
48
  <script type="module" src="./code.js"></script>
49
  </head>
50
+
51
  <body class="container mx-auto max-w-2xl p-4 bg-[#020058]">
52
  <img src="./imgs/cat.png" class="fixed top-0 left-0 w-20 -z-10" />
53
  <header class="py-2 mb-6">
 
70
  </header>
71
  <form
72
  id="form"
73
+ class="flex text-normal px-1 py-2 border-2 border-white rounded-md items-center">
 
74
  <input type="submit" hidden="" />
75
  <input
76
  type="text"
77
  id="prompt"
78
  class="w-full px-3 py-2 mx-1 resize-none outline-none bg-[#020058] text-white"
79
  placeholder="Add your prompt here..."
80
+ value="Once upon a time" />
 
81
  <button
82
  id="run"
83
+ class="bg-white hover:bg-gray-400 text-black font-normal py-2 w-20 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
 
84
  Create
85
  </button>
86
  </form>
 
90
  <label for="model" class="text-sm">Models Options: </label>
91
  <select
92
  id="model"
93
+ class="border-2 border-white rounded-md font-light text-white bg-[#020058] px-1">
 
94
  <option value="stories15M" selected>stories 15M (60.8 MB)</option>
95
  <option value="stories42M">stories 42M (167 MB)</option>
96
  <option value="stories110M">stories 110M (438 MB) WARNING</option>
 
104
  max="256"
105
  step="1"
106
  value="150"
107
+ oninput="this.nextElementSibling.value = Number(this.value)" />
 
108
  <output class="n-block"> 150</output>
109
  <label class="text-sm font-medium" for="temperature">Temperature</label>
110
  <input
 
114
  max="2"
115
  step="0.01"
116
  value="0.2"
117
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
 
118
  <output class="n-block">0.2</output>
119
 
120
  <label class="text-sm font-medium" for="repeat_penalty"
121
  >Repeat Penalty</label
122
  >
123
+
124
  <input
125
  type="range"
126
  id="repeat-penalty"
127
+ min="1"
128
  max="2"
129
  step="0.01"
130
  value="1.10"
131
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
 
132
  <output class="n-block">1.10</output>
133
+
134
+ <label class="text-sm font-medium" for="top-p">Top P</label>
135
+ <input
136
+ type="range"
137
+ id="top-p"
138
+ min="0"
139
+ max="1"
140
+ step="0.01"
141
+ value="1.00"
142
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
143
+ <output class="n-block">1.00</output>
144
  </div>
145
  </details>
146
  <div class="text-base grid gap-3 py-5" id="container"></div>
lib/m.d.ts CHANGED
@@ -16,11 +16,12 @@ export class Model {
16
  /**
17
  * @param {string} prompt
18
  * @param {number} temp
 
19
  * @param {number} repeat_penalty
20
  * @param {bigint} seed
21
  * @returns {string}
22
  */
23
- init_with_prompt(prompt: string, temp: number, repeat_penalty: number, seed: bigint): string;
24
  /**
25
  * @returns {string}
26
  */
@@ -34,7 +35,7 @@ export interface InitOutput {
34
  readonly __wbg_model_free: (a: number) => void;
35
  readonly model_new: (a: number, b: number, c: number, d: number, e: number) => void;
36
  readonly model_get_seq_len: (a: number) => number;
37
- readonly model_init_with_prompt: (a: number, b: number, c: number, d: number, e: number, f: number, g: number) => void;
38
  readonly model_next_token: (a: number, b: number) => void;
39
  readonly main: (a: number, b: number) => number;
40
  readonly __wbindgen_add_to_stack_pointer: (a: number) => number;
 
16
  /**
17
  * @param {string} prompt
18
  * @param {number} temp
19
+ * @param {number} top_p
20
  * @param {number} repeat_penalty
21
  * @param {bigint} seed
22
  * @returns {string}
23
  */
24
+ init_with_prompt(prompt: string, temp: number, top_p: number, repeat_penalty: number, seed: bigint): string;
25
  /**
26
  * @returns {string}
27
  */
 
35
  readonly __wbg_model_free: (a: number) => void;
36
  readonly model_new: (a: number, b: number, c: number, d: number, e: number) => void;
37
  readonly model_get_seq_len: (a: number) => number;
38
+ readonly model_init_with_prompt: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number) => void;
39
  readonly model_next_token: (a: number, b: number) => void;
40
  readonly main: (a: number, b: number) => number;
41
  readonly __wbindgen_add_to_stack_pointer: (a: number) => number;
lib/m.js CHANGED
@@ -181,18 +181,19 @@ export class Model {
181
  /**
182
  * @param {string} prompt
183
  * @param {number} temp
 
184
  * @param {number} repeat_penalty
185
  * @param {bigint} seed
186
  * @returns {string}
187
  */
188
- init_with_prompt(prompt, temp, repeat_penalty, seed) {
189
  let deferred3_0;
190
  let deferred3_1;
191
  try {
192
  const retptr = wasm.__wbindgen_add_to_stack_pointer(-16);
193
  const ptr0 = passStringToWasm0(prompt, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc);
194
  const len0 = WASM_VECTOR_LEN;
195
- wasm.model_init_with_prompt(retptr, this.__wbg_ptr, ptr0, len0, temp, repeat_penalty, seed);
196
  var r0 = getInt32Memory0()[retptr / 4 + 0];
197
  var r1 = getInt32Memory0()[retptr / 4 + 1];
198
  var r2 = getInt32Memory0()[retptr / 4 + 2];
 
181
  /**
182
  * @param {string} prompt
183
  * @param {number} temp
184
+ * @param {number} top_p
185
  * @param {number} repeat_penalty
186
  * @param {bigint} seed
187
  * @returns {string}
188
  */
189
+ init_with_prompt(prompt, temp, top_p, repeat_penalty, seed) {
190
  let deferred3_0;
191
  let deferred3_1;
192
  try {
193
  const retptr = wasm.__wbindgen_add_to_stack_pointer(-16);
194
  const ptr0 = passStringToWasm0(prompt, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc);
195
  const len0 = WASM_VECTOR_LEN;
196
+ wasm.model_init_with_prompt(retptr, this.__wbg_ptr, ptr0, len0, temp, top_p, repeat_penalty, seed);
197
  var r0 = getInt32Memory0()[retptr / 4 + 0];
198
  var r1 = getInt32Memory0()[retptr / 4 + 1];
199
  var r2 = getInt32Memory0()[retptr / 4 + 2];
lib/m_bg.wasm CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:761e18f3da99de2c2eac0f5bc13dee39fc412c472694d277bfe1cd4b1e5809d7
3
- size 3725264
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7e872b0d61cb8dc0ea356c867b6082ac50c68639165c63728c7f3ef1a0f5979
3
+ size 3794192
lib/m_bg.wasm.d.ts CHANGED
@@ -4,7 +4,7 @@ export const memory: WebAssembly.Memory;
4
  export function __wbg_model_free(a: number): void;
5
  export function model_new(a: number, b: number, c: number, d: number, e: number): void;
6
  export function model_get_seq_len(a: number): number;
7
- export function model_init_with_prompt(a: number, b: number, c: number, d: number, e: number, f: number, g: number): void;
8
  export function model_next_token(a: number, b: number): void;
9
  export function main(a: number, b: number): number;
10
  export function __wbindgen_add_to_stack_pointer(a: number): number;
 
4
  export function __wbg_model_free(a: number): void;
5
  export function model_new(a: number, b: number, c: number, d: number, e: number): void;
6
  export function model_get_seq_len(a: number): number;
7
+ export function model_init_with_prompt(a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number): void;
8
  export function model_next_token(a: number, b: number): void;
9
  export function main(a: number, b: number): number;
10
  export function __wbindgen_add_to_stack_pointer(a: number): number;
llama2c.worker.js CHANGED
@@ -53,20 +53,28 @@ async function generate(data) {
53
  tokenizerURL,
54
  prompt,
55
  temp,
 
56
  repeatPenalty,
57
  seed,
58
  maxSeqLen,
59
  } = data;
60
  try {
 
61
  self.postMessage({ status: "loading", message: "Starting llama2.c" });
62
  const model = await Llama2C.getInstance(weightsURL, modelID, tokenizerURL);
63
 
64
  self.postMessage({ status: "loading", message: "Initializing model" });
65
- model.init_with_prompt(prompt, temp, repeatPenalty, seed);
 
 
 
 
 
 
66
 
67
  const seq_len = model.get_seq_len();
68
 
69
- let sentence = "";
70
  let maxTokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1;
71
  let startTime = performance.now();
72
  let tokensCount = 0;
 
53
  tokenizerURL,
54
  prompt,
55
  temp,
56
+ top_p,
57
  repeatPenalty,
58
  seed,
59
  maxSeqLen,
60
  } = data;
61
  try {
62
+ console.log(data);
63
  self.postMessage({ status: "loading", message: "Starting llama2.c" });
64
  const model = await Llama2C.getInstance(weightsURL, modelID, tokenizerURL);
65
 
66
  self.postMessage({ status: "loading", message: "Initializing model" });
67
+ const firstToken = model.init_with_prompt(
68
+ prompt,
69
+ temp,
70
+ top_p,
71
+ repeatPenalty,
72
+ seed
73
+ );
74
 
75
  const seq_len = model.get_seq_len();
76
 
77
+ let sentence = firstToken;
78
  let maxTokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1;
79
  let startTime = performance.now();
80
  let tokensCount = 0;