Kr08 committed on
Commit
a3f7705
·
verified ·
1 Parent(s): 8dac0cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -59
app.py CHANGED
@@ -1,63 +1,139 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion, yielding the reply as it grows.

    *history* is a list of (user, assistant) pairs; falsy entries on either
    side of a pair are skipped.  Sampling is controlled by *max_tokens*,
    *temperature* and *top_p*.
    """
    messages = [{"role": "system", "content": system_message}]
    for pair in history:
        if pair[0]:
            messages.append({"role": "user", "content": pair[0]})
        if pair[1]:
            messages.append({"role": "assistant", "content": pair[1]})
    messages.append({"role": "user", "content": message})

    stream = client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )

    response = ""
    for chunk in stream:
        # Each streamed chunk carries one delta token; accumulate and
        # re-yield the full text so the UI shows a growing answer.
        response += chunk.choices[0].delta.content
        yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- if __name__ == "__main__":
63
- demo.launch()
 
1
  import gradio as gr
2
+ from audio_processing import process_audio
3
+ from transformers import pipeline
4
+ import spaces
5
+ import torch
6
+ import logging
7
+ import traceback
8
+ import sys
9
+
10
+ logging.basicConfig(
11
+ level=logging.INFO,
12
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
13
+ handlers=[
14
+ logging.StreamHandler(sys.stdout)
15
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  )
17
+ logger = logging.getLogger(__name__)
18
+
19
def load_summarization_model():
    """Create a DistilBART summarization pipeline, preferring the GPU.

    When CUDA is available, try device 0 first and fall back to the CPU if
    GPU initialization fails.  A CPU load failure (e.g. the model cannot be
    downloaded) is a real error and is allowed to propagate.

    Returns:
        A `transformers` summarization pipeline.
    """
    logger.info("Loading summarization model...")
    model_name = "sshleifer/distilbart-cnn-12-6"

    # Only attempt the GPU when one exists.  The original code's fallback
    # branch re-ran the *identical* CPU construction that had just failed
    # whenever CUDA was unavailable, masking the real error.
    if torch.cuda.is_available():
        try:
            summarizer = pipeline("summarization", model=model_name, device=0)
            logger.info("Summarization model loaded successfully on GPU")
            return summarizer
        except Exception as e:
            logger.warning(f"Failed to load summarization model on GPU. Falling back to CPU. Error: {str(e)}")

    summarizer = pipeline("summarization", model=model_name, device=-1)
    logger.info("Summarization model loaded successfully on CPU")
    return summarizer
31
+
32
def process_with_fallback(func, *args, **kwargs):
    """Invoke *func*; on a CUDA/GPU-related failure, retry once on CPU.

    Every exception is logged with its traceback.  Non-GPU errors are
    re-raised unchanged.

    NOTE(review): the retry injects ``use_gpu=False`` into kwargs — this
    assumes *func* accepts that keyword; confirm against ``process_audio``.
    """
    try:
        return func(*args, **kwargs)
    except Exception as e:
        logger.error(f"Error during processing: {str(e)}")
        logger.error(traceback.format_exc())
        reason = str(e)
        if "CUDA" not in reason and "GPU" not in reason:
            raise
        logger.info("Falling back to CPU processing...")
        kwargs['use_gpu'] = False
        return func(*args, **kwargs)
44
+
45
@spaces.GPU(duration=60)
def transcribe_audio(audio_file, translate, model_size):
    """Run `process_audio` on *audio_file* via the CPU-fallback helper.

    Failures are logged and re-raised as `gr.Error` so Gradio surfaces them
    to the user.
    """
    logger.info(f"Starting transcription: translate={translate}, model_size={model_size}")
    try:
        result = process_with_fallback(process_audio, audio_file, translate=translate, model_size=model_size)
    except Exception as e:
        logger.error(f"Transcription failed: {str(e)}")
        raise gr.Error(f"Transcription failed: {str(e)}")
    logger.info("Transcription completed successfully")
    return result
55
+
56
@spaces.GPU(duration=60)
def summarize_text(text):
    """Summarize *text* with DistilBART.

    The pipeline is created lazily on first use and memoized on the function
    object, so repeated calls do not reload the model from disk (the
    original reloaded it on every invocation).

    Returns:
        The summary string, or a fixed error message if summarization fails.
        Callers display the return value directly, so errors are returned
        rather than raised.
    """
    logger.info("Starting text summarization")
    try:
        # Lazy, memoized model load.
        summarizer = getattr(summarize_text, "_summarizer", None)
        if summarizer is None:
            summarizer = load_summarization_model()
            summarize_text._summarizer = summarizer
        summary = summarizer(text, max_length=150, min_length=50, do_sample=False)[0]['summary_text']
        logger.info("Summarization completed successfully")
        return summary
    except Exception as e:
        logger.error(f"Summarization failed: {str(e)}")
        logger.error(traceback.format_exc())
        return "Error occurred during summarization. Please try again."
68
+
69
@spaces.GPU(duration=60)
def process_and_summarize(audio_file, translate, model_size, do_summarize=True):
    """Transcribe *audio_file*, format the result, and optionally summarize.

    Returns a ``(formatted transcription, concatenated plain text, summary)``
    triple; the summary is ``""`` when *do_summarize* is false.  Failures are
    logged and surfaced to the UI as `gr.Error`.
    """
    logger.info(f"Starting process_and_summarize: translate={translate}, model_size={model_size}, do_summarize={do_summarize}")
    try:
        language_segments, final_segments = transcribe_audio(audio_file, translate, model_size)

        pieces = []
        for seg in language_segments:
            pieces.append(f"Language: {seg['language']}\n")
            pieces.append(f"Time: {seg['start']:.2f}s - {seg['end']:.2f}s\n\n")
        pieces.append(f"Transcription with language detection and speaker diarization (using {model_size} model):\n\n")

        text_parts = []
        for seg in final_segments:
            pieces.append(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] ({seg['language']}) {seg['speaker']}:\n")
            pieces.append(f"Original: {seg['text']}\n")
            if translate:
                pieces.append(f"Translated: {seg['translated']}\n")
                text_parts.append(seg['translated'])
            else:
                text_parts.append(seg['text'])
            pieces.append("\n")

        transcription = "".join(pieces)
        # Match the original concatenation exactly: every segment text was
        # followed by a single trailing space.
        full_text = "".join(part + " " for part in text_parts)

        summary = summarize_text(full_text) if do_summarize else ""
        logger.info("Process and summarize completed successfully")
        return transcription, full_text, summary
    except Exception as e:
        logger.error(f"Process and summarize failed: {str(e)}\n")
        logger.error(traceback.format_exc())
        raise gr.Error(f"Processing failed: {str(e)}")
100
+
101
# Main interface
with gr.Blocks() as iface:
    gr.Markdown("# WhisperX Audio Transcription, Translation, and Summarization (with ZeroGPU support)")

    audio_input = gr.Audio(type="filepath")
    translate_checkbox = gr.Checkbox(label="Enable Translation")
    # Summarization is only meaningful once translation is enabled, so the
    # toggle starts non-interactive and is flipped by the change handler.
    summarize_checkbox = gr.Checkbox(label="Enable Summarization", interactive=False)
    # diarization_checkbox = gr.Checkbox(label="Enable Speaker Diarization")
    model_dropdown = gr.Dropdown(choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], label="Whisper Model Size", value="small")
    process_button = gr.Button("Process Audio")
    transcription_output = gr.Textbox(label="Transcription/Translation")
    # BUG FIX: this box previously reused the label "Transcription/Translation",
    # making the two output boxes indistinguishable in the UI.
    full_text_output = gr.Textbox(label="Full Text")
    summary_output = gr.Textbox(label="Summary")


    def update_summarize_checkbox(translate):
        # Enable the summarize toggle only while translation is on.
        return gr.Checkbox(interactive=translate)

    translate_checkbox.change(update_summarize_checkbox, inputs=[translate_checkbox], outputs=[summarize_checkbox])

    process_button.click(
        process_and_summarize,
        inputs=[audio_input, translate_checkbox, model_dropdown, summarize_checkbox],
        outputs=[transcription_output, full_text_output, summary_output]
    )

    gr.Markdown(
        f"""
        ## System Information
        - Device: {"CUDA" if torch.cuda.is_available() else "CPU"}
        - CUDA Available: {"Yes" if torch.cuda.is_available() else "No"}

        ## ZeroGPU Support
        This application supports ZeroGPU for Hugging Face Spaces pro users.
        GPU-intensive tasks are automatically optimized for better performance when available.
        """
    )

iface.launch()