import json
import logging
import os
from abc import ABC
from dataclasses import dataclass
from typing import Optional

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
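
# Returned in place of a translation for inputs that exceed the tokenizer's
# maximum sequence length (see EngCopHandler.inference below).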
MAX_TOKEN_LENGTH_ERR = {
    "code": 422,
    "type": "MaxTokenLengthError",
    "message": "Max token length exceeded",
}


class EngCopHandler(BaseHandler, ABC):
    @dataclass
    class GenerationConfig:
        """Generation hyperparameters forwarded to model.generate()."""

        max_length: int = 20
        max_new_tokens: Optional[int] = None
        min_length: int = 0
        min_new_tokens: Optional[int] = None
        early_stopping: bool = True
        do_sample: bool = False
        num_beams: int = 1
        num_beam_groups: int = 1
        top_k: int = 50
        top_p: float = 0.95
        temperature: float = 1.0
        diversity_penalty: float = 0.0

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the Hugging Face seq2seq model and tokenizer from the
        model directory and prepare them for inference.

        Args:
            ctx (context): JSON object containing information pertaining
                to the model artifact parameters.
        """
        logger.info("Start initialize")
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        setup_config_path = os.path.join(model_dir, "setup_config.json")
        if os.path.isfile(setup_config_path):
            with open(setup_config_path) as setup_config_file:
                self.setup_config = json.load(setup_config_file)
        torch.manual_seed(42)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Device: %s", self.device)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.config = EngCopHandler.GenerationConfig(
            max_new_tokens=128,
            min_new_tokens=1,
            num_beams=5,
        )
        self.initialized = True
        logger.info("Init done")

    def preprocess(self, requests):
        preprocessed_data = []
        for data in requests:
            data_item = data.get("data")
            if data_item is None:
                data_item = data.get("body")
            if isinstance(data_item, (bytes, bytearray)):
                data_item = data_item.decode("utf-8")
            preprocessed_data.append(data_item)
        logger.info("preprocessed_data: %s", preprocessed_data)
        return preprocessed_data

    def inference(self, data):
        """Translate each input, mapping those that exceed the model's
        max token length to None."""
        indices = {}
        batch = []
        for i, item in enumerate(data):
            tokens = self.tokenizer(item, return_tensors="pt", padding=True)
            # Check the sequence dimension directly; squeeze() would yield a
            # 0-d tensor (and break len()) for single-token inputs.
            if tokens.input_ids.size(1) > self.tokenizer.model_max_length:
                logger.info("Skipping input %s: max token length exceeded", i)
                continue
            indices[i] = len(batch)
            batch.append(item)
        logger.info("inference batch: %s", batch)
        result = self.batch_translate(batch)
        return [
            degreekify(result[indices[i]]) if i in indices else None
            for i in range(len(data))
        ]

    def postprocess(self, output):
        return output

    def handle(self, requests, context):
        logger.info("requests: %s", requests)
        preprocessed = self.preprocess(requests)
        inference_data = self.inference(preprocessed)
        postprocessed = self.postprocess(inference_data)
        logger.info("inference result: %s", postprocessed)
        responses = [
            {"code": 200, "translation": translation}
            # None marks inputs skipped in inference() for exceeding the max
            # token length; an empty translation still counts as a success.
            if translation is not None
            else MAX_TOKEN_LENGTH_ERR
            for translation in postprocessed
        ]
        return responses

    def batch_translate(self, input_sentences, output_confidence=False):
        if len(input_sentences) == 0:
            return []
        inputs = self.tokenizer(input_sentences, return_tensors="pt", padding=True).to(
            self.device
        )
        outputs = self.model.generate(
            **inputs,
            # When max_new_tokens is set, it takes precedence over max_length.
            max_length=self.config.max_length,
            max_new_tokens=self.config.max_new_tokens,
            min_length=self.config.min_length,
            min_new_tokens=self.config.min_new_tokens,
            early_stopping=self.config.early_stopping,
            do_sample=self.config.do_sample,
            num_beams=self.config.num_beams,
            num_beam_groups=self.config.num_beam_groups,
            top_k=self.config.top_k,
            top_p=self.config.top_p,
            temperature=self.config.temperature,
            diversity_penalty=self.config.diversity_penalty,
            output_scores=output_confidence,
            # Must stay True: outputs.sequences is accessed below.
            return_dict_in_generate=True,
        )
        translated_text = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )
        return translated_text
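

# Greek -> Coptic character map used by degreekify(). The ASCII keys
# ("s", "f", "k", "h", "j", "c", "t") stand in for Coptic-specific letters
# that have no Greek counterpart.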
GREEK_TO_COPTIC = {
    "α": "ⲁ",
    "β": "ⲃ",
    "γ": "ⲅ",
    "δ": "ⲇ",
    "ε": "ⲉ",
    "ϛ": "ⲋ",
    "ζ": "ⲍ",
    "η": "ⲏ",
    "θ": "ⲑ",
    "ι": "ⲓ",
    "κ": "ⲕ",
    "λ": "ⲗ",
    "μ": "ⲙ",
    "ν": "ⲛ",
    "ξ": "ⲝ",
    "ο": "ⲟ",
    "π": "ⲡ",
    "ρ": "ⲣ",
    "σ": "ⲥ",
    "τ": "ⲧ",
    "υ": "ⲩ",
    "φ": "ⲫ",
    "χ": "ⲭ",
    "ψ": "ⲯ",
    "ω": "ⲱ",
    "s": "ϣ",
    "f": "ϥ",
    "k": "ϧ",
    "h": "ϩ",
    "j": "ϫ",
    "c": "ϭ",
    "t": "ϯ",
}


def degreekify(greek_text):
    """Replace each Greek character in the (lower-cased) text with its
    Coptic equivalent; characters outside the map pass through unchanged."""
    chars = []
    for c in greek_text:
        l_c = c.lower()
        chars.append(GREEK_TO_COPTIC.get(l_c, l_c))
    return "".join(chars)