|
import gradio as gr |
|
import tensorflow as tf |
|
from transformers import TFAutoModel, AutoTokenizer |
|
import os |
|
import numpy as np |
|
|
|
# Hugging Face checkpoint id, used for both the tokenizer and the RoBERTa
# object handed to Keras below.
model_name = 'cardiffnlp/twitter-roberta-base-sentiment-latest'

# Tokenizer that converts raw text into the tensors consumed by `model`.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the Keras model from the local HDF5 file (path is relative to the
# current working directory).  `custom_objects` presumably lets Keras resolve
# the 'TFRobertaModel' name recorded in the saved file.
# NOTE(review): custom_objects conventionally maps names to classes or
# functions, whereas an instantiated pretrained model is passed here --
# confirm this is what the saved model expects.
model = tf.keras.models.load_model(
    "model.h5",
    custom_objects={
        'TFRobertaModel': TFAutoModel.from_pretrained(model_name)
    }
)
|
|
|
# Specialist names shown to the user.  `inference` pairs this list
# positionally with the model's output values, so the order must not change.
labels = [
    'Cardiologist',
    'Dermatologist',
    'ENT Specialist',
    'Gastro-enterologist',
    'General-Physicians',
    'Neurologist/Gastro-enterologist',
    'Ophthalmologist',
    'Orthopedist',
    'Psychiatrist',
    'Respirologist',
    'Rheumatologist',
    'Rheumatologist/Gastro-enterologist',
    'Rheumatologist/Orthopedist',
    'Surgeon',
]

# Fixed token length: `prep_data` pads or truncates every input to this size.
seq_len = 152
|
|
|
def prep_data(text):
    """Tokenize ``text`` into the input dict that is passed to the model.

    The module-level ``tokenizer`` is asked to truncate or pad the input to
    exactly ``seq_len`` tokens and to return TensorFlow tensors.

    Args:
        text: Text to encode; forwarded unchanged to the tokenizer.

    Returns:
        A dict with the keys ``'input_ids'`` and ``'attention_mask'``, taken
        from the tokenizer output.
    """
    tokens = tokenizer(
        text, max_length=seq_len, truncation=True,
        padding='max_length',
        add_special_tokens=True,
        return_tensors='tf'
    )
    # Forward only these two entries; anything else the tokenizer returns is
    # dropped.
    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask']
    }
|
|
|
def inference(text):
    """Score ``text`` against every specialist label.

    Args:
        text: The user's description of their problem.

    Returns:
        A dict mapping each entry of ``labels`` to the model output value at
        the same position, as a built-in ``float``.
    """
    encoded_text = prep_data(text)
    probs = model.predict_on_batch(encoded_text)
    # tolist() converts NumPy scalars to built-in floats, so the mapping
    # handed to gr.Label is plain str -> float and stays JSON-serializable.
    scores = probs.flatten().tolist()
    # zip() pairs positionally and stops at the shorter sequence.
    return dict(zip(labels, scores))
|
|
|
# Custom CSS: transparent background and a translucent indigo border for
# every <textarea> on the page.
css = """
textarea {
    background-color: #00000000;
    border: 1px solid #6366f160;
}
"""

# Page layout: a title row, then a row holding the input column and the
# prediction label.
# NOTE(review): the original nesting was lost; `label` is placed as a sibling
# of the input column inside the second row -- confirm the intended layout.
with gr.Blocks(title="SpecX", css=css, theme=gr.themes.Soft()) as demo:
    with gr.Row():
        textmd = gr.Markdown('''
        <div style="margin: 50px 0;"></div>

        <h1 style="width:100%; text-align: center;">SpecX: Find the Right Specialist For Your Symptoms!</h1>

        ''')
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            text_box = gr.Textbox(label="Explain your problem in one sentence.")
            submit_btn = gr.Button("Submit", elem_id="warningk", variant='primary')
            # Clickable sample sentences that fill the text box.
            examples = gr.Examples(examples=[
                "When I remember her I feel down",
                "The area around my heart doesn't feel good.",
                "I have a split on my thumb that will not heal."
            ], inputs=text_box)
        # Shows the four highest-scoring specialists.
        label = gr.Label(num_top_classes=4, label="Recommended Specialist")
    # Event wiring stays inside the Blocks context.
    submit_btn.click(inference, inputs=text_box, outputs=label)

demo.launch()