import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-scico')
model = AutoModelForSequenceClassification.from_pretrained('allenai/longformer-scico')

# Special tokens marking the start and end of a mention in the input text
start_token = tokenizer.convert_tokens_to_ids("<m>")
end_token = tokenizer.convert_tokens_to_ids("</m>")


def get_global_attention(input_ids):
    global_attention_mask = torch.zeros(input_ids.shape)
    global_attention_mask[:, 0] = 1  # global attention to the CLS token
    start = torch.nonzero(input_ids == start_token)  # global attention to the <m> tokens
    end = torch.nonzero(input_ids == end_token)  # global attention to the </m> tokens
    globs = torch.cat((start, end))
    value = torch.ones(globs.shape[0])
    global_attention_mask.index_put_(tuple(globs.t()), value)
    return global_attention_mask


def inference(m1, m2):
    # Join the two mention sentences with the Longformer separator tokens
    inputs = m1 + " </s></s> " + m2
    tokens = tokenizer(inputs, return_tensors='pt')
    global_attention_mask = get_global_attention(tokens['input_ids'])
    with torch.no_grad():
        output = model(tokens['input_ids'], tokens['attention_mask'], global_attention_mask)
    scores = torch.softmax(output.logits, dim=-1)
    listscore = scores.tolist()
    print(listscore)
    b = {}
    b['not related'] = listscore[0][0]
    b['coref'] = listscore[0][1]
    b['parent'] = listscore[0][2]
    b['child'] = listscore[0][3]
    return b


title = "Longformer-scico"
description = "Gradio demo for Longformer-scico. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
article = "SciCo: Hierarchical Cross-Document Coreference for Scientific Concepts | Github Repo"

examples = [[
    "In this paper we present the results of an experiment in <m> automatic concept and definition extraction </m> from written sources of law using relatively simple natural methods.",
    "This task is important since many natural language processing (NLP) problems, such as <m> information extraction </m>, summarization and dialogue.",
]]

gr.Interface(
    inference,
    [gr.inputs.Textbox(label="Input1"), gr.inputs.Textbox(label="Input2")],
    gr.outputs.Label(label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples,
).launch(debug=True)