Daniel Fried commited on
Commit
44efa8c
·
1 Parent(s): 8a85023

fix query encoding and add new examples

Browse files
Files changed (2) hide show
  1. modules/app.py +27 -1
  2. static/index.html +23 -7
modules/app.py CHANGED
@@ -2,6 +2,7 @@ import sys
2
  from typing import List
3
  import traceback
4
  import os
 
5
  # needs to be imported *before* transformers
6
  if os.path.exists('use_normal_tokenizers'):
7
  import tokenizers
@@ -11,8 +12,10 @@ else:
11
  import tokenizers_patch
12
  BIG_MODEL = True
13
  CUDA = True
 
14
  from transformers import AutoModelForCausalLM, AutoTokenizer
15
  import json
 
16
 
17
  # from flask import Flask, request, render_template
18
  # from flask_cors import CORS
@@ -32,8 +35,14 @@ TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in t
32
 
33
  if BIG_MODEL:
34
  model_name = "facebook/incoder-6B"
 
 
 
 
 
35
  else:
36
  model_name = "facebook/incoder-1B"
 
37
 
38
  from fastapi import FastAPI, Request
39
  from fastapi.staticfiles import StaticFiles
@@ -43,7 +52,7 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
43
 
44
 
45
  print("loading model")
46
- model = AutoModelForCausalLM.from_pretrained(model_name)
47
  print("loading tokenizer")
48
  tokenizer = AutoTokenizer.from_pretrained(model_name)
49
  print("loading complete")
@@ -154,9 +163,18 @@ def index() -> FileResponse:
154
  return FileResponse(path="static/index.html", media_type="text/html")
155
 
156
  @app.get('/generate')
 
157
  async def generate_maybe(info: str):
158
  # form = await info.json()
 
 
 
 
 
 
159
  form = json.loads(info)
 
 
160
  prompt = form['prompt']
161
  length_limit = int(form['length'])
162
  temperature = float(form['temperature'])
@@ -174,9 +192,17 @@ async def generate_maybe(info: str):
174
  return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
175
 
176
  @app.get('/infill')
 
177
  async def infill_maybe(info: str):
178
  # form = await info.json()
 
 
 
 
 
 
179
  form = json.loads(info)
 
180
  length_limit = int(form['length'])
181
  temperature = float(form['temperature'])
182
  max_retries = 1
 
2
  from typing import List
3
  import traceback
4
  import os
5
+ import base64
6
  # needs to be imported *before* transformers
7
  if os.path.exists('use_normal_tokenizers'):
8
  import tokenizers
 
12
  import tokenizers_patch
13
  BIG_MODEL = True
14
  CUDA = True
15
+ import torch
16
  from transformers import AutoModelForCausalLM, AutoTokenizer
17
  import json
18
+ import pprint
19
 
20
  # from flask import Flask, request, render_template
21
  # from flask_cors import CORS
 
35
 
36
  if BIG_MODEL:
37
  model_name = "facebook/incoder-6B"
38
+ kwargs = dict(
39
+ revision="float16",
40
+ torch_dtype=torch.float16,
41
+ low_cpu_mem_usage=True,
42
+ )
43
  else:
44
  model_name = "facebook/incoder-1B"
45
+ kwargs = dict()
46
 
47
  from fastapi import FastAPI, Request
48
  from fastapi.staticfiles import StaticFiles
 
52
 
53
 
54
  print("loading model")
55
+ model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
56
  print("loading tokenizer")
57
  tokenizer = AutoTokenizer.from_pretrained(model_name)
58
  print("loading complete")
 
163
  return FileResponse(path="static/index.html", media_type="text/html")
164
 
165
  @app.get('/generate')
166
+ # async def generate_maybe(request: Request):
167
  async def generate_maybe(info: str):
168
  # form = await info.json()
169
+ # form = await request.json()
170
+ # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
171
+ # fix padding, following https://stackoverflow.com/a/9956217/1319683
172
+ print(info)
173
+ info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
174
+ print(info)
175
  form = json.loads(info)
176
+ pprint.pprint(form)
177
+ # print(form)
178
  prompt = form['prompt']
179
  length_limit = int(form['length'])
180
  temperature = float(form['temperature'])
 
192
  return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
193
 
194
  @app.get('/infill')
195
+ # async def infill_maybe(request: Request):
196
  async def infill_maybe(info: str):
197
  # form = await info.json()
198
+ # form = await request.json()
199
+ # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
200
+ # fix padding, following https://stackoverflow.com/a/9956217/1319683
201
+ print(info)
202
+ info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
203
+ print(info)
204
  form = json.loads(info)
205
+ pprint.pprint(form)
206
  length_limit = int(form['length'])
207
  temperature = float(form['temperature'])
208
  max_retries = 1
static/index.html CHANGED
@@ -134,6 +134,7 @@ label {
134
  <span class="softspan">Infill Examples:</span>
135
  <br>
136
  <span class="softspan"><a href='javascript:select_example("type-pred");'>Type prediction</a></span>
 
137
  <span class="softspan"><a href='javascript:select_example("docstring");'>Function to docstring</a></span>
138
  <span class="softspan"><a href='javascript:select_example("python-infill2");'>Docstring to function</a></span>
139
  <span class="softspan"><a href='javascript:select_example("class");'>Class generation</a></span>
@@ -252,12 +253,20 @@ def <infill>
252
  "temperature": 0.2,
253
  "mode": "python"
254
  },
255
-
 
 
 
 
 
 
 
 
 
 
256
  "type-pred": {
257
  "prompt":
258
- `<| file ext=.py |>
259
-
260
- def count_words(filename: str) -> <infill>
261
  """Count the number of occurrences of each word in the file."""
262
  with open(filename, 'r') as f:
263
  word_counts = {}
@@ -310,7 +319,7 @@ def count_words(filename):
310
  "mode": "python"
311
  },
312
  "javascript": {
313
- "prompt": "<| file ext=.js |>\n // is something really happening here",
314
  "length": 64,
315
  "temperature": 0.6,
316
  "mode": "javascript"
@@ -529,6 +538,7 @@ function make_generate_listener(url) {
529
  console.log("Response:");
530
  console.log(receive_data);
531
  if (receive_data["result"] == "success") {
 
532
  // $("#prompt").text(data["prompt"]);
533
  // $("#response").text(data["text"]);
534
  set_text(receive_data["text"]);
@@ -540,6 +550,7 @@ function make_generate_listener(url) {
540
  $("#warning").text("");
541
  }
542
  } else {
 
543
  set_text(receive_data["text"])
544
  $("#error").text(receive_data["message"]);
545
  }
@@ -552,13 +563,18 @@ function make_generate_listener(url) {
552
  $("#error").text(err);
553
  }
554
 
555
- encoded_data = JSON.stringify(send_data)
556
 
557
  try {
558
  const response = await fetch(`${url}?info=${encoded_data}`);
 
 
 
 
559
  if (response.status >= 400) {
560
  error(response.statusText);
561
- complete();
 
562
  } else {
563
  response.json().then(success).catch(error).finally(complete);
564
  }
 
134
  <span class="softspan">Infill Examples:</span>
135
  <br>
136
  <span class="softspan"><a href='javascript:select_example("type-pred");'>Type prediction</a></span>
137
+ <span class="softspan"><a href='javascript:select_example("multi-region");'>Multi-region</a></span>
138
  <span class="softspan"><a href='javascript:select_example("docstring");'>Function to docstring</a></span>
139
  <span class="softspan"><a href='javascript:select_example("python-infill2");'>Docstring to function</a></span>
140
  <span class="softspan"><a href='javascript:select_example("class");'>Class generation</a></span>
 
253
  "temperature": 0.2,
254
  "mode": "python"
255
  },
256
+ "multi-region": {
257
+ "prompt":
258
+ `<| file ext=.py |>
259
+ <infill>
260
+ """ Load the given gzip jsonl file. """
261
+ <infill>
262
+ `,
263
+ "length": 64,
264
+ "temperature": 0.2,
265
+ "mode": "python"
266
+ },
267
  "type-pred": {
268
  "prompt":
269
+ `def count_words(filename: str) -> <infill>
 
 
270
  """Count the number of occurrences of each word in the file."""
271
  with open(filename, 'r') as f:
272
  word_counts = {}
 
319
  "mode": "python"
320
  },
321
  "javascript": {
322
+ "prompt": "// fetch from the given URL and load the response contents into a new div",
323
  "length": 64,
324
  "temperature": 0.6,
325
  "mode": "javascript"
 
538
  console.log("Response:");
539
  console.log(receive_data);
540
  if (receive_data["result"] == "success") {
541
+ console.log("success");
542
  // $("#prompt").text(data["prompt"]);
543
  // $("#response").text(data["text"]);
544
  set_text(receive_data["text"]);
 
550
  $("#warning").text("");
551
  }
552
  } else {
553
+ console.log("error");
554
  set_text(receive_data["text"])
555
  $("#error").text(receive_data["message"]);
556
  }
 
563
  $("#error").text(err);
564
  }
565
 
566
+ encoded_data = encodeURIComponent(btoa(JSON.stringify(send_data)))
567
 
568
  try {
569
  const response = await fetch(`${url}?info=${encoded_data}`);
570
+ // const response = await fetch(`${url}` {
571
+ // method: 'GET',
572
+ // body: encoded_data,
573
+ // });
574
  if (response.status >= 400) {
575
  error(response.statusText);
576
+ console.log("here");
577
+ console.log(response.status);
578
  } else {
579
  response.json().then(success).catch(error).finally(complete);
580
  }