lrschuman17 commited on
Commit
296b925
·
verified ·
1 Parent(s): 9553e58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -1
app.py CHANGED
@@ -1,3 +1,60 @@
1
  import gradio as gr
 
 
2
 
3
- gr.load("models/facebook/bart-large-mnli").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import pandas as pd
3
+ from transformers import pipeline
4
 
5
+ # Initialize the zero-shot classifier
6
+ classifier = pipeline('zero-shot-classification', model='distilbert-base-uncased')
7
+
8
+ # Define the candidate labels for injury classification
9
+ candidate_labels = ["ACL Tear", "Meniscus Tear", "Achilles Tear", "Fracture", "Hamstring", "Foot", "Shoulder", "Hip", "Calf", "Hand", "Wrist"]
10
+
11
+ def process_injury_notes(file):
12
+ # Read CSV file
13
+ df = pd.read_csv(file.name)
14
+
15
+ # Limit to a sample (for performance in demonstration)
16
+ new_df = df.head(100).copy()
17
+
18
+ # Classify each note and save results
19
+ classifications = classifier(new_df['Notes'].tolist(), candidate_labels)
20
+ new_df['Classifications'] = classifications
21
+
22
+ # Extract top classification and score
23
+ new_df['Top Classification'] = new_df['Classifications'].apply(lambda x: x['labels'][0] if isinstance(x, dict) else None)
24
+ new_df['Top Score'] = new_df['Classifications'].apply(lambda x: x['scores'][0] if isinstance(x, dict) else None)
25
+
26
+ # Initialize the 'Specific Injury' column with default value
27
+ new_df['Specific Injury'] = None
28
+
29
+ # Function to extract specific injury classification based on keywords
30
+ def extract_specific_injury(note, injury):
31
+ note = note.lower()
32
+ if "left" in note:
33
+ return f"left {injury.lower()} injury"
34
+ elif "right" in note:
35
+ return f"right {injury.lower()} injury"
36
+ else:
37
+ return f"{injury.lower()} injury"
38
+
39
+ # Apply specific injury classification based on keywords
40
+ for injury in candidate_labels:
41
+ new_df.loc[new_df['Top Classification'].str.contains(injury, case=False, na=False), 'Specific Injury'] = \
42
+ new_df['Notes'].apply(lambda x: extract_specific_injury(x, injury) if injury.lower() in x.lower() else None)
43
+
44
+ # Sort by 'Top Score' in descending order
45
+ new_df_sorted = new_df.sort_values(by='Top Score', ascending=False)
46
+
47
+ # Return sorted DataFrame
48
+ return new_df_sorted[['Notes', 'Top Classification', 'Top Score', 'Specific Injury']]
49
+
50
+ # Set up Gradio interface
51
+ iface = gr.Interface(
52
+ fn=process_injury_notes,
53
+ inputs=gr.File(label="Upload CSV File (must have 'Notes' column)"),
54
+ outputs="dataframe",
55
+ title="Injury Classification App",
56
+ description="Upload a CSV file with injury notes to classify injuries based on given categories. Displays top classification and specificity."
57
+ )
58
+
59
+ # Launch the Gradio app
60
+ iface.launch()