import torch
import librosa
import gradio as gr
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

def format_time(milliseconds):
    # Convert a duration in milliseconds to an SRT timestamp (HH:MM:SS,mmm).
    seconds, milliseconds = divmod(int(milliseconds), 1000)
    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"

def detect_speech_activity(y, sr, frame_length=1024, hop_length=512, threshold=0.01):
    # Simple energy-based voice activity detection: frames whose RMS energy
    # exceeds the threshold are treated as speech, and consecutive speech
    # frames are merged into (start, end) regions in seconds.
    energy = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
    speech_frames = energy > threshold
    speech_regions = []
    in_speech = False
    for i, speech in enumerate(speech_frames):
        if speech and not in_speech:
            start = i
            in_speech = True
        elif not speech and in_speech:
            end = i
            speech_regions.append((start * hop_length / sr, end * hop_length / sr))
            in_speech = False
    if in_speech:
        # The audio ended while still inside a speech region; close it at the end.
        speech_regions.append((start * hop_length / sr, len(y) / sr))
    return speech_regions
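
# Note on units, assuming the 16 kHz sample rate used below: hop_length=512
# gives 512 / 16000 = 0.032 s per frame, so a region spanning frames 10..250
# would cover roughly 0.32 s to 8.0 s of audio.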

def post_process_text(text):
    # Collapse repeated whitespace and trim the decoded transcription.
    text = " ".join(text.split())
    return text

def transcribe_audio(audio_file):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Wav2Vec2 CTC model fine-tuned for Kurmanji Kurdish.
    model_name = "Akashpb13/xlsr_kurmanji_kurdish"
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
    # Wav2Vec2 expects 16 kHz audio.
    y, sr = librosa.load(audio_file, sr=16000)
    voiced_segments = detect_speech_activity(y, sr, threshold=0.005)
    srt_content = ""
    subtitle_index = 1
    for start, end in voiced_segments:
        segment_audio = y[int(start * sr):int(end * sr)]
        input_values = processor(segment_audio, sampling_rate=sr, return_tensors="pt").input_values
        input_values = input_values.to(device)
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        transcription = post_process_text(transcription)
        if transcription:
            start_time = format_time(start * 1000)
            end_time = format_time(end * 1000)
            srt_content += f"{subtitle_index}\n"
            srt_content += f"{start_time} --> {end_time}\n"
            # Break long lines into shorter ones (max 50 characters).
            words = transcription.split()
            lines = []
            current_line = ""
            for word in words:
                if len(current_line) + len(word) > 50:
                    lines.append(current_line.strip())
                    current_line = ""
                current_line += word + " "
            if current_line:
                lines.append(current_line.strip())
            srt_content += "\n".join(lines) + "\n\n"
            # Only increment the SRT index for segments that produced text,
            # so the numbering stays sequential.
            subtitle_index += 1
    return srt_content
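
# Sketch of one SRT entry as assembled above (illustrative timestamps and
# placeholder text, not real model output):
#
#   1
#   00:00:01,500 --> 00:00:04,200
#   (transcribed Kurmanji text, wrapped at ~50 characters)
#
# Each entry is an index, a "start --> end" timestamp line, the text, and a
# blank separator line.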

def save_srt(audio_file):
    # Transcribe the uploaded audio and write the result to an SRT file.
    srt_content = transcribe_audio(audio_file)
    output_filename = "output.srt"
    with open(output_filename, "w", encoding="utf-8") as f:
        f.write(srt_content)
    return output_filename, srt_content

iface = gr.Interface(
    fn=save_srt,
    inputs=gr.Audio(type="filepath"),
    outputs=[
        gr.File(label="Download SRT"),
        gr.Textbox(label="SRT Content", lines=10)
    ],
    title="Kurdish Speech-to-Text Transcription",
    description="Upload an audio file to generate an SRT subtitle file with Kurdish transcription."
)

if __name__ == "__main__":
    iface.launch()