# saujasv's picture
# reduce sampling budget
# 24ce2c0
from typing import Any
import gradio as gr
import itertools
import torch
from listener import Listener
# Use the GPU when torch can see one; otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced anywhere else in this file —
# confirm whether it is still needed.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Decoding settings handed to the listener as its second constructor argument.
# NOTE(review): these look like Hugging Face `generate()` keyword arguments
# (sampling with top_p=0.9, 100 returned sequences, up to 128 new tokens) —
# confirm how listener.Listener consumes them.
# An earlier revision also built a second listener from
# "pragmatic-programs/pragmatic-ft-listener" with identical settings.
_listener_generation_kwargs = dict(
    do_sample=True,
    num_return_sequences=100,
    num_beams=1,
    temperature=1,
    top_p=0.9,
    max_new_tokens=128,
)
listener = Listener(
    "pragmatic-programs/listener-suffix-idx-300k",
    _listener_generation_kwargs,
)

# Number of (example string, label) rows shown in the UI.
N_EXAMPLES = 3
def synthesize(*inps):
    """Synthesize a program consistent with the labelled examples from the UI.

    Gradio passes the interface inputs positionally as alternating
    ``(text, label)`` values: ``inps[2*i]`` is the example string for row
    ``i`` and ``inps[2*i + 1]`` is that row's Radio choice (``"match"``,
    ``"not match"``, or ``None`` when nothing is selected).

    Rows whose text is empty are ignored, whatever their label. A row that has
    text but no label makes the specification invalid, because its polarity
    is unknown.

    Returns:
        The first program returned by ``listener`` for the specification, or
        one of the status strings ``"Empty specification"``,
        ``"Invalid specification"`` or ``"No program found"``.
    """
    examples = []
    # Pair each string with its own label *before* filtering, so an empty or
    # unlabelled row can never shift later labels onto the wrong strings.
    for text, choice in zip(inps[0::2], inps[1::2]):
        if not text:
            continue
        if choice is None:
            return "Invalid specification"
        examples.append((text, "+" if choice == "match" else "-"))

    if not examples:
        return "Empty specification"

    # The listener takes a batch of specifications; this is a batch of one.
    spec = [examples]
    outputs = listener.synthesize(spec).programs
    if len(outputs[0]) > 0:
        return outputs[0][0]
    return "No program found"
def _example_row(index):
    """Build the (example string, label) input components for UI row ``index``."""
    example_box = gr.Textbox(
        lines=1,
        label=f"Example {index + 1}",
        container=True,
    )
    label_choice = gr.Radio(["match", "not match"], container=False, label="Label")
    return example_box, label_choice


# One (Textbox, Radio) pair per example row, in display order.
input_fields = [_example_row(row) for row in range(N_EXAMPLES)]

# `synthesize` reads its positional arguments as alternating (string, label)
# values, so the components are flattened row by row in exactly that order.
iface = gr.Interface(
    fn=synthesize,
    inputs=[component for pair in input_fields for component in pair],
    outputs=gr.Textbox(lines=1, label="Synthesizer output"),
    title="Synthesizing regular expressions from examples",
    theme=gr.themes.Soft(primary_hue="blue"),
)
iface.launch()