File size: 5,752 Bytes
f63856c
 
 
 
 
 
09731af
f63856c
 
 
 
 
09731af
f63856c
 
 
 
 
 
 
 
 
 
09731af
f63856c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import html
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from transformers import pipeline

# Markdown/HTML header shown at the top of the page (rendered with gr.Markdown in the UI below).
description = """# <p style="text-align: center; color: white;"> 🎅 <span style='color: #ff75b3;'>SantaFixer:</span> Code Generation </p>
<span style='color: white;'>This is a demo to generate code with <a href="https://huggingface.co/bigcode/santacoder" style="color: #ff75b3;">SantaCoder</a>,
a 1.1B parameter model for code generation in Python, Java & JavaScript. The model can also do infilling, just specify where you would like the model to complete code
with the <span style='color: #ff75b3;'>&lt;FILL-HERE&gt;</span> token.</span>"""

# Hugging Face Hub token passed as `use_auth_token` when downloading the repos below
# (presumably needed because they are gated/private — confirm).
# os.environ[...] raises KeyError at import time if HUB_TOKEN is not set.
token = os.environ["HUB_TOKEN"]
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


# Fill-in-the-middle (FIM) control tokens. infill() builds prompts of the form
# <fim-prefix>{prefix}<fim-suffix>{suffix}<fim-middle> and the model generates the middle.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
# End-of-document token; also used as the FIM tokenizer's padding token (see below).
EOD = "<|endoftext|>"

# HTML heading prepended to every rendered generation result.
GENERATION_TITLE= "<p style='font-size: 16px; color: white;'>Generated code:</p>"

# Tokenizer used only for infilling. Left padding keeps every prompt in a batch
# right-aligned, so generated tokens directly follow each prompt.
# NOTE(review): this tokenizer comes from a different repo than the model below;
# presumably both share the same vocabulary — confirm.
tokenizer_fim = AutoTokenizer.from_pretrained("lambdasec/santafixer", use_auth_token=token, padding_side="left")

# Register the FIM markers as special tokens so each is handled as a single unit,
# and pad batches with EOD.
tokenizer_fim.add_special_tokens({
  "additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
  "pad_token": EOD,
})

# Tokenizer and model for plain left-to-right completion. trust_remote_code=True
# executes model code shipped in the Hub repo.
tokenizer = AutoTokenizer.from_pretrained("bigcode/christmas-models", use_auth_token=token)
model = AutoModelForCausalLM.from_pretrained("bigcode/christmas-models", trust_remote_code=True, use_auth_token=token).to(device)
# Text-generation pipeline used by code_generation() for non-infilling prompts.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)

def post_processing(prompt, completion):
    """Render a left-to-right completion as a highlighted HTML code snippet.

    The prompt is shown in blue and the generated continuation in pink.
    Both pieces are HTML-escaped first: source code routinely contains
    ``<``, ``>`` and ``&`` (comparisons, generics, shifts), which would
    otherwise be parsed as markup. That corrupts the displayed code and
    lets prompt or model text inject arbitrary HTML into the page.

    Args:
        prompt: The user-supplied code that was completed.
        completion: The generated text that follows ``prompt``.

    Returns:
        An HTML string: ``GENERATION_TITLE`` followed by the highlighted code.
    """
    prompt_html = "<span style='color: #727cd6;'>" + html.escape(prompt) + "</span>"
    completion_html = "<span style='color: #ff75b3;'>" + html.escape(completion) + "</span>"
    code_html = f"<br><hr><br><pre style='font-size: 12px'><code>{prompt_html}{completion_html}</code></pre><br><hr>"
    return GENERATION_TITLE + code_html

def post_processing_fim(prefix, middle, suffix):
    """Render an infilling result as a highlighted HTML code snippet.

    The user-written prefix and suffix are shown in blue and the
    model-generated middle in pink. Every piece is HTML-escaped so that
    code containing ``<``, ``>`` or ``&`` is displayed literally instead
    of being interpreted as markup.

    Args:
        prefix: Code before the ``<FILL-HERE>`` marker.
        middle: Text generated by the model for the gap.
        suffix: Code after the ``<FILL-HERE>`` marker.

    Returns:
        An HTML string: ``GENERATION_TITLE`` followed by the highlighted code.
    """
    prefix_html = "<span style='color: #727cd6;'>" + html.escape(prefix) + "</span>"
    middle_html = "<span style='color: #ff75b3;'>" + html.escape(middle) + "</span>"
    suffix_html = "<span style='color: #727cd6;'>" + html.escape(suffix) + "</span>"
    code_html = f"<br><hr><br><pre style='font-size: 12px'><code>{prefix_html}{middle_html}{suffix_html}</code></pre><br><hr>"
    return GENERATION_TITLE + code_html

def fim_generation(prompt, max_new_tokens, temperature):
    """Infill the ``<FILL-HERE>`` gap in *prompt* and render the result as HTML.

    The prompt is split at the FIRST ``<FILL-HERE>`` marker. Any later
    markers stay part of the suffix instead of being silently discarded
    together with the text that follows them.

    Args:
        prompt: User code containing a ``<FILL-HERE>`` marker.
        max_new_tokens: Maximum number of tokens to generate for the gap.
        temperature: Sampling temperature passed to the model.

    Returns:
        HTML produced by ``post_processing_fim``.
    """
    prefix, _, suffix = prompt.partition("<FILL-HERE>")
    [middle] = infill((prefix, suffix), max_new_tokens, temperature)
    return post_processing_fim(prefix, middle, suffix)

def extract_fim_part(s: str) -> str:
    """Return the infilled middle section from a decoded FIM generation.

    The decoded text looks like
    ``[pads]<fim-prefix>...<fim-suffix>...<fim-middle>MIDDLE<|endoftext|>...``.
    Everything after the first ``FIM_MIDDLE`` marker up to the next ``EOD``
    is returned. If generation stopped before emitting ``EOD``, the whole
    tail is returned. The previous ``find(...) or len(s)`` treated -1 as a
    valid index and dropped the last character.

    Args:
        s: A decoded sequence, decoded WITHOUT skipping special tokens.

    Returns:
        The generated middle text, or ``""`` if ``FIM_MIDDLE`` is absent.
    """
    _, found, after_middle = s.partition(FIM_MIDDLE)
    if not found:
        # No <fim-middle> marker: there is no infilled section to extract.
        return ""
    middle, _, _ = after_middle.partition(EOD)
    return middle

def infill(prefix_suffix_tuples, max_new_tokens, temperature):
    """Generate the missing middle for one or more (prefix, suffix) pairs.

    Args:
        prefix_suffix_tuples: A single ``(prefix, suffix)`` tuple, or an
            iterable of such tuples to process as one batch.
        max_new_tokens: Maximum number of tokens to generate per pair.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        A list with one infilled string per input pair, in input order.
    """
    if isinstance(prefix_suffix_tuples, tuple):
        prefix_suffix_tuples = [prefix_suffix_tuples]

    prompts = [f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}" for prefix, suffix in prefix_suffix_tuples]
    # `return_token_type_ids=False` is essential, or we get nonsense output.
    inputs = tokenizer_fim(prompts, return_tensors="pt", padding=True, return_token_type_ids=False).to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            # Use the pad id of the tokenizer that actually padded this batch
            # (tokenizer_fim pads with EOD), not the plain completion tokenizer's.
            pad_token_id=tokenizer_fim.pad_token_id,
        )
    # WARNING: cannot use skip_special_tokens, because it blows away the FIM special tokens.
    return [
        extract_fim_part(tokenizer_fim.decode(sequence, skip_special_tokens=False))
        for sequence in outputs
    ]


def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42):
    """Generate code for *prompt* and return it rendered as HTML.

    If the prompt contains ``<FILL-HERE>``, the model infills that gap.
    Otherwise it continues the prompt left to right.

    Args:
        prompt: User code, optionally containing a ``<FILL-HERE>`` marker.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature, used by BOTH modes. Previously
            the infilling path ignored it and always used 0.2.
        seed: Random seed for reproducible sampling, or ``None`` to leave
            the RNG state untouched.

    Returns:
        An HTML string with the prompt and the generated code highlighted.
    """
    if seed is not None:
        # The UI passes the seed from a slider; cast to int so seeding the
        # underlying RNGs does not reject a float value.
        set_seed(int(seed))

    if "<FILL-HERE>" in prompt:
        return fim_generation(prompt, max_new_tokens, temperature)

    completion = pipe(prompt, do_sample=True, top_p=0.95, temperature=temperature, max_new_tokens=max_new_tokens)[0]['generated_text']
    # The pipeline echoes the prompt; keep only the newly generated text.
    completion = completion[len(prompt):]
    return post_processing(prompt, completion)


# Gradio page with a dark background applied through custom CSS.
demo = gr.Blocks(
    css=".gradio-container {background-color: #20233fff; color:white}"
)
with demo:
    with gr.Row():
        # Three columns are created in the row. The two outer scale-1 columns are
        # left empty (bound to `_`) and act as side margins; all content goes into
        # the wider scale-6 middle column.
        _, colum_2, _ = gr.Column(scale=1), gr.Column(scale=6), gr.Column(scale=1)
        with colum_2:
            gr.Markdown(value=description)
            # Input editor, pre-filled with an example function header and docstring.
            code = gr.Code(lines=5, language="python", label="Input code", value="def all_odd_elements(sequence):\n    \"\"\"Returns every odd element of the sequence.\"\"\"")
            
            # Generation parameters, collapsed by default. These module-level names
            # hold Gradio components, not the values code_generation() receives.
            with gr.Accordion("Advanced settings", open=False):
                max_new_tokens= gr.Slider(
                    minimum=8,
                    maximum=1024,
                    step=1,
                    value=48,
                    label="Number of tokens to generate",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.5,
                    step=0.1,
                    value=0.2,
                    label="Temperature",
                )
                # NOTE(review): no `value` is given, so the initial seed is Gradio's
                # slider default — confirm which value that is.
                seed = gr.Slider(
                    minimum=0,
                    maximum=1000,
                    step=1,
                    label="Random seed to use for the generation"
                )
            run = gr.Button()
            # Receives the HTML string returned by code_generation().
            output = gr.HTML(label="Generated code")

    # Clicking the button calls code_generation(code, max_new_tokens, temperature, seed)
    # with the components' current values and shows the result in `output`.
    # api_name="predict" also exposes this handler as a named API endpoint.
    event = run.click(code_generation, [code, max_new_tokens, temperature, seed], output, api_name="predict")
    # Contact banner placed after the main row.
    gr.HTML(label="Contact", value="<img src='https://huggingface.co/datasets/bigcode/admin/resolve/main/bigcode_contact.png' alt='contact' style='display: block; margin: auto; max-width: 800px;'>")

# Starts the web server as soon as this module is executed.
demo.launch()