File size: 2,482 Bytes
fcc241e
 
 
 
 
 
7b02dda
 
 
 
 
 
 
 
 
 
 
 
 
 
fcc241e
 
 
 
8c4e0aa
7b02dda
 
 
fcc241e
 
 
 
 
7b02dda
 
 
 
8c4e0aa
7b02dda
4c072f6
 
 
 
 
7b02dda
 
 
 
 
 
fcc241e
8c4e0aa
fcc241e
 
 
 
 
8c4e0aa
fcc241e
8c4e0aa
fcc241e
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import html

import gradio as gr
from transformers import pipeline

# Hugging Face token-classification pipeline for the Clinical-AI-Apollo/Medical-NER
# model. aggregation_strategy="simple" groups sub-word tokens into whole entity
# spans, so each result carries one "entity_group" plus "start"/"end" character
# offsets into the input text.
pipe = pipeline(
    task="token-classification",
    model="Clinical-AI-Apollo/Medical-NER",
    aggregation_strategy="simple",
)

# Define colors for different entity types
entity_colors = {
    "AGE": "#ffadad",
    "SEX": "#ffd6a5",
    "DISEASE_DISORDER": "#caffbf",
    "SIGN_SYMPTOM": "#9bf6ff",
    "LAB_VALUE": "#a0c4ff",
    "THERAPEUTIC_PROCEDURE": "#bdb2ff",
    "CLINICAL_EVENT": "#ffc6ff",
    "DIAGNOSTIC_PROCEDURE": "#fffffc",
    "DETAILED_DESCRIPTION": "#fdffb6",
    "BIOLOGICAL_STRUCTURE": "#ffb5a7"
}

def classify_text(text: str, *, classifier=None, colors=None) -> str:
    """Run entity classification on *text* and return it as highlighted HTML.

    Each entity span reported by the classifier is wrapped in a coloured
    ``<span>`` followed by a small badge naming the entity type. Text between
    entities is copied through unchanged. All user-supplied text is
    HTML-escaped, because the result is rendered as raw HTML.

    Args:
        text: Raw input text. Empty or ``None`` yields an empty wrapper
            ``<div>`` without calling the classifier.
        classifier: Callable taking ``text`` and returning a list of dicts
            with ``"entity_group"``, ``"start"`` and ``"end"`` keys. Defaults
            to the module-level ``pipe``.
        colors: Mapping from entity label to CSS colour. Defaults to the
            module-level ``entity_colors``. Unknown labels use ``#e0e0e0``.

    Returns:
        The annotated text wrapped in a styled ``<div>``.
    """
    if classifier is None:
        classifier = pipe
    if colors is None:
        colors = entity_colors

    text = text or ""
    # Skip the model call when there is nothing to classify.
    entities = classifier(text) if text else []
    # Render left to right. Sorting protects against out-of-order results.
    entities = sorted(entities, key=lambda ent: ent["start"])

    parts = []
    cursor = 0
    for ent in entities:
        start, end = ent["start"], ent["end"]
        # An entity overlapping one already rendered would duplicate text.
        if start < cursor:
            continue

        # Plain (escaped) text between the previous entity and this one.
        parts.append(html.escape(text[cursor:start]))

        label = ent["entity_group"]
        color = colors.get(label, "#e0e0e0")  # gray for undefined entity types
        # Slice the original input rather than using ent["word"]: the
        # pipeline's "word" is re-decoded from tokens and may differ from
        # what the user typed (casing, spacing, sub-word artefacts).
        span_text = html.escape(text[start:end])
        # Compact markup: no newlines or indentation inside the span, which
        # would otherwise render as extra spaces around the entity.
        parts.append(
            f"<span style='background-color:{color}; padding:2px; border-radius:5px;'>"
            f"{span_text}"
            "<span style='display:inline-block; background-color:#fff; color:#000; "
            "border-radius:3px; padding:2px; margin-left:5px; font-size:10px;'>"
            f"{html.escape(label)}</span>"
            "</span>"
        )
        cursor = end

    # Remaining text after the last entity.
    parts.append(html.escape(text[cursor:]))

    body = "".join(parts)
    return f"<div style='font-family: Arial, sans-serif; line-height: 1.5;'>{body}</div>"

# Gradio Interface: a multi-line text box feeds classify_text, and the HTML
# string it returns is displayed by a gr.HTML output component. The
# `examples` rows are sample inputs offered to the user in the UI.
demo = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(lines=5, label="Enter Medical Text"),
    outputs=gr.HTML(label="Entity Classification with Highlighting and Labels"),
    title="Medical Entity Classification",
    description="Enter medical-related text, and the model will classify medical entities with color highlighting and labels.",
    examples=[
        ["45 year old woman diagnosed with CAD"],
        ["A 65-year-old male presents with acute chest pain and a history of hypertension."],
        ["The patient underwent a laparoscopic cholecystectomy."]
    ]
)

# Start the web app only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()