# YT-Trainer / app.py
import os
import re
import faiss
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    pipeline,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoModelForCausalLM,
    T5Tokenizer,
    T5ForConditionalGeneration,
)
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from googleapiclient.discovery import build
from youtube_transcript_api import YouTubeTranscriptApi
# Initialize FastAPI app
app = FastAPI()
# YouTube Data API setup
API_KEY = "AIzaSyDBdxA6KdOwtaaTgt26EBYRyvknOObmgAc"
YOUTUBE_API_SERVICE_NAME = "youtube"
YOUTUBE_API_VERSION = "v3"
youtube = build(YOUTUBE_API_SERVICE_NAME, YOUTUBE_API_VERSION, developerKey=API_KEY)
# Preprocessing function
def preprocess_text(text):
"""
Cleans and tokenizes text.
"""
text = re.sub(r"http\S+|www\S+|https\S+", "", text, flags=re.MULTILINE) # Remove URLs
text = re.sub(r"\s+", " ", text).strip() # Remove extra spaces
text = re.sub(r"[^\w\s]", "", text) # Remove punctuation
return text.lower()
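# Illustrative behaviour of the cleaning steps above (the sample string is arbitrary):
#   preprocess_text("Check https://example.com NOW!!")  ->  "check now"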
# Content Classification Model
class ContentClassifier:
def __init__(self, model_name="bert-base-uncased"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
self.pipeline = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer)
def classify(self, text):
"""
Classifies text into predefined categories.
"""
result = self.pipeline(text)
return result
# Relevance Detection Model
class RelevanceDetector:
    def __init__(self, model_name="bert-base-uncased"):
        # Note: the base model only emits LABEL_0/LABEL_1, so detect_relevance() needs a
        # checkpoint fine-tuned with a "RELEVANT" label to ever return True.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.pipeline = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer)

    def detect_relevance(self, text, threshold=0.5):
        """
        Detects whether a text is relevant to a specific domain.
        """
        result = self.pipeline(text, truncation=True)
        return result[0]["label"] == "RELEVANT" and result[0]["score"] > threshold
# Topic Extraction Model using BERTopic
class TopicExtractor:
def __init__(self):
self.model = BERTopic()
    def extract_topics(self, documents):
        """
        Extracts topics from a list of documents.
        Note: BERTopic expects a reasonably sized corpus; fitting on a single
        document will generally fail or yield only an outlier topic.
        """
        topics, probs = self.model.fit_transform(documents)
        return self.model.get_topic_info()
# Summarization Model
class Summarizer:
def __init__(self, model_name="t5-small"):
self.tokenizer = T5Tokenizer.from_pretrained(model_name)
self.model = T5ForConditionalGeneration.from_pretrained(model_name)
def summarize(self, text, max_length=100):
"""
Summarizes a given text.
"""
inputs = self.tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=512, truncation=True)
summary_ids = self.model.generate(inputs, max_length=max_length, min_length=25, length_penalty=2.0, num_beams=4)
summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return summary
# Search and Recommendation Model using FAISS
class SearchEngine:
def __init__(self, embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
self.model = SentenceTransformer(embedding_model)
self.index = None
self.documents = []
def build_index(self, docs):
"""
Builds a FAISS index for document retrieval.
"""
self.documents = docs
embeddings = self.model.encode(docs, convert_to_tensor=True, show_progress_bar=True)
self.index = faiss.IndexFlatL2(embeddings.shape[1])
self.index.add(embeddings.cpu().detach().numpy())
    def search(self, query, top_k=5):
        """
        Searches the index for the top_k most relevant documents.
        """
        if self.index is None:
            raise ValueError("Index has not been built; call build_index() first.")
        query_embedding = self.model.encode(query, convert_to_tensor=True)
        distances, indices = self.index.search(query_embedding.cpu().detach().numpy().reshape(1, -1), top_k)
        # indices holds positions into self.documents; pair each hit with its distance by rank.
        return [(self.documents[idx], float(distances[0][rank])) for rank, idx in enumerate(indices[0])]
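# Illustrative usage of SearchEngine (the document and query strings are placeholders):
#   engine = SearchEngine()
#   engine.build_index(["first transcript ...", "second transcript ..."])
#   engine.search("machine learning basics", top_k=2)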
# Conversational Model using GPT-2
class Chatbot:
def __init__(self, model_name="gpt2"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)
    def generate_response(self, prompt, max_length=50):
        """
        Generates a response to a user query using GPT-2.
        """
        inputs = self.tokenizer.encode(prompt, return_tensors="pt")
        # GPT-2 has no pad token, so reuse EOS as the pad token to avoid generation warnings.
        outputs = self.model.generate(
            inputs,
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
# Initialize models
classifier = ContentClassifier()
relevance_detector = RelevanceDetector()
summarizer = Summarizer()
search_engine = SearchEngine()
topic_extractor = TopicExtractor()
chatbot = Chatbot()
# Fetch video metadata using YouTube Data API
def fetch_video_metadata(video_id):
request = youtube.videos().list(
part="snippet,statistics",
id=video_id
)
response = request.execute()
return response["items"][0] if response["items"] else None
# Fetch video transcript using youtube-transcript-api
def fetch_video_transcript(video_id):
try:
transcript = YouTubeTranscriptApi.get_transcript(video_id)
return " ".join([entry["text"] for entry in transcript])
except Exception as e:
print(f"Error fetching transcript: {e}")
return None
# Fetch and preprocess video data
def fetch_and_preprocess_video_data(video_id):
metadata = fetch_video_metadata(video_id)
if not metadata:
return None
transcript = fetch_video_transcript(video_id)
# Preprocess the data
    video_data = {
        "video_id": video_id,
        "video_link": f"https://www.youtube.com/watch?v={video_id}",
        "title": metadata["snippet"]["title"],
        "text": transcript if transcript else metadata["snippet"]["description"],
        "channel": metadata["snippet"]["channelTitle"],
        "channel_id": metadata["snippet"]["channelId"],
        "date": metadata["snippet"]["publishedAt"],
        "license": "Unknown",
        "original_language": "Unknown",
        "source_language": "Unknown",
        "transcription_language": "Unknown",
    }
    # Count words/characters on the text actually stored (transcript when available, else the description).
    video_data["word_count"] = len(video_data["text"].split())
    video_data["character_count"] = len(video_data["text"])
return video_data
# Pydantic models for request validation
class VideoRequest(BaseModel):
video_id: str
class TextRequest(BaseModel):
text: str
class QueryRequest(BaseModel):
query: str
class PromptRequest(BaseModel):
prompt: str
# API Endpoints
@app.post("/classify")
async def classify(request: VideoRequest):
video_id = request.video_id
video_data = fetch_and_preprocess_video_data(video_id)
if not video_data:
raise HTTPException(status_code=400, detail="Failed to fetch video data")
result = classifier.classify(video_data["text"])
return {"result": result}
@app.post("/relevance")
async def relevance(request: VideoRequest):
video_id = request.video_id
video_data = fetch_and_preprocess_video_data(video_id)
if not video_data:
raise HTTPException(status_code=400, detail="Failed to fetch video data")
relevant = relevance_detector.detect_relevance(video_data["text"])
return {"relevant": relevant}
@app.post("/summarize")
async def summarize(request: VideoRequest):
video_id = request.video_id
video_data = fetch_and_preprocess_video_data(video_id)
if not video_data:
raise HTTPException(status_code=400, detail="Failed to fetch video data")
summary = summarizer.summarize(video_data["text"])
return {"summary": summary}
@app.post("/search")
async def search(request: QueryRequest):
query = request.query
if not query:
raise HTTPException(status_code=400, detail="No query provided")
results = search_engine.search(query)
return {"results": results}
@app.post("/topics")
async def topics(request: TextRequest):
text = request.text
if not text:
raise HTTPException(status_code=400, detail="No text provided")
result = topic_extractor.extract_topics([text])
return {"topics": result.to_dict()}
@app.post("/chat")
async def chat(request: PromptRequest):
prompt = request.prompt
if not prompt:
raise HTTPException(status_code=400, detail="No prompt provided")
response = chatbot.generate_response(prompt)
return {"response": response}
# Start the FastAPI app
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
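# Illustrative client call against a running instance (the video_id value is a placeholder):
#   import requests
#   requests.post("http://localhost:8000/summarize", json={"video_id": "<VIDEO_ID>"}).json()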