Spaces:
Sleeping
Sleeping
import streamlit as st | |
import gradio as gr | |
import shap | |
import numpy as np | |
import scipy as sp | |
import torch | |
import tensorflow as tf | |
import transformers | |
from transformers import pipeline | |
from transformers import RobertaTokenizer, RobertaModel | |
from transformers import AutoModelForSequenceClassification | |
from transformers import TFAutoModelForSequenceClassification | |
from transformers import AutoTokenizer | |
import matplotlib.pyplot as plt | |
# Prefer the first CUDA GPU when available; otherwise run everything on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Severity classifier (severe vs. non-severe adverse reaction) and its tokenizer,
# both loaded from the same Hub checkpoint; the model is placed on `device`.
tokenizer = AutoTokenizer.from_pretrained("paragon-analytics/ADRv1")
model = AutoModelForSequenceClassification.from_pretrained(
    "paragon-analytics/ADRv1"
).to(device)

# Pipeline wrapper around the same model/tokenizer, returning the score for
# every label; SHAP needs a callable like this to attribute predictions to tokens.
pred = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)
explainer = shap.Explainer(pred)

##
# Separate cross-encoder used by adr_predict to score whether the text
# mentions a medication.
classifier = pipeline(
    "text-classification",
    model="cross-encoder/qnli-electra-base",
)
def med_score(x):
    """Extract the confidence score from a single classification result.

    Args:
        x: One prediction dict, e.g. the first element of a ``transformers``
            text-classification pipeline result. Only its ``"score"`` entry
            is read.

    Returns:
        The value stored under ``x["score"]``.
    """
    return x["score"]
## | |
def adr_predict(x):
    """Score a text for adverse-reaction severity and medication mentions.

    Args:
        x: The input text. ``main`` passes it already lower-cased.

    Returns:
        A 3-tuple, in the order the Gradio outputs expect:
          * dict with "Severe Reaction" / "Non-severe Reaction" probabilities
            (softmax over the ADR model's logits; index 1 is treated as
            severe, index 0 as non-severe);
          * the SHAP text plot for the input, rendered with
            ``display=False`` (shown in a ``gr.HTML`` component);
          * dict with "Contains Medication" (the cross-encoder score) and
            "No Medications" (its complement, ``1 - score``).
    """
    # The model lives on `device` (possibly CUDA); the inputs must be moved
    # there too, or the forward pass fails with a device-mismatch error.
    encoded_input = tokenizer(x, return_tensors="pt").to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**encoded_input)[0][0]
    # Bring the probabilities back to host memory before converting to NumPy;
    # calling .numpy() on a CUDA tensor raises.
    probs = torch.softmax(logits, dim=-1).cpu().numpy()

    shap_values = explainer([str(x).lower()])
    local_plot = shap.plots.text(shap_values[0], display=False)

    # The appended hypothesis sentence is scored by the cross-encoder as a
    # proxy for "the text mentions a medication".
    med = med_score(classifier(x + ", There is a medication.")[0])

    return (
        {
            "Severe Reaction": float(probs[1]),
            "Non-severe Reaction": float(probs[0]),
        },
        local_plot,
        {"Contains Medication": float(med), "No Medications": float(1 - med)},
    )
def main(prob1):
    """Gradio callback: lower-case the input text and run ``adr_predict``.

    Args:
        prob1: Value of the input textbox; it is converted with ``str`` first.

    Returns:
        The three outputs wired to the (label, local_plot, med) components:
        severity scores, SHAP plot and medication scores.
    """
    severity, shap_plot, medication = adr_predict(str(prob1).lower())
    return severity, shap_plot, medication
# Page title and introductory text shown at the top of the Gradio app.
title = "Welcome to **ADR Detector** 🪐"
description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe)
adverse reactions to medications. Please do NOT use for medical diagnosis."""
# Gradio UI: one text input, an "Analyze" button, and three outputs
# (severity label, SHAP explanation, medication label), all fed by `main`.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    prob1 = gr.Textbox(
        label="Enter Your Text Here:", lines=2, placeholder="Type it here ..."
    )
    submit_btn = gr.Button("Analyze")
    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Label")
        local_plot = gr.HTML(label="Shap:")
        med = gr.Label(label="Contains Medication")
    # Output order must match the tuple returned by `main`.
    submit_btn.click(
        main,
        [prob1],
        [label, local_plot, med],
        api_name="adr",
    )
    # The caption previously referred to "resilience messaging", which is not
    # what this app scores; it now describes the ADR prediction.
    gr.Markdown(
        "### Click on any of the examples below to see to what extent they describe a severe adverse reaction:"
    )
    # cache_examples=True runs `main` on these inputs ahead of time.
    gr.Examples(
        [["I have severe pain."], ["I have minor pain."]],
        [prob1],
        [label, local_plot, med],
        main,
        cache_examples=True,
    )
demo.launch()