# Spaces: Runtime error
import functools

import gradio as gr
import torch
from transformers import pipeline, set_seed, AutoTokenizer

from uq import BertForUQSequenceClassification
@functools.lru_cache(maxsize=1)
def _get_cola_classifier():
    """Build the CoLA text-classification pipeline once and reuse it.

    Constructing a ``pipeline`` loads the model weights and tokenizer, so it is
    far too expensive to repeat on every request. ``lru_cache`` does not cache
    exceptions, so a failed load is retried on the next call.
    """
    model_path = "tombm/bert-base-uncased-finetuned-cola"
    return pipeline("text-classification", model=model_path, tokenizer=model_path)


def predict(sentence):
    """Classify *sentence* with the CoLA fine-tuned BERT model.

    Args:
        sentence: The text to classify.

    Returns:
        The ``"label"`` field of the first result returned by the
        text-classification pipeline.
    """
    classifier = _get_cola_classifier()
    return classifier(sentence)[0]["label"]
@functools.lru_cache(maxsize=1)
def _get_uq_tokenizer_and_model():
    """Load the tokenizer and GP-head model once and reuse them across calls.

    Returns:
        A ``(tokenizer, model)`` tuple, with ``model.return_gp_cov`` already
        enabled so that every forward pass also yields the GP covariance.
    """
    model_path = "tombm/bert-base-uncased-finetuned-cola"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = BertForUQSequenceClassification.from_pretrained(model_path)
    # transformers' from_pretrained normally returns the model in eval mode;
    # made explicit so the guarantee does not depend on the UQ subclass' loader.
    model.eval()
    model.return_gp_cov = True
    return tokenizer, model


def uncertainty(sentence):
    """Return the GP-head covariance the UQ model reports for *sentence*.

    Args:
        sentence: The text to score.

    Returns:
        ``gp_cov.item()``, i.e. the covariance as a plain Python number.
        NOTE(review): ``.item()`` assumes the model returns a single-element
        covariance tensor for one sentence. Confirm against ``uq``.
    """
    tokenizer, model = _get_uq_tokenizer_and_model()
    test_input = tokenizer(sentence, return_tensors="pt")
    # Pure inference: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        _, gp_cov = model(**test_input)
    return gp_cov.item()
# The two CoLA example sentences used throughout the demo: the first is
# grammatical, the second deliberately is not.
_DEMO_SENTENCES = (
    "Good morning.",
    "This sentence is sentence, this is a correct sentence!",
)

_INTRO_MD = (
    "The *cola* dataset focuses on determining whether sentences are grammatically correct.\n"
    "Firstly, let's see how our finetuned model classifies two sentences,\n"
    "the first of which is correct (i.e. valid) and the second is not (i.e. invalid):"
)
_EXPLAIN_MD = (
    "As we can see, our model correctly classifies the first sentence, but misclassifies the second.\n"
    "Let's now inspect the uncertainties associated with each prediction generated by our GP head:"
)
_FINAL_MD = (
    "We can see here that the variance for the misclassified example is much higher than for the correctly\n"
    "classified example. This is great, as now we have some indication of when our model might be uncertain!"
)


def _render_sentence_interfaces(handler, build_output):
    """Create one ``gr.Interface`` per demo sentence, in order.

    Intended to be called inside the ``with gr.Blocks()`` context below.

    Args:
        handler: Function each interface calls with its textbox contents.
        build_output: Zero-argument factory for the interface's ``outputs``
            spec. It is invoked once per interface so that no output component
            instance is shared between interfaces.
    """
    for sentence in _DEMO_SENTENCES:
        gr.Interface(
            fn=handler,
            inputs=gr.Textbox(value=sentence, label="Input"),
            outputs=build_output(),
        )


with gr.Blocks() as demo:
    set_seed(12)
    gr.Markdown(value=_INTRO_MD)
    # Predicted label for each example sentence.
    _render_sentence_interfaces(predict, lambda: "text")
    gr.Markdown(value=_EXPLAIN_MD)
    # GP-head variance per sentence: expected low for the first example and
    # high for the second.
    _render_sentence_interfaces(
        uncertainty,
        lambda: gr.Number(label="Variance from GP head"),
    )
    gr.Markdown(value=_FINAL_MD)

demo.launch()
# iface = gr.Interface(fn=predict, inputs="text", outputs="text") | |
# iface.launch() | |