Spaces:
Sleeping
Sleeping
Daniel Fried
commited on
Commit
·
44efa8c
1
Parent(s):
8a85023
fix query encoding and add new examples
Browse files- modules/app.py +27 -1
- static/index.html +23 -7
modules/app.py
CHANGED
@@ -2,6 +2,7 @@ import sys
|
|
2 |
from typing import List
|
3 |
import traceback
|
4 |
import os
|
|
|
5 |
# needs to be imported *before* transformers
|
6 |
if os.path.exists('use_normal_tokenizers'):
|
7 |
import tokenizers
|
@@ -11,8 +12,10 @@ else:
|
|
11 |
import tokenizers_patch
|
12 |
BIG_MODEL = True
|
13 |
CUDA = True
|
|
|
14 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
15 |
import json
|
|
|
16 |
|
17 |
# from flask import Flask, request, render_template
|
18 |
# from flask_cors import CORS
|
@@ -32,8 +35,14 @@ TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in t
|
|
32 |
|
33 |
if BIG_MODEL:
|
34 |
model_name = "facebook/incoder-6B"
|
|
|
|
|
|
|
|
|
|
|
35 |
else:
|
36 |
model_name = "facebook/incoder-1B"
|
|
|
37 |
|
38 |
from fastapi import FastAPI, Request
|
39 |
from fastapi.staticfiles import StaticFiles
|
@@ -43,7 +52,7 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
43 |
|
44 |
|
45 |
print("loading model")
|
46 |
-
model = AutoModelForCausalLM.from_pretrained(model_name)
|
47 |
print("loading tokenizer")
|
48 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
49 |
print("loading complete")
|
@@ -154,9 +163,18 @@ def index() -> FileResponse:
|
|
154 |
return FileResponse(path="static/index.html", media_type="text/html")
|
155 |
|
156 |
@app.get('/generate')
|
|
|
157 |
async def generate_maybe(info: str):
|
158 |
# form = await info.json()
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
form = json.loads(info)
|
|
|
|
|
160 |
prompt = form['prompt']
|
161 |
length_limit = int(form['length'])
|
162 |
temperature = float(form['temperature'])
|
@@ -174,9 +192,17 @@ async def generate_maybe(info: str):
|
|
174 |
return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
|
175 |
|
176 |
@app.get('/infill')
|
|
|
177 |
async def infill_maybe(info: str):
|
178 |
# form = await info.json()
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
form = json.loads(info)
|
|
|
180 |
length_limit = int(form['length'])
|
181 |
temperature = float(form['temperature'])
|
182 |
max_retries = 1
|
|
|
2 |
from typing import List
|
3 |
import traceback
|
4 |
import os
|
5 |
+
import base64
|
6 |
# needs to be imported *before* transformers
|
7 |
if os.path.exists('use_normal_tokenizers'):
|
8 |
import tokenizers
|
|
|
12 |
import tokenizers_patch
|
13 |
BIG_MODEL = True
|
14 |
CUDA = True
|
15 |
+
import torch
|
16 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
17 |
import json
|
18 |
+
import pprint
|
19 |
|
20 |
# from flask import Flask, request, render_template
|
21 |
# from flask_cors import CORS
|
|
|
35 |
|
36 |
if BIG_MODEL:
|
37 |
model_name = "facebook/incoder-6B"
|
38 |
+
kwargs = dict(
|
39 |
+
revision="float16",
|
40 |
+
torch_dtype=torch.float16,
|
41 |
+
low_cpu_mem_usage=True,
|
42 |
+
)
|
43 |
else:
|
44 |
model_name = "facebook/incoder-1B"
|
45 |
+
kwargs = dict()
|
46 |
|
47 |
from fastapi import FastAPI, Request
|
48 |
from fastapi.staticfiles import StaticFiles
|
|
|
52 |
|
53 |
|
54 |
print("loading model")
|
55 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
|
56 |
print("loading tokenizer")
|
57 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
58 |
print("loading complete")
|
|
|
163 |
return FileResponse(path="static/index.html", media_type="text/html")
|
164 |
|
165 |
@app.get('/generate')
|
166 |
+
# async def generate_maybe(request: Request):
|
167 |
async def generate_maybe(info: str):
|
168 |
# form = await info.json()
|
169 |
+
# form = await request.json()
|
170 |
+
# info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
|
171 |
+
# fix padding, following https://stackoverflow.com/a/9956217/1319683
|
172 |
+
print(info)
|
173 |
+
info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
|
174 |
+
print(info)
|
175 |
form = json.loads(info)
|
176 |
+
pprint.pprint(form)
|
177 |
+
# print(form)
|
178 |
prompt = form['prompt']
|
179 |
length_limit = int(form['length'])
|
180 |
temperature = float(form['temperature'])
|
|
|
192 |
return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
|
193 |
|
194 |
@app.get('/infill')
|
195 |
+
# async def infill_maybe(request: Request):
|
196 |
async def infill_maybe(info: str):
|
197 |
# form = await info.json()
|
198 |
+
# form = await request.json()
|
199 |
+
# info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
|
200 |
+
# fix padding, following https://stackoverflow.com/a/9956217/1319683
|
201 |
+
print(info)
|
202 |
+
info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
|
203 |
+
print(info)
|
204 |
form = json.loads(info)
|
205 |
+
pprint.pprint(form)
|
206 |
length_limit = int(form['length'])
|
207 |
temperature = float(form['temperature'])
|
208 |
max_retries = 1
|
static/index.html
CHANGED
@@ -134,6 +134,7 @@ label {
|
|
134 |
<span class="softspan">Infill Examples:</span>
|
135 |
<br>
|
136 |
<span class="softspan"><a href='javascript:select_example("type-pred");'>Type prediction</a></span>
|
|
|
137 |
<span class="softspan"><a href='javascript:select_example("docstring");'>Function to docstring</a></span>
|
138 |
<span class="softspan"><a href='javascript:select_example("python-infill2");'>Docstring to function</a></span>
|
139 |
<span class="softspan"><a href='javascript:select_example("class");'>Class generation</a></span>
|
@@ -252,12 +253,20 @@ def <infill>
|
|
252 |
"temperature": 0.2,
|
253 |
"mode": "python"
|
254 |
},
|
255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
"type-pred": {
|
257 |
"prompt":
|
258 |
-
|
259 |
-
|
260 |
-
def count_words(filename: str) -> <infill>
|
261 |
"""Count the number of occurrences of each word in the file."""
|
262 |
with open(filename, 'r') as f:
|
263 |
word_counts = {}
|
@@ -310,7 +319,7 @@ def count_words(filename):
|
|
310 |
"mode": "python"
|
311 |
},
|
312 |
"javascript": {
|
313 |
-
"prompt": "
|
314 |
"length": 64,
|
315 |
"temperature": 0.6,
|
316 |
"mode": "javascript"
|
@@ -529,6 +538,7 @@ function make_generate_listener(url) {
|
|
529 |
console.log("Response:");
|
530 |
console.log(receive_data);
|
531 |
if (receive_data["result"] == "success") {
|
|
|
532 |
// $("#prompt").text(data["prompt"]);
|
533 |
// $("#response").text(data["text"]);
|
534 |
set_text(receive_data["text"]);
|
@@ -540,6 +550,7 @@ function make_generate_listener(url) {
|
|
540 |
$("#warning").text("");
|
541 |
}
|
542 |
} else {
|
|
|
543 |
set_text(receive_data["text"])
|
544 |
$("#error").text(receive_data["message"]);
|
545 |
}
|
@@ -552,13 +563,18 @@ function make_generate_listener(url) {
|
|
552 |
$("#error").text(err);
|
553 |
}
|
554 |
|
555 |
-
encoded_data = JSON.stringify(send_data)
|
556 |
|
557 |
try {
|
558 |
const response = await fetch(`${url}?info=${encoded_data}`);
|
|
|
|
|
|
|
|
|
559 |
if (response.status >= 400) {
|
560 |
error(response.statusText);
|
561 |
-
|
|
|
562 |
} else {
|
563 |
response.json().then(success).catch(error).finally(complete);
|
564 |
}
|
|
|
134 |
<span class="softspan">Infill Examples:</span>
|
135 |
<br>
|
136 |
<span class="softspan"><a href='javascript:select_example("type-pred");'>Type prediction</a></span>
|
137 |
+
<span class="softspan"><a href='javascript:select_example("multi-region");'>Multi-region</a></span>
|
138 |
<span class="softspan"><a href='javascript:select_example("docstring");'>Function to docstring</a></span>
|
139 |
<span class="softspan"><a href='javascript:select_example("python-infill2");'>Docstring to function</a></span>
|
140 |
<span class="softspan"><a href='javascript:select_example("class");'>Class generation</a></span>
|
|
|
253 |
"temperature": 0.2,
|
254 |
"mode": "python"
|
255 |
},
|
256 |
+
"multi-region": {
|
257 |
+
"prompt":
|
258 |
+
`<| file ext=.py |>
|
259 |
+
<infill>
|
260 |
+
""" Load the given gzip jsonl file. """
|
261 |
+
<infill>
|
262 |
+
`,
|
263 |
+
"length": 64,
|
264 |
+
"temperature": 0.2,
|
265 |
+
"mode": "python"
|
266 |
+
},
|
267 |
"type-pred": {
|
268 |
"prompt":
|
269 |
+
`def count_words(filename: str) -> <infill>
|
|
|
|
|
270 |
"""Count the number of occurrences of each word in the file."""
|
271 |
with open(filename, 'r') as f:
|
272 |
word_counts = {}
|
|
|
319 |
"mode": "python"
|
320 |
},
|
321 |
"javascript": {
|
322 |
+
"prompt": "// fetch from the given URL and load the response contents into a new div",
|
323 |
"length": 64,
|
324 |
"temperature": 0.6,
|
325 |
"mode": "javascript"
|
|
|
538 |
console.log("Response:");
|
539 |
console.log(receive_data);
|
540 |
if (receive_data["result"] == "success") {
|
541 |
+
console.log("success");
|
542 |
// $("#prompt").text(data["prompt"]);
|
543 |
// $("#response").text(data["text"]);
|
544 |
set_text(receive_data["text"]);
|
|
|
550 |
$("#warning").text("");
|
551 |
}
|
552 |
} else {
|
553 |
+
console.log("error");
|
554 |
set_text(receive_data["text"])
|
555 |
$("#error").text(receive_data["message"]);
|
556 |
}
|
|
|
563 |
$("#error").text(err);
|
564 |
}
|
565 |
|
566 |
+
encoded_data = encodeURIComponent(btoa(JSON.stringify(send_data)))
|
567 |
|
568 |
try {
|
569 |
const response = await fetch(`${url}?info=${encoded_data}`);
|
570 |
+
// const response = await fetch(`${url}` {
|
571 |
+
// method: 'GET',
|
572 |
+
// body: encoded_data,
|
573 |
+
// });
|
574 |
if (response.status >= 400) {
|
575 |
error(response.statusText);
|
576 |
+
console.log("here");
|
577 |
+
console.log(response.status);
|
578 |
} else {
|
579 |
response.json().then(success).catch(error).finally(complete);
|
580 |
}
|