import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-scico')
model = AutoModelForSequenceClassification.from_pretrained('allenai/longformer-scico')
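# allenai/longformer-scico is a Longformer-based cross-encoder fine-tuned on SciCo; given
# two marked mentions it scores their relation (not related / coref / parent / child).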
start_token = tokenizer.convert_tokens_to_ids("<m>")
end_token = tokenizer.convert_tokens_to_ids("</m>")
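# <m> and </m> wrap the target mention in each input; their token ids are used below to
# place global attention on the mention boundaries.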
def get_global_attention(input_ids):
    global_attention_mask = torch.zeros(input_ids.shape)
    global_attention_mask[:, 0] = 1  # global attention to the CLS token
    start = torch.nonzero(input_ids == start_token)  # global attention to the <m> token
    end = torch.nonzero(input_ids == end_token)  # global attention to the </m> token
    globs = torch.cat((start, end))
    value = torch.ones(globs.shape[0])
    global_attention_mask.index_put_(tuple(globs.t()), value)
    return global_attention_mask
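# Quick sanity check one can run locally (illustrative only, not executed by the demo,
# and assuming <m>/</m> are single tokens in this tokenizer's vocabulary):
#
#   toy = tokenizer("<m> word embeddings </m>", return_tensors='pt')
#   get_global_attention(toy['input_ids'])
#
# The result is a 1 x seq_len float tensor with 1.0 at position 0 (the <s>/CLS token)
# and at every <m>/</m> position, and 0.0 elsewhere, so Longformer attends globally at
# exactly those positions.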
def inference(m1, m2):
    # Concatenate the two mentions (with context) into a single Longformer input.
    inputs = m1 + " </s></s> " + m2
    tokens = tokenizer(inputs, return_tensors='pt')
    global_attention_mask = get_global_attention(tokens['input_ids'])
    with torch.no_grad():
        output = model(input_ids=tokens['input_ids'],
                       attention_mask=tokens['attention_mask'],
                       global_attention_mask=global_attention_mask)
    scores = torch.softmax(output.logits, dim=-1)
    listscore = scores.tolist()
    b = {}
    b['not related'] = listscore[0][0]
    b['coref'] = listscore[0][1]
    b['parent'] = listscore[0][2]
    b['child'] = listscore[0][3]
    return b
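# Example call (illustrative; the mention texts here are hypothetical):
#
#   inference("We use <m> neural machine translation </m> for this task.",
#             "Recent work on <m> machine translation </m> relies on transformers.")
#
# returns a dict with the keys 'not related', 'coref', 'parent' and 'child', holding the
# softmax probabilities of the four relations (they sum to 1).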
title = "Longformer-scico"
description = """Gradio demo for Longformer-scico. To use it, simply add your text, or click one of the examples to load them. Read more at the links below. The model takes as input two mentions m1 and m2 with their corresponding context and outputs 4 scores:
0: not related
1: m1 and m2 corefer
2: m1 is a parent of m2
3: m1 is a child of m2."""
article = "<p style='text-align: center'><a href='https://openreview.net/forum?id=OFLbgUP04nC' target='_blank'>SciCo: Hierarchical Cross-Document Coreference for Scientific Concepts</a> | <a href='https://github.com/ariecattan/SciCo' target='_blank'>Github Repo</a></p>"
examples = [["In this paper we present the results of an experiment in <m> automatic concept and definition extraction </m> from written sources of law using relatively simple natural methods.","This task is important since many natural language processing (NLP) problems, such as <m> information extraction </m>, summarization and dialogue."]]
gr.Interface(
    inference,
    [gr.inputs.Textbox(label="m1"), gr.inputs.Textbox(label="m2")],
    gr.outputs.Label(label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples
).launch(debug=True)