# dusense / app.py (Hugging Face Space by saily)
# Commit e14dd36: "change app" (1.42 kB)
import gradio as gr
import spaces
import numpy as np
import os
import time
import torch
from config import Config
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
def set_seed(seed):
    """Seed the NumPy and PyTorch RNGs (CPU and every CUDA device) with one value.

    Also switches cuDNN to deterministic algorithms so repeated runs give the
    same numerical results.

    Args:
        seed: integer seed handed unchanged to each generator.
    """
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Deterministic cuDNN kernels trade some speed for reproducibility.
    torch.backends.cudnn.deterministic = True
@spaces.GPU
def greet(inputStr):
    """Classify one piece of text with the fine-tuned BERT model and return its label.

    The Config, tokenizer, BERT config and fine-tuned weights are all (re)loaded
    on every call.

    Args:
        inputStr: text submitted through the Gradio textbox.

    Returns:
        The entry of ``config.label_list`` at the index of the highest logit.
    """
    set_seed(1)
    # NOTE(review): "./data_12345" is presumably the data/checkpoint directory
    # that Config reads (num_labels, device, saved_model, max_seq_len,
    # label_list); confirm against config.py.
    config = Config("./data_12345")
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
    bert_config = BertConfig.from_pretrained(
        "bert-base-chinese", num_labels=config.num_labels
    )
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-chinese", config=bert_config
    )
    model.to(config.device)
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint saved on a GPU can still be loaded on a CPU-only host (and
    # vice versa) instead of raising.
    state_dict = torch.load(config.saved_model, map_location=config.device)
    model.load_state_dict(state_dict)
    model.eval()
    # NOTE(review): all of the loading above repeats on every request. Consider
    # caching it once latency matters, but first check how spaces.GPU treats
    # state that is kept across calls.

    inputs = tokenizer(
        inputStr,
        max_length=config.max_seq_len,
        truncation="longest_first",
        return_tensors="pt",
    ).to(config.device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # The batch holds a single example: pick the index of its largest logit.
    label_id = int(logits.argmax(dim=-1)[0])
    result = config.label_list[label_id]
    print("Classification result:" + result)
    return result
# Single textbox in, single text label out, backed by greet().
demo = gr.Interface(fn=greet, inputs="text", outputs="text")

# Launch only when run as a script, so importing this module (e.g. from tests
# or tooling) does not start the web server as a side effect.
if __name__ == "__main__":
    demo.launch()