Hiveurban's picture
Upload handler.py with huggingface_hub
a76b97c verified
raw
history blame
1.45 kB
from typing import Dict, List, Any
from transformers import AutoModel, AutoTokenizer
class EndpointHandler:
    """Inference Endpoint handler that runs the model's ``predict`` and reduces
    each parsed sentence to its lemmatized text and its named entities."""

    def __init__(self, path="."):
        """Load the tokenizer and model from ``path``.

        Args:
            path: Directory (or hub id) passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # trust_remote_code lets the checkpoint supply its own model class;
        # ``predict`` (used in __call__) is not a standard transformers method,
        # so it presumably comes from that remote code.
        self.model = AutoModel.from_pretrained(
            path,
            trust_remote_code=True,
            # Task flags the original author left disabled (kept for reference):
            # do_syntax=True, do_prefix=False, do_morph=False, do_ner=True, do_lex=True
        )
        # Inference only: put the module in eval mode (disables dropout etc.).
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Parse the request's sentences and return lemmas and entities.

        Args:
            data: Request payload. ``data["inputs"]`` is a single sentence
                (``str``) or a list of sentences.

        Returns:
            One dict per input sentence, with keys:
              - ``"lex"``: the sentence's per-token ``lex`` values joined by
                single spaces;
              - ``"ner"``: a list of entity dicts with ``"word"`` (the entity's
                ``lex`` values joined by spaces), ``"entity_group"`` (the
                entity label), ``"token_start"`` and ``"token_end"``.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` key.
        """
        sentences = data['inputs']
        # Always hand ``predict`` a list so the comprehension below iterates
        # over per-sentence results; the shape ``predict`` returns for a bare
        # string is not guaranteed to be a list.
        if isinstance(sentences, str):
            sentences = [sentences]
        # Nothing to parse: don't send an empty batch to the tokenizer/model.
        if not sentences:
            return []
        parsed = self.model.predict(sentences, self.tokenizer, output_style='json')
        return [self._format_output(sentence) for sentence in parsed]

    @staticmethod
    def _format_output(parsed: Dict[str, Any]) -> Dict[str, Any]:
        """Reduce one parsed sentence (``tokens`` + ``ner_entities``) to its
        space-joined lemma string and a list of entity dicts."""
        lexemes = [token['lex'] for token in parsed['tokens']]
        entities = [
            {
                # token_end is treated as inclusive, hence the +1 on the slice.
                'word': ' '.join(lexemes[entity['token_start']:entity['token_end'] + 1]),
                'entity_group': entity['label'],
                'token_start': entity['token_start'],
                'token_end': entity['token_end'],
            }
            for entity in parsed['ner_entities']
        ]
        return {'lex': ' '.join(lexemes), 'ner': entities}