from fastapi import FastAPI, HTTPException, Depends, status from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from pydantic import BaseModel from jose import JWTError, jwt from datetime import datetime, timedelta, timezone from openai import OpenAI from pathlib import Path from typing import List, Optional, Dict from datasets import Dataset, load_dataset from sentence_transformers import SentenceTransformer from huggingface_hub import login from contextlib import asynccontextmanager import pandas as pd import numpy as np import torch as t import os import logging from functools import lru_cache from diskcache import Cache # Configure logging logging.basicConfig(level=logging.INFO) @asynccontextmanager async def lifespan(app: FastAPI): # Preload the model get_sentence_transformer() yield # Initialize FastAPI app app = FastAPI() # Initialize disk cache cache = Cache('./cache') # JWT Configuration SECRET_KEY = os.environ.get("prime_auth", "c0369f977b69e717dc16f6fc574039eb2b1ebde38014d2be") REFRESH_SECRET_KEY = os.environ.get("prolonged_auth", "916018771b29084378c9362c0cd9e631fd4927b8aea07f91") ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 REFRESH_TOKEN_EXPIRE_DAYS = 7 oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") # Pydantic models class QueryInput(BaseModel): query: str class SearchResult(BaseModel): text: str similarity: float model_type: str class TokenResponse(BaseModel): access_token: str refresh_token: str token_type: str class SaveInput(BaseModel): user_type: str username: str query: str retrieved_text: str model_type: str reaction: str class SaveBatchInput(BaseModel): items: List[SaveInput] class RefreshRequest(BaseModel): refresh_token: str # Cache management @lru_cache(maxsize=1) def get_sentence_transformer(): """Load and cache the SentenceTransformer model with lru_cache""" return SentenceTransformer(model_name_or_path="all-mpnet-base-v2", device="cpu") def get_cached_embeddings(text: str, model_type: str) -> Optional[List[float]]: """Try to get embeddings from cache""" cache_key = f"{model_type}_{hash(text)}" return cache.get(cache_key) def set_cached_embeddings(text: str, model_type: str, embeddings: List[float]): """Store embeddings in cache""" cache_key = f"{model_type}_{hash(text)}" cache.set(cache_key, embeddings, expire=86400) # Cache for 24 hours @lru_cache(maxsize=1) def load_dataframe(): """Load and cache the parquet dataframe""" database_file = Path(__file__).parent / "[all_embedded] The Alchemy of Happiness (Ghazzālī, Claud Field) (Z-Library).parquet" return pd.read_parquet(database_file) # Utility functions def cosine_similarity(embedding_0, embedding_1): dot_product = sum(a * b for a, b in zip(embedding_0, embedding_1)) norm_0 = sum(a * a for a in embedding_0) ** 0.5 norm_1 = sum(b * b for b in embedding_1) ** 0.5 return dot_product / (norm_0 * norm_1) def generate_embedding(model, text: str, model_type: str) -> List[float]: # Try to get from cache first cached_embedding = get_cached_embeddings(text, model_type) if cached_embedding is not None: return cached_embedding # Generate new embedding if not in cache if model_type == "all-mpnet-base-v2": chunk_embedding = model.encode( text, convert_to_tensor=True ) embedding = np.array(t.Tensor.cpu(chunk_embedding)).tolist() elif model_type == "text-embedding-3-small": response = model.embeddings.create( input=text, model="text-embedding-3-small" ) embedding = response.data[0].embedding # Cache the new embedding set_cached_embeddings(text, model_type, embedding) return embedding def search_query(client, st_model, query: str, df: pd.DataFrame, n: int = 1) -> List[Dict]: # Generate embeddings for both models mpnet_embedding = generate_embedding(st_model, query, "all-mpnet-base-v2") openai_embedding = generate_embedding(client, query, "text-embedding-3-small") # Calculate similarities df['mpnet_similarities'] = df.all_mpnet_embedding.apply( lambda x: cosine_similarity(x, mpnet_embedding) ) df['openai_similarities'] = df.openai_embedding.apply( lambda x: cosine_similarity(x, openai_embedding) ) # Get top results for each model mpnet_results = df.nlargest(n, 'mpnet_similarities') openai_results = df.nlargest(n, 'openai_similarities') # Format results results = [] for _, row in mpnet_results.iterrows(): results.append({ "text": row["ext"], "similarity": float(row["mpnet_similarities"]), "model_type": "all-mpnet-base-v2" }) for _, row in openai_results.iterrows(): results.append({ "text": row["ext"], "similarity": float(row["openai_similarities"]), "model_type": "text-embedding-3-small" }) return results # Authentication functions def load_credentials(): credentials = {} for i in range(1, 51): username = os.environ.get(f"login_{i}") password = os.environ.get(f"password_{i}") if username and password: credentials[username] = password return credentials def authenticate_user(username: str, password: str): credentials_dict = load_credentials() if username in credentials_dict and credentials_dict[username] == password: return username return None def create_token(data: dict, expires_delta: timedelta, secret_key: str): to_encode = data.copy() expire = datetime.utcnow() + expires_delta to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM) return encoded_jwt def verify_token(token: str, secret_key: str): credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) try: payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) username: str = payload.get("sub") if username is None: raise credentials_exception except JWTError: raise credentials_exception return username def verify_access_token(token: str = Depends(oauth2_scheme)): username = verify_token(token, SECRET_KEY) # Check if token is blacklisted if cache.get(f"blacklist_{token}"): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has been revoked", headers={"WWW-Authenticate": "Bearer"}, ) return username # Endpoints @app.get("/") def index() -> FileResponse: """Serve the custom HTML page from the static directory""" file_path = Path(__file__).parent / "static" / "index.html" return FileResponse(path=str(file_path), media_type="text/html") @app.post("/login", response_model=TokenResponse) def login_app(form_data: OAuth2PasswordRequestForm = Depends()): username = authenticate_user(form_data.username, form_data.password) if not username: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password", headers={"WWW-Authenticate": "Bearer"}, ) access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) refresh_token_expires = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) access_token = create_token( data={"sub": username}, expires_delta=access_token_expires, secret_key=SECRET_KEY ) refresh_token = create_token( data={"sub": username}, expires_delta=refresh_token_expires, secret_key=REFRESH_SECRET_KEY ) return { "access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer" } @app.post("/refresh", response_model=TokenResponse) async def refresh(refresh_request: RefreshRequest): """ Endpoint to refresh an access token using a valid refresh token. Returns a new access token and the existing refresh token. """ try: # Verify the refresh token username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY) # Create new access token access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) access_token = create_token( data={"sub": username}, expires_delta=access_token_expires, secret_key=SECRET_KEY ) return { "access_token": access_token, "refresh_token": refresh_request.refresh_token, # Return the same refresh token "token_type": "bearer" } except JWTError: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) @app.post("/logout") def logout( token: str = Depends(oauth2_scheme), username: str = Depends(verify_access_token) ): try: # Decode token to get expiration time payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) exp_timestamp = payload.get("exp") if exp_timestamp is None: raise HTTPException(status_code=400, detail="Token missing expiration time") # Calculate remaining token validity current_time = datetime.now(timezone.utc).timestamp() remaining_time = exp_timestamp - current_time if remaining_time > 0: # Add to blacklist cache with TTL matching token expiration cache_key = f"blacklist_{token}" cache.set(cache_key, True, expire=remaining_time) return {"message": "Successfully logged out"} except JWTError: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token", headers={"WWW-Authenticate": "Bearer"}, ) @app.post("/search", response_model=List[SearchResult]) async def search( query_input: QueryInput, username: str = Depends(verify_access_token), ): try: # Initialize clients using cached functions client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) st_model = get_sentence_transformer() df = load_dataframe() # Perform search with both models results = search_query(client, st_model, query_input.query, df, n=1) return [SearchResult(**result) for result in results] except Exception as e: logging.error(f"Search error: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Search failed: {str(e)}" ) @app.post("/save") async def save_data( save_input: SaveBatchInput, username: str = Depends(verify_access_token) ): try: # Login to Hugging Face hf_token = os.environ.get("al_ghazali_rag_retrieval_evaluation") if not hf_token: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Hugging Face API token not found" ) login(token=hf_token) # Prepare data for saving data = { "user_type": [], "username": [], "query": [], "retrieved_text": [], "model_type": [], "reaction": [], "timestamp": [] } # Add each item to the data dict for item in save_input.items: data["user_type"].append(item.user_type) data["username"].append(item.username) data["query"].append(item.query) data["retrieved_text"].append(item.retrieved_text) data["model_type"].append(item.model_type) data["reaction"].append(item.reaction) data["timestamp"].append(datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z')) try: # Load existing dataset and merge dataset = load_dataset( "HumbleBeeAI/al-ghazali-rag-retrieval-evaluation", split="train" ) existing_data = dataset.to_dict() # Add new data for key in data: if key not in existing_data: existing_data[key] = ["" if key in ["timestamp"] else None] * len(next(iter(existing_data.values()))) existing_data[key].extend(data[key]) except Exception as e: logging.warning(f"Could not load existing dataset, creating new one: {str(e)}") existing_data = data # Create and push dataset updated_dataset = Dataset.from_dict(existing_data) updated_dataset.push_to_hub( "HumbleBeeAI/al-ghazali-rag-retrieval-evaluation" ) return {"message": "Data saved successfully"} except Exception as e: logging.error(f"Save error: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to save data: {str(e)}" ) # Make sure to keep the static files mounting app.mount("/home", StaticFiles(directory="static", html=True), name="home") # Startup event to create cache directory if it doesn't exist @app.on_event("startup") async def startup_event(): os.makedirs("./cache", exist_ok=True) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)