"""Gradio demo: extract the phrase in a context paragraph that answers a question.

Run as a script. It loads the fine-tuned DeBERTa checkpoint from
``models_file/`` and serves a two-textbox Gradio interface.
"""


def main():
    """Load the tokenizer and model, then launch the Gradio app."""
    # Heavy and project-local imports stay inside main(). Importing this module
    # therefore has no side effects, just as when they lived under the
    # __main__ guard.
    import os

    import gradio as gr
    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from configuration import CFG
    from dataset import SingleInputDataset
    from model import CustomModel
    from utils import get_char_probs, get_results, get_text, inference_fn

    device = torch.device("cpu")
    models_dir = "models_file"
    config_path = os.path.join(models_dir, "config.pth")
    model_path = os.path.join(
        models_dir, "microsoft-deberta-base_0.9449373420387531_8_best.pth"
    )

    tokenizer = AutoTokenizer.from_pretrained(os.path.join(models_dir, "tokenizer"))
    model = CustomModel(CFG, config_path=config_path, pretrained=False)
    # NOTE(review): torch.load unpickles the file. Only load trusted checkpoints.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state["model"])
    # Inference only: disable dropout so repeated requests give the same answer.
    # This is harmless if inference_fn also calls eval().
    model.eval()

    def get_answer(context, feature):
        """Return the text span of *context* predicted to answer *feature*.

        Args:
            context: The paragraph to search (e.g. a patient history).
            feature: The question / feature text to look for.

        Returns:
            The text returned by ``utils.get_text`` for the first result of
            ``utils.get_results`` at threshold 0.5.
        """
        # Encode (context, feature) as a sentence pair padded to CFG.max_len.
        # truncation=True caps long user input at CFG.max_len. Without it the
        # tokenizer returns more ids than max_len, and the reshape below fails.
        inputs_single = tokenizer(
            context,
            feature,
            add_special_tokens=True,
            max_length=CFG.max_len,
            padding="max_length",
            truncation=True,
            return_offsets_mapping=False,
        )
        for key, value in inputs_single.items():
            inputs_single[key] = torch.tensor(value, dtype=torch.long)

        # Wrap the single encoded sample so inference_fn can consume it.
        single_input_loader = DataLoader(
            SingleInputDataset(inputs_single),
            batch_size=1,
            shuffle=False,
            # One sample per request: worker processes would only add
            # start-up overhead.
            num_workers=0,
        )

        # no_grad is harmless if inference_fn already disables gradients.
        with torch.no_grad():
            output = inference_fn(single_input_loader, model, device)
        prediction = output.reshape((1, CFG.max_len))

        # Presumably maps token-level probabilities back to character
        # positions of `context` (see utils.get_char_probs) — TODO confirm.
        char_probs = get_char_probs([context], prediction, tokenizer)
        predictions = np.mean([char_probs], axis=0)
        results = get_results(predictions, th=0.5)
        print(results)
        return get_text(context, results[0])

    # NOTE(review): the original article text appears to have been lost in a
    # reformat (the literal was split across lines). Restore the real content.
    article = "\n"

    app = gr.Interface(
        fn=get_answer,
        inputs=[
            gr.Textbox(label="Context Para", lines=10),
            gr.Textbox(label="Question", lines=1),
        ],
        outputs=gr.Textbox(label="Answer"),
        allow_flagging="never",
        title="Phrase Extraction",
        article=article,
        cache_examples=False,
        css="footer {visibility: hidden}",
    )
    # .queue() replaces Interface(enable_queue=True), which Gradio 4 removed.
    app.queue().launch()


if __name__ == "__main__":
    main()