obtu-ai / routes.py
Jose Benitez
clean code
025cc15
raw
history blame
5.52 kB
# routes.py
from fastapi import APIRouter, Depends, Request
from starlette.responses import RedirectResponse
from auth import oauth
from database import get_or_create_user, update_user_credits, get_user_by_id
from authlib.integrations.starlette_client import OAuthError
import gradio as gr
from utils.stripe_utils import create_checkout_session, verify_webhook, retrieve_stripe_session
from config import DOMAIN
# Router holding the landing redirect, Google OAuth login/logout/callback and
# Stripe credit-purchase endpoints defined below. Presumably included into the
# main FastAPI app elsewhere — confirm.
router = APIRouter()
def get_user(request: Request):
    """Return the logged-in user's display name, refreshing the session copy.

    Used as a FastAPI dependency (see ``public``). When the session holds a
    user, the record is re-read with ``get_user_by_id`` and written back to
    ``request.session['user']`` so the session reflects the database.

    Args:
        request: Incoming request; its session is read and may be updated.

    Returns:
        The user's ``'name'`` field, or ``None`` when nobody is logged in or
        the stored user can no longer be found in the database.
    """
    user_data = request.session.get('user')
    if not user_data:
        return None
    # Refresh user data from the database.
    user = get_user_by_id(user_data['id'])
    if not user:
        # The account no longer resolves: drop the stale session entry instead
        # of crashing on ``None['name']`` (and storing ``None`` in the session).
        request.session.pop('user', None)
        return None
    request.session['user'] = user  # Update session with fresh data
    return user['name']
@router.get('/')
def public(request: Request, user = Depends(get_user)):
    """Send logged-in visitors to the Gradio app and everyone else to ``/main``.

    ``user`` is whatever ``get_user`` resolves (the display name, or ``None``
    when nobody is logged in). Both targets are prefixed with the root URL
    Gradio derives for this request.
    """
    base_url = gr.route_utils.get_root_url(request, "/", None)
    print(f'Root URL: {base_url}')
    destination = 'gradio' if user else 'main'
    return RedirectResponse(url=f'{base_url}/{destination}/')
@router.route('/logout')
async def logout(request: Request):
    """Log the user out and redirect to ``/``.

    Besides ``'user'``, this also removes the ``'stripe_session_id'`` and
    ``'user_id'`` keys that ``buy_credits`` stores. ``payment_success`` loads
    ``get_user_by_id(session['user_id'])`` into ``session['user']`` when that
    Stripe session is paid. Leaving those keys behind would let a later visit
    to ``/success`` in the same browser log this user back in.
    """
    for key in ('user', 'stripe_session_id', 'user_id'):
        request.session.pop(key, None)
    return RedirectResponse(url='/')
@router.route('/login')
async def login(request: Request):
    """Start the Google OAuth flow.

    The OAuth callback is ``<root>/auth``, where ``<root>`` is the root URL
    Gradio derives from this ``/login`` request. Authlib builds the redirect
    to Google's authorization endpoint.
    """
    callback_url = "{}/auth".format(gr.route_utils.get_root_url(request, "/login", None))
    return await oauth.google.authorize_redirect(request, callback_url)
@router.route('/auth')
async def auth(request: Request):
    """OAuth callback: exchange the authorization response and log the user in.

    On success the record returned by ``get_or_create_user`` is stored in
    ``request.session['user']`` and the browser goes to ``/gradio``. The browser
    is sent to ``/main`` instead in any of these cases:

    - an ``OAuthError`` is raised during the token exchange;
    - the token has no ``userinfo``;
    - the userinfo lacks the ``sub`` or ``email`` claim.
    """
    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError as e:
        print(f"OAuth Error: {str(e)}")
        return RedirectResponse(url='/main')

    user_info = token.get('userinfo')
    if not user_info:
        return RedirectResponse(url='/main')

    google_id = user_info.get('sub')
    email = user_info.get('email')
    if not google_id or not email:
        # Without a subject id and an email the account cannot be looked up or
        # created; previously this surfaced as an uncaught KeyError (HTTP 500).
        print("OAuth Error: userinfo is missing 'sub' or 'email'")
        return RedirectResponse(url='/main')

    # 'name', 'given_name' and 'picture' are optional OIDC claims, so fall back
    # instead of failing the whole login when the provider omits one.
    name = user_info.get('name') or email
    given_name = user_info.get('given_name') or name
    profile_picture = user_info.get('picture', '')

    user = get_or_create_user(google_id, email, name, given_name, profile_picture)
    request.session['user'] = user
    return RedirectResponse(url='/gradio')
# Handle Stripe payments
@router.get("/buy_credits")
async def buy_credits(request: Request):
    """Create a Stripe Checkout session for a credit pack and redirect to it.

    The pack is described by the ``create_checkout_session(100, 50, ...)`` call
    below ($1 for 50 credits, per the original note). The Stripe session id and
    the buyer's id are saved in the cookie session so ``/success`` can confirm
    the payment afterwards.

    Returns:
        A redirect to the checkout URL returned by Stripe. When no user is logged
        in, returns ``{"error": "User not authenticated"}`` with HTTP 401.
    """
    from starlette.concurrency import run_in_threadpool
    from starlette.responses import JSONResponse

    user = request.session.get('user')
    if not user:
        # Same body as before, but with a real 401: a bare dict is served as
        # HTTP 200, which clients cannot tell apart from success.
        return JSONResponse({"error": "User not authenticated"}, status_code=401)

    # create_checkout_session is a plain (non-async) call that presumably hits
    # the Stripe API over the network; run it in the threadpool so it does not
    # stall the event loop for every other request meanwhile.
    session = await run_in_threadpool(
        create_checkout_session, 100, 50, user['id'], request  # $1 for 50 credits
    )
    # Store the session ID and user ID in the session for /success.
    request.session['stripe_session_id'] = session['id']
    request.session['user_id'] = user['id']
    print(f"Stripe session created: {session['id']} for user {user['id']}")
    return RedirectResponse(session['url'])
def _credit_paid_checkout(session):
    """Add one purchased credit pack to the user referenced by a paid Checkout session.

    The user id comes from the session's ``client_reference_id``. A missing id
    or an unknown user is logged and skipped, so the webhook still
    acknowledges the event.
    """
    user_id = session.get('client_reference_id')
    if not user_id:
        print("No client_reference_id found in the session")
        return
    user = get_user_by_id(user_id)
    if not user:
        print(f"User not found for ID: {user_id}")
        return
    # NOTE(review): this is a read-modify-write with no locking and no event-id
    # de-duplication. Stripe can deliver an event more than once, so a
    # redelivered or concurrent event may double-credit or lose an update.
    # 50 presumably mirrors the credit count passed in buy_credits — keep in sync.
    new_credits = user['generation_credits'] + 50
    update_user_credits(user['id'], new_credits, user['train_credits'])
    print(f"Credits updated for user {user['id']}")


@router.post("/webhook")
async def stripe_webhook(request: Request):
    """Stripe webhook: grant purchased credits once a Checkout session is paid.

    The raw body and ``Stripe-Signature`` header go to ``verify_webhook``. A
    ``None`` result gets HTTP 400, so the failure shows up in Stripe's delivery
    log instead of being recorded as accepted.

    Credits are granted only when ``payment_status`` is ``'paid'``, matching the
    check in ``/success``. Two events are handled:

    - ``checkout.session.completed``;
    - ``checkout.session.async_payment_succeeded``, which covers delayed payment
      methods and is only delivered if the endpoint is subscribed to it.

    Returns:
        ``{"status": "success"}`` for every verified event, handled or not.
    """
    payload = await request.body()
    sig_header = request.headers.get("Stripe-Signature")
    event = verify_webhook(payload, sig_header)
    if event is None:
        from starlette.responses import JSONResponse
        return JSONResponse({"error": "Invalid payload or signature"}, status_code=400)

    if event['type'] in ('checkout.session.completed',
                         'checkout.session.async_payment_succeeded'):
        session = event['data']['object']
        if session.get('payment_status') == 'paid':
            _credit_paid_checkout(session)
        else:
            # A completed session can still be unpaid (delayed payment method);
            # credits are granted later by async_payment_succeeded.
            print(f"Payment not completed for session: {session.get('id')}")
    return {"status": "success"}
@router.get("/cancel")
async def payment_cancel(request: Request):
    """Handle a cancelled payment; presumably the Checkout cancel URL — confirm.

    Logged-in users are sent back to ``/gradio``; anyone else goes to
    ``/login``. Both redirects use status 303.
    """
    print("Payment cancelled")
    current_user = request.session.get('user')
    print(current_user)
    destination = '/gradio' if current_user else '/login'
    return RedirectResponse(url=destination, status_code=303)
@router.get('/success')
async def payment_success(request: Request):
    """Return point after Stripe checkout: refresh the session user when paid.

    Reads the ``stripe_session_id`` and ``user_id`` that ``buy_credits`` stored
    in the cookie session. If Stripe reports that session as ``'paid'`` and the
    user exists, the function:

    - reloads the user into ``request.session['user']``;
    - removes both bookkeeping keys;
    - redirects to ``{DOMAIN}/gradio/``.

    Every other outcome is logged and redirected to ``/login``. All redirects
    use status 303.
    """
    print('Payment Sucess')
    stripe_session_id = request.session.get('stripe_session_id')
    buyer_id = request.session.get('user_id')
    print(buyer_id)

    if not (stripe_session_id and buyer_id):
        print("No Stripe session ID or user ID found in the session")
        return RedirectResponse(url='/login', status_code=303)

    stripe_session = retrieve_stripe_session(stripe_session_id)
    if stripe_session.get('payment_status') != 'paid':
        print(f"Payment not completed for session: {stripe_session_id}")
        return RedirectResponse(url='/login', status_code=303)

    user = get_user_by_id(buyer_id)
    if not user:
        print(f"User not found for ID: {buyer_id}")
        return RedirectResponse(url='/login', status_code=303)

    # Refresh the session copy so the app sees the latest user record.
    request.session['user'] = user
    print(f"\nUser session updated: {user}\n")
    # The purchase has been confirmed; forget the bookkeeping keys.
    for key in ('stripe_session_id', 'user_id'):
        request.session.pop(key, None)
    return RedirectResponse(url=f'{DOMAIN}/gradio/', status_code=303)