import re
from typing import Any, Dict, List

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str):
        # Load the tokenizer and model from the given checkpoint path,
        # sharding across available devices and quantizing to 4 bits
        # (load_in_4bit requires the bitsandbytes package).
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=torch.bfloat16, load_in_4bit=True
        )
        self.pipeline = transformers.pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        inputs = data.pop("inputs", data)

        # Default generation parameters, overridable per request.
        parameters = {
            "max_new_tokens": 128,
            "do_sample": True,
            "top_k": 10,
            "temperature": 1.0,
            "return_full_text": False,
        }
        parameters.update(data.pop("parameters", {}))

        # A single string yields one sequence; a list yields one per input.
        unique = isinstance(inputs, str)
        inputs, denormalize_funcs = claire_text_preproc(inputs)

        sequences = self.pipeline(inputs, **parameters)

        if unique:
            return [{"generated_text": denormalize_funcs(sequences[0]["generated_text"])}]
        else:
            assert len(denormalize_funcs) == len(sequences)
            return [
                {"generated_text": denormalize_func(seq[0]["generated_text"])}
                for denormalize_func, seq in zip(denormalize_funcs, sequences)
            ]


def claire_text_preproc(text):
    """Normalize `text` (a string, or a list/tuple of strings) for Claire,
    and return it along with a function that maps generated text back to
    the original speaker tags."""
    if isinstance(text, (list, tuple)):
        assert len(text)
        # Process each element, returning parallel lists of texts and
        # denormalization functions.
        texts, denormalize_funcs = zip(*[claire_text_preproc(t) for t in text])
        return list(texts), list(denormalize_funcs)

    if not isinstance(text, str):
        # Pass non-string inputs through unchanged, with an identity
        # denormalization function so the caller can always unpack a pair.
        return text, lambda x: x

    text = format_special_characters(text)

    # Remove isolated dashes (dialogue markers).
    text = re.sub(" - | -$|^- ", " ", text.strip(" "))

    global _reverse_tag_transfo
    _reverse_tag_transfo = {}
    text = format_special_tags(text)

    text = collapse_whitespaces(text)

    if _reverse_tag_transfo:
        # Copy the mapping built while formatting tags, so the returned
        # closure keeps working after the global is reset.
        reverse_tag_transfo = _reverse_tag_transfo.copy()

        def denormalize_func(t):
            for k, v in reverse_tag_transfo.items():
                if k in t:
                    t = t.replace(k, v)
            return t

        return text, denormalize_func

    else:
        return text, lambda x: x
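
# Illustrative round trip through claire_text_preproc:
#   text, denorm = claire_text_preproc("[speaker001:] Bonjour !")
#   text == "[Intervenant 1:] Bonjour !"
#   denorm(text) == "[speaker001:] Bonjour !"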


_brackets = re.compile(r"\[([^\]]*)\]")    # any [bracketed] segment
_pattern_speaker = re.compile(r"[^\]]+:")  # bracket content that looks like "name:"
_non_printable_pattern = r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]"

# Module-level state shared between claire_text_preproc and the tag
# formatting helpers below.
_reverse_tag_transfo = {}
_anonymized_prefix = None


def collapse_whitespaces(text):
    # Squeeze runs of spaces and newlines, remove spaces before "." and ",",
    # then trim leading whitespace and trailing spaces.
    text = re.sub(r" +", " ", text)
    text = re.sub(r"\n+", "\n", text)
    text = re.sub(r" ([\.,])", r"\1", text)
    return text.lstrip().rstrip(" ")


def format_special_tags(text):
    global _reverse_tag_transfo, _anonymized_prefix
    _anonymized_prefix = None
    text = re.sub(_brackets, _format_special_tags, text)
    # If anonymized speaker tags were found, also map the generic prefix
    # back, so that new speakers invented by the model can be denormalized.
    if _anonymized_prefix:
        _reverse_tag_transfo["[Intervenant "] = _anonymized_prefix
    return text


def _format_special_tags(match):
    content_within_brackets = match.group(1)
    if re.match(_pattern_speaker, content_within_brackets):
        # Keep and normalize speaker tags like "[speaker001:]".
        return _format_tag(match.group())
    else:
        # Remove all other bracketed annotations.
        return ""


def _format_tag(text):
    global _reverse_tag_transfo, _anonymized_prefix
    if text.endswith(":]"):
        # Anonymized speaker tags, e.g. "[speaker001:]" -> "[Intervenant 1:]".
        anonymized_spk_prefixes = ["speaker", "spk", "locuteur"]
        for prefix in anonymized_spk_prefixes:
            if text.lower().startswith("[" + prefix):
                try:
                    index = int(text[len(prefix) + 1 : -2])
                except ValueError:
                    return text
                new_spk_tag = f"[Intervenant {index}:]"
                _reverse_tag_transfo[new_spk_tag] = text
                if _anonymized_prefix is None:
                    # Remember the original prefix, including any padding
                    # spaces and zeros, to denormalize speaker tags that do
                    # not appear verbatim in the input.
                    prefix = "[" + prefix
                    while len(prefix) < len(text) and text[len(prefix)] in " 0":
                        prefix += text[len(prefix)]
                    _anonymized_prefix = prefix
                return "\n" + new_spk_tag

        # Named speaker tags, e.g. "[john doe:]" -> "[John Doe:]".
        speaker = text[1:-2]
        speaker = capitalize(speaker)
        new_spk_tag = f"[{speaker}:]"
        if text != new_spk_tag:
            _reverse_tag_transfo[new_spk_tag] = text
        return "\n" + new_spk_tag

    if text == "[PII]":
        return "[Nom]"
    if text == "[NOISE]":
        return "[bruit]"
    if text == "[LAUGHTER]":
        return "[rire]"

    # Return unrecognized tags unchanged rather than falling through to
    # None, which would make re.sub raise a TypeError.
    return text


def capitalize(text):
    # Capitalize each word, leaving all-uppercase words of at most two
    # letters (short acronyms) untouched; also capitalize after "-" and "'".
    words = text.split(" ")
    words = [w.capitalize() if (not w.isupper() or len(w) > 2) else w for w in words]
    for i, w in enumerate(words):
        for sep in "-", "'":
            if sep in w:
                words[i] = sep.join(
                    [x.capitalize() if not x.isupper() else x for x in w.split(sep)]
                )
    return " ".join(words)


def format_special_characters(text):
    for before, after in [
        # Recompose decomposed (NFD) accented characters into precomposed
        # (NFC) form.
        ("a\u0302", "â"),
        ("a\u0300", "à"),
        ("a\u0301", "á"),
        ("e\u0302", "ê"),
        ("e\u0301", "é"),
        ("e\u0300", "è"),
        ("o\u0302", "ô"),
        ("u\u0302", "û"),
        ("i\u0302", "î"),
        # Normalize quotes, ellipses, dashes and exotic spaces.
        ("\x92", "'"),
        ("…", "..."),
        (r"[«“][^\S\r\n]*", '"'),
        (r"[^\S\r\n]*[»”″„]", '"'),
        (r"(``|'')", '"'),
        (r"[’‘‛ʿ]", "'"),
        ("‚", ","),
        (r"–", "-"),
        ("[\u00a0]", " "),  # no-break space (assumed from the source)
        (_non_printable_pattern, ""),
        ("·", "."),
        # Expand superscript ordinal suffixes, e.g. "1ᵉʳ" -> "1er".
        (r"ᵉʳ", "er"),
        (r"ᵉ", "e"),
    ]:
        text = re.sub(before, after, text)

    return text
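
# Illustrative example: format_special_characters("«\u00a0Bonjour…\u00a0»")
# returns '"Bonjour..."'.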


def remove_ligatures(text):
    # Replace typographic ligature characters with their plain-letter
    # equivalents.
    text = re.sub(r"œ", "oe", text)
    text = re.sub(r"æ", "ae", text)
    text = re.sub("\ufb01", "fi", text)  # "fi" ligature
    text = re.sub("\ufb02", "fl", text)  # "fl" ligature
    text = re.sub("\u0133", "ij", text)  # "ij" ligature
    text = re.sub(r"Œ", "Oe", text)
    text = re.sub(r"Æ", "Ae", text)
    return text
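

if __name__ == "__main__":
    # Minimal smoke test of the preprocessing round trip (no model load;
    # the example text is illustrative).
    text, denorm = claire_text_preproc("[speaker001:] Bonjour !")
    print(text)          # [Intervenant 1:] Bonjour !
    print(denorm(text))  # [speaker001:] Bonjour !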