File size: 3,196 Bytes
009f8e2
23c0953
afb1063
d10fd10
ad8550a
 
d10fd10
 
f4b9a92
ad8550a
23c0953
 
 
6213036
23c0953
bab84e2
23c0953
73ca0b1
 
23c0953
5df676a
 
9b8cca7
5df676a
5065ff5
23c0953
bab84e2
23c0953
1811e61
 
 
23c0953
f4b9a92
23c0953
 
5aafc6e
 
afb1063
 
 
23c0953
 
fa07e23
afb1063
90fd9d9
 
21b62df
 
 
90fd9d9
21b62df
 
 
23c0953
90fd9d9
fa07e23
dd8c189
5df676a
 
 
 
 
 
4fe4bcb
 
5065ff5
9b8cca7
5065ff5
afb1063
23c0953
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

# Load the grammar-synthesis checkpoint and its matching tokenizer once, at
# import time, so every request reuses the same objects.
_CHECKPOINT_ID = "pszemraj/flan-t5-large-grammar-synthesis"
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT_ID)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_ID)


def correct_text(text, genConfig):
    """Run the grammar-synthesis model on ``text`` and return its decoded output.

    Args:
        text: Raw user input. Blank or whitespace-only input (or ``None``)
            short-circuits to ``""`` without invoking the model.
        genConfig: A ``transformers.GenerationConfig`` holding the decoding
            settings for this request.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens stripped.
    """
    # Nothing to correct: skip the expensive model call entirely.
    if not text or not text.strip():
        return ""

    inputs = tokenizer.encode(text, return_tensors="pt")
    # Hand the config over through the documented ``generation_config``
    # parameter instead of splatting ``to_dict()`` as keyword overrides, which
    # forwards every field of the config (including ones never set) and may
    # clobber the model's own generation defaults.
    outputs = model.generate(inputs, generation_config=genConfig)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def respond(text, max_new_tokens, min_new_tokens, num_beams, num_beam_groups, temperature, top_k, top_p, no_repeat_ngram_size, guidance_scale, do_sample: bool):
    """Build a ``GenerationConfig`` from the UI controls and yield the correction.

    Every argument after ``text`` comes from a slider or checkbox in the UI.
    Slider values may arrive as floats, so count-like settings are coerced
    with ``int()`` (they size beams/token budgets, where a float such as
    ``5.0`` is not a valid count). ``temperature``, ``top_p`` and
    ``guidance_scale`` are coerced with ``float()``.

    Args:
        text: The sentence to correct.
        max_new_tokens, min_new_tokens: Bounds on generated length.
        num_beams, num_beam_groups: Beam-search settings.
        temperature, top_k, top_p: Sampling settings.
        no_repeat_ngram_size: Size of n-grams that may not repeat (0 = off).
        guidance_scale: Applied only when it is a positive number.
        do_sample: Whether to sample instead of decoding deterministically.

    Yields:
        The corrected text, exactly once.
    """
    config = GenerationConfig(
        max_new_tokens=int(max_new_tokens),
        min_new_tokens=int(min_new_tokens),
        num_beams=int(num_beams),
        num_beam_groups=int(num_beam_groups),
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        early_stopping=True,
        do_sample=bool(do_sample),
    )

    # Only attach a guidance scale when a usable positive value was supplied.
    if guidance_scale is not None and guidance_scale > 0:
        config.guidance_scale = float(guidance_scale)

    corrected = correct_text(text, config)
    yield corrected

def update_prompt(prompt):
    """Pass a sample prompt through unchanged.

    Used as the click handler of the sample buttons: the button's value is
    the input, and whatever is returned is written into the prompt box.
    """
    chosen_prompt = prompt
    return chosen_prompt

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Grammar Correction App")
    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")
    output_box = gr.Textbox()

    # Sample prompts: clicking a button copies its value (its label) into the
    # input box through the identity handler ``update_prompt``.
    with gr.Row():
        for sample_text in (
            "we shood buy an car",
            "she is more taller",
            "John and i saw a sheep over their.",
        ):
            sample_btn = gr.Button(sample_text)
            # The button itself is the event input, so each click reads that
            # button's own value; nothing closes over the loop variable.
            sample_btn.click(update_prompt, sample_btn, prompt_box)
    submit_btn = gr.Button("Submit")

    with gr.Accordion("Generation Parameters:", open=False):
        max_tokens  = gr.Slider(minimum=1,   maximum=256,   value=50,  step=1,    label="Max New Tokens")
        min_tokens  = gr.Slider(minimum=0,   maximum=256,   value=0,   step=1,    label="Min New Tokens")
        num_beams   = gr.Slider(minimum=1,   maximum=20,    value=5,   step=1,    label="Num Beams")
        beam_groups = gr.Slider(minimum=1,   maximum=20,    value=1,   step=1,    label="Num Beams Groups")
        temperature = gr.Slider(minimum=0.1, maximum=100.0, value=0.7, step=0.1,  label="Temperature")
        top_k       = gr.Slider(minimum=0,   maximum=200,   value=50,  step=1,    label="Top-k")
        top_p       = gr.Slider(minimum=0.1, maximum=1.0,   value=1.0, step=0.05, label="Top-p (nucleus sampling)")
        guide_scale = gr.Slider(minimum=0.1, maximum=50.0,  value=1.0, step=0.1,  label="Guidance Scale")
        no_repeat_ngram_size = gr.Slider(0, 20, value=0, step=1, label="Limit N-grams of given Size")
        do_sample = gr.Checkbox(value=True, label="Do Sampling")

    # Input order must match respond()'s positional parameters.
    submit_btn.click(
        respond,
        [
            prompt_box,
            max_tokens,
            min_tokens,
            num_beams,
            beam_groups,
            temperature,
            top_k,
            top_p,
            no_repeat_ngram_size,
            guide_scale,
            do_sample,
        ],
        output_box,
    )

# Start the server only when executed as a script (e.g. ``python app.py``),
# not as a side effect of importing this module.
if __name__ == "__main__":
    demo.launch()