# Source code for search.controllers.advanced

"""
Handle requests to support the advanced search feature.

The primary entrypoint to this module is :func:`.search`, which handles
GET requests to the advanced search endpoint. It uses
:class:`.AdvancedSearchForm` to generate form HTML, validate request
parameters, and produce informative error messages for the user.
"""

from typing import Tuple, Dict, Any, Optional
import re
from datetime import date, datetime
from dateutil.relativedelta import relativedelta
from pytz import timezone


from werkzeug.datastructures import MultiDict, ImmutableMultiDict
from werkzeug.exceptions import InternalServerError, BadRequest, NotFound
from flask import url_for

from arxiv import status, taxonomy

from search.services import index, fulltext, metadata
from search.domain import AdvancedQuery, FieldedSearchTerm, DateRange, \
    Classification, FieldedSearchList, ClassificationList, Query, asdict
from arxiv.base import logging
from search.controllers.util import paginate, catch_underscore_syntax

from . import forms

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Alias for a controller response: a 3-tuple of (dict, int, dict).
# NOTE(review): presumably (response data, HTTP status code, headers) --
# confirm against the callers that consume it.
Response = Tuple[Dict[str, Any], int, Dict[str, Any]]

# pytz timezone for US/Eastern; the date-filter code below uses it to build
# timezone-aware datetimes.
EASTERN = timezone('US/Eastern')





def _query_from_form(form: forms.AdvancedSearchForm) -> AdvancedQuery:
    """
    Build an :class:`.AdvancedQuery` from an :class:`.AdvancedSearchForm`.

    The date, terms, and classification sub-form data are applied in that
    order, followed by the cross-list, older-versions, ordering, and
    abstract-visibility options.

    Parameters
    ----------
    form : :class:`.AdvancedSearchForm`
        Presumed to be filled and valid.

    Returns
    -------
    :class:`.AdvancedQuery`

    """
    query = AdvancedQuery()

    # Each updater takes the query plus one sub-form's data and returns the
    # query; ``.data`` is read only when its updater is about to run.
    updaters = (
        (_update_query_with_dates, form.date),
        (_update_query_with_terms, form.terms),
        (_update_query_with_classification, form.classification),
    )
    for update, field in updaters:
        query = update(query, field.data)

    classification = form.classification
    query.include_cross_list = (
        classification.include_cross_list.data
        == classification.INCLUDE_CROSS_LIST
    )

    # Only ever switched on here; a falsy value leaves the query untouched.
    if form.include_older_versions.data:
        query.include_older_versions = True

    # The literal string 'None' is treated the same as no selection.
    sort_order = form.order.data
    if sort_order and sort_order != 'None':
        query.order = sort_order

    query.hide_abstracts = form.abstracts.data == form.HIDE_ABSTRACTS
    return query


def _update_query_with_classification(q: AdvancedQuery, data: MultiDict) \
        -> AdvancedQuery:
    """
    Replace the query's classification list using classification form data.

    Parameters
    ----------
    q : :class:`.AdvancedQuery`
        Query to update; its ``classification`` attribute is overwritten.
    data : dict-like
        Classification form data. A truthy value under an archive field name
        selects that archive. ``physics`` together with ``physics_archives``
        selects the physics group, optionally narrowed to one archive.

    Returns
    -------
    :class:`.AdvancedQuery`
        The same ``q`` that was passed in.

    """
    # Form field name -> archive identifier.
    archive_for_field = {
        'computer_science': 'cs',
        'economics': 'econ',
        'eess': 'eess',
        'mathematics': 'math',
        'q_biology': 'q-bio',
        'q_finance': 'q-fin',
        'statistics': 'stat',
    }

    q.classification = ClassificationList()
    for field_name, archive_id in archive_for_field.items():
        if not data.get(field_name):
            continue
        # Fix for these typing issues is coming soon!
        #  See: https://github.com/python/mypy/pull/4397
        q.classification.append(
            Classification(archive={'id': archive_id})  # type: ignore
        )

    # Physics needs both the checkbox and an archive selection to be present.
    if not data.get('physics') or 'physics_archives' not in data:
        return q

    physics = {'group': {'id': 'grp_physics'}}
    if 'all' not in data['physics_archives']:
        # A specific physics archive was chosen; narrow the group to it.
        physics['archive'] = {'id': data['physics_archives']}
    q.classification.append(Classification(**physics))  # type: ignore
    return q


def _update_query_with_terms(q: AdvancedQuery, terms_data: list) \
        -> AdvancedQuery:
    """
    Set the query's search terms from the terms sub-form data.

    Parameters
    ----------
    q : :class:`.AdvancedQuery`
        Query to update; its ``terms`` attribute is overwritten.
    terms_data : list
        One mapping per term row. Each mapping is expanded into the keyword
        arguments of :class:`.FieldedSearchTerm`; rows whose ``'term'`` value
        is falsy are skipped.

    Returns
    -------
    :class:`.AdvancedQuery`
        The same ``q`` that was passed in.

    """
    populated = []
    for row in terms_data:
        if row['term']:
            populated.append(FieldedSearchTerm(**row))  # type: ignore
    q.terms = FieldedSearchList(populated)
    return q


def _eastern_midnight(day: date) -> datetime:
    """Return midnight at the start of ``day`` as an aware US/Eastern time."""
    # A pytz zone has to be attached with ``localize``. Handing it to a
    # datetime constructor as ``tzinfo=`` silently picks the zone's earliest
    # recorded offset (local mean time) rather than EST/EDT.
    return EASTERN.localize(datetime.combine(day, datetime.min.time()))


def _update_query_with_dates(q: AdvancedQuery, date_data: MultiDict) \
        -> AdvancedQuery:
    """
    Apply the date filter selected on the form to the query.

    Parameters
    ----------
    q : :class:`.AdvancedQuery`
        Query to update.
    date_data : dict-like
        Date sub-form data. ``filter_by`` selects the mode:

        - ``'all_dates'``: the query is returned unchanged.
        - ``'past_12'``: from the first day of the month twelve months ago,
          with no end date.
        - ``'specific_year'``: from 1 January of ``date_data['year'].year``
          to 1 January of the following year.
        - ``'date_range'``: from ``from_date`` to ``to_date``; a falsy bound
          is passed through unchanged.

        ``date_type`` is read whenever a date range ends up set. The mapping
        is not modified.

    Returns
    -------
    :class:`.AdvancedQuery`
        The same ``q`` that was passed in. All generated bounds are midnight,
        US/Eastern.

    """
    filter_by = date_data['filter_by']
    if filter_by == 'all_dates':    # Nothing to do; all dates by default.
        return q
    elif filter_by == 'past_12':
        one_year_ago = date.today() - relativedelta(months=12)
        # Fix for these typing issues is coming soon!
        #  See: https://github.com/python/mypy/pull/4397
        q.date_range = DateRange(   # type: ignore
            # Snap to the first of the month.
            start_date=_eastern_midnight(one_year_ago.replace(day=1))
        )
    elif filter_by == 'specific_year':
        year = date_data['year'].year
        q.date_range = DateRange(   # type: ignore
            start_date=_eastern_midnight(date(year, 1, 1)),
            end_date=_eastern_midnight(date(year + 1, 1, 1)),
        )
    elif filter_by == 'date_range':
        # Work on locals so the caller's form data is left untouched.
        from_date = date_data['from_date']
        to_date = date_data['to_date']
        if from_date:
            from_date = _eastern_midnight(from_date)
        if to_date:
            to_date = _eastern_midnight(to_date)

        q.date_range = DateRange(   # type: ignore
            start_date=from_date,
            end_date=to_date,
        )

    if q.date_range:
        q.date_range.date_type = date_data['date_type']
    return q


# TODO: this _could_ go on the AdvancedSearchForm or ClassificationForm.