# Source code for arxiv.submission.services.classic.util
"""Utility classes and functions for :mod:`.services.classic`."""
import json
from contextlib import contextmanager
from typing import Optional, Generator

from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import create_engine
import sqlalchemy.types as types
from sqlalchemy.engine import Dialect, Engine
from sqlalchemy.orm.session import Session
from sqlalchemy.orm import sessionmaker

from arxiv.base.globals import get_application_config, get_application_global
from arxiv.base import logging

from .exceptions import ClassicBaseException, TransactionFailed
from ...exceptions import InvalidEvent
from ... import serializer
class ClassicSQLAlchemy(SQLAlchemy):
    """SQLAlchemy integration for the classic database."""

    def init_app(self, app: Flask) -> None:
        """
        Set default configuration, then initialize the extension on ``app``.

        ``SQLALCHEMY_DATABASE_URI`` defaults to the value of
        ``CLASSIC_DATABASE_URI``, or to ``'sqlite://'`` if that is unset too.
        ``SQLALCHEMY_TRACK_MODIFICATIONS`` defaults to ``False``. Keys already
        present in ``app.config`` are left untouched (``setdefault``).
        """
        app.config.setdefault(
            'SQLALCHEMY_DATABASE_URI',
            app.config.get('CLASSIC_DATABASE_URI', 'sqlite://')
        )
        app.config.setdefault('SQLALCHEMY_TRACK_MODIFICATIONS', False)
        super().init_app(app)

    def apply_pool_defaults(self, app: Flask, options: dict) -> dict:
        """
        Set options for ``create_engine()``.

        After applying the parent class's pool defaults, MySQL databases get
        the project ``serializer`` installed as the engine's JSON
        serializer/deserializer.

        Parameters
        ----------
        app : :class:`flask.Flask`
            Application whose config supplies ``SQLALCHEMY_DATABASE_URI``.
        options : dict
            Engine keyword options being assembled by Flask-SQLAlchemy.

        Returns
        -------
        dict
            The (updated) engine options. Flask-SQLAlchemy 2.5+ uses the
            return value of this hook, so it must not be ``None``. Older
            versions mutate ``options`` in place and ignore the return value.
        """
        result = super().apply_pool_defaults(app, options)
        if result is not None:
            # Newer Flask-SQLAlchemy returns the options dict; prefer it in
            # case it is not the same object that was passed in.
            options = result
        uri = app.config.get('SQLALCHEMY_DATABASE_URI') or ''
        if str(uri).startswith('mysql'):
            options['json_serializer'] = serializer.dumps
            options['json_deserializer'] = serializer.loads
        return options
logger = logging.getLogger(__name__)
# Module-wide extension instance for the classic database. It is annotated as
# the base SQLAlchemy type, and callers in this module reach the engine and
# session through it.
db: SQLAlchemy = ClassicSQLAlchemy()
class SQLiteJSON(types.TypeDecorator):
    """
    A SQLite-friendly JSON data type.

    Values are stored as serialized JSON text in a ``TEXT`` column, using the
    project ``serializer`` in both directions. ``None`` passes through
    unchanged in both directions.
    """

    impl = types.TEXT
    # This type defines no per-instance state that affects the SQL it renders,
    # so it is safe for SQLAlchemy (1.4+) to cache statements that use it.
    # Without this flag SQLAlchemy 1.4+ warns and disables caching for it;
    # earlier versions never read the attribute.
    cache_ok = True

    def process_bind_param(self, value: Optional[dict],
                           dialect: Dialect) -> Optional[str]:
        """Serialize a dict to JSON text; ``None`` is passed through."""
        if value is None:
            return None
        return serializer.dumps(value)

    def process_result_value(self, value: Optional[str],
                             dialect: Dialect) -> Optional[dict]:
        """Deserialize JSON text to a dict; ``None`` is passed through."""
        if value is None:
            return None
        return serializer.loads(value)
# SQLite is treated as lacking JSON support. For the 'sqlite' dialect, the
# generic JSON type is therefore replaced by SQLiteJSON, which stores
# serialized JSON in a TEXT column; every other dialect uses types.JSON.
FriendlyJSON = types.JSON().with_variant(
    SQLiteJSON,
    'sqlite',
)
def current_engine() -> Engine:
    """Return (creating on first use) the :class:`.Engine` of ``db``."""
    engine: Engine = db.engine
    return engine
def current_session() -> Session:
    """Return (creating if needed) the :class:`.Session` for this context."""
    # ``db.session`` is a session factory/registry; calling it yields the
    # session object for the current context.
    session_factory = db.session
    return session_factory()
@contextmanager
def transaction() -> Generator[Session, None, None]:
    """
    Context manager for a database transaction.

    Yields the current session. When the ``with`` body completes without
    error, the session is committed if it still holds new, dirty, or deleted
    objects. Callers may also commit explicitly inside the body, e.g. to do
    their own exception handling around the commit.

    On any exception the session is rolled back. Then:

    - :class:`.ClassicBaseException` and :class:`.InvalidEvent` are re-raised
      unchanged;
    - any other exception is wrapped in :class:`.TransactionFailed`, chained
      to the original.

    Raises
    ------
    ClassicBaseException
        Propagated unchanged from the body.
    InvalidEvent
        Propagated unchanged from the body.
    TransactionFailed
        For any other exception raised by the body or by the commit.
    """
    session = current_session()
    logger.debug('transaction with session %s', id(session))
    try:
        yield session
        # Only commit if there are un-flushed changes. The caller may commit
        # explicitly, e.g. to do exception handling.
        # NOTE(review): changes that were already flushed (explicitly, or by
        # autoflush before a query) no longer appear in new/dirty/deleted, so
        # this check does not commit them -- confirm callers commit in that
        # case.
        if session.dirty or session.deleted or session.new:
            session.commit()
            logger.debug('committed!')
    except (ClassicBaseException, InvalidEvent) as e:
        logger.debug('Command failed, rolling back: %s', str(e))
        session.rollback()
        raise   # Propagate these exception types unchanged.
    except Exception as e:
        logger.debug('Command failed, rolling back: %s', str(e))
        session.rollback()
        raise TransactionFailed('Failed to execute transaction') from e