karthi311 committed on
Commit ac75500 · verified · 1 Parent(s): abc09a9

Update app.py

Files changed (1):
    app.py +25 -42
app.py CHANGED
@@ -1,16 +1,12 @@
 import torch
 import gradio as gr
-from transformers import AutoTokenizer
-from gptq import GPTQForCausalLM # GPTQ model handler
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, pipeline
 from pydub import AudioSegment
 from sentence_transformers import SentenceTransformer, util
 import spacy
-spacy.cli.download("en_core_web_sm")
-from subprocess import Popen, PIPE
 import json
 from faster_whisper import WhisperModel
 
-
 # Audio conversion from MP4 to MP3
 def convert_mp4_to_mp3(mp4_path, mp3_path):
     try:
@@ -19,7 +15,6 @@ def convert_mp4_to_mp3(mp4_path, mp3_path):
     except Exception as e:
         raise RuntimeError(f"Error converting MP4 to MP3: {e}")
 
-
 # Check if CUDA is available for GPU acceleration
 if torch.cuda.is_available():
     device = "cuda"
@@ -28,36 +23,21 @@ else:
     device = "cpu"
     compute_type = "int8"
 
-
 # Load Faster Whisper Model for transcription
 def load_faster_whisper():
     model = WhisperModel("deepdml/faster-whisper-large-v3-turbo-ct2", device=device, compute_type=compute_type)
     return model
 
-
-# Load GPTQ Mistral-7B model
-def load_mistral_model():
-    model_name = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
-
-    # Load the tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-    # Load the GPTQ model
-    model = GPTQForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
-
-    return model, tokenizer
-
-
 # Load NLP model and other helpers
 nlp = spacy.load("en_core_web_sm")
 embedder = SentenceTransformer("all-MiniLM-L6-v2")
 
+# Load Summarizer Model (DistilBart-Med-Summary)
 tokenizer = AutoTokenizer.from_pretrained("Mahalingam/DistilBart-Med-Summary")
 model = AutoModelForSeq2SeqLM.from_pretrained("Mahalingam/DistilBart-Med-Summary")
 
 summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
 
-
 soap_prompts = {
     "subjective": "Personal reports, symptoms described by patients, or personal health concerns. Details reflecting individual symptoms or health descriptions.",
     "objective": "Observable facts, clinical findings, professional observations, specific medical specialties, and diagnoses.",
@@ -66,19 +46,31 @@ soap_prompts = {
 }
 soap_embeddings = {section: embedder.encode(prompt, convert_to_tensor=True) for section, prompt in soap_prompts.items()}
 
+# Load LLaMA 7B model and tokenizer
+def load_llama_model():
+    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+    model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
+    model.to(device)
+    return model, tokenizer
+
+# Initialize LLaMA model
+llama_model, llama_tokenizer = load_llama_model()
 
-# Llama query function (same as before)
-def llama_query(user_prompt, soap_note, model="llama3.2"):
+# Query function for LLaMA
+def llama_query(user_prompt, soap_note):
     combined_prompt = f"User Instructions:\n{user_prompt}\n\nContext:\n{soap_note}"
     try:
-        process = Popen(['ollama', 'run', model], stdin=PIPE, stdout=PIPE, stderr=PIPE, text=True, encoding='utf-8')
-        stdout, stderr = process.communicate(input=combined_prompt)
-        if process.returncode != 0:
-            return f"Error: {stderr.strip()}"
-        return stdout.strip()
+        inputs = llama_tokenizer(combined_prompt, return_tensors="pt", truncation=True, max_length=4096).to(device)
+        outputs = llama_model.generate(
+            inputs["input_ids"],
+            max_length=512,
+            temperature=0.7,
+            num_beams=4,
+            no_repeat_ngram_size=3
+        )
+        return llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
     except Exception as e:
-        return f"Unexpected error: {str(e)}"
-
+        return f"Error generating response: {e}"
 
 # Convert the response to JSON format
 def convert_to_json(template):
@@ -96,7 +88,6 @@ def convert_to_json(template):
     except Exception as e:
         return f"Error converting to JSON: {e}"
 
-
 # Transcription using Faster Whisper
 def transcribe_audio(mp4_path):
     try:
@@ -112,13 +103,11 @@ def transcribe_audio(mp4_path):
     except Exception as e:
         return f"Error during transcription: {e}"
 
-
 # Classify the sentence to the correct SOAP section
 def classify_sentence(sentence):
     similarities = {section: util.pytorch_cos_sim(embedder.encode(sentence), soap_embeddings[section]) for section in soap_prompts.keys()}
     return max(similarities, key=similarities.get)
 
-
 # Summarize the section if it's too long
 def summarize_section(section_text):
     if len(section_text.split()) < 50:
@@ -134,7 +123,6 @@ def summarize_section(section_text):
     )
     return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
-
 # Analyze the SOAP content and divide into sections
 def soap_analysis(text):
     doc = nlp(text)
@@ -150,7 +138,6 @@ def soap_analysis(text):
 
     return format_soap_output(soap_note)
 
-
 # Format the SOAP note output
 def format_soap_output(soap_note):
     return (
@@ -160,7 +147,6 @@ def format_soap_output(soap_note):
         f"Plan:\n{soap_note['plan']}\n"
     )
 
-
 # Process file function for audio to SOAP
 def process_file(mp4_file, user_prompt):
     transcription = transcribe_audio(mp4_file.name)
@@ -176,7 +162,6 @@ def process_file(mp4_file, user_prompt):
 
     return soap_note, template_output, json_output
 
-
 # Process text function for text input to SOAP
 def process_text(text, user_prompt):
     soap_note = soap_analysis(text)
@@ -188,7 +173,6 @@ def process_text(text, user_prompt):
 
     return soap_note, template_output, json_output
 
-
 # Launch the Gradio interface
 def launch_gradio():
     with gr.Blocks(theme=gr.themes.Default()) as demo:
@@ -202,7 +186,7 @@ def launch_gradio():
             ],
             outputs=[
                 gr.Textbox(label="SOAP Note"),
-                gr.Textbox(label="Generated Template from Mistral-7B Instruct"),
+                gr.Textbox(label="Generated Template from LLaMA"),
                 gr.Textbox(label="JSON Output"),
             ],
         )
@@ -215,13 +199,12 @@ def launch_gradio():
             ],
            outputs=[
                gr.Textbox(label="SOAP Note"),
-                gr.Textbox(label="Generated Template from Mistral-7B Instruct"),
+                gr.Textbox(label="Generated Template from LLaMA"),
                gr.Textbox(label="JSON Output"),
            ],
        )
    demo.launch(share=True, debug=True)
 
-
 # Run the Gradio app
 if __name__ == "__main__":
     launch_gradio()
 
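A caveat in the new llama_query: in transformers, generate(max_length=512) counts the prompt tokens toward the 512-token cap, so a prompt anywhere near the 4096-token truncation limit leaves no budget for generated text; temperature=0.7 is also inert under beam search unless do_sample=True is passed. A minimal sketch of a variant built on max_new_tokens (llama_query_v2 is a hypothetical name; it assumes the llama_model, llama_tokenizer, and device globals defined in app.py):

    import torch

    def llama_query_v2(user_prompt, soap_note):
        combined_prompt = f"User Instructions:\n{user_prompt}\n\nContext:\n{soap_note}"
        inputs = llama_tokenizer(combined_prompt, return_tensors="pt",
                                 truncation=True, max_length=4096).to(device)
        with torch.no_grad():  # inference only, no gradient buffers
            outputs = llama_model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=512,  # budget for generated tokens, independent of prompt length
                num_beams=4,
                no_repeat_ngram_size=3,
            )
        # Decoder-only models echo the prompt in the output; slice it off before decoding
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return llama_tokenizer.decode(new_tokens, skip_special_tokens=True)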
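Separately, the CPU fallback in load_llama_model keeps huggyllama/llama-7b in float32, which needs on the order of 28 GB of RAM for the weights alone (7B parameters × 4 bytes). A sketch of a lower-memory load path, assuming accelerate is installed in the Space (load_llama_model_sharded is an illustrative name, not part of this commit):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def load_llama_model_sharded():
        tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        model = AutoModelForCausalLM.from_pretrained(
            "huggyllama/llama-7b",
            torch_dtype=torch.float16,  # halves the float32 footprint
            device_map="auto",          # requires accelerate; places weights across GPU/CPU automatically
        )
        # With device_map="auto" the weights are already placed; skip model.to(device)
        return model, tokenizer

On a Space with no GPU at all, float16 inference on CPU is poorly supported, so a smaller instruction-tuned model may be the more realistic fallback.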
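The embedding-based SOAP routing is untouched by this commit; a quick sanity check once the globals in app.py are initialized (the example sentence is illustrative):

    # Assumes embedder, soap_embeddings, and classify_sentence from app.py
    sentence = "The patient reports a dull headache that started three days ago."
    print(classify_sentence(sentence))  # a symptom report like this would be expected to land in "subjective"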