import streamlit as st import gradio as gr import shap import numpy as np import scipy as sp import torch import transformers from transformers import pipeline from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification import matplotlib.pyplot as plt import sys import csv csv.field_size_limit(sys.maxsize) device = "cuda:0" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained("jschwaller/ADRv2024") model = AutoModelForSequenceClassification.from_pretrained("jschwaller/ADRv2024") # Build a pipeline object for predictions pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None) explainer = shap.Explainer(pred) ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all") ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all") ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple") # pass device=0 if using gpu # entity_colors = { 'Severity': '#E63946', # a vivid red 'Sign_symptom': '#2A9D8F', # a deep teal 'Medication': '#457B9D', # a dusky blue 'Age': '#F4A261', # a sandy orange 'Sex': '#F4A261', # same sandy orange for consistency with 'Age' 'Diagnostic_procedure': '#9C6644', # a brown 'Biological_structure': '#BDB2FF', # a light pastel purple } def adr_predict(x): encoded_input = tokenizer(x, return_tensors='pt') output = model(**encoded_input) scores = output[0][0].detach() scores = torch.nn.functional.softmax(scores) shap_values = explainer([str(x).lower()]) local_plot = shap.plots.text(shap_values[0], display=False) res = ner_pipe(x) htext = "" prev_end = 0 for entity in res: start = entity['start'] end = entity['end'] word = entity['word'].replace("##", "") color = entity_colors[entity['entity_group']] htext += f"{x[prev_end:start]}{word}" prev_end = end htext += x[prev_end:] return {"Severe Reaction": float(scores.numpy()[1]), "Non-severe Reaction": float(scores.numpy()[0])}, local_plot, htext def main(prob1): text = str(prob1).lower() obj = adr_predict(text) return obj[0], obj[1], obj[2] # Define HTML for the legend legend_html = """