ledphrey committed
Commit 3061f29 · verified
1 Parent(s): ecadb62

transfer from thankrandomness

Files changed (2)
  1. app.py +187 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,187 @@
+ import os
+ import torch
+ from datasets import load_dataset, DatasetDict
+ from transformers import AutoTokenizer, AutoModel
+ import chromadb
+ import gradio as gr
+ import numpy as np
+ from sklearn.metrics import precision_score, recall_score, f1_score
+
+ # Mean pooling - take the attention mask into account for correct averaging
+ def meanpooling(output, mask):
+     embeddings = output[0]  # First element of the model output contains all token embeddings
+     mask = mask.unsqueeze(-1).expand(embeddings.size()).float()
+     return torch.sum(embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
+
+ # Load the dataset
+ dataset = load_dataset("thankrandomness/mimic-iii")
+
+ # Split the dataset into train and validation sets
+ split_dataset = dataset['train'].train_test_split(test_size=0.2, seed=42)
+ dataset = DatasetDict({
+     'train': split_dataset['train'],
+     'validation': split_dataset['test']
+ })
+
+ # Load the model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("neuml/pubmedbert-base-embeddings-matryoshka")
+ model = AutoModel.from_pretrained("neuml/pubmedbert-base-embeddings-matryoshka")
+
+ # Function to normalize embeddings to unit vectors (row-wise, guarding against zero vectors)
+ def normalize_embedding(embedding):
+     norm = np.linalg.norm(embedding, axis=-1, keepdims=True)
+     return (embedding / np.where(norm > 0, norm, 1.0)).tolist()
+
+ # Function to embed and normalize text
+ def embed_text(text):
+     inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
+     with torch.no_grad():
+         output = model(**inputs)
+     embeddings = meanpooling(output, inputs['attention_mask'])
+     normalized_embeddings = normalize_embedding(embeddings.numpy())
+     return normalized_embeddings
+
+ # Initialize ChromaDB client
+ client = chromadb.Client()
+ collection = client.create_collection(name="pubmedbert_matryoshka_embeddings")
+
+ # Function to upsert data into ChromaDB
+ def upsert_data(dataset_split):
+     for i, row in enumerate(dataset_split):
+         for note in row['notes']:
+             text = note.get('text', '')
+             annotations_list = []
+
+             for annotation in note.get('annotations', []):
+                 try:
+                     code = annotation['code']
+                     code_system = annotation['code_system']
+                     description = annotation['description']
+                     annotations_list.append({"code": code, "code_system": code_system, "description": description})
+                 except KeyError as e:
+                     print(f"Skipping annotation due to missing key: {e}")
+
+             if text and annotations_list:
+                 embeddings = embed_text([text])[0]
+
+                 # Upsert embeddings and annotations into ChromaDB, one entry per annotation
+                 for j, annotation in enumerate(annotations_list):
+                     collection.upsert(
+                         ids=[f"note_{note['note_id']}_{j}"],
+                         embeddings=[embeddings],
+                         metadatas=[annotation]
+                     )
+             else:
+                 print(f"Skipping note {note['note_id']} due to missing 'text' or 'annotations'")
+
+ # Upsert training data
+ upsert_data(dataset['train'])
+
+ # Define retrieval function (similarity threshold currently disabled)
+ def retrieve_relevant_text(input_text):
+     input_embedding = embed_text([input_text])[0]
+     results = collection.query(
+         query_embeddings=[input_embedding],
+         n_results=5,
+         include=["metadatas", "documents", "distances"]
+     )
+
+     output = []
+     # print("Retrieved items and their similarity scores:")
+     for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
+         # print(f"Code: {metadata['code']}, Similarity Score: {distance}")
+         # if distance <= similarity_threshold:
+         output.append({
+             "similarity_score": distance,  # ChromaDB returns a distance: lower means more similar
+             "code": metadata['code'],
+             "code_system": metadata['code_system'],
+             "description": metadata['description']
+         })
+
+     # if not output:
+     #     print("No results met the similarity threshold.")
+     return output
+
+ # Evaluate retrieval efficiency on the validation/test set
+ def evaluate_efficiency(dataset_split):
+     y_true = []
+     y_pred = []
+     total_similarity = 0
+     total_items = 0
+
+     for i, row in enumerate(dataset_split):
+         for note in row['notes']:
+             text = note.get('text', '')
+             annotations_list = [annotation['code'] for annotation in note.get('annotations', []) if 'code' in annotation]
+
+             if text and annotations_list:
+                 retrieved_results = retrieve_relevant_text(text)
+                 retrieved_codes = [result['code'] for result in retrieved_results]
+
+                 # Accumulate retrieval distances for the average calculation
+                 for result in retrieved_results:
+                     total_similarity += result['similarity_score']
+                     total_items += 1
+
+                 # Ground truth
+                 y_true.extend(annotations_list)
+                 # Predictions (limit to length of true annotations to avoid mismatch)
+                 y_pred.extend(retrieved_codes[:len(annotations_list)])
+
+                 # for result in retrieved_results:
+                 #     print(f"  Code: {result['code']}, Similarity Score: {result['similarity_score']:.2f}")
+
+     # Debugging output to check for mismatches and understand results
+     # print("Sample y_true:", y_true[:10])
+     # print("Sample y_pred:", y_pred[:10])
+
+     if total_items > 0:
+         avg_similarity = total_similarity / total_items
+     else:
+         avg_similarity = 0
+
+     if len(y_true) != len(y_pred):
+         min_length = min(len(y_true), len(y_pred))
+         y_true = y_true[:min_length]
+         y_pred = y_pred[:min_length]
+
+     # Calculate metrics
+     precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
+     recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
+     f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
+
+     return precision, recall, f1, avg_similarity
+
+ # Calculate retrieval efficiency metrics
+ precision, recall, f1, avg_similarity = evaluate_efficiency(dataset['validation'])
+
+ # Gradio interface
+ def gradio_interface(input_text):
+     results = retrieve_relevant_text(input_text)
+     formatted_results = [
+         f"Result {i + 1}:\n"
+         f"Similarity Score: {result['similarity_score']:.2f}\n"
+         f"Code: {result['code']}\n"
+         f"Code System: {result['code_system']}\n"
+         f"Description: {result['description']}\n"
+         "-------------------"
+         for i, result in enumerate(results)
+     ]
+     return "\n".join(formatted_results)
+
+ # Display retrieval efficiency metrics
+ # metrics = f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}"
+ metrics = f"Average retrieval distance: {avg_similarity:.2f}"
+
+ with gr.Blocks() as interface:
+     gr.Markdown("# Automated Medical Coding POC")
+     # gr.Markdown(metrics)
+     with gr.Row():
+         with gr.Column():
+             text_input = gr.Textbox(label="Input Text")
+             submit_button = gr.Button("Submit")
+         with gr.Column():
+             text_output = gr.Textbox(label="Retrieved Results", lines=10)
+     submit_button.click(fn=gradio_interface, inputs=text_input, outputs=text_output)
+
+ interface.launch()
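
Note on the reported scores: chromadb.Client().create_collection() defaults to L2 (Euclidean) distance, so the values surfaced as "similarity_score" above are distances, where lower means more similar. A minimal sketch, assuming the standard hnsw:space collection setting is available in the ChromaDB version pinned by this Space, of how the collection could instead be created with cosine distance to match the unit-normalized embeddings:

import chromadb

client = chromadb.Client()
# Cosine distance pairs naturally with unit-normalized embeddings;
# query distances then fall in [0, 2], with 0 meaning identical direction.
collection = client.create_collection(
    name="pubmedbert_matryoshka_embeddings",
    metadata={"hnsw:space": "cosine"},
)

With this setup the rest of the upsert and query code can stay unchanged.
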
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ transformers
+ datasets
+ chromadb
+ gradio
+ numpy
+ scikit-learn
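
For a quick check of the retrieval step without launching the Gradio UI, a minimal sketch (assuming app.py is importable as a module named app; the query text below is only an illustrative placeholder):

from app import retrieve_relevant_text  # importing app runs the full upsert/evaluation pipeline first

# Hypothetical free-text clinical note used purely as an example query
sample_note = "Patient admitted with community-acquired pneumonia, treated with IV antibiotics."
for match in retrieve_relevant_text(sample_note):
    print(f"{match['code']} ({match['code_system']}): {match['description']} "
          f"[distance={match['similarity_score']:.3f}]")

Because app.py upserts the training split and evaluates the validation split at import time, this is only practical for small-scale experiments.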