Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- app.py +51 -0
- qasrl_model_pipeline.py +183 -0
- requirements.txt +2 -0
app.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from qasrl_model_pipeline import QASRL_Pipeline
|
3 |
+
|
4 |
+
models = ["kleinay/qanom-seq2seq-model-baseline",
|
5 |
+
"kleinay/qanom-seq2seq-model-joint"]
|
6 |
+
pipelines = {model: QASRL_Pipeline(model) for model in models}
|
7 |
+
|
8 |
+
|
9 |
+
description = f"""Using Seq2Seq T5 model which takes a sequence of items and outputs another sequence this model generates Questions and Answers (QA) with focus on Semantic Role Labeling (SRL)"""
|
10 |
+
title="Seq2Seq T5 Questions and Answers (QA) with Semantic Role Labeling (SRL)"
|
11 |
+
examples = [[models[0], "In March and April the patient <p> had two falls. One was related to asthma, heart palpitations. The second was due to syncope and post covid vaccination dizziness during exercise. The patient is now getting an EKG. Former EKG had shown that there was a bundle branch block. Patient had some uncontrolled immune system reactions like anaphylaxis and shortness of breath.", True, "fall"],
|
12 |
+
[models[1], "In March and April the patient had two falls. One was related to asthma, heart palpitations. The second was due to syncope and post covid vaccination dizziness during exercise. The patient is now getting an EKG. Former EKG had shown that there was a bundle branch block. Patient had some uncontrolled immune system reactions <p> like anaphylaxis and shortness of breath.", True, "reactions"],
|
13 |
+
[models[0], "In March and April the patient had two falls. One was related <p> to asthma, heart palpitations. The second was due to syncope and post covid vaccination dizziness during exercise. The patient is now getting an EKG. Former EKG had shown that there was a bundle branch block. Patient had some uncontrolled immune system reactions like anaphylaxis and shortness of breath.", True, "relate"],
|
14 |
+
[models[1], "In March and April the patient <p> had two falls. One was related to asthma, heart palpitations. The second was due to syncope and post covid vaccination dizziness during exercise. The patient is now getting an EKG. Former EKG had shown that there was a bundle branch block. Patient had some uncontrolled immune system reactions like anaphylaxis and shortness of breath.", False, "fall"]]
|
15 |
+
|
16 |
+
input_sent_box_label = "Insert sentence here. Mark the predicate by adding the token '<p>' before it."
|
17 |
+
verb_form_inp_placeholder = "e.g. 'decide' for the nominalization 'decision', 'teach' for 'teacher', etc."
|
18 |
+
links = """<p style='text-align: center'>
|
19 |
+
<a href='https://www.qasrl.org' target='_blank'>QASRL Website</a> | <a href='https://huggingface.co/kleinay/qanom-seq2seq-model-baseline' target='_blank'>Model Repo at Huggingface Hub</a>
|
20 |
+
</p>"""
|
21 |
+
def call(model_name, sentence, is_nominal, verb_form):
|
22 |
+
predicate_marker="<p>"
|
23 |
+
if predicate_marker not in sentence:
|
24 |
+
raise ValueError("You must highlight one word of the sentence as a predicate using preceding '<p>'.")
|
25 |
+
|
26 |
+
if not verb_form:
|
27 |
+
if is_nominal:
|
28 |
+
raise ValueError("You should provide the verbal form of the nominalization")
|
29 |
+
|
30 |
+
toks = sentence.split(" ")
|
31 |
+
pred_idx = toks.index(predicate_marker)
|
32 |
+
predicate = toks(pred_idx+1)
|
33 |
+
verb_form=predicate
|
34 |
+
pipeline = pipelines[model_name]
|
35 |
+
pipe_out = pipeline([sentence],
|
36 |
+
predicate_marker=predicate_marker,
|
37 |
+
predicate_type="nominal" if is_nominal else "verbal",
|
38 |
+
verb_form=verb_form)[0]
|
39 |
+
return pipe_out["QAs"], pipe_out["generated_text"]
|
40 |
+
iface = gr.Interface(fn=call,
|
41 |
+
inputs=[gr.inputs.Radio(choices=models, default=models[0], label="Model"),
|
42 |
+
gr.inputs.Textbox(placeholder=input_sent_box_label, label="Sentence", lines=4),
|
43 |
+
gr.inputs.Checkbox(default=True, label="Is Nominalization?"),
|
44 |
+
gr.inputs.Textbox(placeholder=verb_form_inp_placeholder, label="Verbal form (for nominalizations)", default='')],
|
45 |
+
outputs=[gr.outputs.JSON(label="Model Output - QASRL"), gr.outputs.Textbox(label="Raw output sequence")],
|
46 |
+
title=title,
|
47 |
+
description=description,
|
48 |
+
article=links,
|
49 |
+
examples=examples )
|
50 |
+
|
51 |
+
iface.launch()
|
qasrl_model_pipeline.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
import json
|
3 |
+
from argparse import Namespace
|
4 |
+
from pathlib import Path
|
5 |
+
from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM, AutoTokenizer
|
6 |
+
|
7 |
+
def get_markers_for_model(is_t5_model: bool) -> Namespace:
|
8 |
+
special_tokens_constants = Namespace()
|
9 |
+
if is_t5_model:
|
10 |
+
# T5 model have 100 special tokens by default
|
11 |
+
special_tokens_constants.separator_input_question_predicate = "<extra_id_1>"
|
12 |
+
special_tokens_constants.separator_output_answers = "<extra_id_3>"
|
13 |
+
special_tokens_constants.separator_output_questions = "<extra_id_5>" # if using only questions
|
14 |
+
special_tokens_constants.separator_output_question_answer = "<extra_id_7>"
|
15 |
+
special_tokens_constants.separator_output_pairs = "<extra_id_9>"
|
16 |
+
special_tokens_constants.predicate_generic_marker = "<extra_id_10>"
|
17 |
+
special_tokens_constants.predicate_verb_marker = "<extra_id_11>"
|
18 |
+
special_tokens_constants.predicate_nominalization_marker = "<extra_id_12>"
|
19 |
+
|
20 |
+
else:
|
21 |
+
special_tokens_constants.separator_input_question_predicate = "<question_predicate_sep>"
|
22 |
+
special_tokens_constants.separator_output_answers = "<answers_sep>"
|
23 |
+
special_tokens_constants.separator_output_questions = "<question_sep>" # if using only questions
|
24 |
+
special_tokens_constants.separator_output_question_answer = "<question_answer_sep>"
|
25 |
+
special_tokens_constants.separator_output_pairs = "<qa_pairs_sep>"
|
26 |
+
special_tokens_constants.predicate_generic_marker = "<predicate_marker>"
|
27 |
+
special_tokens_constants.predicate_verb_marker = "<verbal_predicate_marker>"
|
28 |
+
special_tokens_constants.predicate_nominalization_marker = "<nominalization_predicate_marker>"
|
29 |
+
return special_tokens_constants
|
30 |
+
|
31 |
+
def load_trained_model(name_or_path):
|
32 |
+
import huggingface_hub as HFhub
|
33 |
+
tokenizer = AutoTokenizer.from_pretrained(name_or_path)
|
34 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path)
|
35 |
+
# load preprocessing_kwargs from the model repo on HF hub, or from the local model directory
|
36 |
+
kwargs_filename = None
|
37 |
+
if name_or_path.startswith("kleinay/"): # and 'preprocessing_kwargs.json' in HFhub.list_repo_files(name_or_path): # the supported version of HFhub doesn't support list_repo_files
|
38 |
+
kwargs_filename = HFhub.hf_hub_download(repo_id=name_or_path, filename="preprocessing_kwargs.json")
|
39 |
+
elif Path(name_or_path).is_dir() and (Path(name_or_path) / "experiment_kwargs.json").exists():
|
40 |
+
kwargs_filename = Path(name_or_path) / "experiment_kwargs.json"
|
41 |
+
|
42 |
+
if kwargs_filename:
|
43 |
+
preprocessing_kwargs = json.load(open(kwargs_filename))
|
44 |
+
# integrate into model.config (for decoding args, e.g. "num_beams"), and save also as standalone object for preprocessing
|
45 |
+
model.config.preprocessing_kwargs = Namespace(**preprocessing_kwargs)
|
46 |
+
model.config.update(preprocessing_kwargs)
|
47 |
+
return model, tokenizer
|
48 |
+
|
49 |
+
|
50 |
+
class QASRL_Pipeline(Text2TextGenerationPipeline):
|
51 |
+
def __init__(self, model_repo: str, **kwargs):
|
52 |
+
model, tokenizer = load_trained_model(model_repo)
|
53 |
+
super().__init__(model, tokenizer, framework="pt")
|
54 |
+
self.is_t5_model = "t5" in model.config.model_type
|
55 |
+
self.special_tokens = get_markers_for_model(self.is_t5_model)
|
56 |
+
self.data_args = model.config.preprocessing_kwargs
|
57 |
+
# backward compatibility - default keyword values implemeted in `run_summarization`, thus not saved in `preprocessing_kwargs`
|
58 |
+
if "predicate_marker_type" not in vars(self.data_args):
|
59 |
+
self.data_args.predicate_marker_type = "generic"
|
60 |
+
if "use_bilateral_predicate_marker" not in vars(self.data_args):
|
61 |
+
self.data_args.use_bilateral_predicate_marker = True
|
62 |
+
if "append_verb_form" not in vars(self.data_args):
|
63 |
+
self.data_args.append_verb_form = True
|
64 |
+
self._update_config(**kwargs)
|
65 |
+
|
66 |
+
def _update_config(self, **kwargs):
|
67 |
+
" Update self.model.config with initialization parameters and necessary defaults. "
|
68 |
+
# set default values that will always override model.config, but can overriden by __init__ kwargs
|
69 |
+
kwargs["max_length"] = kwargs.get("max_length", 80)
|
70 |
+
# override model.config with kwargs
|
71 |
+
for k,v in kwargs.items():
|
72 |
+
self.model.config.__dict__[k] = v
|
73 |
+
|
74 |
+
def _sanitize_parameters(self, **kwargs):
|
75 |
+
preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {}
|
76 |
+
if "predicate_marker" in kwargs:
|
77 |
+
preprocess_kwargs["predicate_marker"] = kwargs["predicate_marker"]
|
78 |
+
if "predicate_type" in kwargs:
|
79 |
+
preprocess_kwargs["predicate_type"] = kwargs["predicate_type"]
|
80 |
+
if "verb_form" in kwargs:
|
81 |
+
preprocess_kwargs["verb_form"] = kwargs["verb_form"]
|
82 |
+
return preprocess_kwargs, forward_kwargs, postprocess_kwargs
|
83 |
+
|
84 |
+
def preprocess(self, inputs, predicate_marker="<predicate>", predicate_type=None, verb_form=None):
|
85 |
+
# Here, inputs is string or list of strings; apply string postprocessing
|
86 |
+
if isinstance(inputs, str):
|
87 |
+
processed_inputs = self._preprocess_string(inputs, predicate_marker, predicate_type, verb_form)
|
88 |
+
elif hasattr(inputs, "__iter__"):
|
89 |
+
processed_inputs = [self._preprocess_string(s, predicate_marker, predicate_type, verb_form) for s in inputs]
|
90 |
+
else:
|
91 |
+
raise ValueError("inputs must be str or Iterable[str]")
|
92 |
+
# Now pass to super.preprocess for tokenization
|
93 |
+
return super().preprocess(processed_inputs)
|
94 |
+
|
95 |
+
def _preprocess_string(self, seq: str, predicate_marker: str, predicate_type: Optional[str], verb_form: Optional[str]) -> str:
|
96 |
+
sent_tokens = seq.split(" ")
|
97 |
+
assert predicate_marker in sent_tokens, f"Input sentence must include a predicate-marker token ('{predicate_marker}') before the target predicate word"
|
98 |
+
predicate_idx = sent_tokens.index(predicate_marker)
|
99 |
+
sent_tokens.remove(predicate_marker)
|
100 |
+
sentence_before_predicate = " ".join([sent_tokens[i] for i in range(predicate_idx)])
|
101 |
+
predicate = sent_tokens[predicate_idx]
|
102 |
+
sentence_after_predicate = " ".join([sent_tokens[i] for i in range(predicate_idx+1, len(sent_tokens))])
|
103 |
+
|
104 |
+
if self.data_args.predicate_marker_type == "generic":
|
105 |
+
predicate_marker = self.special_tokens.predicate_generic_marker
|
106 |
+
# In case we want special marker for each predicate type: """
|
107 |
+
elif self.data_args.predicate_marker_type == "pred_type":
|
108 |
+
assert predicate_type is not None, "For this model, you must provide the `predicate_type` either when initializing QASRL_Pipeline(...) or when applying __call__(...) on it"
|
109 |
+
assert predicate_type in ("verbal", "nominal"), f"`predicate_type` must be either 'verbal' or 'nominal'; got '{predicate_type}'"
|
110 |
+
predicate_marker = {"verbal": self.special_tokens.predicate_verb_marker ,
|
111 |
+
"nominal": self.special_tokens.predicate_nominalization_marker
|
112 |
+
}[predicate_type]
|
113 |
+
|
114 |
+
if self.data_args.use_bilateral_predicate_marker:
|
115 |
+
seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {predicate_marker} {sentence_after_predicate}"
|
116 |
+
else:
|
117 |
+
seq = f"{sentence_before_predicate} {predicate_marker} {predicate} {sentence_after_predicate}"
|
118 |
+
|
119 |
+
# embed also verb_form
|
120 |
+
if self.data_args.append_verb_form and verb_form is None:
|
121 |
+
raise ValueError(f"For this model, you must provide the `verb_form` of the predicate when applying __call__(...)")
|
122 |
+
elif self.data_args.append_verb_form:
|
123 |
+
seq = f"{seq} {self.special_tokens.separator_input_question_predicate} {verb_form} "
|
124 |
+
else:
|
125 |
+
seq = f"{seq} "
|
126 |
+
|
127 |
+
# append source prefix (for t5 models)
|
128 |
+
prefix = self._get_source_prefix(predicate_type)
|
129 |
+
|
130 |
+
return prefix + seq
|
131 |
+
|
132 |
+
def _get_source_prefix(self, predicate_type: Optional[str]):
|
133 |
+
if not self.is_t5_model or self.data_args.source_prefix is None:
|
134 |
+
return ''
|
135 |
+
if not self.data_args.source_prefix.startswith("<"): # Regular prefix - not dependent on input row x
|
136 |
+
return self.data_args.source_prefix
|
137 |
+
if self.data_args.source_prefix == "<predicate-type>":
|
138 |
+
if predicate_type is None:
|
139 |
+
raise ValueError("source_prefix is '<predicate-type>' but input no `predicate_type`.")
|
140 |
+
else:
|
141 |
+
return f"Generate QAs for {predicate_type} QASRL: "
|
142 |
+
|
143 |
+
def _forward(self, *args, **kwargs):
|
144 |
+
outputs = super()._forward(*args, **kwargs)
|
145 |
+
return outputs
|
146 |
+
|
147 |
+
|
148 |
+
def postprocess(self, model_outputs):
|
149 |
+
output_seq = self.tokenizer.decode(
|
150 |
+
model_outputs["output_ids"].squeeze(),
|
151 |
+
skip_special_tokens=False,
|
152 |
+
clean_up_tokenization_spaces=False,
|
153 |
+
)
|
154 |
+
output_seq = output_seq.strip(self.tokenizer.pad_token).strip(self.tokenizer.eos_token).strip()
|
155 |
+
qa_subseqs = output_seq.split(self.special_tokens.separator_output_pairs)
|
156 |
+
qas = [self._postrocess_qa(qa_subseq) for qa_subseq in qa_subseqs]
|
157 |
+
return {"generated_text": output_seq,
|
158 |
+
"QAs": qas}
|
159 |
+
|
160 |
+
def _postrocess_qa(self, seq: str) -> str:
|
161 |
+
# split question and answers
|
162 |
+
if self.special_tokens.separator_output_question_answer in seq:
|
163 |
+
question, answer = seq.split(self.special_tokens.separator_output_question_answer)[:2]
|
164 |
+
else:
|
165 |
+
print("invalid format: no separator between question and answer found...")
|
166 |
+
return None
|
167 |
+
# question, answer = seq, '' # Or: backoff to only question
|
168 |
+
# skip "_" slots in questions
|
169 |
+
question = ' '.join(t for t in question.split(' ') if t != '_')
|
170 |
+
answers = [a.strip() for a in answer.split(self.special_tokens.separator_output_answers)]
|
171 |
+
return {"question": question, "answers": answers}
|
172 |
+
|
173 |
+
|
174 |
+
if __name__ == "__main__":
|
175 |
+
pipe = QASRL_Pipeline("kleinay/qanom-seq2seq-model-baseline")
|
176 |
+
res1 = pipe("The student was interested in Luke 's <predicate> research about sea animals .", verb_form="research", predicate_type="nominal")
|
177 |
+
res2 = pipe(["The doctor was interested in Luke 's <predicate> treatment .",
|
178 |
+
"The Veterinary student was interested in Luke 's <predicate> treatment of sea animals ."], verb_form="treat", predicate_type="nominal", num_beams=10)
|
179 |
+
res3 = pipe("A number of professions have <predicate> developed that specialize in the treatment of mental disorders .", verb_form="develop", predicate_type="verbal")
|
180 |
+
print(res1)
|
181 |
+
print(res2)
|
182 |
+
print(res3)
|
183 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
transformers==4.15.0
|
2 |
+
torch
|