DexterSptizu commited on
Commit
7b02dda
·
verified ·
1 Parent(s): b183cef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -7
app.py CHANGED
@@ -4,29 +4,56 @@ from transformers import pipeline
4
  # Load the token classification model
5
  pipe = pipeline("token-classification", model="Clinical-AI-Apollo/Medical-NER", aggregation_strategy='simple')
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def classify_text(text):
8
  # Get token classification results
9
  result = pipe(text)
10
 
11
- # Format the results to resemble the UI shown in the image
12
- formatted_output = ""
 
 
13
  for res in result:
14
  entity = res['entity_group']
15
  word = res['word']
16
- score = res['score']
17
  start = res['start']
18
  end = res['end']
19
- formatted_output += f"Entity: {entity}, Word: {word}, Score: {score:.4f}, Span: [{start}:{end}]\n"
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- return formatted_output
22
 
23
  # Gradio Interface
24
  demo = gr.Interface(
25
  fn=classify_text,
26
  inputs=gr.Textbox(lines=5, label="Enter Medical Text"),
27
- outputs=gr.Textbox(label="Entity Classification"),
28
  title="Medical Entity Classification",
29
- description="Enter medical-related text, and the model will classify medical entities.",
30
  examples=[
31
  ["45 year old woman diagnosed with CAD"],
32
  ["A 65-year-old male presents with acute chest pain and a history of hypertension."],
 
4
  # Load the token classification model
5
  pipe = pipeline("token-classification", model="Clinical-AI-Apollo/Medical-NER", aggregation_strategy='simple')
6
 
7
+ # Define colors for different entity types
8
+ entity_colors = {
9
+ "AGE": "#ffadad",
10
+ "SEX": "#ffd6a5",
11
+ "DISEASE_DISORDER": "#caffbf",
12
+ "SIGN_SYMPTOM": "#9bf6ff",
13
+ "LAB_VALUE": "#a0c4ff",
14
+ "THERAPEUTIC_PROCEDURE": "#bdb2ff",
15
+ "CLINICAL_EVENT": "#ffc6ff",
16
+ "DIAGNOSTIC_PROCEDURE": "#fffffc",
17
+ "DETAILED_DESCRIPTION": "#fdffb6",
18
+ "BIOLOGICAL_STRUCTURE": "#ffb5a7"
19
+ }
20
+
21
  def classify_text(text):
22
  # Get token classification results
23
  result = pipe(text)
24
 
25
+ # Format the results into HTML with color highlighting
26
+ highlighted_text = ""
27
+ last_pos = 0
28
+
29
  for res in result:
30
  entity = res['entity_group']
31
  word = res['word']
 
32
  start = res['start']
33
  end = res['end']
34
+
35
+ # Add text before the entity without highlighting
36
+ highlighted_text += text[last_pos:start]
37
+
38
+ # Add highlighted entity text
39
+ color = entity_colors.get(entity, "#e0e0e0") # Default to gray if entity type not defined
40
+ highlighted_text += f"<span style='background-color:{color}; padding:2px; border-radius:5px;'>{word}</span>"
41
+
42
+ # Update last position
43
+ last_pos = end
44
+
45
+ # Add the rest of the text after the last entity
46
+ highlighted_text += text[last_pos:]
47
 
48
+ return highlighted_text
49
 
50
  # Gradio Interface
51
  demo = gr.Interface(
52
  fn=classify_text,
53
  inputs=gr.Textbox(lines=5, label="Enter Medical Text"),
54
+ outputs=gr.HTML(label="Entity Classification with Highlighting"),
55
  title="Medical Entity Classification",
56
+ description="Enter medical-related text, and the model will classify medical entities with color highlighting.",
57
  examples=[
58
  ["45 year old woman diagnosed with CAD"],
59
  ["A 65-year-old male presents with acute chest pain and a history of hypertension."],