"""TorchServe handler for an English-to-Coptic seq2seq translation model."""

from dataclasses import dataclass
import json
import logging
import os
from abc import ABC
from typing import Optional

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)

# Returned for any request whose input exceeds the tokenizer's maximum length.
MAX_TOKEN_LENGTH_ERR = {
    "code": 422,
    "type": "MaxTokenLengthError",
    "message": "Max token length exceeded",
}


class EngCopHandler(BaseHandler, ABC):
    @dataclass
    class GenerationConfig:
        """Generation hyperparameters passed through to model.generate()."""

        max_length: int = 20
        max_new_tokens: Optional[int] = None
        min_length: int = 0
        min_new_tokens: Optional[int] = None
        early_stopping: bool = True
        do_sample: bool = False
        num_beams: int = 1
        num_beam_groups: int = 1
        top_k: int = 50
        top_p: float = 0.95
        temperature: float = 1.0
        diversity_penalty: float = 0.0

    def __init__(self):
        super(EngCopHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the Hugging Face seq2seq model and tokenizer from the model
        directory and move the model to the available device.

        Args:
            ctx (context): A JSON object containing information pertaining to
                the model artifact parameters.
        """
        logger.info("Start initialize")
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        # Kept for parity with the TorchServe handler template; the model
        # itself is loaded from model_dir below.
        model_pt_path = os.path.join(model_dir, serialized_file)

        setup_config_path = os.path.join(model_dir, "setup_config.json")
        if os.path.isfile(setup_config_path):
            with open(setup_config_path) as setup_config_file:
                self.setup_config = json.load(setup_config_file)

        torch.manual_seed(42)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Device: %s", self.device)

        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.config = EngCopHandler.GenerationConfig(
            max_new_tokens=128,
            min_new_tokens=1,
            num_beams=5,
        )
        self.initialized = True
        logger.info("Init done")

    def preprocess(self, requests):
        preprocessed_data = []
        for data in requests:
            data_item = data.get("data")
            if data_item is None:
                data_item = data.get("body")
            if isinstance(data_item, (bytes, bytearray)):
                data_item = data_item.decode("utf-8")
            preprocessed_data.append(data_item)
        logger.info("preprocessed_data: %s", preprocessed_data)
        return preprocessed_data

    def inference(self, data):
        # Map each original request index to its position in the filtered
        # batch so over-long inputs can be reported per-request later.
        indices = {}
        batch = []
        for i, item in enumerate(data):
            tokens = self.tokenizer(item, return_tensors="pt", padding=True)
            if len(tokens.input_ids.squeeze()) > self.tokenizer.model_max_length:
                logger.info("Skipping over-long input at index %s", i)
                continue
            indices[i] = len(batch)
            batch.append(data[i])
        logger.info("inference batch: %s", batch)
        result = self.batch_translate(batch)
        # Skipped inputs yield None so handle() can emit an error response.
        return [
            degreekify(result[indices[i]]) if i in indices else None
            for i in range(len(data))
        ]

    def postprocess(self, output):
        return output

    def handle(self, requests, context):
        logger.info("requests: %s", requests)
        preprocessed = self.preprocess(requests)
        inference_data = self.inference(preprocessed)
        postprocessed = self.postprocess(inference_data)
        logger.info("inference result: %s", postprocessed)
        responses = [
            {"code": 200, "translation": translation}
            if translation is not None
            else MAX_TOKEN_LENGTH_ERR
            for translation in postprocessed
        ]
        return responses

    def batch_translate(self, input_sentences, output_confidence=False):
        if len(input_sentences) == 0:
            return []
        inputs = self.tokenizer(input_sentences, return_tensors="pt", padding=True).to(
            self.device
        )
        # max_new_tokens takes precedence over max_length when both are set;
        # return_dict_in_generate is always True so outputs.sequences exists.
        outputs = self.model.generate(
            **inputs,
            max_length=self.config.max_length,
            max_new_tokens=self.config.max_new_tokens,
            min_length=self.config.min_length,
            min_new_tokens=self.config.min_new_tokens,
            early_stopping=self.config.early_stopping,
            do_sample=self.config.do_sample,
            num_beams=self.config.num_beams,
            num_beam_groups=self.config.num_beam_groups,
            top_k=self.config.top_k,
            top_p=self.config.top_p,
            temperature=self.config.temperature,
            diversity_penalty=self.config.diversity_penalty,
            output_scores=output_confidence,
            return_dict_in_generate=True,
        )
        translated_text = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )
        return translated_text


# Greek letters map to their Coptic counterparts; the trailing ASCII letters
# are stand-ins for the Coptic letters that have no Greek equivalent.
GREEK_TO_COPTIC = {
    "α": "ⲁ", "β": "ⲃ", "γ": "ⲅ", "δ": "ⲇ", "ε": "ⲉ", "ϛ": "ⲋ", "ζ": "ⲍ",
    "η": "ⲏ", "θ": "ⲑ", "ι": "ⲓ", "κ": "ⲕ", "λ": "ⲗ", "μ": "ⲙ", "ν": "ⲛ",
    "ξ": "ⲝ", "ο": "ⲟ", "π": "ⲡ", "ρ": "ⲣ", "σ": "ⲥ", "τ": "ⲧ", "υ": "ⲩ",
    "φ": "ⲫ", "χ": "ⲭ", "ψ": "ⲯ", "ω": "ⲱ",
    "s": "ϣ", "f": "ϥ", "k": "ϧ", "h": "ϩ", "j": "ϫ", "c": "ϭ", "t": "ϯ",
}


def degreekify(greek_text):
    """Transliterate Greek-alphabet model output into the Coptic alphabet,
    leaving unmapped characters unchanged."""
    chars = []
    for c in greek_text:
        l_c = c.lower()
        chars.append(GREEK_TO_COPTIC.get(l_c, l_c))
    return "".join(chars)
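

# A minimal local sketch (not part of the TorchServe request lifecycle):
# degreekify() is a pure function, so the transliteration step can be
# exercised directly without packaging a model archive. The sample strings
# below are illustrative only.
if __name__ == "__main__":
    print(degreekify("λογοσ"))  # Greek letters -> Coptic: "ⲗⲟⲅⲟⲥ"
    print(degreekify("hf"))  # ASCII stand-ins -> "ϩϥ"; unmapped chars pass through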