Kr08 committed on
Commit 127f69f · verified · 1 Parent(s): 9116075

Added Llama 3.1 Instruct Q&A functionality for testing

Files changed (1): app.py (+61 -3)
app.py CHANGED
@@ -16,6 +16,25 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
+
+
+def load_qa_model():
+    logger.info("Loading Q&A model...")
+    try:
+        model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+        qa_pipeline = pipeline(
+            "text-generation",
+            model=model_id,
+            model_kwargs={"torch_dtype": torch.bfloat16},
+            device_map="auto",
+        )
+        logger.info(f"Q&A model loaded successfully")
+        return qa_pipeline
+    except Exception as e:
+        logger.warning(f"Failed to load Q&A model. Error: {str(e)}")
+        return None
+
+
 def load_summarization_model():
     logger.info("Loading summarization model...")
     try:
@@ -29,6 +48,7 @@ def load_summarization_model():
         logger.info("Summarization model loaded successfully on CPU")
         return summarizer
 
+
 def process_with_fallback(func, *args, **kwargs):
     try:
         return func(*args, **kwargs)
@@ -42,6 +62,7 @@ def process_with_fallback(func, *args, **kwargs):
         else:
             raise
 
+
 @spaces.GPU(duration=60)
 def transcribe_audio(audio_file, translate, model_size):
     logger.info(f"Starting transcription: translate={translate}, model_size={model_size}")
@@ -53,6 +74,7 @@ def transcribe_audio(audio_file, translate, model_size):
         logger.error(f"Transcription failed: {str(e)}")
         raise gr.Error(f"Transcription failed: {str(e)}")
 
+
 @spaces.GPU(duration=60)
 def summarize_text(text):
     logger.info("Starting text summarization")
@@ -98,20 +120,50 @@ def process_and_summarize(audio_file, translate, model_size, do_summarize=True):
         logger.error(traceback.format_exc())
         raise gr.Error(f"Processing failed: {str(e)}")
 
+
+@spaces.GPU(duration=60)
+def answer_question(context, question):
+    logger.info("Starting Q&A process")
+    try:
+        qa_pipeline = load_qa_model()
+        if qa_pipeline is None:
+            return "Error: Q&A model could not be loaded."
+
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant who can answer questions based on the given context."},
+            {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"},
+        ]
+
+        outputs = qa_pipeline(messages, max_new_tokens=256)
+        answer = outputs[0]["generated_text"]
+
+        # Extract the answer from the generated text
+        answer = answer.split("assistant:")[-1].strip()
+
+        logger.info("Q&A process completed successfully")
+        return answer
+    except Exception as e:
+        logger.error(f"Q&A process failed: {str(e)}")
+        logger.error(traceback.format_exc())
+        return "Error occurred during Q&A process. Please try again."
+
+
 # Main interface
 with gr.Blocks() as iface:
-    gr.Markdown("# WhisperX Audio Transcription, Translation, and Summarization (with ZeroGPU support)")
+    gr.Markdown("# WhisperX Audio Transcription, Translation, Summarization, and Q&A (with ZeroGPU support)")
 
     audio_input = gr.Audio(type="filepath")
     translate_checkbox = gr.Checkbox(label="Enable Translation")
     summarize_checkbox = gr.Checkbox(label="Enable Summarization", interactive=False)
-    # diarization_checkbox = gr.Checkbox(label="Enable Speaker Diarization")
     model_dropdown = gr.Dropdown(choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], label="Whisper Model Size", value="small")
     process_button = gr.Button("Process Audio")
     transcription_output = gr.Textbox(label="Transcription/Translation")
-    full_text_output = gr.Textbox(label="Transcription/Translation")
+    full_text_output = gr.Textbox(label="Full Text")
     summary_output = gr.Textbox(label="Summary")
 
+    question_input = gr.Textbox(label="Ask a question about the transcription")
+    answer_button = gr.Button("Get Answer")
+    answer_output = gr.Textbox(label="Answer")
+
     def update_summarize_checkbox(translate):
         return gr.Checkbox(interactive=translate)
@@ -123,6 +175,12 @@ with gr.Blocks() as iface:
         inputs=[audio_input, translate_checkbox, model_dropdown, summarize_checkbox],
         outputs=[transcription_output, full_text_output, summary_output]
     )
+
+    answer_button.click(
+        answer_question,
+        inputs=[full_text_output, question_input],
+        outputs=[answer_output]
+    )
 
     gr.Markdown(
         f"""
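
Note on the output handling in answer_question: when a transformers text-generation pipeline is called with a list of chat messages, recent transformers releases (the ones that ship Llama 3.1 support) return the whole conversation in outputs[0]["generated_text"] as a list of {"role", "content"} dicts, with the model's reply appended as the last entry, rather than as a flat string. In that case the split("assistant:") step above finds nothing to split on. A minimal sketch of a list-aware extraction, assuming qa_pipeline and messages are built exactly as in the commit:

# Sketch only, not part of this commit. Handles both the chat-style list
# output and a plain-string output.
outputs = qa_pipeline(messages, max_new_tokens=256)
generated = outputs[0]["generated_text"]
if isinstance(generated, list):
    # Chat format: the assistant's reply is the last message in the list.
    answer = generated[-1]["content"].strip()
else:
    # Plain string: fall back to the commit's split-based extraction.
    answer = str(generated).split("assistant:")[-1].strip()

Either branch leaves answer as a plain string, which is what the Answer textbox expects.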
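
A related practical point: meta-llama/Meta-Llama-3.1-8B-Instruct is a gated repository on the Hugging Face Hub, so load_qa_model succeeds only if the license has been accepted and the Space can authenticate; otherwise the pipeline call raises and the function falls through to the warning path, returning None. A minimal sketch of the usual setup, assuming the token is exposed to the Space as an HF_TOKEN secret (the secret name is illustrative):

import os

from huggingface_hub import login

# Hypothetical startup step: authenticate with the Hub before load_qa_model()
# tries to download the gated weights.
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)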