ashishkgpian's picture
Update app.py
a9de279 verified
raw
history blame
5.41 kB
import gradio as gr
from transformers import pipeline
import pandas as pd
import os
# Build the Hugging Face text-classification pipeline for the model id below.
classifier = pipeline(task="text-classification", model="ashishkgpian/biobert_icd9_classifier_ehr")
# ICD9 lookup table; classify_symptoms reads its ICD9_CODE, SHORT_TITLE and
# LONG_TITLE columns.
# NOTE(review): the path is an empty string as written, so no file can be
# located from it -- confirm the intended CSV path.
icd9_data = pd.read_csv('')
def classify_symptoms(text):
    """Suggest ICD9 codes for a free-text clinical symptom description.

    Calls ``classifier`` on *text* with ``top_k=5`` and attaches the matching
    titles from ``icd9_data`` to every predicted label.

    Args:
        text: Symptom description entered by the user.

    Returns:
        A list with one dict per prediction, using the keys ``"ICD9 Code"``,
        ``"Short Title"``, ``"Long Title"`` and ``"Confidence"`` (the score
        formatted as a percentage). Titles are ``"N/A"`` when the code has no
        row in ``icd9_data``. ``None`` or blank input yields an empty list.
        If classification fails, ``{"error": "<message>"}`` is returned
        instead of raising.
    """
    # Nothing to classify: skip the model call for missing/whitespace-only input.
    if text is None or (isinstance(text, str) and not text.strip()):
        return []

    try:
        predictions = classifier(text, top_k=5)
        formatted_results = []
        for prediction in predictions:
            code = prediction['label']
            # Look up additional information
            code_info = icd9_data[icd9_data['ICD9_CODE'] == code]
            found = not code_info.empty
            formatted_results.append({
                "ICD9 Code": code,
                "Short Title": code_info['SHORT_TITLE'].iloc[0] if found else "N/A",
                "Long Title": code_info['LONG_TITLE'].iloc[0] if found else "N/A",
                "Confidence": f"{prediction['score']:.2%}",
            })
        return formatted_results
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the
        # handler. The result is shown in a gr.JSON component, which treats a
        # str value as serialized JSON, so the message is wrapped in a dict
        # rather than returned as a bare string.
        return {"error": f"Error processing classification: {str(e)}"}
# Enhanced CSS with violet theme and better text contrast.
# Passed to gr.Blocks(css=...) when the UI is built. The class selectors below
# correspond to the elem_classes used in the layout (main-container,
# input-container, submit-btn, output-container, examples-container,
# example-text) and to the "footer" class of the disclaimer <div>.
custom_css = """
.gradio-container {
max-width: 1200px !important;
margin: auto !important;
padding: 2rem !important;
background-color: #f5f3f7 !important;
}
.main-container {
text-align: center;
padding: 1rem;
margin-bottom: 2rem;
background: #ffffff;
border-radius: 10px;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
}
h1 {
color: #4a148c !important;
font-size: 2.5rem !important;
margin-bottom: 0.5rem !important;
}
h3 {
color: #6a1b9a !important;
font-size: 1.2rem !important;
font-weight: normal !important;
}
.input-container {
background: white !important;
padding: 2rem !important;
border-radius: 10px !important;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
margin-bottom: 1.5rem !important;
}
textarea {
background: white !important;
color: #000000 !important;
border: 2px solid #7b1fa2 !important;
border-radius: 8px !important;
padding: 1rem !important;
font-size: 1.1rem !important;
min-height: 120px !important;
}
.submit-btn {
background-color: #6a1b9a !important;
color: white !important;
padding: 0.8rem 2rem !important;
border-radius: 8px !important;
font-size: 1.1rem !important;
margin-top: 1rem !important;
transition: background-color 0.3s ease !important;
}
.submit-btn:hover {
background-color: #4a148c !important;
}
.output-container {
background: white !important;
padding: 2rem !important;
border-radius: 10px !important;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
}
.output-container pre {
background: #f8f9fa !important;
color: #000000 !important;
border-radius: 8px !important;
padding: 1rem !important;
}
.examples-container {
background: white !important;
padding: 1.5rem !important;
border-radius: 10px !important;
margin-top: 1rem !important;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
}
.example-text {
color: #000000 !important;
}
.footer {
text-align: center;
margin-top: 2rem;
padding: 1rem;
background: white;
border-radius: 10px;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
color: #4a148c;
}
"""
# Assemble the page: title banner, input form, JSON results, example prompts,
# event wiring and the disclaimer footer.
with gr.Blocks(css=custom_css) as demo:
    # Title banner.
    with gr.Row(elem_classes=["main-container"]):
        gr.Markdown(
            """
# 🏥 MedAI: Clinical Symptom ICD9 Classifier
### Advanced AI-Powered Diagnostic Code Assistant
"""
        )

    # Free-text input and the button that triggers classification.
    with gr.Row():
        with gr.Column(elem_classes=["input-container"]):
            input_text = gr.Textbox(
                label="Clinical Symptom Description",
                placeholder="Enter detailed patient symptoms and clinical observations...",
                lines=5,
            )
            submit_btn = gr.Button("Analyze Symptoms", elem_classes=["submit-btn"])

    # Classification results.
    with gr.Row(elem_classes=["output-container"]):
        output = gr.JSON(label="Suggested ICD9 Diagnostic Codes with Descriptions")

    # Sample prompts that fill the input box.
    with gr.Row(elem_classes=["examples-container"]):
        examples = gr.Examples(
            examples=[
                ["45-year-old male experiencing severe chest pain, radiating to left arm, with shortness of breath and excessive sweating"],
                ["Persistent headache for 2 weeks, accompanied by dizziness and occasional blurred vision"],
                ["Diabetic patient reporting frequent urination, increased thirst, and unexplained weight loss"],
                ["Elderly patient with chronic knee pain, reduced mobility, and signs of inflammation"],
            ],
            inputs=input_text,
            label="Example Clinical Cases",
            elem_classes=["example-text"],
        )

    # Both the button click and submitting the textbox run the classifier.
    for bind_event in (submit_btn.click, input_text.submit):
        bind_event(fn=classify_symptoms, inputs=input_text, outputs=output)

    # Disclaimer footer.
    with gr.Row():
        gr.Markdown(
            """
<div class="footer">
⚕️ <strong>Medical Disclaimer:</strong> This AI tool is designed to assist medical professionals in ICD9 code classification.
Always verify suggestions with clinical judgment and consult appropriate medical resources.
</div>
""",
        )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()