# sin-kaf / app.py
# Author: selinS
# Last commit: b8c479f "Update app.py" (file size 1.57 kB)
# (Header scraped from the Hugging Face file view: raw | history | blame)
import gradio as gra
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModel
import onnxruntime as rt
# Location of the exported ONNX classifier. Presumably this is the app
# directory of the Hugging Face Space container; adjust when running elsewhere.
MODEL_PATH = "/home/user/app/onnx_model/model.onnx"
# Hub tokenizer used to encode inputs for the ONNX graph. This assumes the
# graph was exported from this checkpoint (TODO confirm).
TOKENIZER_NAME = "Overfit-GM/distilbert-base-turkish-cased-offensive"

# Shared ONNX Runtime session and tokenizer, created once at import time and
# reused by the Gradio handler on every request.
ort_session = rt.InferenceSession(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
def user_greeting(sent):
    """Classify *sent* with the ONNX text classifier and return the top class.

    The text is encoded by the module-level ``tokenizer`` as one sequence of
    exactly 64 tokens, padded or truncated as needed. It is then run through
    the module-level ``ort_session``, and the index of the highest-scoring
    entry for that single input row is returned.

    Args:
        sent: Input text, as passed in by the Gradio text box.

    Returns:
        A NumPy integer: the argmax over the first ONNX output for the one
        input row. The meaning of each index comes from the model's label
        mapping (presumably not-offensive / offensive; confirm against the
        model config).
    """
    encoded = tokenizer(
        sent,
        add_special_tokens=True,
        max_length=64,
        # Explicit replacement for the deprecated ``pad_to_max_length=True``:
        # always produce exactly ``max_length`` tokens, truncating long input.
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    # Hand NumPy arrays to onnxruntime directly. This keeps the tensors'
    # integer dtype instead of relying on a platform-dependent list->array
    # conversion inside onnxruntime.
    input_feed = {
        "input_ids": encoded["input_ids"].numpy(),
        "attention_mask": encoded["attention_mask"].numpy(),
    }
    # run(None, ...) returns every graph output. [0] is the first output and
    # the second [0] is the single batch row.
    output = ort_session.run(None, input_feed)
    return np.argmax(output[0][0])
# Text-in / text-out web UI around the classifier.
app = gra.Interface(fn=user_greeting, inputs="text", outputs="text")

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (tests,
    # gradio reload mode) does not start a server. Pass
    # server_name="0.0.0.0" to listen on all interfaces, e.g. inside Docker.
    app.launch()