# georeactor's picture
# Update app.py
# d797f28
import gradio as gr
import torch
import ecco
import requests
from transformers import AutoTokenizer
from torch.nn import functional as F
# Prompt template handed to the code model. code_from_prompts() substitutes
# the CONN, PROMPT and MIDDLE placeholders. The text ends mid-statement after
# `name =`; generation() then measures how likely the model is to continue
# with a quote followed by `+` (i.e. string concatenation inside the query).
header = """
import psycopg2
conn = psycopg2.connect("CONN")
cur = conn.cursor()
MIDDLE
def rename_customer(id, newName):\n\t# PROMPT\n\tcur.execute("UPDATE customer SET name =
"""
# Display name (shown in the UI) -> Hugging Face model id.
# Commented-out entries are alternatives that are currently disabled.
modelPath = {
    # "GPT2-Medium": "gpt2-medium",
    "CodeParrot-mini": "codeparrot/codeparrot-small",
    # "CodeGen-350-Mono": "Salesforce/codegen-350M-mono",
    # "GPT-Neo-1.3B": "EleutherAI/gpt-neo-1.3B",
    "CodeParrot": "codeparrot/codeparrot",
    # "CodeGen-2B-Mono": "Salesforce/codegen-2B-mono",
}

# Load every enabled model once at import time so that requests only need a
# dictionary lookup instead of a fresh ecco.from_pretrained() call.
preloadModels = {
    name: ecco.from_pretrained(path) for name, path in modelPath.items()
}
def generation(tokenizer, model, content):
    """Continue ``content`` and estimate the chance of a risky concatenation.

    Args:
        tokenizer: Tokenizer for ``model``. It is called directly on
            ``content`` and must provide ``encode`` and ``decode``.
        model: Model wrapper whose ``generate(text, generate=N,
            do_sample=False)`` result exposes ``tokens``, ``hidden_states``,
            ``lm_head`` and ``to`` (presumably an ecco LM -- confirm).
        content: Source-code prompt to continue.

    Returns:
        A two-item list: the tokens of a 6-token, non-sampled generation
        joined into one string, and a message giving the summed probability
        (as a percentage) of the continuations in ``seek_token_ids``.
    """
    # Continuations that open a string literal and then concatenate with `+`.
    # The first encoded token is dropped -- presumably the `=` that already
    # ends the prompt (confirm against the tokenizer in use).
    seek_token_ids = [
        tokenizer.encode('= \'" +')[1:],
        tokenizer.encode('= " +')[1:],
    ]

    full_output = model.generate(content, generate=6, do_sample=False)

    def next_words(code, position, remaining_ids):
        """Return the probability that ``remaining_ids`` follow ``code`` in order."""
        step = model.generate(code, generate=1, do_sample=False)
        # Take the last entry of hidden_states at the token just before
        # `position` and project it through the LM head to get the
        # distribution over the next token.
        hidden_state = step.hidden_states[-1][position - 1]
        logits = step.lm_head(step.to(hidden_state))
        token_prob = F.softmax(logits, dim=-1)[remaining_ids[0]]

        if len(remaining_ids) == 1:
            return token_prob

        # Chain rule: append the token just scored to the prompt and multiply
        # by the probability of the rest of the sequence.
        extended = code + tokenizer.decode(remaining_ids[0])
        return token_prob * next_words(extended, position + 1, remaining_ids[1:])

    # The prompt length is the same for every candidate, so tokenize once.
    prompt_length = len(tokenizer(content)['input_ids'])
    prob = sum(
        next_words(content, prompt_length, ids) for ids in seek_token_ids
    )
    return [
        "".join(full_output.tokens),
        f"{prob.item() * 100}% chance of risky concatenation",
    ]
def code_from_prompts(prompt, model, type_hints, pre_content):
    """Build the code prompt from the UI inputs and run it through the model.

    Args:
        prompt: Comment text inserted at the top of ``rename_customer``.
        model: Key of ``modelPath`` / ``preloadModels`` naming the model.
        type_hints: When truthy, annotate the ``rename_customer`` signature.
        pre_content: Radio label choosing the snippet substituted for the
            ``MIDDLE`` placeholder: ``'None'``, a label containing
            ``'Concatenation'``, or a label containing ``'composition'``.
            ``None`` (nothing selected) and unrecognised labels are handled
            like ``'None'``, i.e. the placeholder line is removed.

    Returns:
        The two-item list produced by ``generation``.

    Raises:
        KeyError: If ``model`` is not a key of ``modelPath``.
    """
    if model not in modelPath:
        raise KeyError(
            "Unknown model %r; choose one of %s" % (model, list(modelPath))
        )

    tokenizer = AutoTokenizer.from_pretrained(modelPath[model])
    # Reuse the instance loaded at import time instead of reloading it.
    language_model = preloadModels[model]

    # The user's comment (PROMPT) is substituted last so that none of the
    # template replacements below can rewrite text the user typed.
    code = header.strip().replace('CONN', "dbname='store'")

    if type_hints:
        code = code.replace('id,', 'id: int,')
        code = code.replace('id)', 'id: int)')
        code = code.replace('newName)', 'newName: str) -> None')

    # The radio passes None when nothing is selected; treat it as "None".
    selection = pre_content or 'None'
    if 'Concatenation' in selection:
        middle = (
            "def get_customer(id):\n"
            "\tcur.execute('SELECT * FROM customers WHERE id = ' + str(id))\n"
            "\treturn cur.fetchall()\n"
        )
    elif 'composition' in selection:
        middle = (
            "def get_customer(id):\n"
            "\tcur.execute('SELECT * FROM customers WHERE id = %s', str(id))\n"
            "\treturn cur.fetchall()\n"
        )
    else:
        middle = None

    if middle is None:
        # No pre-existing function: drop the placeholder line entirely.
        code = code.replace('MIDDLE\n', '')
    else:
        code = code.replace('MIDDLE', middle)

    code = code.replace('PROMPT', prompt)
    return generation(tokenizer, language_model, code)
# Gradio front end: four inputs are passed positionally to code_from_prompts
# (comment text, model name, type-hint flag, pre-existing function choice) and
# its two-item result fills the two output text boxes.
iface = gr.Interface(
    description="Prompt the code model to write a SQL query with string concatenation.",
    fn=code_from_prompts,
    inputs=[
        gr.components.Textbox(label="Insert comment"),
        gr.components.Radio([*modelPath], label="Code Model"),
        gr.components.Checkbox(label="Include type hints"),
        gr.components.Radio(
            [
                "None",
                "Proper composition: Include function 'WHERE id = %s'",
                "Concatenation: Include a function with 'WHERE id = ' + id",
            ],
            label="Has user already written a function?",
        ),
    ],
    outputs=[
        gr.components.Textbox(label="Most probable code"),
        gr.components.Textbox(label="Probability of concat"),
    ],
)

# Start the web server as soon as the module is executed.
iface.launch()