# sin-kaf / app.py
# Last change: commit 25a16fd ("Update app.py") by selinS.
import os

import gradio as gra
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModel
import onnxruntime as rt
# Location of the exported ONNX classifier. The default is the path used on the
# hosting machine; set the ONNX_MODEL_PATH environment variable to run the app
# from a different checkout.
ONNX_MODEL_PATH = os.environ.get(
    "ONNX_MODEL_PATH", "/home/user/app/onnx_model_2/model.onnx"
)
# Pass the execution provider explicitly: onnxruntime builds that ship more
# than one provider (e.g. GPU builds since ORT 1.9) raise ValueError when no
# ``providers`` list is given, while CPU-only builds already default to this one.
ort_session = rt.InferenceSession(
    ONNX_MODEL_PATH, providers=["CPUExecutionProvider"]
)
# Tokenizer matching the fine-tuned Turkish DistilBERT offensive-language
# model; loaded via the Hugging Face Hub at import time.
tokenizer = AutoTokenizer.from_pretrained(
    "Overfit-GM/distilbert-base-turkish-cased-offensive"
)
def user_greeting(sent):
    """Classify a piece of text as offensive or non-offensive.

    Despite its name, this is the Gradio callback for the classifier. The text
    is tokenized with the module-level ``tokenizer`` and run through the ONNX
    model held by ``ort_session``. The arg-max over the first output's single
    row of logits is then mapped to a label.

    Args:
        sent: The raw text entered in the Gradio text box.

    Returns:
        ``'non-offensive'`` when class index 0 has the highest logit,
        otherwise ``'offensive'``.
    """
    encoded = tokenizer(
        sent,
        add_special_tokens=True,
        max_length=64,
        # ``pad_to_max_length`` is deprecated; ``padding="max_length"`` replaces
        # it. Truncation must now be requested explicitly: the old flag left
        # ``padding=False``, which made the tokenizer truncate to ``max_length``
        # implicitly, so every input still yields exactly 64 tokens.
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    # ONNX Runtime consumes NumPy arrays; converting the tokenizer's tensors
    # directly keeps their integer dtype instead of round-tripping through
    # Python lists and letting the runtime infer one.
    input_feed = {
        "input_ids": encoded["input_ids"].numpy(),
        "attention_mask": encoded["attention_mask"].numpy(),
    }
    logits = ort_session.run(None, input_feed)[0]
    # A single sentence was encoded, so row 0 holds its class scores.
    if np.argmax(logits[0]) == 0:
        return 'non-offensive'
    return 'offensive'
# Expose the classifier as a minimal web UI: one text box goes in, and the
# label string returned by user_greeting comes out.
app = gra.Interface(
    fn=user_greeting,
    inputs="text",
    outputs="text",
)
# Start the Gradio server. To listen on all network interfaces instead of
# Gradio's default host, call app.launch(server_name="0.0.0.0").
app.launch()