"""Console API resources for reading and updating the current user's account."""
import datetime

import pytz
from flask import current_app, request
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with, reqparse

from constants.languages import supported_language
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.workspace.error import (
    AccountAlreadyInitedError,
    CurrentPasswordIncorrectError,
    InvalidInvitationCodeError,
    RepeatPasswordNotMatchError,
)
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone
from libs.login import login_required
from models.account import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError


def _naive_utc_now():
    """Return the current UTC time as a naive ``datetime`` (tzinfo stripped)."""
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


class AccountInitApi(Resource):
    """Activate the current account and store its initial preferences."""

    @setup_required
    @login_required
    def post(self):
        """Initialize the current account.

        Reads ``interface_language`` and ``timezone`` (both required) from the
        JSON body. When the app's ``EDITION`` config is ``'CLOUD'`` an
        ``invitation_code`` is also read and consumed.

        Returns:
            ``{'result': 'success'}`` once the changes are committed.

        Raises:
            AccountAlreadyInitedError: the account status is already 'active'.
            ValueError: CLOUD edition and no invitation code was supplied.
            InvalidInvitationCodeError: no unused invitation code matches.
        """
        account = current_user

        if account.status == 'active':
            raise AccountAlreadyInitedError()

        # Evaluated once; it drives both the parser setup and the code check.
        is_cloud_edition = current_app.config['EDITION'] == 'CLOUD'

        parser = reqparse.RequestParser()
        if is_cloud_edition:
            parser.add_argument('invitation_code', type=str, location='json')
        parser.add_argument(
            'interface_language', type=supported_language, required=True, location='json')
        parser.add_argument('timezone', type=timezone, required=True, location='json')
        args = parser.parse_args()

        if is_cloud_edition:
            if not args['invitation_code']:
                raise ValueError('invitation_code is required')

            # check invitation code
            invitation_code = db.session.query(InvitationCode).filter(
                InvitationCode.code == args['invitation_code'],
                InvitationCode.status == 'unused',
            ).first()

            if not invitation_code:
                raise InvalidInvitationCodeError()

            # Mark the code as consumed by this account/tenant.
            invitation_code.status = 'used'
            invitation_code.used_at = _naive_utc_now()
            invitation_code.used_by_tenant_id = account.current_tenant_id
            invitation_code.used_by_account_id = account.id

        account.interface_language = args['interface_language']
        account.timezone = args['timezone']
        account.interface_theme = 'light'
        account.status = 'active'
        account.initialized_at = _naive_utc_now()
        # Single commit persists both the invitation code and account updates.
        db.session.commit()

        return {'result': 'success'}


class AccountProfileApi(Resource):
    """Expose the current account's profile."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def get(self):
        """Return the current user, marshalled with ``account_fields``."""
        return current_user


class AccountNameApi(Resource):
    """Update the current account's display name."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the account name from the required JSON field ``name``.

        Returns:
            The value returned by ``AccountService.update_account``,
            marshalled with ``account_fields``.

        Raises:
            ValueError: the name is shorter than 3 or longer than 30
                characters.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=True, location='json')
        args = parser.parse_args()

        # Validate account name length
        if len(args['name']) < 3 or len(args['name']) > 30:
            raise ValueError(
                "Account name must be between 3 and 30 characters.")

        updated_account = AccountService.update_account(current_user, name=args['name'])

        return updated_account


class AccountAvatarApi(Resource):
    """Update the current account's avatar."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the account avatar from the required JSON field ``avatar``.

        Returns:
            The value returned by ``AccountService.update_account``,
            marshalled with ``account_fields``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('avatar', type=str, required=True, location='json')
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, avatar=args['avatar'])

        return updated_account


class AccountInterfaceLanguageApi(Resource):
    """Update the current account's interface language."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the interface language from the JSON field ``interface_language``.

        The value is converted/validated by ``supported_language`` during
        argument parsing.

        Returns:
            The value returned by ``AccountService.update_account``,
            marshalled with ``account_fields``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument(
            'interface_language', type=supported_language, required=True, location='json')
        args = parser.parse_args()

        updated_account = AccountService.update_account(
            current_user, interface_language=args['interface_language'])

        return updated_account


class AccountInterfaceThemeApi(Resource):
    """Update the current account's interface theme."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the interface theme from the JSON field ``interface_theme``.

        The parser only accepts ``'light'`` or ``'dark'``.

        Returns:
            The value returned by ``AccountService.update_account``,
            marshalled with ``account_fields``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('interface_theme', type=str, choices=[
            'light', 'dark'], required=True, location='json')
        args = parser.parse_args()

        updated_account = AccountService.update_account(
            current_user, interface_theme=args['interface_theme'])

        return updated_account


class AccountTimezoneApi(Resource):
    """Update the current account's timezone."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the timezone from the required JSON field ``timezone``.

        Returns:
            The value returned by ``AccountService.update_account``,
            marshalled with ``account_fields``.

        Raises:
            ValueError: the string is not a timezone name known to pytz.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('timezone', type=str, required=True, location='json')
        args = parser.parse_args()

        # Validate timezone string, e.g. America/New_York, Asia/Shanghai
        # The set variant holds the same names as ``pytz.all_timezones`` but
        # avoids a linear scan of the list on every request.
        if args['timezone'] not in pytz.all_timezones_set:
            raise ValueError("Invalid timezone string.")

        updated_account = AccountService.update_account(current_user, timezone=args['timezone'])

        return updated_account


class AccountPasswordApi(Resource):
    """Change the current account's password."""

    # NOTE(review): this handler returns {"result": "success"} but is wrapped
    # in marshal_with(account_fields), so the response presumably carries the
    # account field names rather than "result" - confirm what clients expect
    # before changing either side.
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the password.

        Reads ``password`` (optional), ``new_password`` and
        ``repeat_new_password`` (both required) from the JSON body.

        Returns:
            ``{"result": "success"}`` (before marshalling).

        Raises:
            RepeatPasswordNotMatchError: the two new-password fields differ.
            CurrentPasswordIncorrectError: the service reported that the
                current password is incorrect.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('password', type=str, required=False, location='json')
        parser.add_argument('new_password', type=str, required=True, location='json')
        parser.add_argument('repeat_new_password', type=str, required=True, location='json')
        args = parser.parse_args()

        if args['new_password'] != args['repeat_new_password']:
            raise RepeatPasswordNotMatchError()

        try:
            AccountService.update_account_password(
                current_user, args['password'], args['new_password'])
        except ServiceCurrentPasswordIncorrectError as err:
            # Translate the service-layer error into the controller-layer one,
            # keeping the original as the cause for debugging.
            raise CurrentPasswordIncorrectError() from err

        return {"result": "success"}


class AccountIntegrateApi(Resource):
    """List the OAuth providers and whether the current account is bound to each."""

    integrate_fields = {
        'provider': fields.String,
        'created_at': TimestampField,
        'is_bound': fields.Boolean,
        'link': fields.String
    }

    integrate_list_fields = {
        'data': fields.List(fields.Nested(integrate_fields)),
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        """Return one entry per known provider ("github", "google").

        A provider with an ``AccountIntegrate`` row for this account is
        reported as bound with no link; otherwise it is reported as unbound
        with an OAuth login link built from the request's root URL.

        Returns:
            ``{'data': [...]}`` marshalled with ``integrate_list_fields``.
        """
        account = current_user

        account_integrates = db.session.query(AccountIntegrate).filter(
            AccountIntegrate.account_id == account.id).all()

        base_url = request.url_root.rstrip('/')
        oauth_base_path = "/console/api/oauth/login"
        providers = ["github", "google"]

        integrate_data = []
        for provider in providers:
            # First matching row wins if several exist for one provider.
            existing_integrate = next(
                (ai for ai in account_integrates if ai.provider == provider), None)
            # NOTE(review): 'id' is set below but integrate_fields declares no
            # 'id' entry, so it presumably never reaches the response - confirm
            # whether it should be exposed.
            if existing_integrate:
                integrate_data.append({
                    'id': existing_integrate.id,
                    'provider': provider,
                    'created_at': existing_integrate.created_at,
                    'is_bound': True,
                    'link': None
                })
            else:
                integrate_data.append({
                    'id': None,
                    'provider': provider,
                    'created_at': None,
                    'is_bound': False,
                    'link': f'{base_url}{oauth_base_path}/{provider}'
                })

        return {'data': integrate_data}


# Register API resources
api.add_resource(AccountInitApi, '/account/init')
api.add_resource(AccountProfileApi, '/account/profile')
api.add_resource(AccountNameApi, '/account/name')
api.add_resource(AccountAvatarApi, '/account/avatar')
api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language')
api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme')
api.add_resource(AccountTimezoneApi, '/account/timezone')
api.add_resource(AccountPasswordApi, '/account/password')
api.add_resource(AccountIntegrateApi, '/account/integrates')
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')