Claire-7B-0.1-GPTQ / handler.py
TheBloke's picture
GPTQ model commit
4eb0875
raw
history blame
6.2 kB
import re
from typing import Any, Dict, List

import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
class EndpointHandler:
    """Text-generation handler wrapping a ``transformers`` pipeline.

    Prompts are normalized with ``claire_text_preproc`` before generation, and
    the function it returns is applied to the generated text to map the
    normalized speaker tags back to the form used in the prompt.
    """

    def __init__(self, path):
        """Load the tokenizer and model stored at *path* and build the pipeline."""
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=torch.bfloat16, load_in_4bit=True
        )
        self.pipeline = transformers.pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate a continuation for one prompt or for a list of prompts.

        Args:
            data: Request payload. ``data["inputs"]`` is a string or a
                list/tuple of strings; ``data["parameters"]`` (optional)
                overrides the default generation parameters. Both keys are
                popped from *data*.

        Returns:
            A list with one ``{"generated_text": ...}`` dict per prompt. Only
            the first sequence returned by the pipeline for each prompt is kept.

        Raises:
            ValueError: If the inputs are neither a string nor a list/tuple.
            RuntimeError: If the pipeline does not return one result per prompt.
        """
        # process input
        inputs = data.pop("inputs", data)

        # default parameters
        parameters = {
            "max_new_tokens": 128,
            "do_sample": True,
            "top_k": 10,
            "temperature": 1.0,
            "return_full_text": False,
        }
        # user parameters (an explicit null is treated like a missing key)
        parameters.update(data.pop("parameters", None) or {})

        unique = isinstance(inputs, str)
        if not unique and not isinstance(inputs, (list, tuple)):
            # Fail early with a clear message instead of an unpacking error
            # further down.
            raise ValueError(
                "'inputs' must be a string or a list of strings, "
                f"got {type(inputs).__name__}"
            )

        inputs, denormalize_funcs = claire_text_preproc(inputs)
        sequences = self.pipeline(inputs, **parameters)

        if unique:
            # Single prompt: `denormalize_funcs` is a single callable.
            generated = sequences[0]["generated_text"]
            return [{"generated_text": denormalize_funcs(generated)}]

        # Explicit check (not an assert, which -O strips): zip() below would
        # otherwise silently drop results on a length mismatch.
        if len(denormalize_funcs) != len(sequences):
            raise RuntimeError(
                f"pipeline returned {len(sequences)} results "
                f"for {len(denormalize_funcs)} prompts"
            )
        return [
            {"generated_text": denormalize_func(seq[0]["generated_text"])}
            for denormalize_func, seq in zip(denormalize_funcs, sequences)
        ]
def claire_text_preproc(text):
    """Normalize a prompt (or a list of prompts) before generation.

    Args:
        text: A string, or a list/tuple of strings. Any other value is passed
            through untouched.

    Returns:
        A pair ``(normalized, denormalize)``. For a string, ``normalized`` is
        the cleaned text and ``denormalize`` is a callable that rewrites the
        normalized speaker tags found in a generated text back to the form
        they had in the prompt. For a list/tuple, both members are lists with
        one entry per element.

    Raises:
        ValueError: If *text* is an empty list or tuple.
    """
    # format_special_tags() records its tag rewrites in this module-level dict.
    global _reverse_tag_transfo

    if isinstance(text, (list, tuple)):
        if not text:
            raise ValueError("claire_text_preproc() needs at least one text")
        # Apply and transpose
        texts, denormalize_funcs = zip(*(claire_text_preproc(t) for t in text))
        return list(texts), list(denormalize_funcs)

    if not isinstance(text, str):
        # Keep the (value, callable) contract that every caller unpacks.
        return text, lambda x: x

    text = format_special_characters(text)
    # remove_ligatures() is not applied here (the call is disabled).
    # Drop isolated dashes (between spaces, or at either end of the text).
    text = re.sub(" - | -$|^- ", " ", text.strip(" "))

    _reverse_tag_transfo = {}
    text = format_special_tags(text)
    text = collapse_whitespaces(text)

    if not _reverse_tag_transfo:
        return text, lambda x: x

    # Snapshot the mapping: the global is reset by the next call.
    reverse_tag_transfo = _reverse_tag_transfo.copy()

    def denormalize_func(t):
        # Insertion order matters: specific speaker tags are restored before
        # the generic "[Intervenant " prefix, which is recorded last.
        for k, v in reverse_tag_transfo.items():
            t = t.replace(k, v)
        return t

    return text, denormalize_func
# Matches a square-bracketed span; group 1 is the text between the brackets.
_brackets = re.compile(r"\[([^\]]*)\]")
# Matches one or more non-"]" characters followed by a colon. It is applied
# with re.match to bracket content to detect speaker tags.
_pattern_speaker = re.compile(r"[^\]]+:")
# Character class of control characters (\x00-\x1F and \x7F-\x9F) other than
# tab (\x09), line feed (\x0A) and carriage return (\x0D).
_non_printable_pattern = r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]"
# Global state remembering the tag normalizations that were applied, so they
# can be reverted on the generated text: maps normalized tag -> original tag.
_reverse_tag_transfo = {}
# Original spelling of the anonymized speaker prefix (e.g. "[speaker00"),
# or None when no anonymized speaker tag was seen.
_anonymized_prefix = None
def collapse_whitespaces(text):
    """Squeeze repeated spaces and newlines and tidy the ends of *text*.

    Runs of spaces and runs of newlines are each reduced to a single
    character, a space directly before "." or "," is dropped, then leading
    whitespace and trailing spaces are removed.
    """
    rules = (
        (r" +", " "),
        (r"\n+", "\n"),
        (r" ([\.,])", r"\1"),
    )
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    # Leading: any whitespace. Trailing: spaces only, so a final newline stays.
    without_leading = text.lstrip()
    return without_leading.rstrip(" ")
def format_special_tags(text):
    """Rewrite every bracketed tag of *text* through ``_format_special_tags``.

    Resets ``_anonymized_prefix`` first. If the rewriting set it, the generic
    "[Intervenant " -> original prefix mapping is added to
    ``_reverse_tag_transfo`` after the per-tag entries.
    """
    global _reverse_tag_transfo, _anonymized_prefix

    _anonymized_prefix = None
    rewritten = _brackets.sub(_format_special_tags, text)

    if not _anonymized_prefix:
        return rewritten

    # At last the generic anonymization
    _reverse_tag_transfo["[Intervenant "] = _anonymized_prefix
    return rewritten
def _format_special_tags(match):
    """Replacement callback for ``_brackets``: normalize or drop one tag.

    Args:
        match: A match of ``_brackets``; group 1 is the bracket content.

    Returns:
        The normalized tag when the content looks like a speaker tag
        (``_pattern_speaker`` matches at its start), otherwise an empty
        string so that the tag is removed from the text.
    """
    content_within_brackets = match.group(1)
    if _pattern_speaker.match(content_within_brackets) is None:
        return ""
    # _format_tag() returns None for tags it does not recognise; always hand
    # re.sub a real string rather than relying on how it treats None.
    return _format_tag(match.group()) or ""
def _format_tag(text):
    """Normalize a single bracketed tag such as "[speaker001:]" or "[PII]".

    Speaker tags (ending in ":]") are rewritten and returned with a leading
    newline:
      * "[speakerN:]", "[spkN:]", "[locuteurN:]" (any case) become
        "[Intervenant N:]";
      * any other speaker name is capitalized with ``capitalize``.
    Each rewrite is recorded in ``_reverse_tag_transfo`` (new tag -> original
    tag), and the first anonymized tag seen sets ``_anonymized_prefix``.

    Args:
        text: The full tag, brackets included.

    Returns:
        The normalized tag, *text* unchanged when an anonymized prefix is not
        followed by an integer, or None when the tag is not recognised.
    """
    global _reverse_tag_transfo, _anonymized_prefix

    if not text.endswith(":]"):
        # Non-speaker tags; unknown ones yield None.
        fixed_tags = {
            "[PII]": "[Nom]",
            "[NOISE]": "[bruit]",
            "[LAUGHTER]": "[rire]",
        }
        return fixed_tags.get(text)

    lowered = text.lower()
    # Conversion "[speaker001:]" -> "[Intervenant 1:]"
    for prefix in ("speaker", "spk", "locuteur"):
        opening = "[" + prefix
        if not lowered.startswith(opening):
            continue
        try:
            index = int(text[len(opening):-2])
        except ValueError:
            # NOTE(review): the tag is returned as-is here (no newline, no
            # capitalization); confirm this is intended.
            return text
        new_spk_tag = f"[Intervenant {index}:]"
        _reverse_tag_transfo[new_spk_tag] = text
        if _anonymized_prefix is None:
            # Keep the prefix exactly as written in the prompt (including its
            # casing) plus any padding spaces/zeros, so that new speaker
            # numbers are rendered like the existing ones.
            end = len(opening)
            while end < len(text) and text[end] in " 0":
                end += 1
            _anonymized_prefix = text[:end]
        return "\n" + new_spk_tag

    # Capitalize speaker name
    speaker = capitalize(text[1:-2])
    new_spk_tag = f"[{speaker}:]"
    if text != new_spk_tag:
        _reverse_tag_transfo[new_spk_tag] = text
    return "\n" + new_spk_tag
def capitalize(text):
    """Capitalize first and last names, including compound ones.

    Each space-separated word is capitalized, except all-uppercase words of
    at most two characters (e.g. initials), which are kept. Inside a word,
    every part delimited by "-" or "'" is capitalized as well, so that
    "jean-d'arc" becomes "Jean-D'Arc".

    Args:
        text: The name to capitalize.

    Returns:
        The capitalized name, with the original spacing preserved.
    """

    def _cap_part(part):
        # All-uppercase parts (e.g. a single initial) are left untouched.
        return part if part.isupper() else part.capitalize()

    def _cap_word(word):
        if not word.isupper() or len(word) > 2:
            word = word.capitalize()
        if "-" not in word and "'" not in word:
            return word
        # Split on both separators at once: handling them one after the
        # other would lower-case what the previous pass had capitalized.
        return "-".join(
            "'".join(_cap_part(part) for part in piece.split("'"))
            for piece in word.split("-")
        )

    return " ".join(_cap_word(w) for w in text.split(" "))
def format_special_characters(text):
    """Replace typographic variants in *text* by plain equivalents.

    The substitutions are regular expressions applied in order: decomposed
    accented letters are composed, typographic quotes, dashes and ellipses
    are replaced by ASCII ones, non-ASCII spaces become a plain space, and
    the characters matched by ``_non_printable_pattern`` are removed.

    NOTE(review): the accent pairs and the space class are written as
    explicit escapes because their characters render identically to their
    replacements; the exact original characters should be confirmed.
    """
    for before, after in [
        # Base letter + combining accent -> single precomposed letter.
        ("a\u0302", "\u00e2"),  # â
        ("a\u0300", "\u00e0"),  # à
        ("a\u0301", "\u00e1"),  # á
        ("e\u0302", "\u00ea"),  # ê
        ("e\u0301", "\u00e9"),  # é
        ("e\u0300", "\u00e8"),  # è
        ("o\u0302", "\u00f4"),  # ô
        ("u\u0302", "\u00fb"),  # û
        ("i\u0302", "\u00ee"),  # î
        ("\x92", "'"),
        ("…", "..."),
        # Quotes: also swallow the spaces just inside the quotation marks.
        (r"[«“][^\S\r\n]*", '"'),
        (r"[^\S\r\n]*[»”″„]", '"'),
        (r"(``|'')", '"'),
        (r"[’‘‛ʿ]", "'"),
        ("‚", ","),
        (r"–", "-"),
        # weird whitespace: every Unicode space separator but the ASCII one
        ("[\u00a0\u1680\u2000-\u200a\u202f\u205f\u3000]", " "),
        (_non_printable_pattern, ""),  # non-printable characters
        ("·", "."),
        # "ᵉʳ" must come before "ᵉ" so the two-character form wins.
        (r"ᵉʳ", "er"),
        (r"ᵉ", "e"),
    ]:
        text = re.sub(before, after, text)
    return text
def remove_ligatures(text):
    """Expand ligature characters of *text* into their separate letters.

    The ligatures are spelled as explicit code points so that they cannot be
    confused with (or silently turned into) the plain letter pairs they map
    to, which would make the replacement a no-op.
    """
    ligatures = str.maketrans(
        {
            "\u0153": "oe",  # œ
            "\u00e6": "ae",  # æ
            "\ufb01": "fi",  # ﬁ
            "\ufb02": "fl",  # ﬂ
            "\u0133": "ij",  # ĳ
            "\u0152": "Oe",  # Œ
            "\u00c6": "Ae",  # Æ
        }
    )
    # No replacement contains a ligature, so one pass is enough.
    return text.translate(ligatures)