Merge pull request #773 from not-my-profile/typing

More typing
Martin Fischer 2022-01-18 16:28:32 +01:00 committed by GitHub
commit 96a1f79c6d
8 changed files with 81 additions and 45 deletions


@@ -4,7 +4,9 @@
import asyncio
import logging
from ssl import SSLContext
import threading
from typing import Any, Dict
import httpx
from httpx_socks import AsyncProxyTransport
@@ -23,7 +25,7 @@ else:
logger = logger.getChild('searx.network.client')
LOOP = None
SSLCONTEXTS = {}
SSLCONTEXTS: Dict[Any, SSLContext] = {}
TRANSPORT_KWARGS = {
'trust_env': False,
}
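
The pattern above recurs throughout this PR: a module-level cache that starts out as an empty literal gets an explicit annotation so a type checker knows what it will eventually hold. A minimal, self-contained sketch of the idea (the cache name and helper below are illustrative, not taken from the commit):

import ssl
from typing import Any, Dict

# An empty dict literal tells the checker nothing about its contents;
# the annotation records that the values will be SSLContext instances.
CONTEXT_CACHE: Dict[Any, ssl.SSLContext] = {}

def get_context(key: Any) -> ssl.SSLContext:
    """Create and memoize one SSLContext per key (illustrative helper only)."""
    if key not in CONTEXT_CACHE:
        CONTEXT_CACHE[key] = ssl.create_default_context()
    return CONTEXT_CACHE[key]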


@@ -7,6 +7,7 @@ import atexit
import asyncio
import ipaddress
from itertools import cycle
from typing import Dict
import httpx
@@ -16,7 +17,7 @@ from .client import new_client, get_loop, AsyncHTTPTransportNoHttp
logger = logger.getChild('network')
DEFAULT_NAME = '__DEFAULT__'
NETWORKS = {}
NETWORKS: Dict[str, 'Network'] = {}
# requests compatibility when reading proxy settings from settings.yml
PROXY_PATTERN_MAPPING = {
'http': 'http://',
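
The value type is written as the string 'Network' because the Network class is defined further down in the same module; quoting it makes a forward reference that the checker resolves later. A small sketch of that pattern (registry and helper names are illustrative):

from typing import Dict

# Annotated before the class exists, so the value type is a quoted forward
# reference; the checker resolves it once Network has been defined.
NETWORK_REGISTRY: Dict[str, 'Network'] = {}

class Network:
    def __init__(self, name: str):
        self.name = name

def get_or_create(name: str) -> 'Network':
    if name not in NETWORK_REGISTRY:
        NETWORK_REGISTRY[name] = Network(name)
    return NETWORK_REGISTRY[name]

With "from __future__ import annotations" (PEP 563) the quotes become unnecessary, since all annotations are then evaluated lazily.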


@@ -10,7 +10,7 @@ from os.path import abspath, basename, dirname, exists, join
from shutil import copyfile
from pkgutil import iter_modules
from logging import getLogger
from typing import List
from typing import List, Tuple
from searx import logger, settings
@@ -22,6 +22,9 @@ class Plugin: # pylint: disable=too-few-public-methods
name: str
description: str
default_on: bool
js_dependencies: Tuple[str]
css_dependencies: Tuple[str]
preference_section: str
logger = logger.getChild("plugins")
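
The Plugin class gains annotations for the rest of its class attributes. One hedge: in typing terms Tuple[str] describes a tuple of exactly one string, while an arbitrary-length tuple is spelled Tuple[str, ...], so a strict checker would flag a plugin that declares several dependencies. A sketch of the distinction (hypothetical class, not the project's code):

from typing import Tuple

class ExamplePlugin:  # hypothetical stand-in for the Plugin base class
    name: str
    description: str
    default_on: bool
    # Tuple[str] would mean "a tuple of exactly one str";
    # Tuple[str, ...] accepts any number of strings.
    js_dependencies: Tuple[str, ...]
    css_dependencies: Tuple[str, ...]
    preference_section: str

plugin = ExamplePlugin()
plugin.js_dependencies = ('plugin.js', 'helper.js')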


@@ -2,7 +2,9 @@ import re
from collections import defaultdict
from operator import itemgetter
from threading import RLock
from typing import List, NamedTuple, Set
from urllib.parse import urlparse, unquote
from searx import logger
from searx.engines import engines
from searx.metrics import histogram_observe, counter_add, count_error
@@ -137,6 +139,18 @@ def result_score(result):
return sum((occurences * weight) / position for position in result['positions'])
class Timing(NamedTuple):
engine: str
total: float
load: float
class UnresponsiveEngine(NamedTuple):
engine: str
error_type: str
suspended: bool
class ResultContainer:
"""docstring for ResultContainer"""
@@ -168,8 +182,8 @@ class ResultContainer:
self.engine_data = defaultdict(dict)
self._closed = False
self.paging = False
self.unresponsive_engines = set()
self.timings = []
self.unresponsive_engines: Set[UnresponsiveEngine] = set()
self.timings: List[Timing] = []
self.redirect_url = None
self.on_result = lambda _: True
self._lock = RLock()
@@ -401,17 +415,12 @@ class ResultContainer:
return 0
return resultnum_sum / len(self._number_of_results)
def add_unresponsive_engine(self, engine_name, error_type, error_message=None, suspended=False):
def add_unresponsive_engine(self, engine_name: str, error_type: str, suspended: bool = False):
if engines[engine_name].display_error_messages:
self.unresponsive_engines.add((engine_name, error_type, error_message, suspended))
self.unresponsive_engines.add(UnresponsiveEngine(engine_name, error_type, suspended))
def add_timing(self, engine_name, engine_time, page_load_time):
timing = {
'engine': engines[engine_name].shortcut,
'total': engine_time,
'load': page_load_time,
}
self.timings.append(timing)
def add_timing(self, engine_name: str, engine_time: float, page_load_time: float):
self.timings.append(Timing(engine_name, total=engine_time, load=page_load_time))
def get_timings(self):
return self.timings


@@ -15,6 +15,7 @@ __all__ = [
]
import threading
from typing import Dict
from searx import logger
from searx import engines
@@ -26,7 +27,7 @@ from .online_currency import OnlineCurrencyProcessor
from .abstract import EngineProcessor
logger = logger.getChild('search.processors')
PROCESSORS = {}
PROCESSORS: Dict[str, EngineProcessor] = {}
"""Cache request processores, stored by *engine-name* (:py:func:`initialize`)"""


@@ -8,6 +8,7 @@
import threading
from abc import abstractmethod, ABC
from timeit import default_timer
from typing import Dict, Union
from searx import settings, logger
from searx.engines import engines
@@ -17,7 +18,7 @@ from searx.exceptions import SearxEngineAccessDeniedException, SearxEngineRespon
from searx.utils import get_engine_from_settings
logger = logger.getChild('searx.search.processor')
SUSPENDED_STATUS = {}
SUSPENDED_STATUS: Dict[Union[int, str], 'SuspendedStatus'] = {}
class SuspendedStatus:
@@ -61,7 +62,7 @@ class EngineProcessor(ABC):
__slots__ = 'engine', 'engine_name', 'lock', 'suspended_status', 'logger'
def __init__(self, engine, engine_name):
def __init__(self, engine, engine_name: str):
self.engine = engine
self.engine_name = engine_name
self.logger = engines[engine_name].logger
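
SUSPENDED_STATUS combines the two patterns shown earlier: a quoted forward reference for the value type, plus a Union key because entries can be keyed by either an int or a string (what exactly those keys are is not spelled out in this diff). An illustrative sketch:

from typing import Dict, Union

# Keys may be ints or strings, hence the Union; the value type is a quoted
# forward reference because the class is only defined further down.
STATUS: Dict[Union[int, str], 'SuspendedStatus'] = {}

class SuspendedStatus:  # hypothetical stand-in
    def __init__(self):
        self.suspend_end_time = 0.0

def status_for(key: Union[int, str]) -> SuspendedStatus:
    return STATUS.setdefault(key, SuspendedStatus())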


@@ -14,8 +14,11 @@ from datetime import datetime, timedelta
from timeit import default_timer
from html import escape
from io import StringIO
import typing
from typing import List, Dict, Iterable
import urllib
import urllib.parse
from urllib.parse import urlencode
import httpx
@@ -28,7 +31,6 @@ import flask
from flask import (
Flask,
request,
render_template,
url_for,
Response,
@@ -55,6 +57,7 @@ from searx import (
searx_debug,
)
from searx.data import ENGINE_DESCRIPTIONS
from searx.results import Timing, UnresponsiveEngine
from searx.settings_defaults import OUTPUT_FORMATS
from searx.settings_loader import get_default_settings_path
from searx.exceptions import SearxParameterException
@@ -89,7 +92,7 @@ from searx.utils import (
)
from searx.version import VERSION_STRING, GIT_URL, GIT_BRANCH
from searx.query import RawTextQuery
from searx.plugins import plugins, initialize as plugin_initialize
from searx.plugins import Plugin, plugins, initialize as plugin_initialize
from searx.plugins.oa_doi_rewrite import get_doi_resolver
from searx.preferences import (
Preferences,
@@ -224,6 +227,21 @@ exception_classname_to_text = {
_flask_babel_get_translations = flask_babel.get_translations
class ExtendedRequest(flask.Request):
"""This class is never initialized and only used for type checking."""
preferences: Preferences
errors: List[str]
user_plugins: List[Plugin]
form: Dict[str, str]
start_time: float
render_time: float
timings: List[Timing]
request = typing.cast(ExtendedRequest, flask.request)
def _get_translations():
if has_request_context() and request.form.get('use-translation') == 'oc':
babel_ext = flask_babel.current_app.extensions['babel']
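
ExtendedRequest exists only to describe the extra attributes the application attaches to flask.request at runtime; typing.cast then tells the checker to treat the request proxy as that richer type, with no effect at runtime. A condensed sketch of the pattern (attribute set abbreviated):

import typing
from typing import List

import flask

app = flask.Flask(__name__)

class ExtendedRequest(flask.Request):
    """Never instantiated; it only documents attributes set per request."""
    start_time: float
    errors: List[str]

# cast() is a no-op at runtime; it only changes the type the checker sees,
# so request.start_time and request.errors type-check cleanly.
request = typing.cast(ExtendedRequest, flask.request)

@app.before_request
def pre_request():
    request.start_time = 0.0
    request.errors = []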
@@ -321,7 +339,7 @@ def code_highlighter(codelines, language=None):
return html_code
def get_current_theme_name(override=None):
def get_current_theme_name(override: str = None) -> str:
"""Returns theme name.
Checks in this order:
@@ -337,14 +355,14 @@ def get_current_theme_name(override=None):
return theme_name
def get_result_template(theme_name, template_name):
def get_result_template(theme_name: str, template_name: str):
themed_path = theme_name + '/result_templates/' + template_name
if themed_path in result_templates:
return themed_path
return 'result_templates/' + template_name
def url_for_theme(endpoint, override_theme=None, **values):
def url_for_theme(endpoint: str, override_theme: str = None, **values):
if endpoint == 'static' and values.get('filename'):
theme_name = get_current_theme_name(override=override_theme)
filename_with_theme = "themes/{}/{}".format(theme_name, values['filename'])
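
One hedge on the new signatures above: a parameter annotated as plain str but defaulting to None is effectively Optional[str]; recent type checkers flag the implicit form by default, so strict checking would want it spelled out, for example:

from typing import Optional

def get_current_theme_name(override: Optional[str] = None) -> str:
    # placeholder body; only the signature is the point here
    return override if override is not None else 'default'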
@@ -354,7 +372,7 @@ def url_for_theme(endpoint, override_theme=None, **values):
return url
def proxify(url):
def proxify(url: str):
if url.startswith('//'):
url = 'https:' + url
@@ -369,7 +387,7 @@ def proxify(url):
return '{0}?{1}'.format(settings['result_proxy']['url'], urlencode(url_params))
def image_proxify(url):
def image_proxify(url: str):
if url.startswith('//'):
url = 'https:' + url
@@ -405,7 +423,7 @@ def get_translations():
}
def _get_enable_categories(all_categories):
def _get_enable_categories(all_categories: Iterable[str]):
disabled_engines = request.preferences.engines.get_disabled()
enabled_categories = set(
# pylint: disable=consider-using-dict-items
@@ -417,14 +435,14 @@ def _get_enable_categories(all_categories):
return [x for x in all_categories if x in enabled_categories]
def get_pretty_url(parsed_url):
def get_pretty_url(parsed_url: urllib.parse.ParseResult):
path = parsed_url.path
path = path[:-1] if len(path) > 0 and path[-1] == '/' else path
path = path.replace("/", " ")
return [parsed_url.scheme + "://" + parsed_url.netloc, path]
def render(template_name, override_theme=None, **kwargs):
def render(template_name: str, override_theme: str = None, **kwargs):
# values from the HTTP requests
kwargs['endpoint'] = 'results' if 'q' in kwargs else request.endpoint
kwargs['cookies'] = request.cookies
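
get_pretty_url's parameter is annotated with urllib.parse.ParseResult, the named tuple returned by urlparse, which explains the extra urllib.parse import added at the top of the file. A tiny usage sketch (the body is a condensed restatement, not the project's exact code):

import urllib.parse

def pretty_url(parsed_url: urllib.parse.ParseResult):
    path = parsed_url.path.rstrip('/').replace('/', ' ')
    return [parsed_url.scheme + '://' + parsed_url.netloc, path]

print(pretty_url(urllib.parse.urlparse('https://example.org/foo/bar/')))
# -> ['https://example.org', ' foo bar']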
@@ -552,7 +570,7 @@ def pre_request():
@app.after_request
def add_default_headers(response):
def add_default_headers(response: flask.Response):
# set default http headers
for header, value in settings['server']['default_http_headers'].items():
if header in response.headers:
@@ -562,29 +580,28 @@ def add_default_headers(response):
@app.after_request
def post_request(response):
def post_request(response: flask.Response):
total_time = default_timer() - request.start_time
timings_all = [
'total;dur=' + str(round(total_time * 1000, 3)),
'render;dur=' + str(round(request.render_time * 1000, 3)),
]
if len(request.timings) > 0:
timings = sorted(request.timings, key=lambda v: v['total'])
timings = sorted(request.timings, key=lambda t: t.total)
timings_total = [
'total_' + str(i) + '_' + v['engine'] + ';dur=' + str(round(v['total'] * 1000, 3))
for i, v in enumerate(timings)
'total_' + str(i) + '_' + t.engine + ';dur=' + str(round(t.total * 1000, 3)) for i, t in enumerate(timings)
]
timings_load = [
'load_' + str(i) + '_' + v['engine'] + ';dur=' + str(round(v['load'] * 1000, 3))
for i, v in enumerate(timings)
if v.get('load')
'load_' + str(i) + '_' + t.engine + ';dur=' + str(round(t.load * 1000, 3))
for i, t in enumerate(timings)
if t.load
]
timings_all = timings_all + timings_total + timings_load
response.headers.add('Server-Timing', ', '.join(timings_all))
return response
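
With Timing as a NamedTuple, the sort key and field reads become attribute accesses, and "if t.load" simply skips entries whose page-load time is falsy rather than probing for a missing dict key. The header built here follows the Server-Timing syntax name;dur=<milliseconds>; a standalone toy reproduction of the formatting:

from typing import List, NamedTuple

class Timing(NamedTuple):
    engine: str
    total: float
    load: float

timings: List[Timing] = [Timing('b', 0.9, 0.6), Timing('a', 0.8, 0.7)]
entries = ['total;dur=' + str(round(1.234 * 1000, 3))]
for i, t in enumerate(sorted(timings, key=lambda t: t.total)):
    entries.append('total_{}_{};dur={}'.format(i, t.engine, round(t.total * 1000, 3)))
    if t.load:  # skip entries whose page-load time is zero or falsy
        entries.append('load_{}_{};dur={}'.format(i, t.engine, round(t.load * 1000, 3)))
print(', '.join(entries))
# total;dur=1234.0, total_0_a;dur=800.0, load_0_a;dur=700.0, total_1_b;dur=900.0, load_1_b;dur=600.0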
def index_error(output_format, error_message):
def index_error(output_format: str, error_message: str):
if output_format == 'json':
return Response(json.dumps({'error': error_message}), mimetype='application/json')
if output_format == 'csv':
@@ -828,23 +845,21 @@ def search():
)
def __get_translated_errors(unresponsive_engines):
def __get_translated_errors(unresponsive_engines: Iterable[UnresponsiveEngine]):
translated_errors = []
# make a copy of unresponsive_engines to avoid "RuntimeError: Set changed size
# during iteration" it happens when an engine modifies the ResultContainer
# after the search_multiple_requests method has stopped waiting
for unresponsive_engine in list(unresponsive_engines):
error_user_text = exception_classname_to_text.get(unresponsive_engine[1])
for unresponsive_engine in unresponsive_engines:
error_user_text = exception_classname_to_text.get(unresponsive_engine.error_type)
if not error_user_text:
error_user_text = exception_classname_to_text[None]
error_msg = gettext(error_user_text)
if unresponsive_engine[2]:
error_msg = "{} {}".format(error_msg, unresponsive_engine[2])
if unresponsive_engine[3]:
if unresponsive_engine.suspended:
error_msg = gettext('Suspended') + ': ' + error_msg
translated_errors.append((unresponsive_engine[0], error_msg))
translated_errors.append((unresponsive_engine.engine, error_msg))
return sorted(translated_errors, key=lambda e: e[0])
@@ -1060,7 +1075,7 @@ def preferences():
)
def _is_selected_language_supported(engine, preferences): # pylint: disable=redefined-outer-name
def _is_selected_language_supported(engine, preferences: Preferences): # pylint: disable=redefined-outer-name
language = preferences.get_value('language')
if language == 'all':
return True


@@ -3,6 +3,7 @@
import json
from urllib.parse import ParseResult
from mock import Mock
from searx.results import Timing
import searx.search.processors
from searx.search import Search
@@ -46,7 +47,10 @@ class ViewsTestCase(SearxTestCase):
},
]
timings = [{'engine': 'startpage', 'total': 0.8, 'load': 0.7}, {'engine': 'youtube', 'total': 0.9, 'load': 0.6}]
timings = [
Timing(engine='startpage', total=0.8, load=0.7),
Timing(engine='youtube', total=0.9, load=0.6),
]
def search_mock(search_self, *args):
search_self.result_container = Mock(