import argparse |
from dataclasses import dataclass |
from pathlib import Path |
import json |
import hashlib |
import os |
from typing import List, Tuple, Iterator |
import assemblyai as aai |
from google import generativeai |
from pydub import AudioSegment |
import asyncio |
import io |
from multiprocessing import Pool |
from functools import partial |
from itertools import groupby |
@dataclass
class Utterance:
    """One stretch of transcribed speech attributed to a single speaker."""

    # Speaker label as supplied by the transcript source.
    speaker: str
    # Transcribed words for this stretch of speech.
    text: str
    # Start offset; `timestamp` treats it as milliseconds.
    start: int
    # End offset, in the same unit as `start`.
    end: int

    @property
    def timestamp(self) -> str:
        """Render the start offset as a zero-padded ``HH:MM:SS`` string."""
        # Drop the sub-second part, then peel off hours and minutes in turn.
        total_seconds = int(self.start // 1000)
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
class Transcriber:
    """Fetch speaker-labelled transcripts from AssemblyAI, caching them on disk.

    Each transcript is cached as JSON under ``output/transcripts/.cache`` in a
    file named after the audio file's stem. The cache entry records the MD5 of
    the audio it was produced from, so a changed audio file is transcribed
    again instead of reusing a stale entry.
    """

    def __init__(self, api_key: str):
        """Store the API key in the AssemblyAI settings and create the cache dir.

        Args:
            api_key: AssemblyAI API key.
        """
        aai.settings.api_key = api_key
        self.cache_dir = Path("output/transcripts/.cache")
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def get_transcript(self, audio_path: Path) -> List[Utterance]:
        """Return the utterances for ``audio_path``, using the cache when valid.

        A cache entry is used only if it can be parsed and its stored hash
        matches the current audio file. Otherwise a new transcript is requested
        from AssemblyAI (speaker labels on, language ``en``) and written to the
        cache.

        Args:
            audio_path: Path to the audio file to transcribe.

        Returns:
            The transcript as a list of ``Utterance`` objects.

        Raises:
            RuntimeError: If AssemblyAI reports that the transcription failed.
        """
        cache_file = self.cache_dir / f"{audio_path.stem}.json"
        # Hash once: the same digest validates an existing cache entry and
        # stamps a newly written one.
        file_hash = self._get_file_hash(audio_path)

        if cache_file.exists():
            try:
                with open(cache_file, encoding="utf-8") as f:
                    data = json.load(f)
                if data["hash"] == file_hash:
                    cached = [
                        Utterance(
                            speaker=u["speaker"],
                            text=u["text"],
                            start=u["start"],
                            end=u["end"],
                        )
                        for u in data["utterances"]
                    ]
                    print("Using cached AssemblyAI transcript...")
                    return cached
            except (OSError, ValueError, KeyError, TypeError) as exc:
                # A truncated, corrupt, or differently shaped cache file is a
                # cache miss, not a fatal error; it is overwritten below.
                print(f"Ignoring unreadable transcript cache ({exc})...")

        print("Getting new transcript from AssemblyAI...")
        config = aai.TranscriptionConfig(speaker_labels=True, language_code="en")
        transcript = aai.Transcriber().transcribe(str(audio_path), config=config)
        # On failure the SDK returns a transcript in the error state rather
        # than raising; surface that instead of failing later on missing data.
        if transcript.status == aai.TranscriptStatus.error:
            raise RuntimeError(f"AssemblyAI transcription failed: {transcript.error}")

        utterances = [
            Utterance(
                speaker=u.speaker,
                text=u.text,
                start=u.start,
                end=u.end,
            )
            # Guard against a transcript that carries no utterance list.
            for u in transcript.utterances or []
        ]

        cache_data = {
            "hash": file_hash,
            "utterances": [
                {
                    "speaker": u.speaker,
                    "text": u.text,
                    "start": u.start,
                    "end": u.end,
                }
                for u in utterances
            ],
        }
        with open(cache_file, "w", encoding="utf-8") as f:
            json.dump(cache_data, f, indent=2)

        return utterances

    def _get_file_hash(self, file_path: Path) -> str:
        """Return the hex MD5 digest of the file's contents.

        The file is read in fixed-size chunks so it is never loaded into
        memory all at once.
        """
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(65536), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()
class Enhancer:
    """Enhance transcript chunks by sending text plus audio to a Gemini model."""

    def __init__(self, api_key: str, model_name: str = "gemini-exp-1206"):
        """Configure the Gemini client and load the enhancement prompt.

        Args:
            api_key: API key passed to ``generativeai.configure``.
            model_name: Name of the generative model to use. Defaults to the
                model this class previously hard-coded, so existing callers
                are unaffected.
        """
        generativeai.configure(api_key=api_key)
        self.model = generativeai.GenerativeModel(model_name)
        # Read as UTF-8 explicitly so the prompt does not depend on the
        # platform's locale encoding.
        self.prompt = Path("prompts/enhance.txt").read_text(encoding="utf-8")

    async def enhance_chunks(
        self,
        chunks: List[Tuple[str, io.BytesIO]],
        max_concurrency: int = 3,
    ) -> List[str]:
        """Enhance transcript chunks concurrently, bounded by a semaphore.

        Args:
            chunks: ``(text, audio)`` pairs; ``audio`` is a buffer that is
                rewound and read in full, then sent as ``audio/mp3``.
            max_concurrency: Maximum number of requests in flight at once.

        Returns:
            The model's response text for each chunk, in the same order as
            ``chunks``.

        Raises:
            ValueError: If ``max_concurrency`` is less than 1.
        """
        if max_concurrency < 1:
            # Semaphore(0) would never let any request start.
            raise ValueError("max_concurrency must be at least 1")

        print(f"Enhancing {len(chunks)} chunks...")
        semaphore = asyncio.Semaphore(max_concurrency)

        async def process_chunk(i: int, chunk: Tuple[str, io.BytesIO]) -> str:
            text, audio = chunk
            async with semaphore:
                # Rewind first: the buffer may have been written or read before.
                audio.seek(0)
                response = await self.model.generate_content_async(
                    [self.prompt, text, {"mime_type": "audio/mp3", "data": audio.read()}]
                )
                print(f"Completed chunk {i+1}/{len(chunks)}")
                return response.text

        tasks = [process_chunk(i, chunk) for i, chunk in enumerate(chunks)]
        # gather() returns results in submission order, not completion order.
        results = await asyncio.gather(*tasks)
        return results
@dataclass
class SpeakerDialogue:
    """A run of back-to-back utterances that all belong to one speaker."""

    # Label of the speaker every utterance in this run belongs to.
    speaker: str
    # The utterances in order; the properties below index into this list,
    # so they raise IndexError if it is empty.
    utterances: List[Utterance]

    @property
    def start(self) -> int:
        """Start offset of the first utterance in the run."""
        first = self.utterances[0]
        return first.start

    @property
    def end(self) -> int:
        """End offset of the last utterance in the run."""
        last = self.utterances[-1]
        return last.end

    @property
    def timestamp(self) -> str:
        """The first utterance's start time as ``HH:MM:SS``."""
        first = self.utterances[0]
        return first.timestamp

    def format(self, markdown: bool = False) -> str:
        """Render the run as a header line followed by the spoken text.

        Utterance texts are separated by blank lines and trailing whitespace
        is stripped from the combined text.

        Args:
            markdown: If True, bold the speaker and italicise the timestamp.
        """
        combined_text = "\n\n".join(u.text for u in self.utterances).rstrip()
        if not markdown:
            return f"Speaker {self.speaker} {self.timestamp}\n\n{combined_text}"
        return f"**Speaker {self.speaker}** *{self.timestamp}*\n\n{combined_text}"
def group_utterances_by_speaker(utterances: List[Utterance]) -> Iterator[SpeakerDialogue]:
    """Lazily yield one SpeakerDialogue per run of adjacent same-speaker utterances.

    Only neighbouring utterances are merged: a speaker who talks again after
    someone else has spoken starts a new SpeakerDialogue.
    """

    def speaker_of(utterance):
        return utterance.speaker

    runs = groupby(utterances, key=speaker_of)
    yield from (
        # Materialise each group right away; groupby invalidates a group's
        # iterator as soon as the outer iterator advances.
        SpeakerDialogue(speaker=name, utterances=list(members))
        for name, members in runs
    )
def estimate_tokens(text: str, chars_per_token: int = 4) -> int:
    """Approximate the token count of ``text`` from its character length.

    The length is divided by ``chars_per_token`` and rounded up.

    Args:
        text: The text whose size is being estimated.
        chars_per_token: Assumed average number of characters per token
            (default 4).
    """
    char_count = len(text)
    # Adding (divisor - 1) before floor division rounds the quotient up.
    padded = char_count + chars_per_token - 1
    return padded // chars_per_token
def chunk_dialogues(
    dialogues: Iterator[SpeakerDialogue],
    max_tokens: int = 2000,
    chars_per_token: int = 4
) -> List[List[SpeakerDialogue]]:
    """Pack consecutive dialogues into batches that stay within a token budget.

    Dialogues are never split: a batch is closed as soon as adding the next
    dialogue would push its estimated size over ``max_tokens``. A single
    dialogue that is over budget on its own still gets a batch to itself.

    Args:
        dialogues: SpeakerDialogues in transcript order.
        max_tokens: Estimated-token ceiling for one batch.
        chars_per_token: Characters assumed per token when estimating
            (default 4).

    Returns:
        The batches in order; empty if ``dialogues`` yields nothing.
    """
    batches = []
    pending = []
    pending_text = ""

    for dialogue in dialogues:
        rendered = dialogue.format()
        # Text the open batch would hold if this dialogue were added to it.
        candidate = pending_text + "\n\n" + rendered if pending_text else rendered

        # An empty batch always accepts the dialogue, whatever its size.
        fits = not pending or estimate_tokens(candidate, chars_per_token) <= max_tokens
        if fits:
            pending.append(dialogue)
            pending_text = candidate
            continue

        # Over budget: close the open batch and start a new one here.
        batches.append(pending)
        pending, pending_text = [dialogue], rendered

    if pending:
        batches.append(pending)
    return batches
def format_chunk(dialogues: List[SpeakerDialogue], markdown: bool = False) -> str:
    """Render a batch of dialogues as one text block, blank-line separated.

    Args:
        dialogues: The dialogues to render, in order.
        markdown: Passed through to each dialogue's ``format``; if True the
            speaker and timestamp get markdown emphasis.
    """
    rendered = [dialogue.format(markdown=markdown) for dialogue in dialogues]
    return "\n\n".join(rendered)
def prepare_audio_chunks(audio_path: Path, utterances: List[Utterance]) -> List[Tuple[str, io.BytesIO]]:
    """Pair each transcript chunk's text with the matching slice of audio.

    The utterances are grouped by speaker and packed into chunks; for every
    chunk the audio between its first dialogue's start and its last
    dialogue's end is exported as MP3 into an in-memory buffer.

    Args:
        audio_path: Audio file the utterances were transcribed from.
        utterances: The full transcript, in order.

    Returns:
        One ``(plain_text, mp3_buffer)`` tuple per chunk, in transcript order.
    """
    chunks = chunk_dialogues(group_utterances_by_speaker(utterances))
    print(f"Preparing {len(chunks)} audio segments...")
    full_audio = AudioSegment.from_file(audio_path)

    def render(chunk):
        first, last = chunk[0], chunk[-1]
        clip = full_audio[first.start:last.end]
        mp3_buffer = io.BytesIO()
        # "-q:a 9" is handed to the encoder as-is; presumably a low-quality
        # setting chosen to keep the payload small (confirm).
        clip.export(mp3_buffer, format="mp3", parameters=["-q:a", "9"])
        return format_chunk(chunk, markdown=False), mp3_buffer

    return [render(chunk) for chunk in chunks]
def main():
    """Transcribe an audio file, enhance the transcript, and write both to disk.

    Steps:
      1. Parse the ``audio_file`` argument and check that the file exists.
      2. Get the AssemblyAI transcript and write it, markdown-formatted, to
         ``output/transcripts/autogenerated-transcript.md``.
      3. Split the transcript into text/audio chunks, enhance them with
         Gemini, and write the result to ``output/transcripts/transcript.md``.

    API keys are read from the ``ASSEMBLYAI_API_KEY`` and ``GOOGLE_API_KEY``
    environment variables.

    Returns:
        0 on success; 1 if transcription, enhancement, or writing raised, in
        which case the error is printed.

    Raises:
        FileNotFoundError: If the audio file does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("audio_file", help="Audio file to transcribe")
    args = parser.parse_args()

    audio_path = Path(args.audio_file)
    if not audio_path.exists():
        raise FileNotFoundError(f"File not found: {audio_path}")

    out_dir = Path("output/transcripts")
    out_dir.mkdir(parents=True, exist_ok=True)

    try:
        transcriber = Transcriber(os.getenv("ASSEMBLYAI_API_KEY"))
        utterances = transcriber.get_transcript(audio_path)

        dialogues = list(group_utterances_by_speaker(utterances))
        original = format_chunk(dialogues, markdown=True)
        # Write as UTF-8 explicitly: transcript text may contain characters
        # the platform's default encoding cannot represent.
        (out_dir / "autogenerated-transcript.md").write_text(original, encoding="utf-8")

        enhancer = Enhancer(os.getenv("GOOGLE_API_KEY"))
        chunks = prepare_audio_chunks(audio_path, utterances)
        enhanced = asyncio.run(enhancer.enhance_chunks(chunks))

        merged = "\n\n".join(chunk.strip() for chunk in enhanced)
        # The chunks sent for enhancement used plain headers; add the
        # markdown emphasis to whatever headers come back.
        merged = apply_markdown_formatting(merged)
        (out_dir / "transcript.md").write_text(merged, encoding="utf-8")

        print("\nTranscripts saved to:")
        print(f"- {out_dir}/autogenerated-transcript.md")
        print(f"- {out_dir}/transcript.md")
    except Exception as e:
        # Top-level boundary: report the failure and signal it to the caller.
        print(f"Error: {e}")
        return 1

    return 0
def apply_markdown_formatting(text: str) -> str:
    """Add markdown emphasis to every ``Speaker X HH:MM:SS`` header in ``text``.

    The speaker label is made bold and the timestamp italic; text that does
    not match the header pattern is returned unchanged.
    """
    import re

    pattern = r"(Speaker \w+) (\d{2}:\d{2}:\d{2})"
    header = re.compile(pattern)
    return header.sub(r"**\1** *\2*", text)
if __name__ == "__main__":
    # main() returns 0 on success and 1 on failure; pass that on as the
    # process exit status so shells and scripts can detect a failed run.
    raise SystemExit(main())