"""Gradio demo: prompt-based sentiment classification with OpenPrompt.

The user supplies a sentence, a template containing the literal tokens
``[SENTENCE]`` and ``[MASK]``, a backbone model, and space-separated label
words for each sentiment class. An OpenPrompt ``PromptForClassification``
built from those pieces picks one of the classes in ``CLASSES``.
"""
import functools

import gradio as gr
import torch
from openprompt.plms import load_plm
from openprompt import PromptDataLoader
from openprompt.prompts import ManualVerbalizer
from openprompt.prompts import ManualTemplate
from openprompt.data_utils import InputExample
from openprompt import PromptForClassification

# Sentiment classes; the classifier's logits are indexed in this order.
CLASSES = ["positive", "neutral", "negative"]

# Hugging Face model id -> model family name passed to OpenPrompt's load_plm.
# Dict order is also the order of the choices shown in the UI.
MODEL_TYPES = {
    "roberta-base": "roberta",
    "bert-base-uncased": "bert",
    "yiyanghkust/finbert-pretrain": "bert",
}

_LABEL_WORDS_HINT = "Separate words with Spaces."


@functools.lru_cache(maxsize=len(MODEL_TYPES))
def _load_backbone(model_name):
    """Load and memoize the pretrained LM, its tokenizer and wrapper class.

    Loading pretrained weights is expensive, so each supported model is loaded
    once and then reused by later requests. The cache is bounded by the
    number of supported models. Callers must validate ``model_name`` first.
    """
    plm, tokenizer, _model_config, wrapper_class = load_plm(
        MODEL_TYPES[model_name], model_name
    )
    return plm, tokenizer, wrapper_class


def _build_template_text(template):
    """Convert a ``[SENTENCE]``/``[MASK]`` template to OpenPrompt template syntax.

    Raises:
        ValueError: if either required token is missing.
    """
    if not template or "[SENTENCE]" not in template or "[MASK]" not in template:
        raise ValueError(
            "Your template must have a [SENTENCE] token and a [MASK] token."
        )
    return template.replace("[SENTENCE]", '{"placeholder":"text_a"}').replace(
        "[MASK]", '{"mask"}'
    )


def _parse_label_words(class_name, text):
    """Split a space-separated label-word field into a non-empty word list.

    ``str.split()`` with no argument ignores repeated, leading and trailing
    whitespace, so no empty-string "words" reach the verbalizer.

    Raises:
        ValueError: if the field contains no words.
    """
    words = (text or "").split()
    if not words:
        raise ValueError(f"Please enter at least one {class_name} label word.")
    return words


def sentiment_analysis(sentence, template, model_name, positive, neutral, negative):
    """Classify *sentence* as positive, neutral or negative with a manual prompt.

    Args:
        sentence: Text to classify.
        template: Prompt text containing the literal tokens ``[SENTENCE]``
            (replaced by the sentence) and ``[MASK]`` (the slot the model fills).
        model_name: One of the keys of ``MODEL_TYPES``.
        positive, neutral, negative: Space-separated label words for each class.

    Returns:
        The predicted class name, one of ``CLASSES``.

    Raises:
        ValueError: if no supported model is selected, the template lacks a
            required token, or a class has no label words.
    """
    if model_name not in MODEL_TYPES:
        raise ValueError(f"Please choose a model: {', '.join(MODEL_TYPES)}.")
    template_text = _build_template_text(template)
    label_words = {
        "positive": _parse_label_words("positive", positive),
        "neutral": _parse_label_words("neutral", neutral),
        "negative": _parse_label_words("negative", negative),
    }

    plm, tokenizer, wrapper_class = _load_backbone(model_name)

    prompt_template = ManualTemplate(
        text=template_text,
        tokenizer=tokenizer,
    )
    verbalizer = ManualVerbalizer(
        classes=CLASSES,
        label_words=label_words,
        tokenizer=tokenizer,
    )
    dataloader = PromptDataLoader(
        dataset=[InputExample(guid=0, text_a=sentence, label=0)],
        tokenizer=tokenizer,
        template=prompt_template,
        tokenizer_wrapper_class=wrapper_class,
        batch_size=1,
        max_seq_length=512,
    )
    prompt_model = PromptForClassification(
        plm=plm,
        template=prompt_template,
        verbalizer=verbalizer,
        freeze_plm=False,  # weights are never updated: inference runs under no_grad
    )

    # Inference only: disable dropout and skip building the autograd graph.
    prompt_model.eval()
    # The dataset holds exactly one example, so there is exactly one batch.
    inputs = next(iter(dataloader))
    with torch.no_grad():
        logits = prompt_model(inputs)
    return CLASSES[torch.argmax(logits, dim=-1)[0].item()]


demo = gr.Interface(
    fn=sentiment_analysis,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here.", label="sentence"),
        gr.Textbox(
            placeholder="Your template must have a [SENTENCE] token and a [MASK] token.",
            label="template",
        ),
        gr.Radio(choices=list(MODEL_TYPES), label="model choice"),
        gr.Textbox(placeholder=_LABEL_WORDS_HINT, label="positive label words"),
        gr.Textbox(placeholder=_LABEL_WORDS_HINT, label="neutral label words"),
        gr.Textbox(placeholder=_LABEL_WORDS_HINT, label="negative label words"),
    ],
    outputs="text",
)

demo.launch()