# Spaces: Build error
import math

import gradio as gr
import torch
from transformers import (
    BartForConditionalGeneration,
    BertForSequenceClassification,
    BertTokenizerFast,
    GPT2LMHeadModel,
)
class CHSentenceSmoothScorer:
    """Score how "smooth" (fluent) sentences look to a Chinese BART model.

    The score of a sentence is the geometric mean of the probabilities the
    model assigns to the sentence's own tokens, i.e.
    ``exp(mean(log p(token_j)))`` where ``p(token_j)`` is read from the
    model's output distribution at position ``j``.  It lies in ``[0, 1]``;
    higher means the model finds the sentence more likely.
    """

    # Checkpoint used for both the tokenizer and the language model.
    MODEL_NAME = "fnlp/bart-base-chinese"
    # Sentences are truncated to this many tokens before scoring.
    MAX_LENGTH = 50

    def __init__(self) -> None:
        """Load the tokenizer and the BART model from ``MODEL_NAME``."""
        super().__init__()
        self.tokenizer = BertTokenizerFast.from_pretrained(self.MODEL_NAME)
        self.model = BartForConditionalGeneration.from_pretrained(
            self.MODEL_NAME)

    def __call__(self, sentences):
        """Return one smoothness score per sentence.

        Args:
            sentences: A list of sentence strings, scored as one batch.

        Returns:
            A list of Python floats, one per input sentence, in input order.
            An empty input yields an empty list.
        """
        if not sentences:
            return []

        encoded = self.tokenizer.batch_encode_plus(
            sentences,
            return_tensors='pt',
            padding=True,
            max_length=self.MAX_LENGTH,
            truncation='longest_first',
            return_attention_mask=True,
        )
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']

        # Inference only: skip autograd bookkeeping.  The attention mask is
        # forwarded so padding added for longer batch-mates is ignored by the
        # model instead of influencing the scores of shorter sentences.
        with torch.no_grad():
            logits = self.model(
                input_ids, attention_mask=attention_mask).logits

        # log_softmax avoids log(0) when a probability underflows, which
        # would otherwise raise ValueError in math.log.
        log_probs = torch.log_softmax(logits, dim=-1)

        # Pick, at every position, the log-probability of the token that is
        # actually there; accumulate in float64 like the scalar math did.
        token_log_probs = log_probs.gather(
            -1, input_ids.unsqueeze(-1)).squeeze(-1).double()

        # Padding positions contribute nothing to the sum or to the length.
        token_log_probs = token_log_probs.masked_fill(attention_mask == 0, 0.0)
        lengths = attention_mask.sum(dim=-1).clamp(min=1)

        mean_log_prob = token_log_probs.sum(dim=-1) / lengths
        return torch.exp(mean_log_prob).tolist()
# Local directory holding the fine-tuned sentence-check classifier and its
# tokenizer files.
CHECK_MODEL_DIR = './ch-sent-check-model'

# Module-level singletons used by judge(): loaded once at start-up.
tokenizer = BertTokenizerFast.from_pretrained(CHECK_MODEL_DIR)
model = BertForSequenceClassification.from_pretrained(CHECK_MODEL_DIR)

# Fluency scorer; constructing it loads its own tokenizer and model.
smooth_scorer = CHSentenceSmoothScorer()
def judge(sentence):
    """Classify a sentence as correct/incorrect and score its fluency.

    Args:
        sentence: The sentence text to check.

    Returns:
        A ``(pred_text, smooth_score)`` tuple.  ``pred_text`` is
        ``'Incorrect'`` when the classifier picks class 0 and ``'Correct'``
        otherwise, followed by the class-1 probability as a percentage.
        ``smooth_score`` is the ``smooth_scorer`` result for the sentence as
        a percentage rounded to two decimals.
    """
    # Truncate to the classifier's position-embedding capacity so that an
    # over-long input is scored on its prefix instead of failing in the model.
    input_ids = tokenizer(
        sentence,
        return_tensors='pt',
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )['input_ids']

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(input_ids).logits

    prob = torch.softmax(logits, dim=-1)
    pred = torch.argmax(prob, dim=-1).item()
    pred_text = 'Incorrect' if pred == 0 else 'Correct'

    # The reported score is always the probability of class 1 ("Correct"),
    # regardless of which class was predicted.
    correct_prob = prob[0][1].item()
    pred_text = pred_text + f", score: {round(correct_prob*100,2)}"

    smooth_score = round(smooth_scorer([sentence])[0] * 100, 2)
    return pred_text, smooth_score
# Single-line input box; the label asks for a Chinese sentence to check.
sentence_input = gr.Textbox(label="請輸入一段中文句子來檢測正確性", lines=1)

# Two single-line outputs, matching the (pred_text, smooth_score) tuple
# returned by judge(): correctness check first, fluency check second.
correctness_output = gr.Textbox(label="正確性檢查", lines=1)
fluency_output = gr.Textbox(label="流暢性檢查", lines=1)

# Example inputs: the same sentence with and without a wrong character.
example_sentences = [
    '請注意用字的鄭確性',
    '請注意用字的正確性',
]

iface = gr.Interface(
    fn=judge,
    inputs=sentence_input,
    outputs=[correctness_output, fluency_output],
    examples=example_sentences,
)

# Start the web UI as soon as the script is executed.
iface.launch()