"""FastAPI service exposing a summarization (/rag) and a text-classification endpoint."""

import logging

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from app.model_utils import load_model_and_tokenizer, generate_summary
from app.classifier import train_classifier, classify_text

logger = logging.getLogger(__name__)

app = FastAPI()

# Load model and tokenizer for the /rag endpoint.
# NOTE: this runs at module import time, so importing this module pays the
# full cost of load_model_and_tokenizer before the app can serve requests.
model_name = "sshleifer/distilbart-cnn-6-6"  # Example model
model, tokenizer = load_model_and_tokenizer(model_name)

# Dummy (text, label) training pairs for the /classification endpoint.
# The classifier is trained once, at import time, on these examples only.
dummy_data = [
    ("I feel very sad and hopeless.", "Depression"),
    ("I have trouble sleeping at night.", "Insomnia"),
    ("I am constantly worrying about everything.", "Anxiety"),
    ("I feel energetic and happy.", "Happiness"),
    ("My mood swings a lot and I feel irritable.", "Mood Disorder"),
]
classifier, vectorizer = train_classifier(dummy_data)


class Prompt(BaseModel):
    """Request body for /rag: the text to summarize."""

    prompt: str


class ClassificationInput(BaseModel):
    """Request body for /classification: the text to classify."""

    data: str


@app.post("/rag")
def rag_endpoint(prompt: Prompt):
    """Summarize ``prompt.prompt`` using the module-level model and tokenizer.

    Returns:
        ``{"summary": <result of generate_summary>}``.

    Raises:
        HTTPException: status 500 if summary generation raises anything.
    """
    try:
        summary = generate_summary(prompt.prompt, model, tokenizer)
    except Exception as e:
        # Log the traceback server-side and chain the cause, so the original
        # error is not lost when it is converted into an HTTP response.
        logger.exception("Summary generation failed")
        # NOTE(review): str(e) exposes internal error text to clients; consider
        # a generic message once callers no longer rely on it.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"summary": summary}


@app.post("/classification")
def classification_endpoint(input: ClassificationInput):  # noqa: A002 - name kept for compatibility
    """Classify ``input.data`` with the classifier trained at import time.

    Returns:
        ``{"category": <result of classify_text>}``.

    Raises:
        HTTPException: status 500 if classification raises anything.
    """
    try:
        category = classify_text(input.data, classifier, vectorizer)
    except Exception as e:
        # Log the traceback server-side and chain the cause, so the original
        # error is not lost when it is converted into an HTTP response.
        logger.exception("Text classification failed")
        # NOTE(review): str(e) exposes internal error text to clients; consider
        # a generic message once callers no longer rely on it.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"category": category}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)