Spaces:
Build error
Build error
import os
from collections import Counter
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from bson import ObjectId
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from passlib.context import CryptContext
from pydantic import BaseModel
from pymongo import MongoClient
app = FastAPI()

# MongoDB connection
# SECURITY: the fallback URI below embeds a username and password in source.
# Set the MONGODB_URI environment variable to override it; the literal is kept
# only so existing deployments keep working, and the credentials should be
# rotated and removed from the code.
MONGODB_URI = os.environ.get(
    "MONGODB_URI",
    "mongodb+srv://sarmadsiddiqui29:Rollno169@cluster0.uchmc.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0",
)
client = MongoClient(
    MONGODB_URI,
    tls=True,
    tlsAllowInvalidCertificates=True,  # For testing only, disable for production
)
db = client["annotations_db"]

# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Secret key for JWT
# SECURITY: "your_secret_key" is a placeholder; set JWT_SECRET_KEY in the
# environment to sign tokens with a real secret.
SECRET_KEY = os.environ.get("JWT_SECRET_KEY", "your_secret_key")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # Token expiration time

# In-memory variable to store the token.
# NOTE: this single module-level slot is overwritten on every login, so it
# holds only the most recently issued token for the whole process.
current_token = None

# MongoDB Collections
users_collection = db["users"]
stories_collection = db["stories"]
prompts_collection = db["prompts"]
summaries_collection = db["summaries"]
# Models | |
class User(BaseModel):
    # Request body shared by the register and login handlers.
    email: str  # looked up in the users collection
    password: str  # plain text as submitted; hashed before it is stored
class Story(BaseModel):
    # Request body for adding or updating a story.
    # annotator_id is removed from the Story model
    story_id: str  # identifier used to look the story up
    story: str  # the story text
class Prompt(BaseModel):
    # Request body for adding or updating a prompt on a story.
    story_id: str
    prompt: str
    # Will be set automatically from the logged-in token by the handlers.
    # Declared Optional so the None default matches the annotation.
    annotator_id: Optional[int] = None
class Summary(BaseModel):
    # Request body for adding or updating a summary of a story.
    story_id: str
    summary: str
    # Will be set automatically from the logged-in token by the handlers.
    # Declared Optional so the None default matches the annotation.
    annotator_id: Optional[int] = None
# Serialize document function | |
def serialize_document(doc):
    """Return a copy of *doc* in which every ObjectId is replaced by its string form.

    Dicts and lists are walked recursively; any other value is returned as is.
    """
    if isinstance(doc, dict):
        return {key: serialize_document(value) for key, value in doc.items()}
    if isinstance(doc, list):
        return list(map(serialize_document, doc))
    return str(doc) if isinstance(doc, ObjectId) else doc
# Helper Functions | |
def hash_password(password: str) -> str:
    """Hash *password* with the module's passlib context and return the hash."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return the passlib context's verdict on *plain_password* against *hashed_password*."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Build a signed JWT carrying *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed; the dict is copied, not mutated.
        expires_delta: Token lifetime. A missing or zero value falls back to
            15 minutes.

    Returns:
        The token produced by ``jwt.encode`` using SECRET_KEY and ALGORITHM.
    """
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)
    to_encode = data.copy()
    # Timezone-aware UTC "now" replaces the deprecated naive datetime.utcnow().
    to_encode.update({"exp": datetime.now(timezone.utc) + lifetime})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def get_annotator_id() -> int:
    """Return the ``annotator_id`` claim of the token stored in ``current_token``.

    Raises:
        HTTPException: 401 when no token is stored, when the token fails to
            decode, or when it lacks the ``annotator_id`` claim.
    """
    if current_token is None:
        raise HTTPException(status_code=401, detail="User not logged in")
    try:
        payload = jwt.decode(current_token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as err:
        raise HTTPException(status_code=401, detail="Invalid token") from err
    try:
        return payload["annotator_id"]
    except KeyError as err:
        # A validly signed token without the claim is still unusable here;
        # report it as an auth failure rather than letting KeyError surface.
        raise HTTPException(status_code=401, detail="Invalid token") from err
# Endpoints for user, story, prompt, and summary operations | |
# Register User | |
async def register_user(user: User):
    """Create a user record with a hashed password and a fresh annotator_id.

    Raises:
        HTTPException: 400 when the email is already present in the users
            collection.
    """
    if db.users.find_one({"email": user.email}):
        raise HTTPException(status_code=400, detail="Email already registered")

    # Take the highest existing annotator_id + 1 rather than the document
    # count + 1: the count can collide with an existing id whenever the ids
    # are not contiguous (e.g. after a user document has been removed).
    # NOTE: two concurrent registrations can still read the same maximum.
    latest = db.users.find_one(
        {"annotator_id": {"$exists": True}},
        sort=[("annotator_id", -1)],
    )
    next_annotator_id = (latest["annotator_id"] if latest else 0) + 1

    user_data = {
        "email": user.email,
        "password": hash_password(user.password),
        "annotator_id": next_annotator_id,
    }
    db.users.insert_one(user_data)
    return {"message": "User registered successfully", "annotator_id": user_data["annotator_id"]}
# Login User | |
async def login_user(user: User):
    """Check the submitted credentials and issue a bearer token.

    The new token is also written to the module-level ``current_token``.

    Raises:
        HTTPException: 400 when the email is unknown or the password does not
            verify.
    """
    global current_token

    account = db.users.find_one({"email": user.email})
    if not (account and verify_password(user.password, account["password"])):
        raise HTTPException(status_code=400, detail="Invalid email or password")

    claims = {"email": account["email"], "annotator_id": account["annotator_id"]}
    current_token = create_access_token(
        data=claims,
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": current_token, "token_type": "bearer"}
# Add Story | |
async def add_story(story: Story):
    """Insert *story* unless a story with the same story_id already exists.

    No annotator_id is attached to stories.

    Raises:
        HTTPException: 400 when the story_id is already present.
    """
    duplicate = db.stories.find_one({"story_id": story.story_id})
    if duplicate:
        raise HTTPException(status_code=400, detail="Story already exists")
    document = story.dict()
    db.stories.insert_one(document)
    return {"message": "Story added successfully"}
# Add Prompt | |
async def add_prompt(prompt: Prompt):
    """Stamp *prompt* with the logged-in annotator's id and store it."""
    # The id always comes from the stored token, replacing any submitted value.
    prompt.annotator_id = get_annotator_id()
    document = prompt.dict()
    db.prompts.insert_one(document)
    return {"message": "Prompt added successfully"}
# Add Summary | |
async def add_summary(summary: Summary):
    """Stamp *summary* with the logged-in annotator's id and store it."""
    # The id always comes from the stored token, replacing any submitted value.
    summary.annotator_id = get_annotator_id()
    document = summary.dict()
    db.summaries.insert_one(document)
    return {"message": "Summary added successfully"}
# Delete All Users | |
async def delete_all_users():
    """Remove every document from the users collection and report the count."""
    removed = db.users.delete_many({}).deleted_count
    return {"message": f"{removed} users deleted"}
# Delete All Stories | |
async def delete_all_stories():
    """Remove every document from the stories collection and report the count."""
    removed = db.stories.delete_many({}).deleted_count
    return {"message": f"{removed} stories deleted"}
# Delete All Prompts | |
async def delete_all_prompts():
    """Remove every document from the prompts collection and report the count."""
    removed = db.prompts.delete_many({}).deleted_count
    return {"message": f"{removed} prompts deleted"}
# Delete All Summaries | |
async def delete_all_summaries():
    """Remove every document from the summaries collection and report the count."""
    removed = db.summaries.delete_many({}).deleted_count
    return {"message": f"{removed} summaries deleted"}
# Test MongoDB Connection | |
async def test_connection():
    """Probe the database by listing its collection names.

    Raises:
        HTTPException: 500 carrying the error text when the probe fails.
    """
    try:
        db.list_collection_names()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"message": "Connected to MongoDB successfully"}
# Display Story by ID | |
async def display_story(story_id: str):
    """Return the full stored document for *story_id* in serializable form.

    Raises:
        HTTPException: 404 when no story matches.
    """
    document = db.stories.find_one({"story_id": story_id})
    if not document:
        raise HTTPException(status_code=404, detail="Story not found")
    # ObjectId values are converted to strings so the response can be encoded.
    return serialize_document(document)
# Display All for a Given Annotator ID | |
from fastapi import Query | |
from fastapi import Query, HTTPException | |
async def display_all(story_id: str = Query(...)):
    """Return the story text plus the logged-in annotator's prompt and summary.

    Args:
        story_id: Required query parameter naming the story.

    Returns:
        A dict with ``story_id``, ``story``, ``annotator_id``, ``summary`` and
        ``prompt``; a missing story or summary yields an empty string.

    Raises:
        HTTPException: 404 when the annotator has no prompt for this story;
            401 (from ``get_annotator_id``) when nobody is logged in.
    """
    annotator_id = get_annotator_id()  # Taken from the stored token

    # The prompt is mandatory: without it there is nothing to display.
    prompt = db.prompts.find_one({"story_id": story_id, "annotator_id": annotator_id})
    if not prompt:
        raise HTTPException(status_code=404, detail="Prompt not found for this annotator and story ID")

    # Story and summary are optional; fall back to empty documents.
    story = db.stories.find_one({"story_id": story_id}) or {}
    summary = db.summaries.find_one({"story_id": story_id, "annotator_id": annotator_id}) or {}

    result = {
        "story_id": story_id,
        # .get avoids a KeyError on a story document stored without a "story" field.
        "story": story.get("story", ""),
        "annotator_id": prompt["annotator_id"],
        "summary": summary.get("summary", ""),
        "prompt": prompt.get("prompt", ""),
    }
    return serialize_document(result)
async def delete_prompt(story_id: str):
    """Delete the logged-in annotator's prompts for *story_id*.

    Raises:
        HTTPException: 404 when nothing matched.
    """
    owner_filter = {"story_id": story_id, "annotator_id": get_annotator_id()}
    result = db.prompts.delete_many(owner_filter)
    if result.deleted_count > 0:
        return {"message": f"{result.deleted_count} prompt(s) deleted successfully"}
    raise HTTPException(status_code=404, detail="No prompts found for this annotator and story ID")
async def delete_summary(story_id: str):
    """Delete the logged-in annotator's summaries for *story_id*.

    Raises:
        HTTPException: 404 when nothing matched.
    """
    owner_filter = {"story_id": story_id, "annotator_id": get_annotator_id()}
    result = db.summaries.delete_many(owner_filter)
    if result.deleted_count > 0:
        return {"message": f"{result.deleted_count} summary(ies) deleted successfully"}
    raise HTTPException(status_code=404, detail="No summaries found for this annotator and story ID")
async def delete_story(story_id: str):
    """Delete a story together with the logged-in annotator's prompts and summaries for it.

    Stories carry no annotator_id, so the story itself is matched on
    ``story_id`` alone; prompts and summaries are matched on story_id and the
    current annotator.

    Raises:
        HTTPException: 404 when the story does not exist. In that case no
            prompts or summaries are touched.
    """
    annotator_id = get_annotator_id()  # Taken from the stored token

    story_result = db.stories.delete_one({"story_id": story_id})
    if story_result.deleted_count == 0:
        # Fail before the dependent deletes so a 404 never has side effects.
        raise HTTPException(status_code=404, detail="Story not found for this annotator")

    owned = {"story_id": story_id, "annotator_id": annotator_id}
    prompts_result = db.prompts.delete_many(owned)
    summaries_result = db.summaries.delete_many(owned)

    return {
        "message": "Story deleted successfully",
        "deleted_prompts": prompts_result.deleted_count,
        "deleted_summaries": summaries_result.deleted_count,
    }
async def update_story(story_id: str, updated_story: Story):
    """Replace the text of the story identified by *story_id*.

    Args:
        story_id: Identifier of the story to change.
        updated_story: Body whose ``story`` field supplies the new text.

    Raises:
        HTTPException: 404 when no story has this story_id; 401 (from
            ``get_annotator_id``) when nobody is logged in.
    """
    get_annotator_id()  # Only enforces that a user is logged in

    # Story documents are stored without an annotator_id (see add_story), so
    # the lookup must match on story_id alone; filtering on annotator_id as
    # well could never match and made every update fail with 404.
    existing_story = db.stories.find_one({"story_id": story_id})
    if not existing_story:
        raise HTTPException(status_code=404, detail="Story not found or does not belong to this annotator")

    db.stories.update_one({"story_id": story_id}, {"$set": {"story": updated_story.story}})
    return {"message": "Story updated successfully"}
async def update_prompt(story_id: str, updated_prompt: Prompt):
    """Overwrite the text of the logged-in annotator's prompt for *story_id*.

    Raises:
        HTTPException: 404 when the annotator has no prompt for this story.
    """
    owner_filter = {"story_id": story_id, "annotator_id": get_annotator_id()}
    if not db.prompts.find_one(owner_filter):
        raise HTTPException(status_code=404, detail="Prompt not found or does not belong to this annotator")
    change = {"$set": {"prompt": updated_prompt.prompt}}
    db.prompts.update_one(owner_filter, change)
    return {"message": "Prompt updated successfully"}
async def update_summary(story_id: str, updated_summary: Summary):
    """Overwrite the text of the logged-in annotator's summary for *story_id*.

    Raises:
        HTTPException: 404 when the annotator has no summary for this story.
    """
    owner_filter = {"story_id": story_id, "annotator_id": get_annotator_id()}
    if not db.summaries.find_one(owner_filter):
        raise HTTPException(status_code=404, detail="Summary not found or does not belong to this annotator")
    change = {"$set": {"summary": updated_summary.summary}}
    db.summaries.update_one(owner_filter, change)
    return {"message": "Summary updated successfully"}
async def get_prompt(story_id: str):
    """Return the logged-in annotator's prompt text for *story_id*.

    The ``prompt`` value is an empty string when no prompt is stored.
    """
    record = db.prompts.find_one({"story_id": story_id, "annotator_id": get_annotator_id()})
    text = record.get("prompt", "") if record else ""
    return {"story_id": story_id, "prompt": text}
async def get_summary(story_id: str):
    """Return the logged-in annotator's summary text for *story_id*.

    The ``summary`` value is an empty string when no summary is stored.
    """
    record = db.summaries.find_one({"story_id": story_id, "annotator_id": get_annotator_id()})
    text = record.get("summary", "") if record else ""
    return {"story_id": story_id, "summary": text}
async def get_story(story_id: str):
    """Return the text of the story with *story_id*.

    The ``story`` value is an empty string when no such story is stored.
    No login is required by this handler.
    """
    record = db.stories.find_one({"story_id": story_id})
    text = record.get("story", "") if record else ""
    return {"story_id": story_id, "story": text}
async def get_annotators():
    """Return, for each annotator_id found on a prompt, how many prompts carry it.

    Prompt documents without an ``annotator_id`` field are skipped.
    """
    tally = Counter()
    for document in prompts_collection.find():
        if 'annotator_id' in document:
            tally[document['annotator_id']] += 1

    annotators = []
    for annotator_id, count in tally.items():
        annotators.append({"annotator_id": annotator_id, "prompt_count": count})
    return JSONResponse(content=annotators)