Spaces:
Sleeping
Sleeping
import html

import gradio as gr
from transformers import pipeline
# Hugging Face Hub identifier of the medical NER model used by this app.
_MODEL_ID = "Clinical-AI-Apollo/Medical-NER"

# Token-classification pipeline. aggregation_strategy="simple" groups sub-word
# tokens into entity spans; classify_text reads each span's
# "entity_group", "start" and "end" fields.
pipe = pipeline(
    "token-classification",
    model=_MODEL_ID,
    aggregation_strategy='simple',
)
# Background colour used to highlight each entity group returned by the NER
# pipeline. Groups missing from this mapping fall back to grey in
# classify_text. Every colour must stay clearly visible against a white page
# and distinct from the white label badge: the previous DIAGNOSTIC_PROCEDURE
# value "#fffffc" was effectively white, so those entities appeared
# unhighlighted.
entity_colors = {
    "AGE": "#ffadad",
    "SEX": "#ffd6a5",
    "DISEASE_DISORDER": "#caffbf",
    "SIGN_SYMPTOM": "#9bf6ff",
    "LAB_VALUE": "#a0c4ff",
    "THERAPEUTIC_PROCEDURE": "#bdb2ff",
    "CLINICAL_EVENT": "#ffc6ff",
    "DIAGNOSTIC_PROCEDURE": "#e9c46a",
    "DETAILED_DESCRIPTION": "#fdffb6",
    "BIOLOGICAL_STRUCTURE": "#ffb5a7",
}
def classify_text(text, *, ner=None, colors=None):
    """Run medical NER on *text* and return it as highlighted HTML.

    Each entity span reported by the NER callable is wrapped in a coloured
    ``<span>`` followed by a small badge naming its entity group. Text
    between entities is emitted unhighlighted. The displayed entity text is
    the exact slice ``text[start:end]`` of the input, not the tokenizer's
    decoded ``word``, so casing and spacing match what the user typed.

    Args:
        text: Free text from the Gradio textbox. It is untrusted and is
            always HTML-escaped before being emitted. ``None`` is treated
            as empty.
        ner: Callable taking the text and returning a list of dicts with
            ``entity_group``, ``start`` and ``end`` keys (the shape an
            aggregated token-classification pipeline produces). Defaults to
            the module-level ``pipe``.
        colors: Mapping from entity group to CSS colour. Defaults to the
            module-level ``entity_colors``. Unknown groups render grey.

    Returns:
        An HTML ``<div>`` string. ``white-space: pre-wrap`` preserves the
        user's line breaks and spacing.
    """
    if ner is None:
        ner = pipe
    if colors is None:
        colors = entity_colors

    text = text or ""
    # Skip the model call entirely for blank input.
    entities = ner(text) if text.strip() else []

    parts = []
    last_pos = 0
    for ent in sorted(entities, key=lambda e: e["start"]):
        # Clamp overlapping spans so no character of the input is emitted twice.
        start = max(ent["start"], last_pos)
        end = ent["end"]
        if end <= start:
            continue
        label = ent["entity_group"]
        color = colors.get(label, "#e0e0e0")  # grey for undefined entity types
        parts.append(html.escape(text[last_pos:start]))
        # Built on a single line: with pre-wrap, any literal newlines or
        # indentation here would appear in the rendered output.
        parts.append(
            f"<span style='background-color:{color}; padding:2px; border-radius:5px;'>"
            f"{html.escape(text[start:end])}"
            "<span style='display:inline-block; background-color:#fff; color:#000; "
            "border-radius:3px; padding:2px; margin-left:5px; font-size:10px;'>"
            f"{html.escape(label)}</span></span>"
        )
        last_pos = end
    parts.append(html.escape(text[last_pos:]))

    return (
        "<div style='font-family: Arial, sans-serif; line-height: 1.5; "
        "white-space: pre-wrap;'>" + "".join(parts) + "</div>"
    )
# Sample inputs listed under the textbox; each inner list supplies the value
# for the interface's single Textbox input.
_EXAMPLES = [
    ["45 year old woman diagnosed with CAD"],
    ["A 65-year-old male presents with acute chest pain and a history of hypertension."],
    ["The patient underwent a laparoscopic cholecystectomy."],
]

# One text input in, one HTML output (from classify_text) out.
demo = gr.Interface(
    classify_text,
    gr.Textbox(label="Enter Medical Text", lines=5),
    gr.HTML(label="Entity Classification with Highlighting and Labels"),
    examples=_EXAMPLES,
    description="Enter medical-related text, and the model will classify medical entities with color highlighting and labels.",
    title="Medical Entity Classification",
)

if __name__ == "__main__":
    # Launch the app only when executed directly, not when imported.
    demo.launch()