Source code for arxiv.submission.services.classic.util

"""Utility classes and functions for :mod:`.services.classic`."""

import json
from contextlib import contextmanager
from typing import Optional, Generator

from flask import Flask
from sqlalchemy import create_engine
import sqlalchemy.types as types
from sqlalchemy.engine import Engine
from sqlalchemy.orm.session import Session
from sqlalchemy.orm import sessionmaker
from flask_sqlalchemy import SQLAlchemy

from arxiv.base.globals import get_application_config, get_application_global
from arxiv.base import logging
from .exceptions import ClassicBaseException, TransactionFailed
from ...exceptions import InvalidEvent
from ... import serializer


class ClassicSQLAlchemy(SQLAlchemy):
    """SQLAlchemy integration for the classic database."""

    def init_app(self, app: Flask) -> None:
        """
        Set default configuration, then register the extension on ``app``.

        ``SQLALCHEMY_DATABASE_URI`` defaults to ``CLASSIC_DATABASE_URI`` if
        that is configured, otherwise to ``'sqlite://'``. Modification
        tracking defaults to off. Values already present in ``app.config``
        are left untouched.

        Parameters
        ----------
        app : :class:`flask.Flask`
            The application whose config is populated.
        """
        app.config.setdefault(
            'SQLALCHEMY_DATABASE_URI',
            app.config.get('CLASSIC_DATABASE_URI', 'sqlite://')
        )
        app.config.setdefault('SQLALCHEMY_TRACK_MODIFICATIONS', False)
        super().init_app(app)

    def apply_pool_defaults(self, app: Flask, options: dict) -> dict:
        """
        Set options for :func:`sqlalchemy.create_engine`.

        On top of the parent's pool defaults, MySQL engines get the project
        ``serializer`` as their JSON (de)serializer.

        Parameters
        ----------
        app : :class:`flask.Flask`
            Application whose config supplies the database URI.
        options : dict
            Engine options, updated in place.

        Returns
        -------
        dict
            The updated ``options``. Flask-SQLAlchemy 2.5+ uses this return
            value, so it must never be ``None``. Older versions ignore it.
        """
        parent_options = super().apply_pool_defaults(app, options)
        # 2.5+ returns the options dict; 2.4 mutates in place and returns None.
        if parent_options is not None:
            options = parent_options
        uri = app.config.get('SQLALCHEMY_DATABASE_URI') or ''
        if uri.startswith('mysql'):
            options['json_serializer'] = serializer.dumps
            options['json_deserializer'] = serializer.loads
        return options
# Module-wide classic-database extension; bind it to an application with
# ``db.init_app(app)``.
db: SQLAlchemy = ClassicSQLAlchemy()

# Logger for this module.
logger = logging.getLogger(__name__)
class SQLiteJSON(types.TypeDecorator):
    """
    A SQLite-friendly JSON data type.

    Stores JSON as ``TEXT``. Bound values are serialized with
    ``serializer.dumps`` and result values are deserialized with
    ``serializer.loads``. ``None`` passes through unchanged in both
    directions.
    """

    impl = types.TEXT

    # This type has no per-instance state that affects the SQL it renders,
    # so SQLAlchemy (1.4+) may safely cache statements that use it. If
    # ``cache_ok`` is left unset, 1.4+ warns and disables caching for this
    # type. Older versions simply ignore the attribute.
    cache_ok = True

    def process_bind_param(self, value: Optional[dict],
                           dialect: object) -> Optional[str]:
        """Serialize a dict to a JSON string; ``None`` stays ``None``."""
        if value is None:
            return None
        return serializer.dumps(value)

    def process_result_value(self, value: Optional[str],
                             dialect: object) -> Optional[dict]:
        """Deserialize a JSON string to a dict; ``None`` stays ``None``."""
        if value is None:
            return None
        return serializer.loads(value)
# SQLite has no native JSON column type. ``FriendlyJSON`` therefore uses the
# generic JSON type on every dialect except 'sqlite', where it falls back to
# the text-backed ``SQLiteJSON`` decorator.
FriendlyJSON = types.JSON().with_variant(
    SQLiteJSON,
    'sqlite',
)
def current_engine() -> Engine:
    """Return (creating if necessary) the :class:`.Engine` for this context."""
    engine = db.engine
    return engine
def current_session() -> Session:
    """Return (creating if necessary) the :class:`.Session` for this context."""
    session_registry = db.session
    return session_registry()
@contextmanager
def transaction() -> 'Generator[Session, None, None]':
    """
    Context manager for a database transaction.

    Yields the session for the current context. If the ``with`` body
    completes without error, the session is committed. If anything raises,
    including the commit itself, the session is rolled back.

    Yields
    ------
    :class:`.Session`

    Raises
    ------
    :class:`.ClassicBaseException`
        Re-raised unchanged after rollback.
    :class:`.InvalidEvent`
        Re-raised unchanged after rollback.
    :class:`.TransactionFailed`
        Raised after rollback, chained from the original, for any other
        :class:`Exception`.
    """
    session = current_session()
    logger.debug('transaction with session %s', id(session))
    try:
        yield session
        # Always commit. Checking ``session.dirty/new/deleted`` is not enough:
        # after a flush (explicit, or an autoflush before a query) those
        # collections are empty even though the flushed changes are still
        # uncommitted, so they would be discarded. If the caller already
        # committed explicitly, this commit has nothing left to do.
        session.commit()
        logger.debug('committed!')
    except (ClassicBaseException, InvalidEvent) as e:
        # Known domain errors propagate unchanged.
        logger.debug('Command failed, rolling back: %s', str(e))
        session.rollback()
        raise
    except Exception as e:
        logger.debug('Command failed, rolling back: %s', str(e))
        session.rollback()
        raise TransactionFailed('Failed to execute transaction') from e