"""
OAuth2 (RFC6749) implementation, using :mod:`authlib`.
This module extends the :mod:`authlib.flask` implementation, leveraging client
data stored in :mod:`registry.services.datastore` and instantiating authorized
sessions in :mod:`registry.services.sessions`.
The current implementation supports the `client_credentials` and
`authorization_code` grants.
.. todo:: Implement backend & integration to control client endorsements.
"""
from typing import List, Optional, Any
import hashlib
import hmac
from datetime import timedelta, datetime

from flask import Request, Flask, current_app, request
from authlib.flask.oauth2 import AuthorizationServer
from authlib.specs.rfc6749 import ClientMixin, grants, OAuth2Request, \
    OAuth2Error
from authlib.common.security import generate_token

from arxiv.base.globals import get_application_config, get_application_global
from arxiv.base import logging
from arxiv import taxonomy

from ..services import datastore, sessions
from .. import domain
# Module-level logger, named after this module; used by every function and
# class below for debug tracing.
logger = logging.getLogger(__name__)
class OAuth2User:
    """
    Resource owner as seen by the OAuth2 workflows.

    A thin adapter that holds a :class:`domain.User` and exposes its
    identifying fields through accessor methods, in support of the Authlib
    integration.
    """

    def __init__(self, user: domain.User) -> None:
        """Keep a reference to the wrapped :class:`domain.User`."""
        self._user = user

    def get_user_id(self) -> str:
        """Return the ``user_id`` of the wrapped user."""
        wrapped = self._user
        return wrapped.user_id

    def get_user_email(self) -> str:
        """Return the ``email`` of the wrapped user."""
        wrapped = self._user
        return wrapped.email

    def get_username(self) -> str:
        """Return the ``username`` of the wrapped user."""
        wrapped = self._user
        return wrapped.username
class OAuth2AuthorizationCode:
    """
    Adapter around :class:`domain.AuthorizationCode` for OAuth2 workflows.

    Attribute reads for any name listed in ``_fields`` are forwarded to the
    wrapped code object; everything else raises :class:`AttributeError`.
    """

    # Names that may be read through to the wrapped authorization code.
    _fields = [
        'user_id', 'username', 'user_email', 'client_id',
        'redirect_uri', 'scope', 'code', 'created', 'expires',
    ]

    def __init__(self, auth_code: domain.AuthorizationCode) -> None:
        """Store the :class:`domain.AuthorizationCode` being wrapped."""
        self._code = auth_code

    def __getattr__(self, key: str) -> Any:
        """
        Forward a whitelisted attribute lookup to the wrapped code.

        Only invoked when normal attribute lookup fails, so ``_fields`` (a
        class attribute) and ``_code`` (set in ``__init__``) never reach it
        on a fully initialised instance.
        """
        if key not in self._fields:
            raise AttributeError(f'No attribute {key}')
        return getattr(self._code, key)

    def is_expired(self) -> bool:
        """Return True once the current local time has reached ``expires``."""
        expiry = self._code.expires
        return expiry <= datetime.now()

    def get_redirect_uri(self) -> str:
        """Return the redirect URI recorded on the authorization code."""
        wrapped = self._code
        return wrapped.redirect_uri

    def get_scope(self) -> str:
        """Return the scope recorded on the authorization code."""
        wrapped = self._code
        return wrapped.scope
class OAuth2Client(ClientMixin):
    """
    Implementation of an OAuth2 client as described in RFC6749.

    This class essentially wraps an aggregate of registry domain objects for a
    particular client, and implements methods expected by the
    :class:`AuthorizationServer`.
    """

    def __init__(self, client: domain.Client,
                 credential: domain.ClientCredential,
                 authorizations: List[domain.ClientAuthorization],
                 grant_types: List[domain.ClientGrantType]) -> None:
        """
        Initialize with domain data about a client.

        Parameters
        ----------
        client : :class:`domain.Client`
            Core client record (ID, name, URL, redirect URI, ...).
        credential : :class:`domain.ClientCredential`
            Holds the stored (hashed) client secret.
        authorizations : list
            Each item contributes one authorized scope (``auth.scope``).
        grant_types : list
            Each item contributes one permitted grant type
            (``gtype.grant_type``).

        """
        logger.debug('New OAuth2Client with client_id %s', client.client_id)
        self._client = client
        self._credential = credential
        # Scopes are normalised to strings so they can be compared against
        # the plain-string scopes in incoming requests.
        self._scopes = {str(auth.scope) for auth in authorizations}
        self._grant_types = [gtype.grant_type for gtype in grant_types]

    @property
    def name(self) -> str:
        """Get the client name."""
        return self._client.name

    @property
    def description(self) -> str:
        """Get the client description."""
        return self._client.description

    @property
    def scopes(self) -> List[str]:
        """Authorized scopes as a list."""
        return list(self._scopes)

    @property
    def url(self) -> str:
        """Get the client URL."""
        return self._client.url

    @property
    def client_id(self) -> str:
        """Get the client ID."""
        return self._client.client_id

    def check_client_secret(self, client_secret: str) -> bool:
        """
        Check that the provided client secret is correct.

        The supplied secret is hashed with SHA-256 and compared against the
        stored digest in constant time.

        Parameters
        ----------
        client_secret : str
            Plaintext secret presented by the client.

        Returns
        -------
        bool
            True only if the digest of ``client_secret`` matches the stored
            value. A missing secret on either side yields False.

        """
        # The secret itself must never reach the logs; identify the client
        # instead.
        logger.debug('Check client secret for client %s',
                     self._client.client_id)
        stored = self._credential.client_secret
        if client_secret is None or not isinstance(stored, str):
            return False
        hashed = hashlib.sha256(client_secret.encode('utf-8')).hexdigest()
        # compare_digest avoids leaking, through timing, how many leading
        # characters of the digest matched.
        return hmac.compare_digest(stored.encode('utf-8'),
                                   hashed.encode('utf-8'))

    def check_grant_type(self, grant_type: str) -> bool:
        """Check that the client is authorized for the proposed grant type."""
        logger.debug('Check grant type %s', grant_type)
        return grant_type in self._grant_types

    def check_redirect_uri(self, redirect_uri: str) -> bool:
        """Check that the provided redirect URI is authorized."""
        logger.debug('Check redirect URI: %s, %s',
                     redirect_uri, self._client.redirect_uri)
        # Exact string match against the single registered redirect URI.
        return redirect_uri == self._client.redirect_uri

    def check_requested_scopes(self, scopes: set) -> bool:
        """
        Check that the requested scopes are authorized for this client.

        Parameters
        ----------
        scopes : set
            Scopes requested by the client.

        Returns
        -------
        bool
            True if every requested scope is authorized for the client and,
            when a user is present on the request session, for that user's
            session as well.

        """
        logger.debug('Client requests scopes: %s', scopes)
        # If there is an active user on the session, ensure that we are not
        # granting scopes for which the user themself is not authorized.
        if request.session and request.session.user:
            session_scopes = {
                str(s) for s in request.session.authorizations.scopes
            }
            logger.debug('Authorized scopes on user session: %s',
                         session_scopes)
            return self._scopes.issuperset(scopes) and \
                session_scopes.issuperset(scopes)
        return self._scopes.issuperset(scopes)

    def check_response_type(self, response_type: str) -> bool:
        """Check the proposed response type; only ``code`` is accepted."""
        logger.debug('Check response type: %s', response_type)
        return response_type == 'code'

    def check_token_endpoint_auth_method(self, method: str) -> bool:
        """Force POST auth method (``client_secret_post``)."""
        logger.debug('Check endpoint auth method: %s', method)
        return method == 'client_secret_post'

    def get_default_redirect_uri(self) -> str:
        """Get the default redirect URI for the client."""
        return self._client.redirect_uri

    def has_client_secret(self) -> bool:
        """Check that the client has a secret."""
        logger.debug('Check has client secret')
        return self._credential.client_secret is not None
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization code grant for arXiv users."""

    # Lifetime of an authorization code, in seconds.
    EXPIRES = 3600
    # Only client authentication via POST parameters is accepted at the token
    # endpoint.
    TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post']

    def create_authorization_code(self, client: OAuth2Client,
                                  grant_user: OAuth2User,
                                  request: OAuth2Request) -> str:
        """
        Generate and store a new authorization code.

        Parameters
        ----------
        client : :class:`OAuth2Client`
            The client requesting authorization.
        grant_user : :class:`OAuth2User`
            The resource owner who has granted authorization to the client.
        request : :class:`OAuth2Request`
            The request wrapper containing request details.

        Returns
        -------
        str
            An authorization code that the client can exchange for an access
            token.

        """
        code = generate_token(48)
        # Naive local time, matching the comparison made in
        # OAuth2AuthorizationCode.is_expired().
        created = datetime.now()
        datastore.save_auth_code(domain.AuthorizationCode(
            code=code,
            user_id=grant_user.get_user_id(),
            username=grant_user.get_username(),
            user_email=grant_user.get_user_email(),
            redirect_uri=request.redirect_uri,
            scope=request.scope,
            client_id=client.client_id,
            created=created,
            expires=created + timedelta(seconds=self.EXPIRES)
        ))
        return code

    def parse_authorization_code(self, code: str, client: OAuth2Client) \
            -> Optional[OAuth2AuthorizationCode]:
        """
        Attempt to retrieve an auth code for an API client.

        Parameters
        ----------
        code : str
            Authorization code presented by the client.
        client : :class:`OAuth2Client`
            The client presenting the code.

        Returns
        -------
        :class:`OAuth2AuthorizationCode` or None
            None if no such code exists for the client, or if the code has
            expired.

        """
        # The code is a credential, so it is deliberately kept out of logs.
        logger.debug('Parse authorization code for client %s',
                     client.client_id)
        try:
            code_grant = OAuth2AuthorizationCode(
                datastore.load_auth_code(code, client.client_id)
            )
        except datastore.NoSuchAuthCode:
            logger.debug('No such auth code for client %s', client.client_id)
            return None
        if code_grant.is_expired():
            logger.debug('Auth code for client %s is expired',
                         client.client_id)
            return None
        return code_grant

    def delete_authorization_code(self, auth_code: OAuth2AuthorizationCode) \
            -> None:
        """Delete an auth code from the datastore."""
        datastore.delete_auth_code(auth_code.code, auth_code.client_id)

    def authenticate_user(self, auth_code: OAuth2AuthorizationCode) \
            -> Optional[OAuth2User]:
        """
        Authenticate the user implicated in the auth code.

        The code is re-loaded from the datastore by (code, user ID) and the
        user is rebuilt from the identity fields stored on that record.

        Returns
        -------
        :class:`OAuth2User` or None
            None if the code can no longer be found for that user.

        """
        try:
            code_grant = OAuth2AuthorizationCode(
                datastore.load_auth_code_by_user(auth_code.code,
                                                 auth_code.user_id)
            )
        except datastore.NoSuchAuthCode:
            # NOTE(review): assumes load_auth_code_by_user signals a missing
            # record the same way load_auth_code does -- confirm against the
            # datastore service. Returning None lets the authorization
            # server reject the request instead of failing with an unhandled
            # exception.
            logger.debug('No auth code found for user %s', auth_code.user_id)
            return None
        return OAuth2User(domain.User(
            user_id=code_grant.user_id,
            email=code_grant.user_email,
            username=code_grant.username
        ))
class ClientCredentialsGrant(grants.ClientCredentialsGrant):
    """Client credentials grant restricted to POST-based client auth."""

    # The only token endpoint authentication method this grant lists.
    TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post']
def get_client(client_id: str) -> Optional[OAuth2Client]:
    """
    Look up a client and wrap it as an :class:`OAuth2Client`.

    The values returned by :func:`datastore.load_client` are passed
    positionally to the :class:`OAuth2Client` constructor.

    Parameters
    ----------
    client_id : str
        Identifier of the client to load.

    Returns
    -------
    :class:`OAuth2Client` or None
        None when the datastore reports that no such client exists.

    """
    logger.debug('Get client with ID %s', client_id)
    try:
        client_data = datastore.load_client(client_id)
        oauth_client = OAuth2Client(*client_data)
    except datastore.NoSuchClient as exc:
        logger.debug('No such client %s: %s', client_id, exc)
        return None
    else:
        logger.debug('Got client %s', client_id)
    return oauth_client
def save_token(token: dict, oauth_request: OAuth2Request) -> None:
    """
    Persist an auth token as a :class:`domain.Session`.

    We use the access token as the session ID. This makes for a fast lookup
    by the :mod:`authenticator` service.

    Because the access token doubles as the session ID, neither the token
    nor the session ID is written to the log.

    Parameters
    ----------
    token : dict
        Token data generated by the OAuth2 :class:`AuthorizationServer`.
        At this point the token has not been stored. Must contain an
        ``access_token`` key.
    oauth_request : :class:`OAuth2Request`
        Wrapper for OAuth2-related request data.

    Raises
    ------
    KeyError
        If ``token`` has no ``access_token`` entry.

    """
    client = oauth_request.client
    # Log only non-secret metadata; the token dict carries the bearer
    # credential.
    logger.debug("Persist token of type %s for client %s",
                 token.get('token_type'), client.client_id)
    session_id = token['access_token']
    logger.debug("Client has scopes %s", client.scopes)
    # With no resource owner on the request the session has no user.
    user = oauth_request.user._user if oauth_request.user else None
    authorizations = domain.Authorizations(
        scopes=client.scopes,
        # get_endorsements is declared to take the domain-level client, so
        # unwrap the OAuth2Client here.
        endorsements=get_endorsements(client._client)
    )
    # NOTE(review): remote_addr is passed for both leading positional
    # arguments -- presumably IP address and remote host; verify against
    # sessions.create.
    sessions.create(authorizations, request.remote_addr,
                    request.remote_addr, user=user,
                    client=client._client, session_id=session_id)
    logger.debug('Created session for client %s', client.client_id)
def get_endorsements(client: domain.Client) -> List[domain.Category]:
    """
    Return the categories for which a client is endorsed.

    This is a placeholder: ``client`` is not consulted, and the result is
    always a single wildcard category.

    Parameters
    ----------
    client : :class:`domain.Client`
        Currently unused.

    Returns
    -------
    list
        One-element list holding ``domain.Category('*', '*')``.

    """
    wildcard = domain.Category('*', '*')
    return [wildcard]
def create_server() -> AuthorizationServer:
    """
    Build an :class:`AuthorizationServer` with this module's hooks.

    The server looks clients up with :func:`get_client`, persists tokens
    with :func:`save_token`, and has the client-credentials and
    authorization-code grants registered, in that order.
    """
    auth_server = AuthorizationServer(save_token=save_token,
                                      query_client=get_client)
    for grant_cls in (ClientCredentialsGrant, AuthorizationCodeGrant):
        auth_server.register_grant(grant_cls)
    logger.debug('Created server %s', id(auth_server))
    return auth_server
def init_app(app: Flask) -> None:
    """
    Create an :class:`AuthorizationServer` and bind it to a Flask app.

    The server is initialised against ``app`` and then stored on it as
    ``app.server``.
    """
    oauth_server = create_server()
    oauth_server.init_app(app)
    app.server = oauth_server