mominah commited on
Commit
97f13e6
·
verified ·
1 Parent(s): 982bc60

Create auth.py

Browse files
Files changed (1) hide show
  1. auth.py +208 -0
auth.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # auth.py
2
+ import os
3
+ import uuid
4
+ import logging
5
+ from datetime import datetime, timedelta
6
+ from urllib.parse import quote_plus
7
+ from typing import Optional
8
+
9
+ from dotenv import load_dotenv
10
+ from fastapi import APIRouter, HTTPException, Depends, Request, UploadFile, File, Form
11
+ from fastapi.responses import StreamingResponse
12
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
13
+ from jose import JWTError, jwt
14
+ from passlib.context import CryptContext
15
+ from pymongo import MongoClient
16
+ import gridfs
17
+
18
+ from models import User, UserUpdate, Token, LoginResponse
19
+
20
+ load_dotenv()
21
+
22
+ logger = logging.getLogger("uvicorn")
23
+ logger.setLevel(logging.INFO)
24
+
25
+ # MongoDB setup for user management
26
+ MONGO_URL = os.getenv("CONNECTION_STRING")
27
+ client = MongoClient(MONGO_URL)
28
+ db = client.users_database
29
+ users_collection = db.users
30
+ # GridFS instance for storing avatars
31
+ fs = gridfs.GridFS(db, collection="avatars")
32
+
33
+ # OAuth2 setup
34
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
35
+
36
+ router = APIRouter(prefix="/auth", tags=["auth"])
37
+
38
+ # Password hashing
39
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
40
+
41
+ def verify_password(plain_password: str, hashed_password: str) -> bool:
42
+ return pwd_context.verify(plain_password, hashed_password)
43
+
44
+ def get_password_hash(password: str) -> str:
45
+ return pwd_context.hash(password)
46
+
47
+ def get_user(email: str) -> Optional[dict]:
48
+ return users_collection.find_one({"email": email})
49
+
50
+ def authenticate_user(email: str, password: str) -> Optional[dict]:
51
+ user = get_user(email)
52
+ if not user or not verify_password(password, user["hashed_password"]):
53
+ return None
54
+ return user
55
+
56
+ def create_token(data: dict, expires_delta: timedelta = None) -> str:
57
+ to_encode = data.copy()
58
+ expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
59
+ to_encode.update({"exp": expire})
60
+ secret_key = os.getenv("SECRET_KEY")
61
+ algorithm = "HS256"
62
+ return jwt.encode(to_encode, secret_key, algorithm=algorithm)
63
+
64
+ def create_access_token(email: str) -> str:
65
+ return create_token({"sub": email}, timedelta(minutes=int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "90"))))
66
+
67
+ def create_refresh_token(email: str) -> str:
68
+ return create_token({"sub": email}, timedelta(days=int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))))
69
+
70
+ def get_current_user(token: str = Depends(oauth2_scheme)) -> dict:
71
+ secret_key = os.getenv("SECRET_KEY")
72
+ try:
73
+ payload = jwt.decode(token, secret_key, algorithms=["HS256"])
74
+ email: str = payload.get("sub")
75
+ if not email:
76
+ raise HTTPException(status_code=401, detail="Invalid credentials")
77
+ user = get_user(email)
78
+ if not user:
79
+ raise HTTPException(status_code=401, detail="User not found")
80
+ return user
81
+ except JWTError:
82
+ raise HTTPException(status_code=401, detail="Invalid token")
83
+
84
+ async def save_avatar_file_to_gridfs(file: UploadFile) -> str:
85
+ allowed_types = ["image/jpeg", "image/png", "image/gif"]
86
+ if file.content_type not in allowed_types:
87
+ logger.error(f"Unsupported file type: {file.content_type}")
88
+ raise HTTPException(
89
+ status_code=400,
90
+ detail="Invalid image format. Only JPEG, PNG, and GIF are accepted."
91
+ )
92
+ try:
93
+ contents = await file.read()
94
+ file_id = fs.put(contents, filename=file.filename, contentType=file.content_type)
95
+ logger.info(f"Avatar stored in GridFS with file_id: {file_id}")
96
+ return str(file_id)
97
+ except Exception as e:
98
+ logger.exception("Failed to store avatar in GridFS")
99
+ raise HTTPException(status_code=500, detail="Could not store avatar file in MongoDB.")
100
+
101
+ @router.post("/signup", response_model=Token)
102
+ async def signup(
103
+ request: Request,
104
+ name: str = Form(...),
105
+ email: str = Form(...),
106
+ password: str = Form(...),
107
+ avatar: Optional[UploadFile] = File(None)
108
+ ):
109
+ try:
110
+ _ = User(name=name, email=email, password=password)
111
+ except Exception as e:
112
+ logger.error(f"Validation error during signup: {e}")
113
+ raise HTTPException(status_code=400, detail=str(e))
114
+ if get_user(email):
115
+ logger.warning(f"Attempt to register already existing email: {email}")
116
+ raise HTTPException(status_code=400, detail="Email already registered")
117
+ hashed_password = get_password_hash(password)
118
+ user_data = {
119
+ "name": name,
120
+ "email": email,
121
+ "hashed_password": hashed_password,
122
+ "chat_histories": []
123
+ }
124
+ if avatar:
125
+ file_id = await save_avatar_file_to_gridfs(avatar)
126
+ user_data["avatar"] = file_id
127
+ users_collection.insert_one(user_data)
128
+ logger.info(f"New user registered: {email}")
129
+ return {
130
+ "access_token": create_access_token(email),
131
+ "refresh_token": create_refresh_token(email),
132
+ "token_type": "bearer"
133
+ }
134
+
135
+ @router.post("/login", response_model=LoginResponse)
136
+ async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
137
+ user = authenticate_user(form_data.username, form_data.password)
138
+ if not user:
139
+ logger.warning(f"Failed login attempt for: {form_data.username}")
140
+ raise HTTPException(status_code=401, detail="Incorrect username or password")
141
+ logger.info(f"User logged in: {user['email']}")
142
+ avatar_url = None
143
+ if "avatar" in user and user["avatar"]:
144
+ avatar_url = f"/auth/avatar/{user['avatar']}"
145
+ return {
146
+ "access_token": create_access_token(user["email"]),
147
+ "refresh_token": create_refresh_token(user["email"]),
148
+ "token_type": "bearer",
149
+ "name": user["name"],
150
+ "avatar": avatar_url
151
+ }
152
+
153
+ @router.get("/user/data")
154
+ async def get_user_data(request: Request, current_user: dict = Depends(get_current_user)):
155
+ avatar_url = None
156
+ if "avatar" in current_user and current_user["avatar"]:
157
+ avatar_url = f"/auth/avatar/{current_user['avatar']}"
158
+ return {
159
+ "name": current_user["name"],
160
+ "email": current_user["email"],
161
+ "avatar": avatar_url,
162
+ "chat_histories": current_user.get("chat_histories", [])
163
+ }
164
+
165
+ @router.put("/user/update")
166
+ async def update_user(
167
+ request: Request,
168
+ name: Optional[str] = Form(None),
169
+ email: Optional[str] = Form(None),
170
+ password: Optional[str] = Form(None),
171
+ avatar: Optional[UploadFile] = File(None),
172
+ current_user: dict = Depends(get_current_user)
173
+ ):
174
+ update_data = {}
175
+ if name is not None:
176
+ update_data["name"] = name
177
+ if email is not None:
178
+ update_data["email"] = email
179
+ if password is not None:
180
+ try:
181
+ _ = User(name=current_user["name"], email=current_user["email"], password=password)
182
+ except Exception as e:
183
+ logger.error(f"Password validation error during update: {e}")
184
+ raise HTTPException(status_code=400, detail=str(e))
185
+ update_data["hashed_password"] = get_password_hash(password)
186
+ if avatar:
187
+ file_id = await save_avatar_file_to_gridfs(avatar)
188
+ update_data["avatar"] = file_id
189
+ if not update_data:
190
+ logger.info("No update parameters provided")
191
+ raise HTTPException(status_code=400, detail="No update parameters provided")
192
+ users_collection.update_one({"email": current_user["email"]}, {"$set": update_data})
193
+ logger.info(f"User updated: {current_user['email']}")
194
+ return {"message": "User updated successfully"}
195
+
196
+ @router.post("/logout")
197
+ async def logout(request: Request, current_user: dict = Depends(get_current_user)):
198
+ logger.info(f"User logged out: {current_user['email']}")
199
+ return {"message": "User logged out successfully"}
200
+
201
+ @router.get("/avatar/{file_id}")
202
+ async def get_avatar(file_id: str):
203
+ try:
204
+ file = fs.get(file_id)
205
+ return StreamingResponse(file, media_type=file.content_type)
206
+ except Exception as e:
207
+ logger.error(f"Avatar not found for file_id {file_id}: {e}")
208
+ raise HTTPException(status_code=404, detail="Avatar not found")