|
from __future__ import annotations |
|
|
|
import hashlib |
|
import os |
|
import typing |
|
import urllib.parse |
|
import warnings |
|
from dataclasses import dataclass, field |
|
|
|
import fastapi |
|
from fastapi.responses import RedirectResponse |
|
from huggingface_hub import HfFolder, whoami |
|
|
|
from .utils import get_space |
|
|
|
# OAuth configuration read from the environment. A Space provides these variables when
# `hf_oauth: true` is set in its metadata. Each value is None when the variable is unset
# (for instance outside a Space).
OAUTH_CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
OAUTH_CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET")
OAUTH_SCOPES = os.getenv("OAUTH_SCOPES")
OPENID_PROVIDER_URL = os.getenv("OPENID_PROVIDER_URL")
|
|
|
|
|
def attach_oauth(app: fastapi.FastAPI):
    """Enable Hugging Face OAuth on a Gradio FastAPI app.

    Registers the `/login/huggingface`, `/login/callback` and `/logout` routes and adds
    Starlette's cookie-based `SessionMiddleware`, in which the routes store the OAuth
    info. When running in a Space (`get_space()` is not None), the real OAuth routes are
    added. Otherwise, mocked routes log the user in with the profile of the Hugging Face
    account configured on the local machine.

    Args:
        app: The FastAPI app to attach the routes and the session middleware to.

    Raises:
        ImportError: If the optional `gradio[oauth]` dependencies are not installed.
        ValueError: Propagated from route setup, e.g. when a required OAuth environment
            variable is missing (Space) or no local HF login is found (mocked mode).
    """
    try:
        from starlette.middleware.sessions import SessionMiddleware
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth due to a missing library. Please run `pip install gradio[oauth]` or add "
            "`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
        ) from e

    # Real OAuth inside a Space, mocked login/logout everywhere else.
    if get_space() is not None:
        _add_oauth_routes(app)
    else:
        _add_mocked_oauth_routes(app)

    # The cookie signing key is derived from the OAuth client secret. The "-v2" suffix acts
    # as a version tag: changing it changes the key and therefore invalidates all existing
    # session cookies (e.g. after a change of the stored session format).
    # NOTE(review): when OAUTH_CLIENT_SECRET is unset (typically outside a Space), the key
    # is derived from "-v2" alone and is therefore predictable.
    session_secret = (OAUTH_CLIENT_SECRET or "") + "-v2"
    app.add_middleware(
        SessionMiddleware,
        secret_key=hashlib.sha256(session_secret.encode()).hexdigest(),
        same_site="none",
        https_only=True,
    )
|
|
|
|
|
def _add_oauth_routes(app: fastapi.FastAPI) -> None:
    """Add OAuth routes to the FastAPI app (login, callback handler and logout).

    - `/login/huggingface` redirects the user to the Hugging Face OAuth page.
    - `/login/callback` exchanges the provider response for an access token and stores
      the result in `request.session["oauth_info"]`.
    - `/logout` removes `oauth_info` from the session.
    Both the callback and logout routes then redirect to the `_target_url` query parameter
    (or `/`).

    Raises:
        ImportError: If `authlib` (installed with `gradio[oauth]`) is missing.
        ValueError: If one of the required OAuth environment variables is not set.
    """
    try:
        from authlib.integrations.starlette_client import OAuth
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth due to a missing library. Please run `pip install gradio[oauth]` or add "
            "`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
        ) from e

    msg = (
        "OAuth is required but {} environment variable is not set. Make sure you've enabled OAuth in your Space by"
        " setting `hf_oauth: true` in the Space metadata."
    )
    if OAUTH_CLIENT_ID is None:
        raise ValueError(msg.format("OAUTH_CLIENT_ID"))
    if OAUTH_CLIENT_SECRET is None:
        raise ValueError(msg.format("OAUTH_CLIENT_SECRET"))
    if OAUTH_SCOPES is None:
        raise ValueError(msg.format("OAUTH_SCOPES"))
    if OPENID_PROVIDER_URL is None:
        raise ValueError(msg.format("OPENID_PROVIDER_URL"))

    # Tolerate a trailing slash on the provider URL so the discovery URL never contains "//".
    server_metadata_url = (
        OPENID_PROVIDER_URL.rstrip("/") + "/.well-known/openid-configuration"
    )

    oauth = OAuth()
    oauth.register(
        name="huggingface",
        client_id=OAUTH_CLIENT_ID,
        client_secret=OAUTH_CLIENT_SECRET,
        client_kwargs={"scope": OAUTH_SCOPES},
        server_metadata_url=server_metadata_url,
    )

    # NOTE: do not rename the handlers below. `_generate_redirect_uri` looks up the
    # callback route by name via `request.url_for("oauth_redirect_callback")`.
    @app.get("/login/huggingface")
    async def oauth_login(request: fastapi.Request):
        """Endpoint that redirects to HF OAuth page."""
        redirect_uri = _generate_redirect_uri(request)
        return await oauth.huggingface.authorize_redirect(request, redirect_uri)

    @app.get("/login/callback")
    async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that handles the OAuth callback."""
        oauth_info = await oauth.huggingface.authorize_access_token(request)
        request.session["oauth_info"] = oauth_info
        return _redirect_to_target(request)

    @app.get("/logout")
    async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that logs out the user (e.g. delete cookie session)."""
        request.session.pop("oauth_info", None)
        return _redirect_to_target(request)
|
|
|
|
|
def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
    """Add fake oauth routes if Gradio is run locally and OAuth is enabled.

    Clicking on a gr.LoginButton will have the same behavior as in a Space (i.e. gets redirected in a new tab) but
    instead of authenticating with HF, a mocked user profile is added to the session.

    The mocked payload is built once, when the routes are added, from the Hugging Face
    account configured on this machine (see `_get_mocked_oauth_info`).

    NOTE(review): every visitor who hits `/login/callback` receives this payload in their
    session, including the machine's own HF access token. Do not expose an app running
    in this mode publicly.

    Raises:
        ValueError: Propagated from `_get_mocked_oauth_info` when the machine is not
            logged in with a personal HF account.
    """
    warnings.warn(
        "Gradio does not support OAuth features outside of a Space environment. To help"
        " you debug your app locally, the login and logout buttons are mocked with your"
        " profile. To make it work, your machine must be logged in to Huggingface."
    )
    mocked_oauth_info = _get_mocked_oauth_info()

    # NOTE: handler names must match the real routes. `_generate_redirect_uri` looks up
    # `oauth_redirect_callback` by name.
    @app.get("/login/huggingface")
    async def oauth_login(request: fastapi.Request):
        """Fake endpoint that redirects to HF OAuth page."""
        # Skip the HF consent page: go straight to the callback. The callback then forwards
        # to the full callback URL (which carries the final target) and from there to the
        # page the user started from.
        redirect_uri = _generate_redirect_uri(request)
        return RedirectResponse(
            "/login/callback?" + urllib.parse.urlencode({"_target_url": redirect_uri})
        )

    @app.get("/login/callback")
    async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that handles the OAuth callback."""
        request.session["oauth_info"] = mocked_oauth_info
        return _redirect_to_target(request)

    @app.get("/logout")
    async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that logs out the user (e.g. delete cookie session)."""
        request.session.pop("oauth_info", None)
        return RedirectResponse(url=_logout_redirect_url(str(request.url)))


def _logout_redirect_url(url: str) -> str:
    """Return `url` with a trailing `/logout` path segment replaced by `/`.

    Only the path is rewritten, so any mount prefix (e.g. `/app/logout` -> `/app/`) is
    kept. The query string is left untouched even if it contains "/logout".
    """
    parts = urllib.parse.urlsplit(url)
    path = parts.path
    if path.endswith("/logout"):
        path = path[: -len("/logout")] + "/"
    return urllib.parse.urlunsplit(parts._replace(path=path))
|
|
|
|
|
def _generate_redirect_uri(request: fastapi.Request) -> str:
    """Build the absolute URL of the OAuth callback route for `request`.

    The page the user should end up on after logging in travels in the `_target_url`
    query parameter of the callback URL. It is either the `_target_url` given to the
    current request, or the app root carrying the current request's query parameters.

    Args:
        request: The incoming login request.

    Returns:
        The callback URL as a string (forced to `https://` on `*.hf.space` hosts).
    """
    if "_target_url" in request.query_params:
        # The caller explicitly asked for a post-login destination.
        target = request.query_params["_target_url"]
    else:
        # Return to the root page with the original query string. `multi_items()` keeps
        # every value of repeated keys. Encoding the mapping directly would keep only one.
        target = "/?" + urllib.parse.urlencode(request.query_params.multi_items())

    redirect_uri = request.url_for("oauth_redirect_callback").include_query_params(
        _target_url=target
    )
    redirect_uri_as_str = str(redirect_uri)
    if redirect_uri.netloc.endswith(".hf.space") and redirect_uri_as_str.startswith(
        "http://"
    ):
        # Spaces are reached over HTTPS, but the app can see plain HTTP (presumably behind a
        # TLS-terminating proxy). Rewrite only the scheme prefix, nothing else in the URL.
        redirect_uri_as_str = "https://" + redirect_uri_as_str[len("http://") :]
    return redirect_uri_as_str
|
|
|
|
|
def _redirect_to_target(
    request: fastapi.Request, default_target: str = "/"
) -> RedirectResponse:
    """Redirect to the `_target_url` query parameter, falling back to `default_target`.

    `_target_url` is user-controlled (anyone can craft a login or logout link carrying
    it), so it is only followed when it points back to this app. That means a relative
    path, or an absolute http(s) URL on the same host as the current request; the latter
    is needed by the mocked login flow. Any other value falls back to `default_target`
    instead of becoming an open redirect.
    """
    target = request.query_params.get("_target_url", default_target)
    if not _is_safe_redirect_target(target, request.url.netloc):
        target = default_target
    return RedirectResponse(target)


def _is_safe_redirect_target(target: str, host: str) -> bool:
    """Return True if `target` is a relative path or an http(s) URL on `host`."""
    # Normalise the way browsers do before inspecting the target: strip leading and trailing
    # C0 controls and spaces, drop tabs and newlines, and treat "\" as "/". Otherwise
    # " //evil.com" or "/\evil.com" would pass as "relative" paths.
    cleaned = target.strip("".join(map(chr, range(0x21))))
    for char in "\t\r\n":
        cleaned = cleaned.replace(char, "")
    cleaned = cleaned.replace("\\", "/")
    parts = urllib.parse.urlsplit(cleaned)
    if not parts.scheme and not parts.netloc:
        # Relative reference: only absolute paths on this host are accepted.
        return cleaned.startswith("/")
    return parts.scheme in ("http", "https") and parts.netloc.lower() == host.lower()
|
|
|
|
|
@dataclass
class OAuthProfile(typing.Dict):
    """
    A Gradio OAuthProfile object holding the profile of the logged-in user. Declare a
    function parameter as `OAuthProfile` or `Optional[OAuthProfile]` and Gradio fills it
    from the FastAPI session when the user is logged in. If nobody is logged in and the
    parameter is a plain `OAuthProfile`, an error is raised.

    The object is a dict holding the whole userinfo payload. The most useful entries are
    also exposed as attributes.

    Attributes:
        name (str): The name of the user (e.g. 'Abubakar Abid').
        username (str): The username of the user (e.g. 'abidlabs')
        profile (str): The profile URL of the user (e.g. 'https://huggingface.co/abidlabs').
        picture (str): The profile picture URL of the user.

    Example:
        import gradio as gr
        from typing import Optional


        def hello(profile: Optional[gr.OAuthProfile]) -> str:
            if profile is None:
                return "I don't know you."
            return f"Hello {profile.name}"


        with gr.Blocks() as demo:
            gr.LoginButton()
            gr.LogoutButton()
            gr.Markdown().attach_load_event(hello, None)
    """

    name: str = field(init=False)
    username: str = field(init=False)
    profile: str = field(init=False)
    picture: str = field(init=False)

    def __init__(self, data: dict):
        # Keep the complete userinfo payload reachable through the dict interface.
        super().__init__(data)
        # Mirror selected claims as attributes. A missing claim raises KeyError.
        for attribute, claim in (
            ("name", "name"),
            ("username", "preferred_username"),
            ("profile", "profile"),
            ("picture", "picture"),
        ):
            setattr(self, attribute, self[claim])
|
|
|
|
|
@dataclass
class OAuthToken:
    """
    A Gradio OAuthToken object holding the access token of the logged-in user. Declare a
    function parameter as `OAuthToken` or `Optional[OAuthToken]` and Gradio fills it from
    the FastAPI session when the user is logged in. If nobody is logged in and the
    parameter is a plain `OAuthToken`, an error is raised.

    Attributes:
        token (str): The access token of the user.
        scope (str): The scope of the access token.
        expires_at (int): The expiration timestamp of the access token.

    Example:
        import gradio as gr
        from typing import Optional
        from huggingface_hub import whoami


        def list_organizations(oauth_token: Optional[gr.OAuthToken]) -> str:
            if oauth_token is None:
                return "Please log in to list organizations."
            org_names = [org["name"] for org in whoami(oauth_token.token)["orgs"]]
            return f"You belong to {', '.join(org_names)}."


        with gr.Blocks() as demo:
            gr.LoginButton()
            gr.LogoutButton()
            gr.Markdown().attach_load_event(list_organizations, None)
    """

    # Raw access token string, usable with huggingface_hub calls as in the example.
    token: str
    # Scope string granted to the token.
    scope: str
    # Expiration timestamp (presumably Unix seconds — confirm against the provider).
    expires_at: int
|
|
|
|
|
def _get_mocked_oauth_info() -> typing.Dict:
    """Build a fake OAuth session payload for the HF account logged in on this machine.

    The payload has the same shape as the one stored by the real OAuth callback. The
    local HF token is used as the access token, and the user profile comes from `whoami`.
    Time-related fields (`expires_at`, `iat`, `auth_time`, `exp`) are computed relative to
    the moment this function is called, so the mocked token is not already expired.

    Raises:
        ValueError: If no HF token is found locally, or if the token does not belong to a
            personal (user) account.
    """
    import time

    token = HfFolder.get_token()
    if token is None:
        raise ValueError(
            "Your machine must be logged in to HF to debug a Gradio app locally. Please"
            " run `huggingface-cli login` or set `HF_TOKEN` as environment variable "
            "with one of your access token. You can generate a new token in your "
            "settings page (https://huggingface.co/settings/tokens)."
        )

    # Query the account behind the exact token that was just checked.
    user = whoami(token=token)
    if user["type"] != "user":
        raise ValueError(
            "Your machine is not logged in with a personal account. Please use a "
            "personal access token. You can generate a new token in your settings page"
            " (https://huggingface.co/settings/tokens)."
        )

    expires_in = 3600  # seconds, i.e. a one-hour mocked token lifetime
    issued_at = int(time.time())
    expires_at = issued_at + expires_in

    return {
        "access_token": token,
        "token_type": "bearer",
        "expires_in": expires_in,
        "id_token": "AAAAAAAAAAAAAAAAAAAAAAAAAA",
        "scope": "openid profile",
        "expires_at": expires_at,
        "userinfo": {
            "sub": "11111111111111111111111",
            "name": user["fullname"],
            "preferred_username": user["name"],
            "profile": f"https://huggingface.co/{user['name']}",
            "picture": user["avatarUrl"],
            "website": "",
            "aud": "00000000-0000-0000-0000-000000000000",
            "auth_time": issued_at,
            "nonce": "aaaaaaaaaaaaaaaaaaa",
            "iat": issued_at,
            "exp": expires_at,
            "iss": "https://huggingface.co",
        },
    }
|
|