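"""SearchInVideo FastAPI service.

Transcribes YouTube audio with Whisper, embeds transcript lines with a
multilingual SentenceTransformer model, retrieves the most relevant lines from
an in-memory FAISS index, and asks Gemini to answer the query from that context.
"""
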
import asyncio
import logging
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from sentence_transformers import SentenceTransformer
import chromadb
import json
from typing import List
from functools import lru_cache
import os
import yt_dlp
import whisper
import tempfile
import google.generativeai as genai
import faiss
import numpy as np
import hashlib

app = FastAPI()

# Load the multilingual embedding model (supports 50+ languages)
model_name = "paraphrase-multilingual-mpnet-base-v2"
model = SentenceTransformer(model_name)

# Ensure the ChromaDB directory exists
chroma_db_path = "/tmp/chroma_db"
os.makedirs(chroma_db_path, exist_ok=True)

# Persistent ChromaDB client (currently unused by the request path below,
# which searches the in-memory FAISS index instead)
chroma_client = chromadb.PersistentClient(path=chroma_db_path)
collection_name = "transcriptions"

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configure Gemini (the API key is read from the environment, never hard-coded)
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)

# In-memory FAISS index over transcript-line embeddings (exact L2 search)
dimension = model.get_sentence_embedding_dimension()
index = faiss.IndexFlatL2(dimension)
transcription_data = []

# Per-process caches: query results, plus transcript embeddings keyed by content hash
query_cache = {}
embedding_cache = {}

# Default batch size for bulk embedding (larger batches amortize encoder overhead)
DEFAULT_BATCH_SIZE = 256

@lru_cache(maxsize=1000)
def embed_text(text: str):
    """Embed a single string, memoizing repeated queries."""
    return model.encode(text)

def embed_texts(texts: List[str], batch_size=DEFAULT_BATCH_SIZE):
    """Embed a list of strings in batches to bound peak memory."""
    embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        embeddings.extend(model.encode(batch))
    return embeddings
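
# A note on the expected payload (illustrative values): the transcription JSON
# is a list shaped like the output of /extract_youtube_audio below, e.g.
#   [{"duration": 4.2, "lang": "en", "offset": 0.0, "text": "Hello everyone"}, ...]
# Only 'text' (embedded here) and 'offset' (shown to Gemini) are read.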
def precompute_embeddings(transcription, batch_size=DEFAULT_BATCH_SIZE):
    """(Re)build the FAISS index for a transcription, caching embeddings by content hash."""
    global transcription_data, index
    transcription_data = transcription
    texts = [item['text'] for item in transcription]
    # sort_keys makes the hash stable across key orderings of the same payload
    transcription_hash = hashlib.md5(json.dumps(transcription, sort_keys=True).encode()).hexdigest()
    if transcription_hash in embedding_cache:
        embeddings = embedding_cache[transcription_hash]
    else:
        logger.info(f"Computing embeddings with batch size {batch_size}")
        embeddings = embed_texts(texts, batch_size)
        embedding_cache[transcription_hash] = embeddings
    index = faiss.IndexFlatL2(dimension)
    # FAISS expects a contiguous float32 matrix
    index.add(np.array(embeddings, dtype=np.float32))

async def generate_gemini_response(context, sentence):
    # Named gemini_model to avoid shadowing the global SentenceTransformer `model`
    gemini_model = genai.GenerativeModel('gemini-1.5-flash-latest')
    # Run the blocking SDK call in a worker thread so the event loop stays responsive
    response = await asyncio.to_thread(
        gemini_model.generate_content,
        f"Given the following context from a video transcription:\n\n{context}\n\n"
        f"Generate a relevant response to the query: {sentence}"
    )
    return response.text
@app.post("/generate")
async def generate_text(sentence: str = Form(...), file: UploadFile = File(...),
batch_size: int = Form(DEFAULT_BATCH_SIZE)):
try:
contents = await file.read()
transcription = json.loads(contents)
start_time = asyncio.get_event_loop().time()
precompute_embeddings(transcription, batch_size)
precompute_time = asyncio.get_event_loop().time() - start_time
logger.info(f"Embeddings precomputed and stored in {precompute_time:.2f} seconds")
if sentence in query_cache:
return query_cache[sentence]
query_embedding = embed_text(sentence)
D, I = index.search(np.array([query_embedding]), 10)
relevant_lines = [transcription_data[i] for i in I[0]]
context = "\n".join([f"{line['offset']}: {line['text']}" for line in relevant_lines])
gemini_start_time = asyncio.get_event_loop().time()
gemini_response = await generate_gemini_response(context, sentence)
gemini_time = asyncio.get_event_loop().time() - gemini_start_time
result = {"relevant_lines": relevant_lines, "gemini_response": gemini_response}
query_cache[sentence] = result
total_time = asyncio.get_event_loop().time() - start_time
logger.info(f"Total processing time: {total_time:.2f} seconds")
logger.info(f"Gemini response time: {gemini_time:.2f} seconds")
return result
except Exception as e:
logger.error(f"Error during processing: {e}")
raise HTTPException(status_code=500, detail=str(e))
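
# Example /generate request (hypothetical local run; adjust host, port, and file name):
#   curl -X POST http://localhost:8000/generate \
#     -F 'sentence=What does the speaker say about pricing?' \
#     -F 'file=@transcription.json' \
#     -F 'batch_size=256'
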
@app.get("/")
def greet_json():
return {"Hello": "World!"}
@app.post("/extract_youtube_audio")
async def extract_youtube_audio(url: str = Form(...)):
try:
# Configure yt-dlp options
ydl_opts = {
'format': 'bestaudio/best',
'postprocessors': [{
'key': 'FFmpegExtractAudio',
'preferredcodec': 'wav',
}],
'outtmpl': 'temp_audio.%(ext)s',
}
# Download audio from YouTube
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
ydl.download([url])
# Load Whisper model
model = whisper.load_model("base")
# Transcribe audio
result = model.transcribe("temp_audio.wav")
# Format the result
formatted_result = []
for segment in result["segments"]:
formatted_result.append({
"duration": round(segment["end"] - segment["start"], 3),
"lang": segment["language"],
"offset": round(segment["start"], 3),
"text": segment["text"].strip()
})
# Clean up temporary file
os.remove("temp_audio.wav")
return formatted_result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
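
# Example /extract_youtube_audio request (hypothetical local run):
#   curl -X POST http://localhost:8000/extract_youtube_audio \
#     -F 'url=https://www.youtube.com/watch?v=VIDEO_ID'
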
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
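
# Alternatively, run via the uvicorn CLI (assuming this file is named app.py):
#   uvicorn app:app --host 0.0.0.0 --port 8000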