# gp-uq-tester / app.py — Hugging Face Space demo (commit 5212a08)
# Author: tombm — "Add functionality to app"
from functools import lru_cache

import gradio as gr
from transformers import pipeline, set_seed, AutoTokenizer

from uq import BertForUQSequenceClassification
@lru_cache(maxsize=1)
def _get_classifier(model_path):
    """Build and memoize the text-classification pipeline for *model_path*.

    Constructing a transformers pipeline downloads/deserializes the whole
    checkpoint, so it must happen once per process — not once per request,
    as the original per-call construction did.
    """
    return pipeline("text-classification", model=model_path, tokenizer=model_path)


def predict(sentence):
    """Classify *sentence* with the CoLA-finetuned BERT model.

    Returns the raw pipeline output unchanged, which the gradio "label"
    output component renders directly.
    """
    classifier = _get_classifier("tombm/bert-base-uncased-finetuned-cola")
    return classifier(sentence)
@lru_cache(maxsize=1)
def _load_uq_model(model_path):
    """Load and memoize the (tokenizer, model) pair for *model_path*.

    The original reloaded both objects on every call; caching them keeps
    the expensive from_pretrained work to a single load per process.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = BertForUQSequenceClassification.from_pretrained(model_path)
    # Have forward passes also return the GP-head covariance
    # (same effect as setting the flag before each call).
    model.return_gp_cov = True
    return tokenizer, model


def uncertainty(sentence):
    """Return the GP-head predictive variance for *sentence* as a string.

    The model's forward pass yields a (logits, gp_cov) pair; only the
    scalar covariance is surfaced, stringified for a gradio "text" output.
    """
    tokenizer, model = _load_uq_model("tombm/bert-base-uncased-finetuned-cola")
    encoded = tokenizer(sentence, return_tensors="pt")
    _, gp_cov = model(**encoded)
    return str(gp_cov.item())
# The demo walks through two fixed example sentences twice:
# first classification, then GP-head uncertainty.
_VALID_EXAMPLE = "Good morning."
_INVALID_EXAMPLE = "This sentence is sentence, this is a correct sentence!"

with gr.Blocks() as demo:
    set_seed(12)

    gr.Markdown(
        value="""The *cola* dataset focuses on determining whether sentences are grammatically correct.
Firstly, let's see how our finetuned model classifies two sentences,
the first of which is correct (i.e. valid) and the second is not (i.e. invalid):"""
    )
    for example in (_VALID_EXAMPLE, _INVALID_EXAMPLE):
        gr.Interface(
            fn=predict,
            inputs=gr.Textbox(value=example, label="Input"),
            outputs="label",
        )

    gr.Markdown(
        value="""As we can see, our model correctly classifies the first sentence, but misclassifies the second.
Let's now inspect the uncertainties associated with each prediction generated by our GP head:"""
    )
    # Expectation: low uncertainty for the valid sentence, high for the invalid one.
    for example in (_VALID_EXAMPLE, _INVALID_EXAMPLE):
        gr.Interface(
            fn=uncertainty,
            inputs=gr.Textbox(value=example, label="Input"),
            outputs="text",
        )

    gr.Markdown(
        value="""We can see here that the variance for the misclassified example is much higher than for the correctly
classified example. This is great, as now we have some indication of when our model might be uncertain!"""
    )

demo.launch()
# iface = gr.Interface(fn=predict, inputs="text", outputs="text")
# iface.launch()