import html
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from transformers import pipeline
# Markdown shown at the top of the page. The FILL-HERE marker is wrapped in
# backticks so Markdown renders it literally instead of as an HTML tag.
description = """# 🎅 SantaFixer: Code Generation

This is a demo to generate code with SantaCoder,
a 1.1B parameter model for code generation in Python, Java & JavaScript. The model can also do infilling:
just mark where you would like the model to complete code with the `<FILL-HERE>` token."""

# Hugging Face Hub access token; startup fails with KeyError if HUB_TOKEN is unset.
token = os.environ["HUB_TOKEN"]
# First CUDA GPU when one is available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fill-in-the-middle sentinel tokens of the SantaCoder vocabulary. An infilling
# prompt is laid out as <fim-prefix>{prefix}<fim-suffix>{suffix}<fim-middle>,
# and the model writes the missing middle after <fim-middle>.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"
# Heading placed above every result rendered in the gr.HTML output panel.
GENERATION_TITLE = "<p>Generated code:</p>"

# Tokenizer for FIM prompts. Left padding keeps the end of every prompt in a
# batch at the last position, which is where generation continues.
tokenizer_fim = AutoTokenizer.from_pretrained("lambdasec/santafixer", use_auth_token=token, padding_side="left")
# Register the sentinels as special tokens so they are never split into
# sub-tokens; EOD doubles as the padding token.
tokenizer_fim.add_special_tokens({
    "additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
    "pad_token": EOD,
})
# Base tokenizer/model for left-to-right completion; the model is moved to
# `device` once here and wrapped in a text-generation pipeline.
tokenizer = AutoTokenizer.from_pretrained("bigcode/christmas-models", use_auth_token=token)
model = AutoModelForCausalLM.from_pretrained(
    "bigcode/christmas-models", trust_remote_code=True, use_auth_token=token
).to(device)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
def render_code_html(before, generated, after=""):
    """Render code as an HTML ``<pre><code>`` block with the generated part highlighted.

    Every piece of text is passed through ``html.escape`` so source code such
    as ``a < b`` or ``<script>`` is displayed literally in the ``gr.HTML``
    output instead of being interpreted as markup.

    Args:
        before: Text shown before the generated part (the prompt or FIM prefix).
        generated: Model output; wrapped in a coloured ``<span>``.
        after: Text shown after the generated part (the FIM suffix), if any.

    Returns:
        str: The HTML fragment.
    """
    highlighted = f"<span style='color: #499cd5;'>{html.escape(generated)}</span>"
    body = html.escape(before) + highlighted + html.escape(after)
    # <pre> preserves indentation and newlines; overflow-x lets long lines scroll.
    return f"<pre style='overflow-x: auto;'><code>{body}</code></pre>"


def post_processing(prompt, completion):
    """Format a left-to-right completion for the HTML output panel.

    Args:
        prompt: Code typed by the user.
        completion: Text generated after the prompt.

    Returns:
        str: ``GENERATION_TITLE`` followed by the escaped, highlighted code.
    """
    return GENERATION_TITLE + render_code_html(prompt, completion)


def post_processing_fim(prefix, middle, suffix):
    """Format a fill-in-the-middle result for the HTML output panel.

    Args:
        prefix: User code before the fill marker.
        middle: Generated code that fills the gap.
        suffix: User code after the fill marker.

    Returns:
        str: ``GENERATION_TITLE`` followed by the escaped code, with *middle* highlighted.
    """
    return GENERATION_TITLE + render_code_html(prefix, middle, suffix)
def fim_generation(prompt, max_new_tokens, temperature):
    """Fill in the code at the ``<FILL-HERE>`` marker of *prompt*.

    Text before the first marker is the prefix and everything after it is the
    suffix; the model generates the middle joining them. Without a marker the
    whole prompt becomes the prefix and the suffix is empty.

    Args:
        prompt: User code containing the ``<FILL-HERE>`` marker.
        max_new_tokens: Maximum number of tokens to generate for the middle.
        temperature: Sampling temperature.

    Returns:
        str: HTML showing prefix, generated middle and suffix.
    """
    # partition splits on the first marker only and keeps all remaining text
    # in the suffix, so nothing the user wrote is dropped.
    prefix, _, suffix = prompt.partition("<FILL-HERE>")
    [middle] = infill((prefix, suffix), max_new_tokens, temperature)
    return post_processing_fim(prefix, middle, suffix)
def extract_fim_part(s: str, middle_token=None, end_token=None) -> str:
    """Return the generated middle from a decoded fill-in-the-middle sequence.

    The middle is the text between the first *middle_token* and the next
    *end_token* after it, or the end of *s* when no end token follows.

    Args:
        s: Decoded model output with special tokens kept.
        middle_token: FIM middle sentinel; defaults to ``FIM_MIDDLE``.
        end_token: End-of-document token; defaults to ``EOD``.

    Returns:
        str: The middle text, or ``""`` when *s* contains no *middle_token*.
    """
    if middle_token is None:
        middle_token = FIM_MIDDLE
    if end_token is None:
        end_token = EOD
    # Find the index of the middle sentinel; the generated code starts right after it.
    marker = s.find(middle_token)
    if marker == -1:
        return ""
    start = marker + len(middle_token)
    stop = s.find(end_token, start)
    if stop == -1:
        # No end token (e.g. generation stopped at max_new_tokens): keep the rest.
        stop = len(s)
    return s[start:stop]
def infill(prefix_suffix_tuples, max_new_tokens, temperature):
    """Generate the missing middle for one or more ``(prefix, suffix)`` pairs.

    Args:
        prefix_suffix_tuples: A single ``(prefix, suffix)`` tuple or a list of them.
        max_new_tokens: Maximum number of tokens to generate per sequence.
        temperature: Sampling temperature.

    Returns:
        list[str]: One generated middle per input pair, in input order.
    """
    if isinstance(prefix_suffix_tuples, tuple):
        prefix_suffix_tuples = [prefix_suffix_tuples]
    prompts = [
        f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
        for prefix, suffix in prefix_suffix_tuples
    ]
    # `return_token_type_ids=False` is essential, or we get nonsense output.
    inputs = tokenizer_fim(
        prompts, return_tensors="pt", padding=True, return_token_type_ids=False
    ).to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            # Use the pad id of the tokenizer that actually padded `inputs`.
            pad_token_id=tokenizer_fim.pad_token_id,
        )
    # WARNING: cannot use skip_special_tokens, because it blows away the FIM special tokens.
    return [
        extract_fim_part(tokenizer_fim.decode(sequence, skip_special_tokens=False))
        for sequence in outputs
    ]
def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42):
    """Generate code for *prompt* and return it as HTML for the output panel.

    Prompts containing the ``<FILL-HERE>`` marker are answered with
    fill-in-the-middle generation; all others are completed left to right.

    Args:
        prompt: Code typed by the user.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature, used by both generation modes.
        seed: Seed applied before sampling so results are reproducible.

    Returns:
        str: HTML markup for the ``gr.HTML`` output component.
    """
    # Gradio sliders may deliver floats; seeding and token limits need ints.
    set_seed(int(seed))
    max_new_tokens = int(max_new_tokens)
    if "<FILL-HERE>" in prompt:
        return fim_generation(prompt, max_new_tokens, temperature)
    completion = pipe(
        prompt,
        do_sample=True,
        top_p=0.95,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        return_full_text=False,  # only the newly generated text
    )[0]["generated_text"]
    return post_processing(prompt, completion)
demo = gr.Blocks(
    css=".gradio-container {background-color: #20233fff; color:white}"
)

with demo:
    with gr.Row():
        # Empty narrow columns on both sides centre the wide main column.
        _, main_column, _ = gr.Column(scale=1), gr.Column(scale=6), gr.Column(scale=1)
        with main_column:
            gr.Markdown(value=description)
            code = gr.Code(lines=5, language="python", label="Input code", value="def all_odd_elements(sequence):\n \"\"\"Returns every odd element of the sequence.\"\"\"")
            with gr.Accordion("Advanced settings", open=False):
                max_new_tokens = gr.Slider(
                    minimum=8,
                    maximum=1024,
                    step=1,
                    value=48,
                    label="Number of tokens to generate",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.5,
                    step=0.1,
                    value=0.2,
                    label="Temperature",
                )
                # Default matches code_generation's own seed default.
                seed = gr.Slider(
                    minimum=0,
                    maximum=1000,
                    step=1,
                    value=42,
                    label="Random seed to use for the generation",
                )
            run = gr.Button()
            output = gr.HTML(label="Generated code")
    # Exposed as the "predict" API endpoint as well as the button action.
    event = run.click(code_generation, [code, max_new_tokens, temperature, seed], output, api_name="predict")
    # NOTE(review): empty placeholder; its content appears to be missing — confirm.
    gr.HTML(label="Contact", value="")

# Launch only when run as a script, so importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()