# test / app.py — Hugging Face Space by adamtayzzz
# Last commit: f915ec0 ("Update app.py"), 2.11 kB
import gradio as gr
import requests
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transform
import os
import logging
from transformers import WEIGHTS_NAME,AdamW,AlbertConfig,AlbertTokenizer,BertConfig,BertTokenizer
from pabee.modeling_albert import AlbertForSequenceClassification
from pabee.modeling_bert import BertForSequenceClassification
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors
import datasets
from whitebox_utils.classifier import MyClassifier
import random
import numpy as np
import torch
import argparse
def random_seed(seed):
    """Seed every random number generator used by this app.

    Seeds, in order: Python's ``random`` module, NumPy's global generator,
    the torch CPU generator, and the torch generators of all CUDA devices.

    Args:
        seed: Integer seed applied to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
logger = logging.getLogger(__name__)

# TODO: dataset model tokenizer etc.
# Maps "<model name>_<dataset>" to the checkpoint directory the model weights
# are loaded from.
# NOTE(review): the 'albert_STS-B' entry points at an SST-2 checkpoint and uses
# a './outputs' prefix, while output_dir/data_dir below use './PABEE/...'.
# Confirm this is the intended checkpoint for the STS-B task.
best_model_path = {
    'albert_STS-B': './outputs/train/albert/SST-2/checkpoint-7500',
}

# Model family name -> (config class, PABEE sequence-classification model
# class, tokenizer class).
MODEL_CLASSES = {
    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
    "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
}

# Keep the model *name* in its own variable: `model` is bound to the loaded
# model object below. Reusing one name for both made MyClassifier receive the
# model object where the model-type name was meant.
model_name = 'albert'
dataset = 'STS-B'
task_name = dataset.lower()
if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()  # transformers package-preprocessor
output_mode = output_modes[task_name]  # output type
label_list = processor.get_labels()
num_labels = len(label_list)
output_dir = f'./PABEE/outputs/train/{model_name}/{dataset}'
data_dir = f'./PABEE/glue_data/{dataset}'

config_class, model_class, tokenizer_class = MODEL_CLASSES[model_name]
# The tokenizer is read from output_dir while the weights come from
# best_model_path. NOTE(review): confirm both directories belong to the same run.
tokenizer = tokenizer_class.from_pretrained(output_dir, do_lower_case=True)
model = model_class.from_pretrained(best_model_path[f'{model_name}_{dataset}'])

# Early-exit settings forwarded to MyClassifier.
exit_type = 'patience'
exit_value = 3
# NOTE(review): the last positional argument is assumed to be the model-type
# name ('albert' / 'bert'); verify against whitebox_utils.classifier.MyClassifier.
classifier = MyClassifier(
    model,
    tokenizer,
    label_list,
    output_mode,
    exit_type,
    exit_value,
    model_name,
)
def greet(text, text2, exit_pos):
    """Run the early-exit classifier on a sentence pair and return its output.

    Args:
        text: First sentence of the pair.
        text2: Second sentence of the pair.
        exit_pos: Exit position forwarded to ``classifier.get_prob_time`` as
            ``exit_position``. An integral float (e.g. from a Gradio number
            input) is converted to ``int``. Any other value, including
            ``None``, is passed through unchanged.

    Returns:
        Whatever ``classifier.get_prob_time`` returns. Previously the result
        was discarded, so the UI always received ``None``.
        NOTE(review): the interface renders this as an image; confirm that
        matches what get_prob_time produces.
    """
    text_input = [(text, text2)]
    if isinstance(exit_pos, float) and exit_pos.is_integer():
        exit_pos = int(exit_pos)
    return classifier.get_prob_time(text_input, exit_position=exit_pos)
# greet(text, text2, exit_pos) takes three arguments, so the interface needs
# three input components. The original single 'text' input made Gradio call
# greet with one argument, which raises TypeError on every request.
iface = gr.Interface(
    fn=greet,
    inputs=["text", "text", "number"],
    outputs="image",
)
iface.launch()