# Captured from a Hugging Face Spaces page; at capture time the Space's
# status read "Runtime error".
import gradio as gr | |
import torch | |
import torch.nn.functional as F | |
from transformers import XGLMTokenizer, XGLMForCausalLM | |
# Hugging Face checkpoint shared by the tokenizer and the model. Naming it once
# keeps the two from silently pointing at different checkpoints.
MODEL_NAME = "facebook/xglm-564M"

tokenizer = XGLMTokenizer.from_pretrained(MODEL_NAME)
model = XGLMForCausalLM.from_pretrained(MODEL_NAME)
# Sample COPA-style items keyed by language code. Each record has a premise,
# two alternatives, the relation being asked about ("effect" here), and a label.
# The "zh" entry is the Chinese text of the same example as "en".
# NOTE(review): nothing in the visible code reads data_samples; it looks
# intended as example inputs for the demo. Confirm before relying on it.
# presumably "label" is the 0-based index of the correct alternative stored as
# a string ("1" -> choice2), matching the 0/1 convention of COPA_eval. Verify.
data_samples = {
    'en': [
        {
            "premise": "I wanted to conserve energy.",
            "choice1": "I swept the floor in the unoccupied room.",
            "choice2": "I shut off the light in the unoccupied room.",
            "question": "effect",
            "label": "1",
        },
    ],
    'zh': [
        {
            "premise": "我想节约能源。",
            "choice1": "我在空着的房间里扫了地板。",
            "choice2": "我把空房间里的灯关了。",
            "question": "effect",
            "label": "1",
        },
    ],
}
def get_logprobs(prompt):
    """Return the log-probability of each token of *prompt* under the LM.

    The prompt is tokenized with the module-level ``tokenizer`` and scored with
    the module-level ``model``. Entry ``t`` of the result is
    ``log p(token[t + 1] | token[0..t])``. The first token has no prediction
    and is therefore not scored.

    Args:
        prompt: Text to score.

    Returns:
        A tensor of shape ``(batch, seq_len - 1, 1)``; for a single string
        prompt, batch is 1. Summing it gives the log-likelihood of everything
        after the first token.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]
    # Pure inference: no autograd graph is needed, and passing ``labels`` would
    # only compute a loss that is never used.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Logits at position t predict token t + 1, so drop the last position and
    # score every token except the first.
    log_probs = F.log_softmax(logits[:, :-1, :], dim=2)
    target_ids = input_ids[:, 1:]
    return torch.gather(log_probs, 2, target_ids.unsqueeze(2))
def COPA_eval(premise, choice1, choice2):
    """Zero-shot Choice of Plausible Alternatives (COPA): pick the likelier choice.

    Each alternative is joined to the premise with a newline. The whole text is
    scored by its summed token log-probability (via ``get_logprobs``), and the
    alternative with the higher score wins.

    Args:
        premise: The premise sentence.
        choice1: The first alternative.
        choice2: The second alternative.

    Returns:
        The text of the winning alternative (``choice1`` or ``choice2``), not a
        0/1 index. Ties go to ``choice2``.
    """
    # NOTE: the cause/effect direction of the question is not part of the
    # prompt. Because the score is a sum of non-positive per-token terms, longer
    # alternatives tend to score lower.
    lprob1 = get_logprobs(premise + "\n" + choice1).sum()
    lprob2 = get_logprobs(premise + "\n" + choice2).sum()
    return choice1 if lprob1 > lprob2 else choice2
# Web demo: three free-text boxes feed COPA_eval positionally
# (premise, choice1, choice2); one text box shows the alternative it picked.
iface = gr.Interface(fn=COPA_eval, inputs=["text"] * 3, outputs=["text"])
iface.launch()