Source code for registry.oauth2

"""
OAuth2 (RFC6749) implementation, using :mod:`authlib`.

This module extends the :mod:`authlib.flask` implementation, leveraging client
data stored in :mod:`registry.services.datastore` and instantiating authorized
sessions in :mod:`registry.services.sessions`.

The current implementation supports the `client_credentials` and
`authorization_code` grants.

.. todo:: Implement backend & integration to control client endorsements.


"""

import hashlib
import hmac
from datetime import timedelta, datetime
from typing import List, Optional, Any

from flask import Request, Flask, current_app, request
from authlib.flask.oauth2 import AuthorizationServer
from authlib.specs.rfc6749 import ClientMixin, grants, OAuth2Request, \
    OAuth2Error
from authlib.common.security import generate_token

from arxiv.base.globals import get_application_config, get_application_global
from arxiv.base import logging
from arxiv import taxonomy
from ..services import datastore, sessions
from .. import domain

# Module-level logger, obtained from :mod:`arxiv.base.logging` (imported above
# as ``logging``), namespaced by this module's name.
logger = logging.getLogger(__name__)


class OAuth2User(object):
    """
    Resource owner as seen by the OAuth2 workflows.

    Adapts a :class:`domain.User` to the small accessor interface that the
    Authlib grant classes in this module call (``get_user_id``,
    ``get_username``, ``get_user_email``).
    """

    def __init__(self, user: domain.User) -> None:
        """Wrap ``user``, a :class:`domain.User`."""
        self._user = user

    def get_user_id(self) -> str:
        """Return the wrapped user's ID."""
        wrapped = self._user
        return wrapped.user_id

    def get_user_email(self) -> str:
        """Return the wrapped user's email address."""
        wrapped = self._user
        return wrapped.email

    def get_username(self) -> str:
        """Return the wrapped user's username."""
        wrapped = self._user
        return wrapped.username
class OAuth2AuthorizationCode(object):
    """
    Adapter exposing a :class:`domain.AuthorizationCode` to OAuth2 workflows.

    The names listed in ``_fields`` are proxied straight through to the
    wrapped domain object; any other unknown attribute raises
    :class:`AttributeError`.
    """

    # Attribute names proxied to the wrapped domain.AuthorizationCode.
    _fields = ['user_id', 'username', 'user_email', 'client_id',
               'redirect_uri', 'scope', 'code', 'created', 'expires']

    def __init__(self, auth_code: domain.AuthorizationCode) -> None:
        """Wrap ``auth_code``, a :class:`domain.AuthorizationCode`."""
        self._code = auth_code

    def __getattr__(self, key: str) -> Any:
        """Proxy whitelisted attribute lookups to the wrapped code."""
        # __getattr__ only runs after normal lookup fails, so ``_fields`` (a
        # class attribute) and ``_code`` (set in __init__) never land here.
        if key not in self._fields:
            raise AttributeError(f'No attribute {key}')
        return getattr(self._code, key)

    def is_expired(self) -> bool:
        """Return True once the current local time reaches ``expires``."""
        expiry = self._code.expires
        return expiry <= datetime.now()

    def get_redirect_uri(self) -> str:
        """Return the redirect URI recorded with the code."""
        wrapped = self._code
        return wrapped.redirect_uri

    def get_scope(self) -> str:
        """Return the scope recorded with the code."""
        wrapped = self._code
        return wrapped.scope
class OAuth2Client(ClientMixin):
    """
    Implementation of an OAuth2 client as described in RFC6749.

    This class wraps an aggregate of registry domain objects for a particular
    client, and implements the methods expected by the
    :class:`AuthorizationServer`.
    """

    def __init__(self, client: domain.Client,
                 credential: domain.ClientCredential,
                 authorizations: List[domain.ClientAuthorization],
                 grant_types: List[domain.ClientGrantType]) -> None:
        """
        Initialize with domain data about a client.

        Parameters
        ----------
        client : :class:`domain.Client`
            Core client record (ID, name, redirect URI, ...).
        credential : :class:`domain.ClientCredential`
            Holds the SHA-256 hex digest of the client secret, as compared
            in :meth:`check_client_secret`.
        authorizations : list
            Scope authorizations for the client; kept as a set of strings.
        grant_types : list
            Grant types the client is allowed to use.
        """
        logger.debug('New OAuth2Client with client_id %s', client.client_id)
        self._client = client
        self._credential = credential
        self._scopes = {str(auth.scope) for auth in authorizations}
        self._grant_types = [gtype.grant_type for gtype in grant_types]

    @property
    def name(self) -> str:
        """Get the client name."""
        return self._client.name

    @property
    def description(self) -> str:
        """Get the client description."""
        return self._client.description

    @property
    def scopes(self) -> List[str]:
        """Authorized scopes as a list."""
        return list(self._scopes)

    @property
    def url(self) -> str:
        """Get the client URL."""
        return self._client.url

    @property
    def client_id(self) -> str:
        """Get the client ID."""
        return self._client.client_id

    def check_client_secret(self, client_secret: str) -> bool:
        """
        Check that the provided client secret is correct.

        The secret is SHA-256 hashed and compared with the stored hex digest
        in constant time.

        Returns
        -------
        bool
            False when no secret is stored or the digests differ.
        """
        # Never log the plaintext secret; the client ID is enough to trace.
        logger.debug('Check client secret for client %s', self.client_id)
        stored = self._credential.client_secret
        if stored is None:
            return False
        hashed = hashlib.sha256(client_secret.encode('utf-8')).hexdigest()
        # compare_digest avoids leaking, via timing, how much of the digest
        # matched. Compare bytes so non-ASCII stored values cannot raise.
        # NOTE(review): assumes the stored digest is a ``str``, as the
        # original ``==`` comparison against a hexdigest implied.
        return hmac.compare_digest(stored.encode('utf-8'),
                                   hashed.encode('utf-8'))

    def check_grant_type(self, grant_type: str) -> bool:
        """Check that the client is authorized for the proposed grant type."""
        logger.debug('Check grant type %s', grant_type)
        return grant_type in self._grant_types

    def check_redirect_uri(self, redirect_uri: str) -> bool:
        """Check that the provided redirect URI exactly matches the client's."""
        logger.debug('Check redirect URI: %s, %s',
                     redirect_uri, self._client.redirect_uri)
        return redirect_uri == self._client.redirect_uri

    def check_requested_scopes(self, scopes: set) -> bool:
        """
        Check that the requested scopes are authorized for this client.

        If the current request carries a session with a user, the requested
        scopes must also be a subset of that user's authorized scopes, so a
        client is never granted more than the user holds.
        """
        logger.debug('Client requests scopes: %s', scopes)
        # The session attribute is not guaranteed to exist on every request
        # (e.g. a bare token-endpoint call); treat its absence as "no user".
        session = getattr(request, 'session', None)
        if session and session.user:
            session_scopes = {str(s) for s in session.authorizations.scopes}
            logger.debug('Authorized scopes on user session: %s',
                         session_scopes)
            return self._scopes.issuperset(scopes) and \
                session_scopes.issuperset(scopes)
        return self._scopes.issuperset(scopes)

    def check_response_type(self, response_type: str) -> bool:
        """Only the ``code`` response type is supported."""
        logger.debug('Check response type: %s', response_type)
        return response_type == 'code'

    def check_token_endpoint_auth_method(self, method: str) -> bool:
        """Force POST auth method."""
        logger.debug('Check endpoint auth method: %s', method)
        return method == 'client_secret_post'

    def get_default_redirect_uri(self) -> str:
        """Get the default redirect URI for the client."""
        return self._client.redirect_uri

    def has_client_secret(self) -> bool:
        """Check that the client has a secret."""
        logger.debug('Check has client secret')
        return self._credential.client_secret is not None
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization code grant for arXiv users."""

    # Lifetime of an issued authorization code, in seconds.
    # NOTE(review): RFC 6749 §4.1.2 recommends a maximum authorization code
    # lifetime of 10 minutes (600 s); confirm that 3600 is intended.
    EXPIRES = 3600
    TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post']

    def create_authorization_code(self, client: OAuth2Client,
                                  grant_user: OAuth2User,
                                  request: OAuth2Request) -> str:
        """
        Generate and store a new authorization code.

        Parameters
        ----------
        client : :class:`OAuth2Client`
            The client requesting authorization.
        grant_user : :class:`OAuth2User`
            The resource owner who has granted authorization to the client.
        request : :class:`OAuth2Request`
            The request wrapper containing request details. (This parameter
            shadows the module-level Flask ``request`` inside this method.)

        Returns
        -------
        str
            An authorization code that the client can exchange for an access
            token.
        """
        code = generate_token(48)
        created = datetime.now()
        datastore.save_auth_code(domain.AuthorizationCode(
            code=code,
            user_id=grant_user.get_user_id(),
            username=grant_user.get_username(),
            user_email=grant_user.get_user_email(),
            redirect_uri=request.redirect_uri,
            scope=request.scope,
            client_id=client.client_id,
            created=created,
            expires=created + timedelta(seconds=self.EXPIRES)
        ))
        return code

    def parse_authorization_code(self, code: str, client: OAuth2Client) \
            -> Optional[OAuth2AuthorizationCode]:
        """
        Attempt to retrieve a still-valid auth code for an API client.

        Returns
        -------
        :class:`OAuth2AuthorizationCode` or None
            None if the code does not exist for this client or has expired.
        """
        # The auth code is a bearer credential; log only the client.
        logger.debug('Parse authorization code for %s', client.client_id)
        try:
            stored = datastore.load_auth_code(code, client.client_id)
        except datastore.NoSuchAuthCode:
            logger.debug('No such auth code for client %s', client.client_id)
            return None
        code_grant = OAuth2AuthorizationCode(stored)
        if code_grant.is_expired():
            logger.debug('Auth code for client %s has expired',
                         client.client_id)
            return None
        return code_grant

    def delete_authorization_code(self, auth_code: OAuth2AuthorizationCode) \
            -> None:
        """Delete an auth code."""
        datastore.delete_auth_code(auth_code.code, auth_code.client_id)

    def authenticate_user(self, auth_code: OAuth2AuthorizationCode) \
            -> Optional[OAuth2User]:
        """
        Authenticate the user implicated in the auth code.

        Returns
        -------
        :class:`OAuth2User` or None
            None if the code can no longer be loaded for its user. The
            intent is that Authlib then rejects the token request as a
            client error instead of this raising a server error.
        """
        try:
            code_grant = OAuth2AuthorizationCode(
                datastore.load_auth_code_by_user(auth_code.code,
                                                 auth_code.user_id)
            )
        except datastore.NoSuchAuthCode:
            # NOTE(review): presumably load_auth_code_by_user raises the same
            # NoSuchAuthCode as load_auth_code; verify in the datastore.
            logger.debug('Auth code no longer available for user %s',
                         auth_code.user_id)
            return None
        return OAuth2User(domain.User(
            user_id=code_grant.user_id,
            email=code_grant.user_email,
            username=code_grant.username
        ))
class ClientCredentialsGrant(grants.ClientCredentialsGrant):
    """
    Client credentials grant restricted to the ``client_secret_post`` method.

    ``client_secret_post`` is the only token-endpoint authentication method
    listed, so clients must send their credentials in the POST body.
    """

    # Sole permitted token-endpoint authentication method.
    TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post']
def get_client(client_id: str) -> Optional[OAuth2Client]:
    """
    Build an :class:`OAuth2Client` from the registry datastore.

    Parameters
    ----------
    client_id : str
        Identifier of the client to look up.

    Returns
    -------
    :class:`OAuth2Client` or None
        None when the datastore reports that no such client exists.
    """
    logger.debug('Get client with ID %s', client_id)
    try:
        # load_client's result is unpacked positionally into OAuth2Client.
        client_parts = datastore.load_client(client_id)
        oauth_client = OAuth2Client(*client_parts)
    except datastore.NoSuchClient as err:
        logger.debug('No such client %s: %s', client_id, err)
        return None
    logger.debug('Got client %s', client_id)
    return oauth_client
def _granted_scopes(token: dict, client: OAuth2Client) -> List[str]:
    """Restrict the client's scopes to those granted on ``token``, if any."""
    granted = token.get('scope')
    if not granted:
        # No explicit scope on the token: grant everything the client is
        # authorized for, as this module always has.
        return client.scopes
    if isinstance(granted, str):
        # RFC 6749 §3.3: a scope value is a space-delimited list of strings.
        granted = granted.split()
    granted_set = {str(scope) for scope in granted}
    return [scope for scope in client.scopes if scope in granted_set]


def save_token(token: dict, oauth_request: OAuth2Request) -> None:
    """
    Persist an auth token as a :class:`domain.Session`.

    We use the access token as the session ID. This makes for a fast lookup
    by the :mod:`authenticator` service.

    The session receives only the scopes recorded on the token (i.e. those
    granted for this request). If the token carries no scope, it receives
    every scope the client is authorized for.

    Parameters
    ----------
    token : dict
        Token data generated by the OAuth2 :class:`AuthorizationServer`. At
        this point the token has not been stored.
    oauth_request : :class:`OAuth2Request`
        Wrapper for OAuth2-related request data.
    """
    client = oauth_request.client
    # The access token doubles as the session ID, so it must never be logged.
    logger.debug('Persist %s token for client %s',
                 token.get('token_type'), client.client_id)
    session_id = token['access_token']
    logger.debug("Client has scopes %s", client.scopes)
    user = oauth_request.user._user if oauth_request.user else None
    authorizations = domain.Authorizations(
        scopes=_granted_scopes(token, client),
        endorsements=get_endorsements(client)
    )
    # The remote address is passed for both address arguments, as before.
    session = sessions.create(authorizations, request.remote_addr,
                              request.remote_addr, user=user,
                              client=client._client, session_id=session_id)
    logger.debug('Created session %s', session.session_id)
def get_endorsements(client: domain.Client) -> List[domain.Category]:
    """
    Return the categories a client is endorsed for.

    For now this ignores ``client`` and always grants every category, by
    returning a single wildcard category.

    Parameters
    ----------
    client : :class:`domain.Client`
        Currently unused.

    Returns
    -------
    list
        A one-element list holding :class:`domain.Category` ``('*', '*')``.
    """
    wildcard = domain.Category('*', '*')
    return [wildcard]
def create_server() -> AuthorizationServer:
    """
    Build an :class:`AuthorizationServer` wired to the registry.

    Clients are resolved with :func:`get_client`, tokens are persisted with
    :func:`save_token`, and the client-credentials and authorization-code
    grants are registered, in that order.
    """
    auth_server = AuthorizationServer(query_client=get_client,
                                      save_token=save_token)
    for grant_class in (ClientCredentialsGrant, AuthorizationCodeGrant):
        auth_server.register_grant(grant_class)
    logger.debug('Created server %s', id(auth_server))
    return auth_server
def init_app(app: Flask) -> None:
    """
    Create an :class:`AuthorizationServer` and bind it to ``app``.

    The server is initialized against ``app`` first and then stored as
    ``app.server`` (presumably for use by request handlers elsewhere).
    """
    oauth_server = create_server()
    oauth_server.init_app(app)
    app.server = oauth_server