import gradio as gr
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed

description = """# <p style="text-align: center; color: white;"> 🎅🔨 <span style='color: #ff75b3;'>SantaFixer:</span> Bug Fixing</p>
<span style='color: white;'>This is a demo for <a href="https://huggingface.co/lambdasec/santafixer" style="color: #ff75b3;">SantaFixer</a>, a fine-tuned version of the SantaCoder model that is focused on bug fixing. Mark where you would like the model to fix the code with the <span style='color: #ff75b3;'>&lt;FILL-HERE&gt;</span> token.</span>"""

# token = os.environ["HUB_TOKEN"]
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"
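
# Fill-in-the-middle (FIM) prompt layout used by `infill` below:
#   <fim-prefix>{code before the hole}<fim-suffix>{code after the hole}<fim-middle>
# The model then generates the missing middle and (ideally) terminates it with
# <|endoftext|>; `extract_fim_part` recovers just that middle span.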

GENERATION_TITLE = "<p style='font-size: 16px; color: white;'>Generated code:</p>"

tokenizer_fim = AutoTokenizer.from_pretrained("lambdasec/santafixer", padding_side="left")

tokenizer_fim.add_special_tokens({
  "additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
  "pad_token": EOD,
})

# Load the SantaFixer checkpoint for both plain completion and FIM generation.
tokenizer = AutoTokenizer.from_pretrained("lambdasec/santafixer")
model = AutoModelForCausalLM.from_pretrained("lambdasec/santafixer", trust_remote_code=True).to(device)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)

def post_processing(prompt, completion):
    completion = "<span style='color: #ff75b3;'>" + completion + "</span>"
    prompt = "<span style='color: #727cd6;'>" + prompt + "</span>"
    code_html = f"<br><hr><br><pre style='font-size: 12px'><code>{prompt}{completion}</code></pre><br><hr>"
    return GENERATION_TITLE + code_html

def post_processing_fim(prefix, middle, suffix):
    prefix = "<span style='color: #727cd6;'>" + prefix + "</span>"
    middle = "<span style='color: #ff75b3;'>" + middle + "</span>"
    suffix = "<span style='color: #727cd6;'>" + suffix + "</span>"
    code_html = f"<br><hr><br><pre style='font-size: 12px'><code>{prefix}{middle}{suffix}</code></pre><br><hr>"
    return GENERATION_TITLE + code_html

def fim_generation(prompt, max_new_tokens, temperature):
    # Split on the first <FILL-HERE> marker into the code before and after the hole.
    prefix, suffix = prompt.split("<FILL-HERE>", 1)
    [middle] = infill((prefix, suffix), max_new_tokens, temperature)
    return post_processing_fim(prefix, middle, suffix)

def extract_fim_part(s: str):
    # Extract the generated middle: everything between <fim-middle> and <|endoftext|>.
    start = s.find(FIM_MIDDLE) + len(FIM_MIDDLE)
    stop = s.find(EOD, start)
    if stop == -1:  # no end-of-text token generated; keep everything to the end
        stop = len(s)
    return s[start:stop]
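# For example (hypothetical strings, for illustration only):
#   extract_fim_part("<fim-prefix>a<fim-suffix>c<fim-middle>b<|endoftext|>")
#   returns "b".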

def infill(prefix_suffix_tuples, max_new_tokens, temperature):
    # Accept a single (prefix, suffix) tuple or a list of them.
    if isinstance(prefix_suffix_tuples, tuple):
        prefix_suffix_tuples = [prefix_suffix_tuples]
        
    prompts = [f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}" for prefix, suffix in prefix_suffix_tuples]
    # `return_token_type_ids=False` is essential, or we get nonsense output.
    inputs = tokenizer_fim(prompts, return_tensors="pt", padding=True, return_token_type_ids=False).to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer_fim.pad_token_id
        )
    # WARNING: cannot use skip_special_tokens, because it blows away the FIM special tokens.
    return [
        extract_fim_part(tokenizer_fim.decode(tensor, skip_special_tokens=False))
        for tensor in outputs
    ]
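
# Minimal usage sketch for `infill` (hypothetical snippet, illustration only):
#   prefix = "def is_even(n):\n    return n % "
#   suffix = " == 0\n"
#   [middle] = infill((prefix, suffix), max_new_tokens=8, temperature=0.2)
#   # `middle` holds the model's guess for the missing code, e.g. "2".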


def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42):
    set_seed(seed)

    if "<FILL-HERE>" in prompt:
        return fim_generation(prompt, max_new_tokens, temperature=temperature)
    else:
        completion = pipe(prompt, do_sample=True, top_p=0.95, temperature=temperature, max_new_tokens=max_new_tokens)[0]['generated_text']
        completion = completion[len(prompt):]
        return post_processing(prompt, completion)
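
# `code_generation` dispatches on the <FILL-HERE> marker (hypothetical inputs,
# for illustration only):
#   code_generation("def f(x):\n    return x + <FILL-HERE>\n", 8)  # FIM infill path
#   code_generation("def f(x):\n", 8)                              # plain completion path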


demo = gr.Blocks(
    css=".gradio-container {background-color: #20233fff; color:white}"
)
with demo:
    with gr.Row():
        _, column_2, _ = gr.Column(scale=1), gr.Column(scale=6), gr.Column(scale=1)
        with column_2:
            gr.Markdown(value=description)
            code = gr.Code(lines=5, language="python", label="Input code", value="def all_odd_elements(sequence):\n    \"\"\"Returns every odd element of the sequence.\"\"\"")
            
            with gr.Accordion("Advanced settings", open=False):
                max_new_tokens = gr.Slider(
                    minimum=8,
                    maximum=1024,
                    step=1,
                    value=48,
                    label="Number of tokens to generate",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.5,
                    step=0.1,
                    value=0.2,
                    label="Temperature",
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=1000,
                    step=1,
                    value=42,
                    label="Random seed to use for the generation",
                )
            run = gr.Button()
            output = gr.HTML(label="Generated code")

    run.click(code_generation, [code, max_new_tokens, temperature, seed], output, api_name="predict")

demo.launch()