# Spaces file view — status: Runtime error
# File size: 5,229 Bytes · revision a647c50
import logging
import torch
import base64
import os
from indexify_extractor_sdk import Content, Extractor, Feature
from pyannote.audio import Pipeline
from transformers import pipeline, AutoModelForCausalLM
from .diarization_utils import diarize
from huggingface_hub import HfApi
from starlette.exceptions import HTTPException
from pydantic import BaseModel
from pydantic_settings import BaseSettings
from typing import Optional, Literal, List, Union
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Hugging Face access token taken from the environment; None when HF_TOKEN is
# not set.
token = os.environ.get('HF_TOKEN')
# Model identifiers and credentials used when ASRExtractor loads its models.
# Documented with comments rather than a docstring so the settings class
# itself is left exactly as declared.
class ModelSettings(BaseSettings):
    # Checkpoint handed to the "automatic-speech-recognition" pipeline.
    asr_model: str = "openai/whisper-large-v3"

    # Optional assistant checkpoint; a falsy value means ASRExtractor loads
    # no assistant model.
    assistant_model: Optional[str] = "distil-whisper/distil-large-v3"

    # Optional diarization checkpoint; a falsy value means ASRExtractor
    # builds no diarization pipeline.
    diarization_model: Optional[str] = "pyannote/speaker-diarization-3.1"

    # Hugging Face token; defaults to the module-level `token` read from the
    # HF_TOKEN environment variable.
    hf_token: Optional[str] = token


# Single shared settings instance, created at import time.
model_settings = ModelSettings()
# Per-request inference parameters accepted by ASRExtractor.extract.
# Documented with comments rather than a docstring so the model itself is
# left exactly as declared.
class ASRExtractorConfig(BaseModel):
    # Forwarded to generation as the "task" entry of generate_kwargs.
    task: Literal["transcribe", "translate"] = "transcribe"

    # Batch size passed to the ASR pipeline call.
    batch_size: int = 24

    # When True, the loaded assistant model is forwarded to generation.
    assisted: bool = False

    # Chunk length, in seconds, passed to the ASR pipeline call.
    chunk_length_s: int = 30

    # Not read directly in this module; the whole params object is handed to
    # `diarize` — presumably consumed there (TODO confirm).
    sampling_rate: int = 16000

    # Forwarded to generation as the "language" entry of generate_kwargs.
    language: Optional[str] = None

    # Speaker-count hints; not read directly in this module — presumably
    # consumed by `diarize` (TODO confirm).
    num_speakers: Optional[int] = None
    min_speakers: Optional[int] = None
    max_speakers: Optional[int] = None
class ASRExtractor(Extractor):
    """Extractor that transcribes audio and, when configured, diarizes it.

    Models are selected by the module-level ``model_settings`` and loaded once
    in ``__init__``; ``extract`` then runs ASR followed by optional
    diarization on a single base64-encoded audio payload.
    """

    name = "tensorlake/asrdiarization"
    description = "Powerful ASR + diarization + speculative decoding."
    system_dependencies = ["ffmpeg"]
    input_mime_types = ["audio", "audio/mpeg"]

    def __init__(self):
        """Load the ASR, assistant and diarization models.

        CUDA is used when available, otherwise the CPU.

        Raises:
            RuntimeError: if a diarization model is configured but the
                pipeline could not be loaded.
        """
        super().__init__()

        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        logger.info("Using device: %s", device.type)
        # float16 is only used off-CPU; on the CPU the models stay in float32.
        torch_dtype = torch.float32 if device.type == "cpu" else torch.float16

        # The assistant model is optional; it is only handed to generation
        # when a request sets `assisted`.
        self.assistant_model = None
        if model_settings.assistant_model:
            self.assistant_model = AutoModelForCausalLM.from_pretrained(
                model_settings.assistant_model,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            self.assistant_model.to(device)

        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_settings.asr_model,
            torch_dtype=torch_dtype,
            device=device,
        )

        self.diarization_pipeline = None
        if model_settings.diarization_model:
            # diarization pipeline doesn't raise if there is no token, so the
            # token is validated explicitly first.
            HfApi().whoami(model_settings.hf_token)
            self.diarization_pipeline = Pipeline.from_pretrained(
                checkpoint_path=model_settings.diarization_model,
                use_auth_token=model_settings.hf_token,
            )
            # NOTE(review): pyannote appears to return None instead of raising
            # when the checkpoint cannot be fetched — confirm against the
            # installed version. Fail with a clear message rather than an
            # AttributeError on `.to(device)`.
            if self.diarization_pipeline is None:
                raise RuntimeError(
                    "Could not load diarization model "
                    f"'{model_settings.diarization_model}'; check that "
                    "HF_TOKEN grants access to it."
                )
            self.diarization_pipeline.to(device)

    def extract(self, content: Content, params: ASRExtractorConfig) -> List[Union[Feature, Content]]:
        """Transcribe (and optionally diarize) one audio payload.

        Args:
            content: Input whose ``data`` is base64-encoded audio.
            params: Inference options for this request.

        Returns:
            A one-element list holding a text ``Content``. Its text is
            ``str(transcript)`` (the diarized transcript, or ``"[]"`` when no
            diarization pipeline is loaded) and it carries a metadata feature
            with the ASR ``chunks`` and ``text``.

        Raises:
            HTTPException: status 400 when ASR or diarization raises a
                ``RuntimeError``; status 500 for any other exception. The
                original exception is attached as ``__cause__``.
        """
        audio_bytes = base64.b64decode(content.data)
        logger.info("inference params: %s", params)

        generate_kwargs = {
            "task": params.task,
            "language": params.language,
            "assistant_model": self.assistant_model if params.assisted else None,
        }

        try:
            asr_outputs = self.asr_pipeline(
                audio_bytes,
                chunk_length_s=params.chunk_length_s,
                batch_size=params.batch_size,
                generate_kwargs=generate_kwargs,
                return_timestamps=True,
            )
        except RuntimeError as e:
            logger.exception("ASR inference error: %s", e)
            raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}") from e
        except Exception as e:
            logger.exception("Unknown error during ASR inference: %s", e)
            raise HTTPException(status_code=500, detail=f"Unknown error during ASR inference: {str(e)}") from e

        if self.diarization_pipeline:
            try:
                transcript = diarize(self.diarization_pipeline, audio_bytes, params, asr_outputs)
            except RuntimeError as e:
                logger.exception("Diarization inference error: %s", e)
                raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}") from e
            except Exception as e:
                logger.exception("Unknown error during diarization: %s", e)
                raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}") from e
        else:
            # Without diarization the content text is the string "[]"; the
            # transcription is still available through the metadata feature.
            transcript = []

        feature = Feature.metadata(value={"chunks": asr_outputs["chunks"], "text": asr_outputs["text"]})
        return [Content.from_text(str(transcript), features=[feature])]

    def sample_input(self) -> Content:
        """Return ``sample.mp3`` from the working directory as base64 audio content."""
        filepath = "sample.mp3"
        with open(filepath, 'rb') as f:
            audio_encoded = base64.b64encode(f.read()).decode("utf-8")
        return Content(content_type="audio/mpeg", data=audio_encoded)
if __name__ == "__main__":
    # Manual smoke test: run the extractor end to end on sample.mp3, which is
    # expected in the current working directory, and print the result.
    filepath = "sample.mp3"
    with open(filepath, 'rb') as f:
        audio_encoded = base64.b64encode(f.read()).decode("utf-8")
    data = Content(content_type="audio/mpeg", data=audio_encoded)
    params = ASRExtractorConfig(batch_size=24)
    extractor = ASRExtractor()
    results = extractor.extract(data, params=params)
    print(results)