atereoyinn committed on
Commit
cd33ba8
1 Parent(s): 08af9e7

first commit

Browse files
Files changed (2) hide show
  1. app.py +208 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from faster_whisper import WhisperModel
3
+ from pydantic import BaseModel, Field, AliasChoices, field_validator, ValidationError
4
+ from typing import List
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
6
+ import csv
7
+ import json
8
+ import tempfile
9
+ import torch
10
+
11
+
12
# Checkpoint identifiers for the three models the app loads at import time.
numind_checkpoint = "numind/NuExtract-tiny"
llama_checkpoint = "Atereoyin/Llama3_finetuned_for_medical_entity_extraction"
whisper_checkpoint = "large-v3"

# 8-bit quantization settings shared by both causal language models.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# Speech-to-text model, placed on the GPU.
whisper_model = WhisperModel(whisper_checkpoint, device="cuda")

# NuExtract model and tokenizer (used for the person fields).
numind_model = AutoModelForCausalLM.from_pretrained(
    numind_checkpoint,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
numind_tokenizer = AutoTokenizer.from_pretrained(numind_checkpoint)

# Fine-tuned Llama model and tokenizer (used for the medical fields).
llama_model = AutoModelForCausalLM.from_pretrained(
    llama_checkpoint,
    quantization_config=quantization_config,
    trust_remote_code=True,
)
llama_tokenizer = AutoTokenizer.from_pretrained(llama_checkpoint)
27
+
28
+ # Function to transcribe audio
29
def transcribe_audio(audio_file_path):
    """Transcribe an audio file with the module-level Whisper model.

    Args:
        audio_file_path: Path of the audio file handed to
            ``whisper_model.transcribe``.

    Returns:
        The text of every returned segment concatenated without a separator.
        If anything raises, the exception message is returned instead, so a
        caller cannot tell a transcript from an error by type alone.
    """
    try:
        segments, _info = whisper_model.transcribe(audio_file_path, beam_size=5)
        # Consume the segments inside the ``try`` so that errors raised while
        # iterating them are reported the same way as errors from the call.
        return "".join(segment.text for segment in segments)
    except Exception as exc:
        return str(exc)
36
+
37
+ # Functions for Person entity extraction
38
def predict_NuExtract(model, tokenizer, text, schema, example=("", "", "")):
    """Run NuExtract-style structured extraction over ``text``.

    Args:
        model: Object exposing ``generate(**inputs)``; the first returned
            sequence is decoded.
        tokenizer: Callable that tokenizes the prompt (its result must offer
            ``.to(device)``) and exposes ``decode``.
        text: Source text to extract from.
        schema: JSON string describing the template to fill in.
        example: Iterable of JSON strings used as examples; empty strings are
            skipped. The default is an immutable tuple so that it can never
            be shared and mutated across calls.

    Returns:
        The decoded text found between the ``<|output|>`` and
        ``<|end-output|>`` markers.

    Raises:
        ValueError: If ``schema`` or an example is not valid JSON
            (``json.JSONDecodeError``), or if the decoded output does not
            contain the ``<|output|>`` marker.
    """
    schema = json.dumps(json.loads(schema), indent=4)

    prompt_parts = ["<|input|>\n### Template:\n" + schema + "\n"]
    for sample in example:
        if sample != "":
            prompt_parts.append(
                "### Example:\n" + json.dumps(json.loads(sample), indent=4) + "\n"
            )
    prompt_parts.append("### Text:\n" + text + "\n<|output|>\n")
    input_llm = "".join(prompt_parts)

    input_ids = tokenizer(
        input_llm, return_tensors="pt", truncation=True, max_length=4000
    ).to("cuda")
    output = tokenizer.decode(model.generate(**input_ids)[0], skip_special_tokens=True)

    sections = output.split("<|output|>")
    if len(sections) < 2:
        # The marker can be missing, e.g. if the prompt was truncated before
        # it; report that explicitly instead of a bare IndexError.
        raise ValueError("NuExtract output does not contain the '<|output|>' marker")
    return sections[1].split("<|end-output|>")[0]
50
+
51
+
52
+ # Function for generating prompts for Llama
53
def prompt_format(text):
    """Build the instruction prompt sent to the Llama model.

    Args:
        text: The conversation text placed in the ``### Input:`` section.

    Returns:
        The full prompt: a fixed preamble, the extraction instruction, the
        given text, and an empty ``### Response:`` section for the model to
        complete.
    """
    preamble = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request."
    )
    instruction = "\n".join(
        (
            "Extract the following entities from the medical conversation:",
            "* **Symptoms:** List all the symptoms the patient mentions.",
            "* **Diagnosis:** List the doctor's diagnosis or potential diagnoses.",
            "* **Medical History:** Summarize the patient's relevant medical history.",
            "* **Action Plan:** List the recommended actions or treatment plan.",
            "",
            "Provide the result in the following JSON format:",
            "{",
            '"Symptoms": [...],',
            '"Diagnosis": [...],',
            '"Medical history": [...],',
            '"Action plan": [...]',
            "}",
        )
    )
    # Values are interpolated, never used as a format string, so braces inside
    # the instruction or the text are copied through untouched.
    return (
        f"{preamble}\n\n"
        f"### Instruction:\n{instruction}\n\n"
        f"### Input:\n{text}\n\n"
        "### Response:\n"
    )
80
+
81
+
82
+ #Pydantic Validator to validate Llama's response
83
class _MedicalRecord(BaseModel):
    """Schema of the medical entities expected from Llama's response."""

    Symptoms: List[str] = Field(default_factory=list)
    Diagnosis: List[str] = Field(default_factory=list)
    # Accepted under either spelling listed in the alias choices.
    Medical_history: List[str] = Field(
        default_factory=list,
        validation_alias=AliasChoices('Medical history', 'History of Patient'),
    )
    Action_plan: List[str] = Field(
        default_factory=list,
        validation_alias=AliasChoices('Action plan', 'Plan of Action'),
    )

    @field_validator('*', mode='before')
    @classmethod
    def ensure_list(cls, v):
        """Turn a comma-separated string into a list of non-empty items."""
        if isinstance(v, str):
            # Drop empty pieces so "a, b," or "" do not produce '' entries.
            return [part for part in (item.strip() for item in v.split(',')) if part]
        return v


def validate_medical_record(response):
    """Validate and normalise the entities parsed from Llama's response.

    Args:
        response: Mapping of entity names to values (strings or lists).

    Returns:
        A dict with the keys ``Symptoms``, ``Diagnosis``, ``Medical_history``
        and ``Action_plan``, each a list of strings, when validation succeeds.
        If validation fails, ``response`` is returned unchanged.
    """
    try:
        # model_dump() is the Pydantic v2 replacement for the deprecated
        # .dict(); the model is defined once at module level instead of being
        # rebuilt on every call.
        return _MedicalRecord.model_validate(response).model_dump()
    except ValidationError:
        return response
108
+
109
+
110
+
111
+ # Function to predict medical entities using Llama
112
def _parse_llama_response(response):
    """Turn the model's response text into a dict of entity name -> value.

    The prompt asks for a JSON object, so that is tried first, starting at
    the first ``{``. If no JSON object can be decoded there (for instance
    because generation stopped mid-object), fall back to reading
    ``key: value`` lines.
    """
    start = response.find("{")
    if start != -1:
        try:
            parsed, _end = json.JSONDecoder().raw_decode(response[start:])
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            normalised = {}
            for key, value in parsed.items():
                # Coerce values to strings / lists of strings so they fit a
                # list-of-strings schema.
                if isinstance(value, list):
                    normalised[str(key)] = [str(item) for item in value]
                elif value is None:
                    normalised[str(key)] = ""
                else:
                    normalised[str(key)] = str(value)
            return normalised

    return {
        key.strip(): value.strip()
        for key, value in (
            line.split(': ', 1) for line in response.splitlines() if ': ' in line
        )
    }


def predict_Llama(model, tokenizer, text):
    """Extract medical entities from ``text`` with the Llama model.

    Args:
        model: Object exposing ``generate``; the first returned sequence is
            decoded.
        tokenizer: Callable tokenizer that also exposes ``decode``.
        text: Conversation text inserted into the prompt.

    Returns:
        The result of ``validate_medical_record`` on the parsed response, or
        an empty dict if generation, decoding or parsing raises.
    """
    inputs = tokenizer(prompt_format(text), return_tensors="pt", truncation=True).to("cuda")

    try:
        outputs = model.generate(**inputs, max_new_tokens=128, temperature=0.2, use_cache=True)
        extracted_entities = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only what follows the response header of the prompt.
        response = extracted_entities.split("### Response:", 1)[-1].strip()
        response_dict = _parse_llama_response(response)

        return validate_medical_record(response_dict)
    except Exception as e:
        print(f"Error during Llama prediction: {str(e)}")
        return {}
128
+
129
+
130
+ # Control function that coordinates the other functions and maps the extracted entities to form fields
131
def _join_entities(values):
    """Render one entity field as a comma-separated string for a textbox."""
    if isinstance(values, str):
        # validate_medical_record returns its input unchanged when validation
        # fails, and that input can hold plain strings; joining a str would
        # put ", " between its characters.
        return values
    return ", ".join(str(item) for item in values)


def process_audio(audio):
    """Transcribe an audio input and map the extracted entities to form fields.

    Args:
        audio: Either a path to an audio file (``str``) or the raw audio
            bytes.

    Returns:
        A 7-tuple ``(name, age, gender, symptoms, diagnosis, medical_history,
        action_plan)``; the last four are comma-separated strings.

    Raises:
        gr.Error: If no audio is given or the NuExtract output is not a JSON
            object. The interface declares seven outputs, so a failure is
            raised rather than returned as a single string.
    """
    import os  # Local import: `os` is not among this file's top-level imports.

    if audio is None:
        raise gr.Error("No audio was provided.")

    temp_path = None
    if isinstance(audio, str):
        # Already a file on disk: transcribe it in place instead of copying it.
        audio_path = audio
    else:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            temp_audio.write(audio)
        temp_path = audio_path = temp_audio.name

    try:
        transcription = transcribe_audio(audio_path)
    finally:
        # The temp file is created with delete=False, so remove it explicitly.
        if temp_path is not None:
            os.remove(temp_path)

    person_schema = """{"Name": "","Age": "","Gender": ""}"""
    person_entities_raw = predict_NuExtract(numind_model, numind_tokenizer, transcription, person_schema)

    try:
        person_entities = json.loads(person_entities_raw)
    except json.JSONDecodeError as e:
        raise gr.Error(f"Error in NuExtract response: {str(e)}") from e
    if not isinstance(person_entities, dict):
        raise gr.Error("Error in NuExtract response: expected a JSON object")

    medical_entities = predict_Llama(llama_model, llama_tokenizer, transcription)

    return (
        person_entities.get("Name", ""),
        person_entities.get("Age", ""),
        person_entities.get("Gender", ""),
        _join_entities(medical_entities.get("Symptoms", [])),
        _join_entities(medical_entities.get("Diagnosis", [])),
        _join_entities(medical_entities.get("Medical_history", [])),
        _join_entities(medical_entities.get("Action_plan", [])),
    )
164
+
165
+
166
+
167
+ #Function that allows users to download information
168
def download_csv(name, age, gender, symptoms, diagnosis, medical_history, action_plan):
    """Write the form fields to a temporary CSV file and return its path.

    The file holds a header row followed by one data row. It is created with
    ``delete=False`` so that it still exists after this function returns;
    removing it is left to the caller.

    Args:
        name, age, gender, symptoms, diagnosis, medical_history, action_plan:
            Values written, in that order, as the single data row.

    Returns:
        Path of the CSV file that was written.
    """
    # Write through the temp file's own handle and let the ``with`` block
    # close it, instead of leaving it open and reopening the file by name.
    with tempfile.NamedTemporaryFile(
        mode='w', newline='', suffix=".csv", delete=False
    ) as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["Name", "Age", "Gender", "Symptoms", "Diagnosis", "Medical History", "Plan of Action"])
        writer.writerow([name, age, gender, symptoms, diagnosis, medical_history, action_plan])

    return csv_file.name
177
+
178
+
179
+
180
+ # Gradio interface to create a web-based form for users to input audio and fill the medical diagnostic form
181
# Form assistant UI: one audio input, one textbox per extracted field.
demo = gr.Interface(
    fn=process_audio,
    inputs=[gr.Audio(type="filepath")],
    outputs=[
        gr.Textbox(label=field_label)
        for field_label in (
            "Name",
            "Age",
            "Gender",
            "Symptoms",
            "Diagnosis",
            "Medical History",
            "Plan of Action",
        )
    ],
    title="Medical Diagnostic Form Assistant",
    description="Upload an audio file or record audio to generate a medical diagnostic form.",
)

# Add a button that exports the current textbox values as a CSV file.
with demo:
    download_button = gr.Button("Download CSV")
    download_button.click(
        fn=lambda name, age, gender, symptoms, diagnosis, medical_history, action_plan: download_csv(
            name, age, gender, symptoms, diagnosis, medical_history, action_plan
        ),
        inputs=demo.output_components,
        outputs=gr.File(label="Download CSV"),
    )

demo.launch(share=True)
208
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch
4
+ faster-whisper
5
+ pydantic