import asyncio
import logging
import json
import os
import hashlib
from typing import List

import numpy as np
import faiss
import chromadb
import google.generativeai as genai
import whisper
import yt_dlp
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from sentence_transformers import SentenceTransformer
app = FastAPI()

# Load the multilingual embedding model (supports 50+ languages)
model_name = "paraphrase-multilingual-mpnet-base-v2"
model = SentenceTransformer(model_name)

# Ensure the ChromaDB directory exists, then initialize a persistent client.
# Note: retrieval below is served from an in-memory FAISS index; ChromaDB is
# only initialized here for persistent storage.
chroma_db_path = "/tmp/chroma_db"
os.makedirs(chroma_db_path, exist_ok=True)
chroma_client = chromadb.PersistentClient(path=chroma_db_path)
collection_name = "transcriptions"

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configure Gemini; the key is read from the GOOGLE_API_KEY environment variable
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)

# In-memory FAISS index sized to the embedding model's output dimension
dimension = model.get_sentence_embedding_dimension()
index = faiss.IndexFlatL2(dimension)
transcription_data = []
query_cache = {}      # maps query sentence -> previously computed result
embedding_cache = {}  # maps transcription hash -> precomputed embeddings

# Default batch size for embedding computation
DEFAULT_BATCH_SIZE = 256
def embed_text(text: str):
    """Embed a single query string."""
    return model.encode(text)


def embed_texts(texts: List[str], batch_size=DEFAULT_BATCH_SIZE):
    """Embed a list of texts in batches to bound per-call memory use."""
    embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        embeddings.extend(model.encode(batch))
    return embeddings
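
# Example (a sketch): embedding two short lines in a single batch.
#
#   vecs = embed_texts(["hola mundo", "hello world"], batch_size=2)
#   # len(vecs) == 2; each vector has `dimension` components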
def precompute_embeddings(transcription, batch_size=DEFAULT_BATCH_SIZE):
    """Embed every transcription line and rebuild the FAISS index."""
    global transcription_data, index
    transcription_data = transcription
    texts = [item['text'] for item in transcription]
    # Cache embeddings keyed by a hash of the transcription content
    transcription_hash = hashlib.md5(json.dumps(transcription).encode()).hexdigest()
    if transcription_hash in embedding_cache:
        embeddings = embedding_cache[transcription_hash]
    else:
        logger.info(f"Computing embeddings with batch size {batch_size}")
        embeddings = embed_texts(texts, batch_size)
        embedding_cache[transcription_hash] = embeddings
    index = faiss.IndexFlatL2(dimension)
    index.add(np.array(embeddings))
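
# Expected input shape (a sketch; field names inferred from the accessors in
# this file -- each item needs at least 'text' and 'offset'):
#
#   sample_transcription = [
#       {"offset": 0.0, "text": "Hello and welcome to the video."},
#       {"offset": 4.2, "text": "Today we cover multilingual search."},
#   ]
#   precompute_embeddings(sample_transcription)
#   D, I = index.search(np.array([embed_text("welcome")]), 2)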
async def generate_gemini_response(context, sentence):
    """Ask Gemini for an answer grounded in the retrieved transcript lines."""
    gemini_model = genai.GenerativeModel('gemini-1.5-flash-latest')
    # generate_content is blocking, so run it in a worker thread
    response = await asyncio.to_thread(
        gemini_model.generate_content,
        f"Given the following context from a video transcription:\n\n{context}\n\n"
        f"Generate a relevant response to the query: {sentence}"
    )
    return response.text
# Route decorator restored; the exact path is an assumption, since the
# original decorator was lost in this listing.
@app.post("/generate-text")
async def generate_text(sentence: str = Form(...), file: UploadFile = File(...),
                        batch_size: int = Form(DEFAULT_BATCH_SIZE)):
    try:
        contents = await file.read()
        transcription = json.loads(contents)
        start_time = asyncio.get_event_loop().time()
        precompute_embeddings(transcription, batch_size)
        precompute_time = asyncio.get_event_loop().time() - start_time
        logger.info(f"Embeddings precomputed and stored in {precompute_time:.2f} seconds")
        # Serve repeated queries from the in-memory cache
        if sentence in query_cache:
            return query_cache[sentence]
        query_embedding = embed_text(sentence)
        # Retrieve the 10 nearest transcript lines by L2 distance; FAISS pads
        # the result with -1 when the index holds fewer than 10 vectors, so
        # filter those placeholders out
        D, I = index.search(np.array([query_embedding]), 10)
        relevant_lines = [transcription_data[i] for i in I[0] if i != -1]
        context = "\n".join([f"{line['offset']}: {line['text']}" for line in relevant_lines])
        gemini_start_time = asyncio.get_event_loop().time()
        gemini_response = await generate_gemini_response(context, sentence)
        gemini_time = asyncio.get_event_loop().time() - gemini_start_time
        result = {"relevant_lines": relevant_lines, "gemini_response": gemini_response}
        query_cache[sentence] = result
        total_time = asyncio.get_event_loop().time() - start_time
        logger.info(f"Total processing time: {total_time:.2f} seconds")
        logger.info(f"Gemini response time: {gemini_time:.2f} seconds")
        return result
    except Exception as e:
        logger.error(f"Error during processing: {e}")
        raise HTTPException(status_code=500, detail=str(e))
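
# Example request (a sketch; assumes the route path above and a local server
# on port 8000):
#
#   curl -X POST http://localhost:8000/generate-text \
#        -F "sentence=What is the video about?" \
#        -F "file=@transcription.json" \
#        -F "batch_size=256"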
# Root health-check endpoint (decorator assumed, matching the standard
# FastAPI template this handler comes from)
@app.get("/")
def greet_json():
    return {"Hello": "World!"}
# Route decorator restored; the exact path is an assumption, since the
# original decorator was lost in this listing.
@app.post("/extract-youtube-audio")
async def extract_youtube_audio(url: str = Form(...)):
    try:
        # Configure yt-dlp to download the best audio stream and convert it to WAV
        ydl_opts = {
            'format': 'bestaudio/best',
            'postprocessors': [{
                'key': 'FFmpegExtractAudio',
                'preferredcodec': 'wav',
            }],
            'outtmpl': 'temp_audio.%(ext)s',
        }
        # Download audio from YouTube
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])
        # Load the Whisper model and transcribe the downloaded audio
        whisper_model = whisper.load_model("base")
        result = whisper_model.transcribe("temp_audio.wav")
        # Format the result; Whisper reports the detected language once per
        # file (result["language"]), not per segment
        formatted_result = []
        for segment in result["segments"]:
            formatted_result.append({
                "duration": round(segment["end"] - segment["start"], 3),
                "lang": result["language"],
                "offset": round(segment["start"], 3),
                "text": segment["text"].strip()
            })
        # Clean up the temporary file
        os.remove("temp_audio.wav")
        return formatted_result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
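
# Example response shape (a sketch of the list built above):
#
#   [
#       {"duration": 4.2, "lang": "en", "offset": 0.0, "text": "Hello and welcome."},
#       ...
#   ]
#
# Saved as JSON, this output is exactly the transcription file the query
# endpoint above expects.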
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)