Cryptic committed on
Commit
eb91ddc
1 Parent(s): 90bcc62
Files changed (2):
  1. app.py +48 -58
  2. requirements.txt +3 -3
app.py CHANGED
@@ -1,70 +1,60 @@
-import os
 import tempfile
-import json
-import librosa
-import numpy as np
 import soundfile as sf
-import torch
-import gradio as gr
 from transformers import pipeline

-# Load models globally to avoid reloading on every request
-device = 0 if torch.cuda.is_available() else -1
-models = {
-    'transcriber': pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", device=device, chunk_length_s=30),
-    'summarizer': pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=device)
-}

-def load_and_convert_audio(audio_path):
-    """Load audio using librosa and convert to WAV format"""
-    audio_data, sample_rate = librosa.load(audio_path, sr=16000)  # Whisper expects 16kHz
-    audio_data = audio_data.astype(np.float32)
-
-    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_wav:
-        sf.write(temp_wav.name, audio_data, sample_rate, format='WAV')
-        return temp_wav.name

-def process_audio(audio_file):
-    """Process audio file and return transcription and summary"""
-    results = {}
-
-    try:
-        temp_wav_path = load_and_convert_audio(audio_file.name)
-
-        # Transcription
-        transcription = models['transcriber'](temp_wav_path, return_timestamps=True)
-        results['transcription'] = transcription['text'] if isinstance(transcription, dict) else ' '.join([chunk['text'] for chunk in transcription])
-
-        # Summarization
-        text = results['transcription']
-        words = text.split()
-        chunk_size = 1000
-        chunks = [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
-
-        summaries = [models['summarizer'](chunk, max_length=200, min_length=50, truncation=True)[0]['summary_text'] for chunk in chunks]
-        results['summary'] = ' '.join(summaries)
-
-    except Exception as e:
-        return {'error': str(e)}  # Return error message if something goes wrong
-
-    finally:
-        if os.path.exists(temp_wav_path):
-            os.unlink(temp_wav_path)
-
-    return results

-def gradio_interface(audio):
-    """Gradio interface function"""
-    return process_audio(audio)

-# Create Gradio interface
-iface = gr.Interface(
-    fn=gradio_interface,
-    inputs=gr.inputs.Audio(source="upload", type="file", label="Upload Audio File"),
-    outputs=["json"],
-    title="Audio Transcription and Summarization",
-    description="Upload an audio file to get its transcription and summary."
-)

-if __name__ == "__main__":
-    iface.launch()
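Two defects in the removed version are worth flagging: if `load_and_convert_audio` raises, the `finally` block trips a `NameError` because `temp_wav_path` was never assigned, masking the original error; and `gr.inputs.Audio(source="upload", type="file")` is the legacy Gradio 2.x API that later releases removed. A minimal sketch of a safer handler against the newer Gradio API, assuming the `models` dict and `load_and_convert_audio` from the removed code are still in scope:

```python
import os
import gradio as gr

def process_audio(audio_path):
    """Sketch: same pipeline, but the cleanup can never hit a NameError."""
    temp_wav_path = None  # predefine so the finally block is always safe
    try:
        temp_wav_path = load_and_convert_audio(audio_path)
        transcription = models['transcriber'](temp_wav_path, return_timestamps=True)
        return {'transcription': transcription['text']}
    except Exception as e:
        return {'error': str(e)}
    finally:
        if temp_wav_path is not None and os.path.exists(temp_wav_path):
            os.unlink(temp_wav_path)

# Gradio 3+ moved components to the top level; with type="filepath" the
# callback receives the upload as a path string, not a file object.
iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs="json",
    title="Audio Transcription and Summarization",
)
```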
 
+import streamlit as st
 import tempfile
 import soundfile as sf
 from transformers import pipeline

+# Load models
+transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", device=-1)
+summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=-1)
+question_generator = pipeline("text2text-generation", model="google/t5-efficient-tiny", device=-1)

+# Upload audio file
+uploaded_file = st.file_uploader("Upload Audio", type=["wav", "mp3"])

+if uploaded_file is not None:
+    # Save the uploaded file to a temporary file
+    with tempfile.NamedTemporaryFile(delete=False) as temp_audio_file:
+        temp_audio_file.write(uploaded_file.getbuffer())
+        temp_audio_path = temp_audio_file.name

+    # Read the audio file using SoundFile
+    try:
+        # Load audio data
+        audio_data, sample_rate = sf.read(temp_audio_path)

+        # Transcribing audio
+        lecture_text = transcriber(temp_audio_path)["text"]

+        # Preprocessing data
+        num_words = len(lecture_text.split())
+        max_length = min(num_words, 1024)  # BART model max input length is 1024 tokens
+        max_length = int(max_length * 0.75)  # Convert max words to approx tokens

+        if max_length > 1024:
+            lecture_text = lecture_text[:int(1024 / 0.75)]  # Truncate to fit the model's token limit

+        # Summarization
+        summary = summarizer(
+            lecture_text,
+            max_length=1024,  # DistilBART max input length is 1024 tokens
+            min_length=int(max_length * 0.1),
+            truncation=True
+        )
+
+        # Clean up the summary text
+        if not summary[0]["summary_text"].endswith((".", "!", "?")):
+            last_period_index = summary[0]["summary_text"].rfind(".")
+            if last_period_index != -1:
+                summary[0]["summary_text"] = summary[0]["summary_text"][:last_period_index + 1]
+
+        # Questions Generation
+        context = f"Based on the following lecture summary: {summary[0]['summary_text']}, generate some relevant practice questions."
+        questions = question_generator(context, max_new_tokens=50)
+
+        # Output
+        st.write("\nSummary:\n", summary[0]["summary_text"])
+        for question in questions:
+            st.write(question["generated_text"])  # Output the generated questions

+    except Exception as e:
+        st.error(f"Error during processing: {str(e)}")
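A few issues survive in the added version: the `sf.read` result is never used (the transcriber re-reads the file from its path); the `if max_length > 1024:` branch is unreachable, since `min(num_words, 1024) * 0.75` can never exceed 1024; and in a Hugging Face summarization pipeline, `max_length` and `min_length` bound the generated summary rather than the input, so `max_length=1024` lets a short lecture produce a summary longer than itself. A possible tightening, as a hedged sketch whose length bounds are illustrative choices, not values from the commit:

```python
def summarize_transcript(lecture_text: str) -> str:
    """Sketch: scale the summary-length bounds to the transcript size."""
    num_words = len(lecture_text.split())
    # Keep the summary well below the input length; these bounds are
    # illustrative assumptions, not part of the committed code.
    max_summary = max(30, min(200, num_words // 2))
    min_summary = max(10, max_summary // 4)
    result = summarizer(
        lecture_text,
        max_length=max_summary,  # caps the generated summary, not the input
        min_length=min_summary,
        truncation=True,  # tokenizer truncates inputs past the model limit
    )
    return result[0]["summary_text"]
```

Wrapping the processing in `try`/`finally` and calling `os.unlink(temp_audio_path)` at the end would also keep the Space from accumulating orphaned temporary files, since `delete=False` leaves them behind.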
requirements.txt CHANGED
@@ -1,6 +1,6 @@
-gradio
 torch
 soundfile
-transformers
 numpy
-flask

+streamlit
+transformers
 torch
 soundfile
 numpy
+librosa
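With these requirements the app is started via `streamlit run app.py` rather than a Gradio launch. One observation: `librosa` is added here even though the rewritten app.py never imports it, and the `sf.read` call above goes unused, so the audio-loading dependencies could likely be trimmed further.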