naufderheide commited on
Commit
394a5f3
1 Parent(s): 878af56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -3
app.py CHANGED
@@ -1,7 +1,142 @@
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  demo.launch()
 
1
+ import streamlit as st
2
  import gradio as gr
3
+ import shap
4
+ import numpy as np
5
+ import scipy as sp
6
+ import torch
7
+ import tensorflow as tf
8
+ import transformers
9
+ from transformers import pipeline
10
+ from transformers import RobertaTokenizer, RobertaModel
11
+ from transformers import AutoModelForSequenceClassification
12
+ from transformers import TFAutoModelForSequenceClassification
13
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
14
 
15
+ import matplotlib.pyplot as plt
16
+ import sys
17
+ import csv
18
 
19
+ csv.field_size_limit(sys.maxsize)
20
+
21
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
22
+
23
+ tokenizer = AutoTokenizer.from_pretrained("MSBA24-Team8/ADRv2024")
24
+ model = AutoModelForSequenceClassification.from_pretrained("MSBA24-Team8/ADRv2024").to(device)
25
+
26
+ # build a pipeline object to do predictions
27
+ pred = transformers.pipeline("text-classification", model=model,
28
+ tokenizer=tokenizer, return_all_scores=True)
29
+
30
+ explainer = shap.Explainer(pred)
31
+
32
+ ##
33
+ # classifier = transformers.pipeline("text-classification", model = "cross-encoder/qnli-electra-base")
34
+
35
+ # def med_score(x):
36
+ # label = x['label']
37
+ # score_1 = x['score']
38
+ # return round(score_1,3)
39
+
40
+ # def sym_score(x):
41
+ # label2sym= x['label']
42
+ # score_1sym = x['score']
43
+ # return round(score_1sym,3)
44
+
45
+ ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
46
+ ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
47
+
48
+ ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple") # pass device=0 if using gpu
49
+ #
50
+
51
+ def adr_predict(x):
52
+ encoded_input = tokenizer(x, return_tensors='pt')
53
+ output = model(**encoded_input)
54
+ scores = output[0][0].detach().numpy()
55
+ scores = tf.nn.softmax(scores)
56
+
57
+ shap_values = explainer([str(x).lower()])
58
+ # # Find the index of the class you want as the default reference (e.g., 'label_1')
59
+ # label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
60
+
61
+ # # Plot the SHAP values for a specific instance in your dataset (e.g., instance 0)
62
+ # shap.plots.text(shap_values[label_1_index][0])
63
+
64
+ local_plot = shap.plots.text(shap_values[0], display=False)
65
+
66
+ # med = med_score(classifier(x+str(", There is a medication."))[0])
67
+ # sym = sym_score(classifier(x+str(", There is a symptom."))[0])
68
+
69
+ res = ner_pipe(x)
70
+
71
+ entity_colors = {
72
+ 'Severity': 'red',
73
+ 'Sign_symptom': 'green',
74
+ 'Medication': 'lightblue',
75
+ 'Age': 'yellow',
76
+ 'Sex':'yellow',
77
+ 'Diagnostic_procedure':'gray',
78
+ 'Biological_structure':'silver'}
79
+
80
+ htext = ""
81
+ prev_end = 0
82
+
83
+ for entity in res:
84
+ start = entity['start']
85
+ end = entity['end']
86
+ word = entity['word'].replace("##", "")
87
+ color = entity_colors[entity['entity_group']]
88
+
89
+ htext += f"{x[prev_end:start]}<mark style='background-color:{color};'>{word}</mark>"
90
+ prev_end = end
91
+
92
+ htext += x[prev_end:]
93
+
94
+ return {"Severe Reaction": float(scores.numpy()[1]), "Non-severe Reaction": float(scores.numpy()[0])}, local_plot,htext
95
+ # ,{"Contains Medication": float(med), "No Medications": float(1-med)} , {"Contains Symptoms": float(sym), "No Symptoms": float(1-sym)}
96
+
97
+
98
+ def main(prob1):
99
+ text = str(prob1).lower()
100
+ obj = adr_predict(text)
101
+ return obj[0],obj[1],obj[2]
102
+
103
+ title = "Welcome to **ADR Detector** 🪐"
104
+ description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reaction to medicaitons. Please do NOT use for medical diagnosis."""
105
+
106
+ with gr.Blocks(title=title) as demo:
107
+ gr.Markdown(f"## {title}")
108
+ gr.Markdown(description1)
109
+ gr.Markdown("""---""")
110
+ prob1 = gr.Textbox(label="Enter Your Text Here:",lines=2, placeholder="Type it here ...")
111
+ submit_btn = gr.Button("Analyze")
112
+
113
+ with gr.Row():
114
+
115
+ with gr.Column(visible=True) as output_col:
116
+ label = gr.Label(label = "Predicted Label")
117
+
118
+
119
+ with gr.Column(visible=True) as output_col:
120
+ local_plot = gr.HTML(label = 'Shap:')
121
+ htext = gr.HTML(label="NER")
122
+ # med = gr.Label(label = "Contains Medication")
123
+ # sym = gr.Label(label = "Contains Symptoms")
124
+
125
+ submit_btn.click(
126
+ main,
127
+ [prob1],
128
+ [label
129
+ ,local_plot, htext
130
+ # , med, sym
131
+ ], api_name="adr"
132
+ )
133
+
134
+ with gr.Row():
135
+ gr.Markdown("### Click on any of the examples below to see how it works:")
136
+ gr.Examples([["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
137
+ ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]],
138
+ [prob1], [label,local_plot, htext
139
+ # , med, sym
140
+ ], main, cache_examples=True)
141
+
142
  demo.launch()