"""Gradio demo: synthesize a regular expression from labelled example strings."""

import itertools
from typing import Any

import gradio as gr
import torch

from listener import Listener

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# An alternative listener checkpoint ("pragmatic-programs/pragmatic-ft-listener")
# was previously used here with the same generation settings as below.

# Generation settings passed to the listener: sample 100 candidate programs
# with nucleus sampling (top_p=0.9); the UI displays the first one returned.
listener = Listener(
    "pragmatic-programs/listener-suffix-idx-300k",
    {
        "do_sample": True,
        "num_return_sequences": 100,
        "num_beams": 1,
        "temperature": 1,
        "top_p": 0.9,
        "max_new_tokens": 128,
    },
)

# Number of (example string, label) rows shown in the UI.
N_EXAMPLES = 3

# Radio choice -> label symbol used in the specification sent to the listener.
# Also used as the Radio's choices so the two cannot drift apart.
_LABEL_SYMBOLS = {"match": "+", "not match": "-"}


def synthesize(*inps):
    """Synthesize a program from the example rows entered in the UI.

    Args:
        *inps: Flattened Gradio inputs: ``N_EXAMPLES`` consecutive pairs of
            ``(example_string, radio_label)`` in row order, where
            ``radio_label`` is ``"match"``, ``"not match"`` or ``None``.

    Returns:
        The first program returned by the listener for the specification,
        or one of the messages ``"Empty specification"``,
        ``"Invalid specification"`` or ``"No program found"``.
    """
    examples = []
    for i in range(N_EXAMPLES):
        string, label = inps[2 * i], inps[2 * i + 1]
        # Rows without a string are skipped even if a label is still
        # selected, since a Radio choice cannot easily be unselected.
        if not string:
            continue
        # Pair each string with its OWN label; filtering strings and labels
        # separately and zipping them would misalign rows.
        if label not in _LABEL_SYMBOLS:
            return "Invalid specification"
        examples.append((string, _LABEL_SYMBOLS[label]))

    if not examples:
        return "Empty specification"

    # The listener takes a batch of specifications; send a batch of one.
    outputs = listener.synthesize([examples]).programs
    if outputs[0]:
        return outputs[0][0]
    return "No program found"


# One (Textbox, Radio) pair per example row.
input_fields = [
    (
        gr.Textbox(
            lines=1,
            label=f"Example {i + 1}",
            container=True,
        ),
        gr.Radio(list(_LABEL_SYMBOLS), container=False, label="Label"),
    )
    for i in range(N_EXAMPLES)
]

iface = gr.Interface(
    fn=synthesize,
    # Flatten the pairs so synthesize receives string, label, string, label, ...
    inputs=list(itertools.chain.from_iterable(input_fields)),
    outputs=gr.Textbox(lines=1, label="Synthesizer output"),
    title="Synthesizing regular expressions from examples",
    theme=gr.themes.Soft(primary_hue="blue"),
)

iface.launch()