# sin-kaf / app.py
# selinS's picture
# Update app.py
# f7ca09f
# raw
# history blame
# No virus
# 1.56 kB
import gradio as gra
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModel
import onnxruntime as rt
# Path of the exported ONNX graph and the identifier handed to
# ``AutoTokenizer.from_pretrained``.
ONNX_MODEL_PATH = "/sin-kaf/onnx_model/model.onnx"
TOKENIZER_NAME = "Overfit-GM/distilbert-base-turkish-cased-offensive"

# Inference goes through an ONNX Runtime session created once at import time.
ort_session = rt.InferenceSession(ONNX_MODEL_PATH)
# Queries the session's execution providers; the returned list is not stored.
ort_session.get_providers()

# Earlier model-loading variants, kept for reference:
# model = ORTModel.load_model("/DATA/sin-kaf/onnx_model/model.onnx")
# model = AutoModelForSequenceClassification.from_pretrained('/DATA/sin-kaf/test_trainer/checkpoint-18500')

# Tokenizer shared by every call to the prediction function.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
def user_greeting(sent):
    """Classify one sentence with the ONNX session and return the top class index.

    Args:
        sent: The text to classify.

    Returns:
        The index of the largest value in the first row of the session's
        first output, as returned by ``np.argmax``. What each index means
        depends on the model's label mapping, which this file does not define.
    """
    # ``padding="max_length"`` replaces the deprecated ``pad_to_max_length=True``.
    # ``truncation=True`` states explicitly that longer inputs are cut to
    # ``max_length`` instead of relying on the tokenizer's implicit fallback.
    encoded = tokenizer(
        sent,
        add_special_tokens=True,
        max_length=64,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="np",
    )

    # The feed keeps the original form: nested Python lists with a leading
    # batch dimension of one, keyed by the graph's input names.
    input_feed = {
        "input_ids": encoded["input_ids"].tolist(),
        "attention_mask": encoded["attention_mask"].tolist(),
    }

    output = ort_session.run(None, input_feed)
    return np.argmax(output[0][0])
    # Earlier PyTorch-model variant, kept for reference:
    # outputs = model(input_ids, input_mask)
    # return torch.argmax(outputs['logits'])
# Text-in / text-out Gradio interface around the classifier function.
app = gra.Interface(fn=user_greeting, inputs="text", outputs="text")

if __name__ == "__main__":
    # Start the web server only when this file is executed as a script, so
    # that importing the module does not launch it as a side effect.
    app.launch()
    # app.launch(server_name="0.0.0.0")