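"""Gradio app: ADR Detector.

Takes a short text and predicts to what extent it describes a severe
(or non-severe) adverse reaction to medications, with per-token SHAP
attributions for the prediction.
"""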
import streamlit as st
import gradio as gr
import shap
import numpy as np
import scipy as sp
import torch
import tensorflow as tf
import transformers
from transformers import pipeline
from transformers import RobertaTokenizer, RobertaModel
from transformers import AutoModelForSequenceClassification
from transformers import TFAutoModelForSequenceClassification
from transformers import AutoTokenizer
# from transformers_interpret import SequenceClassificationExplainer
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("paragon-analytics/ADRv1")
model = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1").to(device)

# build a pipeline object to do predictions
pred = transformers.pipeline("text-classification", model=model,
                             tokenizer=tokenizer, return_all_scores=True)
explainer = shap.Explainer(pred)
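
# interpretation_function computes per-token SHAP attribution scores for the
# "severe" class; the result feeds the interpretation panel in the UI below.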
def interpretation_function(text):
    shap_values = explainer([text])
    # shap_values.values has shape (batch, tokens, classes); index 1 = severe class
    scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))
    # gr.components.Interpretation expects {"original": ..., "interpretation": ...}
    return {"original": text, "interpretation": scores}
# model = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1")
# modelc = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1").cuda()
# cls_explainer = SequenceClassificationExplainer(
#     model,
#     tokenizer)

# # define a prediction function
# def f(x):
#     tv = torch.tensor([tokenizer.encode(v, padding='max_length', max_length=500, truncation=True) for v in x]).cuda()
#     outputs = modelc(tv)[0].detach().cpu().numpy()
#     scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
#     val = sp.special.logit(scores[:, 1])  # use one-vs-rest logit units
#     return val
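
# Predict severe vs. non-severe probabilities for a single text, then compute
# SHAP token attributions for the interpretation panel.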
def adr_predict(x):
    encoded_input = tokenizer(str(x), return_tensors='pt').to(device)
    output = model(**encoded_input)
    scores = output[0][0].detach().cpu().numpy()
    scores = tf.nn.softmax(scores)
    # # build a pipeline object to do predictions
    # pred = transformers.pipeline("text-classification", model=model,
    #                              tokenizer=tokenizer, device=0, return_all_scores=True)
    # explainer = shap.Explainer(pred)
    # shap_values = explainer([x])
    # shap_plot = shap.plots.text(shap_values)

    # word_attributions = cls_explainer(str(x))
    # # scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))
    # letter = []
    # score = []
    # for i in word_attributions:
    #     if i[1] > 0.5:
    #         a = "++"
    #     elif (i[1] <= 0.5) and (i[1] > 0.1):
    #         a = "+"
    #     elif (i[1] >= -0.5) and (i[1] < -0.1):
    #         a = "-"
    #     elif i[1] < -0.5:
    #         a = "--"
    #     else:
    #         a = "NA"
    #     letter.append(i[0])
    #     score.append(a)
    # word_attributions = [(letter[i], score[i]) for i in range(0, len(letter))]

    # # SHAP:
    # # build an explainer using a token masker
    # explainer = shap.Explainer(f, tokenizer)
    # shap_values = explainer(str(x), fixed_context=1)
    # scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))
    # # plot the first sentence's explanation
    # # plt = shap.plots.text(shap_values[0], display=False)
    shap_scores = interpretation_function(str(x).lower())
    return {"Severe Reaction": float(scores.numpy()[1]), "Non-severe Reaction": float(scores.numpy()[0])}, shap_scores
    # , word_attributions, scores
def main(prob1):
    text = str(prob1).lower()
    obj = adr_predict(text)
    return obj[0], obj[1]
    # , obj[2]
title = "Welcome to **ADR Detector** 🪐"
description1 = """
This app takes text (up to a few sentences) and predicts to what extent it describes a severe (or non-severe) adverse reaction to medications.
"""
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")

    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")
    submit_btn = gr.Button("Analyze")

    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Label")
        # impplot = gr.HighlightedText(label="Important Words", combine_adjacent=False).style(
        #     color_map={"+++": "royalblue", "++": "cornflowerblue",
        #                "+": "lightsteelblue", "NA": "white"})
        # NER = gr.HTML(label='NER:')
        # intp = gr.HighlightedText(label="Word Scores",
        #                           combine_adjacent=False).style(color_map={"++": "darkred", "+": "red",
        #                                                                    "--": "darkblue",
        #                                                                    "-": "blue", "NA": "white"})
        interpretation = gr.components.Interpretation(prob1)
    submit_btn.click(
        main,
        [prob1],
        [label,
         # intp,
         interpretation],
        api_name="adr",
    )
gr.Markdown("### Click on any of the examples below to see to what extent they contain resilience messaging:") | |
    gr.Examples([["I have minor pain."], ["I have severe pain."]], [prob1],
                [label,
                 # intp,
                 interpretation],
                main, cache_examples=True)

demo.launch()