karthi311 committed on
Commit
4d605d9
·
verified ·
1 Parent(s): 6c2aa57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -32
app.py CHANGED
@@ -1,12 +1,11 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline,AutoTokenizer, AutoModelForSeq2SeqLM
4
  from pydub import AudioSegment
5
  from sentence_transformers import SentenceTransformer, util
6
  import spacy
7
- spacy.cli.download("en_core_web_sm")
8
  import json
9
- from faster_whisper import WhisperModel
10
 
11
  # Audio conversion from MP4 to MP3
12
  def convert_mp4_to_mp3(mp4_path, mp3_path):
@@ -16,7 +15,6 @@ def convert_mp4_to_mp3(mp4_path, mp3_path):
16
  except Exception as e:
17
  raise RuntimeError(f"Error converting MP4 to MP3: {e}")
18
 
19
-
20
  # Check if CUDA is available for GPU acceleration
21
  if torch.cuda.is_available():
22
  device = "cuda"
@@ -25,13 +23,11 @@ else:
25
  device = "cpu"
26
  compute_type = "int8"
27
 
28
-
29
  # Load Faster Whisper Model for transcription
30
  def load_faster_whisper():
31
  model = WhisperModel("deepdml/faster-whisper-large-v3-turbo-ct2", device=device, compute_type=compute_type)
32
  return model
33
 
34
-
35
  # Load NLP model and other helpers
36
  nlp = spacy.load("en_core_web_sm")
37
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
@@ -41,7 +37,6 @@ model = AutoModelForSeq2SeqLM.from_pretrained("Mahalingam/DistilBart-Med-Summary
41
 
42
  summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
43
 
44
-
45
  soap_prompts = {
46
  "subjective": "Personal reports, symptoms described by patients, or personal health concerns. Details reflecting individual symptoms or health descriptions.",
47
  "objective": "Observable facts, clinical findings, professional observations, specific medical specialties, and diagnoses.",
@@ -50,23 +45,15 @@ soap_prompts = {
50
  }
51
  soap_embeddings = {section: embedder.encode(prompt, convert_to_tensor=True) for section, prompt in soap_prompts.items()}
52
 
53
-
54
- # Load llava model and processor
55
- processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
56
- model = AutoModelForImageTextToText.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
57
-
58
-
59
- # Query function for LLAVA
60
- def llava_query(user_prompt, soap_note):
61
  combined_prompt = f"User Instructions:\n{user_prompt}\n\nContext:\n{soap_note}"
62
  try:
63
- inputs = processor(combined_prompt, return_tensors="pt", padding=True)
64
- outputs = model.generate(**inputs)
65
- return processor.decode(outputs[0], skip_special_tokens=True)
66
  except Exception as e:
67
  return f"Error generating response: {e}"
68
 
69
-
70
  # Convert the response to JSON format
71
  def convert_to_json(template):
72
  try:
@@ -83,7 +70,6 @@ def convert_to_json(template):
83
  except Exception as e:
84
  return f"Error converting to JSON: {e}"
85
 
86
-
87
  # Transcription using Faster Whisper
88
  def transcribe_audio(mp4_path):
89
  try:
@@ -99,13 +85,11 @@ def transcribe_audio(mp4_path):
99
  except Exception as e:
100
  return f"Error during transcription: {e}"
101
 
102
-
103
  # Classify the sentence to the correct SOAP section
104
  def classify_sentence(sentence):
105
  similarities = {section: util.pytorch_cos_sim(embedder.encode(sentence), soap_embeddings[section]) for section in soap_prompts.keys()}
106
  return max(similarities, key=similarities.get)
107
 
108
-
109
  # Summarize the section if it's too long
110
  def summarize_section(section_text):
111
  if len(section_text.split()) < 50:
@@ -121,7 +105,6 @@ def summarize_section(section_text):
121
  )
122
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
123
 
124
-
125
  # Analyze the SOAP content and divide into sections
126
  def soap_analysis(text):
127
  doc = nlp(text)
@@ -137,7 +120,6 @@ def soap_analysis(text):
137
 
138
  return format_soap_output(soap_note)
139
 
140
-
141
  # Format the SOAP note output
142
  def format_soap_output(soap_note):
143
  return (
@@ -147,7 +129,6 @@ def format_soap_output(soap_note):
147
  f"Plan:\n{soap_note['plan']}\n"
148
  )
149
 
150
-
151
  # Process file function for audio to SOAP
152
  def process_file(mp4_file, user_prompt):
153
  transcription = transcribe_audio(mp4_file.name)
@@ -156,26 +137,24 @@ def process_file(mp4_file, user_prompt):
156
  soap_note = soap_analysis(transcription)
157
  print("SOAP Notes: ", soap_note)
158
 
159
- template_output = llava_query(user_prompt, soap_note)
160
  print("Template: ", template_output)
161
 
162
  json_output = convert_to_json(template_output)
163
 
164
  return soap_note, template_output, json_output
165
 
166
-
167
  # Process text function for text input to SOAP
168
  def process_text(text, user_prompt):
169
  soap_note = soap_analysis(text)
170
  print(soap_note)
171
 
172
- template_output = llava_query(user_prompt, soap_note)
173
  print(template_output)
174
  json_output = convert_to_json(template_output)
175
 
176
  return soap_note, template_output, json_output
177
 
178
-
179
  # Launch the Gradio interface
180
  def launch_gradio():
181
  with gr.Blocks(theme=gr.themes.Default()) as demo:
@@ -189,7 +168,7 @@ def launch_gradio():
189
  ],
190
  outputs=[
191
  gr.Textbox(label="SOAP Note"),
192
- gr.Textbox(label="Generated Template from LLAVA"),
193
  gr.Textbox(label="JSON Output"),
194
  ],
195
  )
@@ -202,13 +181,12 @@ def launch_gradio():
202
  ],
203
  outputs=[
204
  gr.Textbox(label="SOAP Note"),
205
- gr.Textbox(label="Generated Template from LLAVA"),
206
  gr.Textbox(label="JSON Output"),
207
  ],
208
  )
209
  demo.launch(share=True, debug=True)
210
 
211
-
212
  # Run the Gradio app
213
  if __name__ == "__main__":
214
  launch_gradio()
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
4
  from pydub import AudioSegment
5
  from sentence_transformers import SentenceTransformer, util
6
  import spacy
 
7
  import json
8
+ import ollama
9
 
10
  # Audio conversion from MP4 to MP3
11
  def convert_mp4_to_mp3(mp4_path, mp3_path):
 
15
  except Exception as e:
16
  raise RuntimeError(f"Error converting MP4 to MP3: {e}")
17
 
 
18
  # Check if CUDA is available for GPU acceleration
19
  if torch.cuda.is_available():
20
  device = "cuda"
 
23
  device = "cpu"
24
  compute_type = "int8"
25
 
 
26
  # Load Faster Whisper Model for transcription
27
  def load_faster_whisper():
28
  model = WhisperModel("deepdml/faster-whisper-large-v3-turbo-ct2", device=device, compute_type=compute_type)
29
  return model
30
 
 
31
  # Load NLP model and other helpers
32
  nlp = spacy.load("en_core_web_sm")
33
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
 
37
 
38
  summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
39
 
 
40
  soap_prompts = {
41
  "subjective": "Personal reports, symptoms described by patients, or personal health concerns. Details reflecting individual symptoms or health descriptions.",
42
  "objective": "Observable facts, clinical findings, professional observations, specific medical specialties, and diagnoses.",
 
45
  }
46
  soap_embeddings = {section: embedder.encode(prompt, convert_to_tensor=True) for section, prompt in soap_prompts.items()}
47
 
48
+ # Ollama Llama 2 Model Query function
49
+ def ollama_query(user_prompt, soap_note):
 
 
 
 
 
 
50
  combined_prompt = f"User Instructions:\n{user_prompt}\n\nContext:\n{soap_note}"
51
  try:
52
+ response = ollama.run("llama2:7b-uncensored", prompt=combined_prompt)
53
+ return response
 
54
  except Exception as e:
55
  return f"Error generating response: {e}"
56
 
 
57
  # Convert the response to JSON format
58
  def convert_to_json(template):
59
  try:
 
70
  except Exception as e:
71
  return f"Error converting to JSON: {e}"
72
 
 
73
  # Transcription using Faster Whisper
74
  def transcribe_audio(mp4_path):
75
  try:
 
85
  except Exception as e:
86
  return f"Error during transcription: {e}"
87
 
 
88
  # Classify the sentence to the correct SOAP section
89
  def classify_sentence(sentence):
90
  similarities = {section: util.pytorch_cos_sim(embedder.encode(sentence), soap_embeddings[section]) for section in soap_prompts.keys()}
91
  return max(similarities, key=similarities.get)
92
 
 
93
  # Summarize the section if it's too long
94
  def summarize_section(section_text):
95
  if len(section_text.split()) < 50:
 
105
  )
106
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
107
 
 
108
  # Analyze the SOAP content and divide into sections
109
  def soap_analysis(text):
110
  doc = nlp(text)
 
120
 
121
  return format_soap_output(soap_note)
122
 
 
123
  # Format the SOAP note output
124
  def format_soap_output(soap_note):
125
  return (
 
129
  f"Plan:\n{soap_note['plan']}\n"
130
  )
131
 
 
132
  # Process file function for audio to SOAP
133
  def process_file(mp4_file, user_prompt):
134
  transcription = transcribe_audio(mp4_file.name)
 
137
  soap_note = soap_analysis(transcription)
138
  print("SOAP Notes: ", soap_note)
139
 
140
+ template_output = ollama_query(user_prompt, soap_note)
141
  print("Template: ", template_output)
142
 
143
  json_output = convert_to_json(template_output)
144
 
145
  return soap_note, template_output, json_output
146
 
 
147
  # Process text function for text input to SOAP
148
  def process_text(text, user_prompt):
149
  soap_note = soap_analysis(text)
150
  print(soap_note)
151
 
152
+ template_output = ollama_query(user_prompt, soap_note)
153
  print(template_output)
154
  json_output = convert_to_json(template_output)
155
 
156
  return soap_note, template_output, json_output
157
 
 
158
  # Launch the Gradio interface
159
  def launch_gradio():
160
  with gr.Blocks(theme=gr.themes.Default()) as demo:
 
168
  ],
169
  outputs=[
170
  gr.Textbox(label="SOAP Note"),
171
+ gr.Textbox(label="Generated Template from Llama 2"),
172
  gr.Textbox(label="JSON Output"),
173
  ],
174
  )
 
181
  ],
182
  outputs=[
183
  gr.Textbox(label="SOAP Note"),
184
+ gr.Textbox(label="Generated Template from Llama 2"),
185
  gr.Textbox(label="JSON Output"),
186
  ],
187
  )
188
  demo.launch(share=True, debug=True)
189
 
 
190
  # Run the Gradio app
191
  if __name__ == "__main__":
192
  launch_gradio()