Mirror of https://github.com/searxng/searxng (synced 2024-01-01 19:24:07 +01:00)

Commit: Replace httpx by aiohttp

parent 065b4dab56
commit 4ea887471b

14 changed files with 533 additions and 352 deletions
.github/workflows/integration.yml (vendored, 1 changed line)
@@ -11,6 +11,7 @@ jobs:
    name: Python ${{ matrix.python-version }}
    runs-on: ubuntu-20.04
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-20.04]
        python-version: [3.6, 3.7, 3.8, 3.9]
@@ -7,10 +7,11 @@ lxml==4.6.3
pygments==2.10.0
python-dateutil==2.8.2
pyyaml==5.4.1
httpx[http2]==0.17.1
httpx==0.17.1
aiohttp==3.7.4.post0
aiohttp-socks==0.6.0
Brotli==1.0.9
uvloop==0.16.0; python_version >= '3.7'
uvloop==0.14.0; python_version < '3.7'
httpx-socks[asyncio]==0.3.1
langdetect==1.0.9
setproctitle==1.2.2
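For orientation, a minimal sketch (not part of this diff) of the aiohttp request pattern that replaces the httpx calls in the dependency list above (apparently requirements.txt); the URL and the fetch() helper are illustrative only:

    import asyncio
    import aiohttp

    async def fetch(url: str) -> str:
        # aiohttp is async-only and needs an explicit ClientSession,
        # while httpx also offered a synchronous client
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                return await response.text()

    # Python 3.6 (still in the CI matrix above) has no asyncio.run()
    loop = asyncio.get_event_loop()
    print(loop.run_until_complete(fetch('https://example.org/')))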
@@ -20,7 +20,7 @@ from lxml import etree
from json import loads
from urllib.parse import urlencode

from httpx import HTTPError
from aiohttp import ClientError


from searx import settings

@@ -137,5 +137,5 @@ def search_autocomplete(backend_name, query, lang):

    try:
        return backend(query, lang)
    except (HTTPError, SearxEngineResponseException):
    except (ClientError, SearxEngineResponseException):
        return []
@@ -8,11 +8,11 @@ import concurrent.futures
from types import MethodType
from timeit import default_timer

import httpx
import h2.exceptions
import aiohttp

from .network import get_network, initialize
from .client import get_loop
from .response import Response
from .raise_for_httperror import raise_for_httperror

# queue.SimpleQueue: Support Python 3.6

@@ -73,12 +73,12 @@ def get_context_network():
    return THREADLOCAL.__dict__.get('network') or get_network()


def request(method, url, **kwargs):
def request(method, url, **kwargs) -> Response:
    """same as requests/requests/api.py request(...)"""
    global THREADLOCAL
    time_before_request = default_timer()

    # timeout (httpx)
    # timeout (aiohttp)
    if 'timeout' in kwargs:
        timeout = kwargs['timeout']
    else:

@@ -108,16 +108,15 @@ def request(method, url, **kwargs):
    # network
    network = get_context_network()

    #
    # kwargs['compress'] = True

    # do request
    future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
    try:
        response = future.result(timeout)
    except concurrent.futures.TimeoutError as e:
        raise httpx.TimeoutException('Timeout', request=None) from e

    # requests compatibility
    # see also https://www.python-httpx.org/compatibility/#checking-for-4xx5xx-responses
    response.ok = not response.is_error
        raise asyncio.TimeoutError() from e

    # update total_time.
    # See get_time_for_thread() and reset_time_for_thread()

@@ -132,64 +131,53 @@ def request(method, url, **kwargs):
    return response


def get(url, **kwargs):
def get(url, **kwargs) -> Response:
    kwargs.setdefault('allow_redirects', True)
    return request('get', url, **kwargs)


def options(url, **kwargs):
def options(url, **kwargs) -> Response:
    kwargs.setdefault('allow_redirects', True)
    return request('options', url, **kwargs)


def head(url, **kwargs):
def head(url, **kwargs) -> Response:
    kwargs.setdefault('allow_redirects', False)
    return request('head', url, **kwargs)


def post(url, data=None, **kwargs):
def post(url, data=None, **kwargs) -> Response:
    return request('post', url, data=data, **kwargs)


def put(url, data=None, **kwargs):
def put(url, data=None, **kwargs) -> Response:
    return request('put', url, data=data, **kwargs)


def patch(url, data=None, **kwargs):
def patch(url, data=None, **kwargs) -> Response:
    return request('patch', url, data=data, **kwargs)


def delete(url, **kwargs):
def delete(url, **kwargs) -> Response:
    return request('delete', url, **kwargs)


async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
    try:
        async with network.stream(method, url, **kwargs) as response:
        async with await network.request(method, url, stream=True, **kwargs) as response:
            queue.put(response)
            # aiter_raw: access the raw bytes on the response without applying any HTTP content decoding
            # https://www.python-httpx.org/quickstart/#streaming-responses
            async for chunk in response.aiter_raw(65536):
                if len(chunk) > 0:
            chunk = await response.iter_content(65536)
            while chunk:
                    queue.put(chunk)
    except httpx.ResponseClosed:
        # the response was closed
        pass
    except (httpx.HTTPError, OSError, h2.exceptions.ProtocolError) as e:
                chunk = await response.iter_content(65536)
    except aiohttp.client.ClientError as e:
        queue.put(e)
    finally:
        queue.put(None)


def _close_response_method(self):
    asyncio.run_coroutine_threadsafe(
        self.aclose(),
        get_loop()
    )


def stream(method, url, **kwargs):
    """Replace httpx.stream.
    """Stream Response in sync world

    Usage:
    stream = poolrequests.stream(...)

@@ -197,8 +185,6 @@ def stream(method, url, **kwargs):
    for chunk in stream:
        ...

    httpx.Client.stream requires to write the httpx.HTTPTransport version of the
    the httpx.AsyncHTTPTransport declared above.
    """
    queue = SimpleQueue()
    future = asyncio.run_coroutine_threadsafe(

@@ -210,7 +196,6 @@ def stream(method, url, **kwargs):
    response = queue.get()
    if isinstance(response, Exception):
        raise response
    response.close = MethodType(_close_response_method, response)
    yield response

    # yield chunks
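The synchronous request()/get()/stream() helpers above keep the old requests-style API by handing each coroutine to a dedicated asyncio loop with asyncio.run_coroutine_threadsafe() and blocking on the returned future. A self-contained sketch of that bridge, with an illustrative fetch() coroutine rather than the searx Network class:

    import asyncio
    import threading
    import aiohttp

    LOOP = asyncio.new_event_loop()
    threading.Thread(target=LOOP.run_forever, daemon=True).start()

    async def _fetch(url):
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as resp:
                return await resp.text()

    def fetch(url, timeout=10):
        # schedule the coroutine on the dedicated loop and block the calling
        # (synchronous) thread until the result arrives, like request() above
        future = asyncio.run_coroutine_threadsafe(_fetch(url), LOOP)
        return future.result(timeout)

    print(fetch('https://example.org/')[:80])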
@@ -1,21 +1,17 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# lint: pylint
# pylint: disable=missing-module-docstring, missing-function-docstring, global-statement

import asyncio
import logging
import threading
import httpcore
import httpx
from httpx_socks import AsyncProxyTransport
from python_socks import (
    parse_proxy_url,
    ProxyConnectionError,
    ProxyTimeoutError,
    ProxyError
)
import logging
from typing import Optional

from searx import logger
import aiohttp
from aiohttp.client_reqrep import ClientRequest
from aiohttp_socks import ProxyConnector
from requests.models import InvalidURL

from yarl import URL

# Optional uvloop (support Python 3.6)
try:

@@ -26,118 +22,38 @@ else:
    uvloop.install()


logger = logger.getChild('searx.http.client')
LOOP = None
SSLCONTEXTS = {}
TRANSPORT_KWARGS = {
    'backend': 'asyncio',
    'trust_env': False,
}
logger = logging.getLogger('searx.http.client')
LOOP: Optional[asyncio.AbstractEventLoop] = None
RESOLVER: Optional[aiohttp.ThreadedResolver] = None


# pylint: disable=protected-access
async def close_connections_for_url(
        connection_pool: httpcore.AsyncConnectionPool,
        url: httpcore._utils.URL ):
class ClientRequestNoHttp(ClientRequest):

    origin = httpcore._utils.url_to_origin(url)
    logger.debug('Drop connections for %r', origin)
    connections_to_close = connection_pool._connections_for_origin(origin)
    for connection in connections_to_close:
        await connection_pool._remove_from_pool(connection)
        try:
            await connection.aclose()
        except httpcore.NetworkError as e:
            logger.warning('Error closing an existing connection', exc_info=e)
# pylint: enable=protected-access
    def __init__(self, method: str, url: URL, *args, **kwargs):
        if url.scheme == 'http':
            raise InvalidURL(url)
        super().__init__(method, url, *args, **kwargs)


def get_sslcontexts(proxy_url=None, cert=None, verify=True, trust_env=True, http2=False):
    global SSLCONTEXTS
    key = (proxy_url, cert, verify, trust_env, http2)
    if key not in SSLCONTEXTS:
        SSLCONTEXTS[key] = httpx.create_ssl_context(cert, verify, trust_env, http2)
    return SSLCONTEXTS[key]
def new_client(
        # pylint: disable=too-many-arguments
        enable_http, verify, max_connections, max_keepalive_connections, keepalive_expiry,
        proxy_url, local_address):
    # connector
    conn_kwargs = {
        'ssl': verify,
        'keepalive_timeout': keepalive_expiry or 15,
        'limit': max_connections,
        'limit_per_host': max_keepalive_connections,
        'loop': LOOP,
    }
    if local_address:
        conn_kwargs['local_addr'] = (local_address, 0)


class AsyncHTTPTransportNoHttp(httpcore.AsyncHTTPTransport):
    """Block HTTP request"""

    async def arequest(self, method, url, headers=None, stream=None, ext=None):
        raise httpcore.UnsupportedProtocol("HTTP protocol is disabled")


class AsyncProxyTransportFixed(AsyncProxyTransport):
    """Fix httpx_socks.AsyncProxyTransport

    Map python_socks exceptions to httpcore.ProxyError

    Map socket.gaierror to httpcore.ConnectError

    Note: keepalive_expiry is ignored, AsyncProxyTransport should call:
    * self._keepalive_sweep()
    * self._response_closed(self, connection)

    Note: AsyncProxyTransport inherit from AsyncConnectionPool

    Note: the API is going to change on httpx 0.18.0
    see https://github.com/encode/httpx/pull/1522
    """

    async def arequest(self, method, url, headers=None, stream=None, ext=None):
        retry = 2
        while retry > 0:
            retry -= 1
            try:
                return await super().arequest(method, url, headers, stream, ext)
            except (ProxyConnectionError, ProxyTimeoutError, ProxyError) as e:
                raise httpcore.ProxyError(e)
            except OSError as e:
                # socket.gaierror when DNS resolution fails
                raise httpcore.NetworkError(e)
            except httpcore.RemoteProtocolError as e:
                # in case of httpcore.RemoteProtocolError: Server disconnected
                await close_connections_for_url(self, url)
                logger.warning('httpcore.RemoteProtocolError: retry', exc_info=e)
                # retry
            except (httpcore.NetworkError, httpcore.ProtocolError) as e:
                # httpcore.WriteError on HTTP/2 connection leaves a new opened stream
                # then each new request creates a new stream and raise the same WriteError
                await close_connections_for_url(self, url)
                raise e


class AsyncHTTPTransportFixed(httpx.AsyncHTTPTransport):
    """Fix httpx.AsyncHTTPTransport"""

    async def arequest(self, method, url, headers=None, stream=None, ext=None):
        retry = 2
        while retry > 0:
            retry -= 1
            try:
                return await super().arequest(method, url, headers, stream, ext)
            except OSError as e:
                # socket.gaierror when DNS resolution fails
                raise httpcore.ConnectError(e)
            except httpcore.CloseError as e:
                # httpcore.CloseError: [Errno 104] Connection reset by peer
                # raised by _keepalive_sweep()
                # from https://github.com/encode/httpcore/blob/4b662b5c42378a61e54d673b4c949420102379f5/httpcore/_backends/asyncio.py#L198 # pylint: disable=line-too-long
                await close_connections_for_url(self._pool, url)
                logger.warning('httpcore.CloseError: retry', exc_info=e)
                # retry
            except httpcore.RemoteProtocolError as e:
                # in case of httpcore.RemoteProtocolError: Server disconnected
                await close_connections_for_url(self._pool, url)
                logger.warning('httpcore.RemoteProtocolError: retry', exc_info=e)
                # retry
            except (httpcore.ProtocolError, httpcore.NetworkError) as e:
                await close_connections_for_url(self._pool, url)
                raise e


def get_transport_for_socks_proxy(verify, http2, local_address, proxy_url, limit, retries):
    global TRANSPORT_KWARGS
    if not proxy_url:
        conn_kwargs['resolver'] = RESOLVER
        connector = aiohttp.TCPConnector(**conn_kwargs)
    else:
        # support socks5h (requests compatibility):
        # https://requests.readthedocs.io/en/master/user/advanced/#socks
        # socks5:// hostname is resolved on client side

@@ -147,81 +63,14 @@ def get_transport_for_socks_proxy(verify, http2, local_address, proxy_url, limit
        if proxy_url.startswith(socks5h):
            proxy_url = 'socks5://' + proxy_url[len(socks5h):]
            rdns = True

    proxy_type, proxy_host, proxy_port, proxy_username, proxy_password = parse_proxy_url(proxy_url)
    verify = get_sslcontexts(proxy_url, None, True, False, http2) if verify is True else verify
    return AsyncProxyTransportFixed(
        proxy_type=proxy_type, proxy_host=proxy_host, proxy_port=proxy_port,
        username=proxy_username, password=proxy_password,
        rdns=rdns,
        loop=get_loop(),
        verify=verify,
        http2=http2,
        local_address=local_address,
        max_connections=limit.max_connections,
        max_keepalive_connections=limit.max_keepalive_connections,
        keepalive_expiry=limit.keepalive_expiry,
        retries=retries,
        **TRANSPORT_KWARGS
    )


def get_transport(verify, http2, local_address, proxy_url, limit, retries):
    global TRANSPORT_KWARGS
    verify = get_sslcontexts(None, None, True, False, http2) if verify is True else verify
    return AsyncHTTPTransportFixed(
        # pylint: disable=protected-access
        verify=verify,
        http2=http2,
        local_address=local_address,
        proxy=httpx._config.Proxy(proxy_url) if proxy_url else None,
        limits=limit,
        retries=retries,
        **TRANSPORT_KWARGS
    )


def iter_proxies(proxies):
    # https://www.python-httpx.org/compatibility/#proxy-keys
    if isinstance(proxies, str):
        yield 'all://', proxies
    elif isinstance(proxies, dict):
        for pattern, proxy_url in proxies.items():
            yield pattern, proxy_url


def new_client(
        # pylint: disable=too-many-arguments
        enable_http, verify, enable_http2,
        max_connections, max_keepalive_connections, keepalive_expiry,
        proxies, local_address, retries, max_redirects ):
    limit = httpx.Limits(
        max_connections=max_connections,
        max_keepalive_connections=max_keepalive_connections,
        keepalive_expiry=keepalive_expiry
    )
    # See https://www.python-httpx.org/advanced/#routing
    mounts = {}
    for pattern, proxy_url in iter_proxies(proxies):
        if not enable_http and (pattern == 'http' or pattern.startswith('http://')):
            continue
        if (proxy_url.startswith('socks4://')
                or proxy_url.startswith('socks5://')
                or proxy_url.startswith('socks5h://')
        ):
            mounts[pattern] = get_transport_for_socks_proxy(
                verify, enable_http2, local_address, proxy_url, limit, retries
            )
        else:
            mounts[pattern] = get_transport(
                verify, enable_http2, local_address, proxy_url, limit, retries
            )

    if not enable_http:
        mounts['http://'] = AsyncHTTPTransportNoHttp()

    transport = get_transport(verify, enable_http2, local_address, None, limit, retries)
    return httpx.AsyncClient(transport=transport, mounts=mounts, max_redirects=max_redirects)
        conn_kwargs['resolver'] = RESOLVER
        connector = ProxyConnector.from_url(proxy_url, rdns=rdns, **conn_kwargs)
    # client
    session_kwargs = {}
    if enable_http:
        session_kwargs['request_class'] = ClientRequestNoHttp
    return aiohttp.ClientSession(connector=connector, **session_kwargs)


def get_loop():

@@ -230,14 +79,11 @@ def get_loop():


def init():
    # log
    for logger_name in ('hpack.hpack', 'hpack.table'):
        logging.getLogger(logger_name).setLevel(logging.WARNING)

    # loop
    def loop_thread():
        global LOOP
        global LOOP, RESOLVER
        LOOP = asyncio.new_event_loop()
        RESOLVER = aiohttp.resolver.DefaultResolver(LOOP)
        LOOP.run_forever()

    thread = threading.Thread(
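Condensed, the new new_client() builds either a plain aiohttp.TCPConnector or an aiohttp_socks.ProxyConnector and wraps it in a ClientSession. A sketch under assumed defaults (the parameter values are illustrative, not the ones searx computes):

    import aiohttp
    from aiohttp_socks import ProxyConnector

    def make_session(proxy_url=None, verify=True, max_connections=100, keepalive_expiry=15):
        # must be called from code running inside the event loop thread
        if proxy_url:
            # rdns=True resolves host names on the proxy side (socks5h behaviour)
            connector = ProxyConnector.from_url(proxy_url, rdns=True)
        else:
            connector = aiohttp.TCPConnector(
                ssl=verify,                     # bool or an ssl.SSLContext
                limit=max_connections,
                keepalive_timeout=keepalive_expiry,
            )
        return aiohttp.ClientSession(connector=connector)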
@@ -1,18 +1,24 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# lint: pylint
# pylint: disable=global-statement
# pylint: disable=missing-module-docstring, missing-class-docstring, missing-function-docstring
# pylint: disable=missing-module-docstring, missing-class-docstring, missing-function-docstring, fixme

import typing
import atexit
import asyncio
import ipaddress
from itertools import cycle
from logging import getLogger

import httpx
import aiohttp
from yarl import URL

from .client import new_client, get_loop
from .response import Response
from .utils import URLPattern


logger = getLogger(__name__)
DEFAULT_NAME = '__DEFAULT__'
NETWORKS = {}
# requests compatibility when reading proxy settings from settings.yml

@@ -72,7 +78,7 @@ class Network:
        self.max_redirects = max_redirects
        self._local_addresses_cycle = self.get_ipaddress_cycle()
        self._proxies_cycle = self.get_proxy_cycles()
        self._clients = {}
        self._clients: typing.Dict[aiohttp.ClientSession] = {}
        self.check_parameters()

    def check_parameters(self):

@@ -123,39 +129,44 @@ class Network:
            yield pattern, proxy_url

    def get_proxy_cycles(self):
        # TODO : store proxy_settings, and match URLPattern
        proxy_settings = {}
        for pattern, proxy_urls in self.iter_proxies():
            proxy_settings[pattern] = cycle(proxy_urls)
            proxy_settings[URLPattern(pattern)] = cycle(proxy_urls)
        while True:
            # pylint: disable=stop-iteration-return
            yield tuple((pattern, next(proxy_url_cycle)) for pattern, proxy_url_cycle in proxy_settings.items())
            yield tuple((urlpattern, next(proxy_url_cycle)) for urlpattern, proxy_url_cycle in proxy_settings.items())

    def get_client(self, verify=None, max_redirects=None):
    def get_proxy(self, url: URL):
        proxies = next(self._proxies_cycle)
        for urlpattern, proxy in proxies:
            if urlpattern.matches(url):
                return proxy
        return None

    def get_client(self, url: URL, verify=None) -> aiohttp.ClientSession:
        verify = self.verify if verify is None else verify
        max_redirects = self.max_redirects if max_redirects is None else max_redirects
        local_address = next(self._local_addresses_cycle)
        proxies = next(self._proxies_cycle)  # is a tuple so it can be part of the key
        key = (verify, max_redirects, local_address, proxies)
        if key not in self._clients or self._clients[key].is_closed:
        proxy = self.get_proxy(url)  # is a tuple so it can be part of the key
        key = (verify, local_address, proxy)
        if key not in self._clients or self._clients[key].closed:
            # TODO add parameter self.enable_http
            self._clients[key] = new_client(
                self.enable_http,
                verify,
                self.enable_http2,
                self.max_connections,
                self.max_keepalive_connections,
                self.keepalive_expiry,
                dict(proxies),
                proxy,
                local_address,
                0,
                max_redirects
            )
        return self._clients[key]

    async def aclose(self):
        async def close_client(client):
        async def close_client(client: aiohttp.ClientSession):
            try:
                await client.aclose()
            except httpx.HTTPError:
                await client.close()
            except aiohttp.ClientError:
                pass
        await asyncio.gather(*[close_client(client) for client in self._clients.values()], return_exceptions=False)

@@ -164,11 +175,9 @@ class Network:
        kwargs_clients = {}
        if 'verify' in kwargs:
            kwargs_clients['verify'] = kwargs.pop('verify')
        if 'max_redirects' in kwargs:
            kwargs_clients['max_redirects'] = kwargs.pop('max_redirects')
        return kwargs_clients

    def is_valid_respones(self, response):
    def is_valid_response(self, response):
        # pylint: disable=too-many-boolean-expressions
        if ((self.retry_on_http_error is True and 400 <= response.status_code <= 599)
                or (isinstance(self.retry_on_http_error, list) and response.status_code in self.retry_on_http_error)

@@ -177,33 +186,24 @@ class Network:
                return False
        return True

    async def request(self, method, url, **kwargs):
    async def request(self, method, url, stream=False, **kwargs) -> Response:
        retries = self.retries
        yarl_url = URL(url)
        while retries >= 0:  # pragma: no cover
            kwargs_clients = Network.get_kwargs_clients(kwargs)
            client = self.get_client(**kwargs_clients)
            client = self.get_client(yarl_url, **kwargs_clients)
            kwargs.setdefault('max_redirects', self.max_redirects)
            try:
                response = await client.request(method, url, **kwargs)
                if self.is_valid_respones(response) or retries <= 0:
                    return response
            except (httpx.RequestError, httpx.HTTPStatusError) as e:
                aiohttp_response: aiohttp.ClientResponse = await client.request(method, yarl_url, **kwargs)
                logger.debug('HTTP request "%s %s" %r', method.upper(), url, aiohttp_response.status)
                response = await Response.new_response(aiohttp_response, stream=stream)
                if self.is_valid_response(response) or retries <= 0:
                    break
            except aiohttp.ClientError as e:
                if retries <= 0:
                    raise e
            retries -= 1

    def stream(self, method, url, **kwargs):
        retries = self.retries
        while retries >= 0:  # pragma: no cover
            kwargs_clients = Network.get_kwargs_clients(kwargs)
            client = self.get_client(**kwargs_clients)
            try:
                response = client.stream(method, url, **kwargs)
                if self.is_valid_respones(response) or retries <= 0:
                    return response
            except (httpx.RequestError, httpx.HTTPStatusError) as e:
                if retries <= 0:
                    raise e
            retries -= 1

    @classmethod
    async def aclose_all(cls):
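Network.request() above retries transport errors and configurable HTTP status codes. A simplified standalone version of the same loop (session handling and response wrapping omitted):

    import aiohttp

    async def request_with_retries(session: aiohttp.ClientSession, method, url, retries=2):
        while True:
            try:
                async with session.request(method, url) as resp:
                    return await resp.text()
            except aiohttp.ClientError:
                # re-raise once the retry budget is exhausted
                if retries <= 0:
                    raise
                retries -= 1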
searx/network/response.py (new file, 122 lines)

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# lint: pylint
# pylint: disable=missing-module-docstring, missing-function-docstring, invalid-name, fixme

from typing import Optional, Type, Any
from types import TracebackType
import json as jsonlib

import aiohttp
import httpx._utils


class Response:
    """Look alike requests.Response from an aiohttp.ClientResponse

    Only the required methods and attributes are implemented
    """

    @classmethod
    async def new_response(cls, aiohttp_response: aiohttp.ClientResponse, stream=False) -> "Response":
        if stream:
            # streamed
            return StreamResponse(aiohttp_response)
        # not streamed
        await aiohttp_response.read()
        response = ContentResponse(aiohttp_response)
        await aiohttp_response.release()
        return response

    def __init__(self, aio_response: aiohttp.ClientResponse):
        self._aio_response = aio_response
        # TODO check if it is the original request or the last one
        self.request = aio_response.request_info
        self.url = aio_response.request_info.url
        self.ok = aio_response.status < 400
        self.cookies = aio_response.cookies
        self.headers = aio_response.headers
        self.content = aio_response._body

    @property
    def encoding(self):
        return self._aio_response.get_encoding()

    @property
    def status_code(self):
        return self._aio_response.status

    @property
    def reason_phrase(self):
        return self._aio_response.reason

    @property
    def elapsed(self):
        return 0

    @property
    def links(self):
        return self._aio_response.links

    def raise_for_status(self):
        return self._aio_response.raise_for_status()

    @property
    def history(self):
        return [
            StreamResponse(r)
            for r in self._aio_response.history
        ]

    def close(self):
        return self._aio_response.release()

    def __repr__(self) -> str:
        ascii_encodable_url = str(self.url)
        if self.reason_phrase:
            ascii_encodable_reason = self.reason_phrase.encode(
                "ascii", "backslashreplace"
            ).decode("ascii")
        else:
            ascii_encodable_reason = self.reason_phrase
        return "<{}({}) [{} {}]>".format(
            type(self).__name__, ascii_encodable_url, self.status_code, ascii_encodable_reason
        )


class ContentResponse(Response):
    """Similar to requests.Response
    """

    @property
    def text(self) -> str:
        encoding = self._aio_response.get_encoding()
        return self.content.decode(encoding, errors='strict')  # type: ignore

    def json(self, **kwargs: Any) -> Any:
        stripped = self.content.strip()  # type: ignore
        encoding = self._aio_response.get_encoding()
        if encoding is None and self.content and len(stripped) > 3:
            encoding = httpx._utils.guess_json_utf(stripped)
            if encoding is not None:
                return jsonlib.loads(self.content.decode(encoding), **kwargs)
        return jsonlib.loads(stripped.decode(encoding), **kwargs)


class StreamResponse(Response):
    """Streamed response, no .content, .text, .json()
    """

    async def __aenter__(self) -> "StreamResponse":
        return self

    async def __aexit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_val: Optional[BaseException],
            exc_tb: Optional[TracebackType],
    ) -> None:
        await self._aio_response.release()

    async def iter_content(self, chunk_size=1):
        # no decode_unicode parameter
        return await self._aio_response.content.read(chunk_size)
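A short usage sketch of the wrapper defined above: Response.new_response() reads the body and returns a requests-like ContentResponse (the module path is the one added by this commit; the URL is illustrative):

    import asyncio
    import aiohttp
    from searx.network.response import Response

    async def main():
        async with aiohttp.ClientSession() as session:
            aio_resp = await session.get('https://example.org/')
            resp = await Response.new_response(aio_resp)  # body read, connection released
            print(resp.status_code, resp.encoding)
            print(resp.text[:80])

    asyncio.get_event_loop().run_until_complete(main())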
searx/network/utils.py (new file, 113 lines)

@@ -0,0 +1,113 @@
# SPDX-License-Identifier: AGPL-3.0-or-later

import typing
import re


from yarl import URL


class URLPattern:
    """
    A utility class currently used for making lookups against proxy keys...
    # Wildcard matching...
    >>> pattern = URLPattern("all")
    >>> pattern.matches(yarl.URL("http://example.com"))
    True
    # Witch scheme matching...
    >>> pattern = URLPattern("https")
    >>> pattern.matches(yarl.URL("https://example.com"))
    True
    >>> pattern.matches(yarl.URL("http://example.com"))
    False
    # With domain matching...
    >>> pattern = URLPattern("https://example.com")
    >>> pattern.matches(yarl.URL("https://example.com"))
    True
    >>> pattern.matches(yarl.URL("http://example.com"))
    False
    >>> pattern.matches(yarl.URL("https://other.com"))
    False
    # Wildcard scheme, with domain matching...
    >>> pattern = URLPattern("all://example.com")
    >>> pattern.matches(yarl.URL("https://example.com"))
    True
    >>> pattern.matches(yarl.URL("http://example.com"))
    True
    >>> pattern.matches(yarl.URL("https://other.com"))
    False
    # With port matching...
    >>> pattern = URLPattern("https://example.com:1234")
    >>> pattern.matches(yarl.URL("https://example.com:1234"))
    True
    >>> pattern.matches(yarl.URL("https://example.com"))
    False
    """

    def __init__(self, pattern: str) -> None:
        if pattern and ":" not in pattern:
            raise ValueError(
                f"Proxy keys should use proper URL forms rather "
                f"than plain scheme strings. "
                f'Instead of "{pattern}", use "{pattern}://"'
            )

        url = URL(pattern)
        self.pattern = pattern
        self.scheme = "" if url.scheme == "all" else url.scheme
        self.host = "" if url.host == "*" else url.host
        self.port = url.port
        if not url.host or url.host == "*":
            self.host_regex: typing.Optional[typing.Pattern[str]] = None
        else:
            if url.host.startswith("*."):
                # *.example.com should match "www.example.com", but not "example.com"
                domain = re.escape(url.host[2:])
                self.host_regex = re.compile(f"^.+\\.{domain}$")
            elif url.host.startswith("*"):
                # *example.com should match "www.example.com" and "example.com"
                domain = re.escape(url.host[1:])
                self.host_regex = re.compile(f"^(.+\\.)?{domain}$")
            else:
                # example.com should match "example.com" but not "www.example.com"
                domain = re.escape(url.host)
                self.host_regex = re.compile(f"^{domain}$")

    def matches(self, other: URL) -> bool:
        if self.scheme and self.scheme != other.scheme:
            return False
        if (
            self.host
            and self.host_regex is not None
            and not self.host_regex.match(other.host or '')
        ):
            return False
        if self.port is not None and self.port != other.port:
            return False
        return True

    @property
    def priority(self) -> tuple:
        """
        The priority allows URLPattern instances to be sortable, so that
        we can match from most specific to least specific.
        """
        # URLs with a port should take priority over URLs without a port.
        port_priority = 0 if self.port is not None else 1
        # Longer hostnames should match first.
        host_priority = -len(self.host or '')
        # Longer schemes should match first.
        scheme_priority = -len(self.scheme)
        return (port_priority, host_priority, scheme_priority)

    def __hash__(self) -> int:
        return hash(self.pattern)

    def __lt__(self, other: "URLPattern") -> bool:
        return self.priority < other.priority

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, URLPattern) and self.pattern == other.pattern

    def __repr__(self) -> str:
        return f"<URLPattern pattern=\"{self.pattern}\">"
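URLPattern is what Network.get_proxy() (earlier in this diff) uses to pick a proxy for a request URL. A small illustration with made-up proxy URLs; sorting relies on the priority property so the most specific pattern wins:

    from yarl import URL
    from searx.network.utils import URLPattern

    proxies = {
        URLPattern('all://'): 'socks5h://127.0.0.1:9050',
        URLPattern('https://example.com'): 'http://localhost:3128',
    }

    def pick_proxy(url: str):
        for pattern in sorted(proxies):          # most specific first
            if pattern.matches(URL(url)):
                return proxies[pattern]
        return None

    print(pick_proxy('https://example.com/page'))  # http://localhost:3128
    print(pick_proxy('https://other.org/'))        # socks5h://127.0.0.1:9050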
@@ -4,6 +4,7 @@ import typing
import types
import functools
import itertools
import asyncio
from time import time
from timeit import default_timer
from urllib.parse import urlparse

@@ -11,7 +12,7 @@ from urllib.parse import urlparse
import re
from langdetect import detect_langs
from langdetect.lang_detect_exception import LangDetectException
import httpx
import aiohttp

from searx import network, logger
from searx.results import ResultContainer

@@ -91,10 +92,10 @@ def _is_url_image(image_url):
            if r.headers["content-type"].startswith('image/'):
                return True
            return False
        except httpx.TimeoutException:
        except asyncio.TimeoutError:
            logger.error('Timeout for %s: %i', image_url, int(time() - a))
            retry -= 1
        except httpx.HTTPError:
        except aiohttp.ClientError:
            logger.exception('Exception for %s', image_url)
            return False
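Pieced together from the fragments above, _is_url_image() now treats timeouts as retryable and any other aiohttp client error as fatal. Roughly (a simplified reading of the hunk, not the verbatim function):

    import asyncio
    import aiohttp
    from searx import network

    def is_url_image(image_url, retry=2):
        while retry > 0:
            try:
                r = network.get(image_url, timeout=10, allow_redirects=True)
                return r.headers.get('content-type', '').startswith('image/')
            except asyncio.TimeoutError:
                retry -= 1          # timeouts are retried
            except aiohttp.ClientError:
                return False        # any other transport error: give up
        return False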
@@ -7,7 +7,8 @@

from timeit import default_timer
import asyncio
import httpx
import aiohttp
import python_socks

import searx.network
from searx import logger

@@ -143,7 +144,7 @@ class OnlineProcessor(EngineProcessor):
            # send requests and parse the results
            search_results = self._search_basic(query, params)
            self.extend_container(result_container, start_time, search_results)
        except (httpx.TimeoutException, asyncio.TimeoutError) as e:
        except (asyncio.TimeoutError, python_socks.ProxyTimeoutError) as e:
            # requests timeout (connect or read)
            self.handle_exception(result_container, e, suspend=True)
            logger.error("engine {0} : HTTP requests timeout"

@@ -151,7 +152,7 @@ class OnlineProcessor(EngineProcessor):
                         .format(self.engine_name, default_timer() - start_time,
                                 timeout_limit,
                                 e.__class__.__name__))
        except (httpx.HTTPError, httpx.StreamError) as e:
        except (aiohttp.ClientError, python_socks.ProxyError, python_socks.ProxyConnectionError) as e:
            # other requests exception
            self.handle_exception(result_container, e, suspend=True)
            logger.exception("engine {0} : requests exception"
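The engine processor above now separates timeout-like failures (asyncio / python_socks timeouts) from other transport failures (aiohttp client errors, proxy errors). The same split outside searx, with an illustrative session, URL and timeout:

    import asyncio
    import aiohttp
    import python_socks

    async def try_engine_request(session: aiohttp.ClientSession, url: str, timeout: float):
        try:
            async with session.get(url, timeout=aiohttp.ClientTimeout(total=timeout)) as resp:
                return await resp.text()
        except (asyncio.TimeoutError, python_socks.ProxyTimeoutError):
            return None   # connect or read timeout: suspend the engine
        except (aiohttp.ClientError, python_socks.ProxyError, python_socks.ProxyConnectionError):
            return None   # any other request exception: suspend the engine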
@@ -178,7 +178,7 @@ SCHEMA = {
    'pool_connections': SettingsValue(int, 100),
    # Picked from constructor
    'pool_maxsize': SettingsValue(int, 10),
    'keepalive_expiry': SettingsValue(numbers.Real, 5.0),
    'keepalive_expiry': SettingsValue(numbers.Real, 115.0),
    # default maximum redirect
    # from https://github.com/psf/requests/blob/8c211a96cdbe9fe320d63d9e1ae15c5c07e179f8/requests/models.py#L55
    'max_redirects': SettingsValue(int, 30),
@@ -19,7 +19,7 @@ from io import StringIO
import urllib
from urllib.parse import urlencode

import httpx
import aiohttp

from pygments import highlight
from pygments.lexers import get_lexer_by_name

@@ -205,14 +205,14 @@ exception_classname_to_text = {
    'httpx.ConnectTimeout': timeout_text,
    'httpx.ReadTimeout': timeout_text,
    'httpx.WriteTimeout': timeout_text,
    'httpx.HTTPStatusError': gettext('HTTP error'),
    'aiohttp.client_exceptions.ClientResponseError': gettext('HTTP error'),
    'httpx.ConnectError': gettext("HTTP connection error"),
    'httpx.RemoteProtocolError': http_protocol_error_text,
    'httpx.LocalProtocolError': http_protocol_error_text,
    'httpx.ProtocolError': http_protocol_error_text,
    'httpx.ReadError': network_error_text,
    'httpx.WriteError': network_error_text,
    'httpx.ProxyError': gettext("proxy error"),
    'python_socks._errors.ProxyConnectionError': gettext("proxy error"),
    'searx.exceptions.SearxEngineCaptchaException': gettext("CAPTCHA"),
    'searx.exceptions.SearxEngineTooManyRequestsException': gettext("too many requests"),
    'searx.exceptions.SearxEngineAccessDeniedException': gettext("access denied"),

@@ -1110,7 +1110,7 @@ def image_proxy():
            return '', 400

        forward_resp = True
    except httpx.HTTPError:
    except aiohttp.ClientError:
        logger.exception('HTTP error')
        return '', 400
    finally:

@@ -1119,7 +1119,7 @@ def image_proxy():
        # we make sure to close the response between searxng and the HTTP server
        try:
            resp.close()
        except httpx.HTTPError:
        except aiohttp.ClientError:
            logger.exception('HTTP error on closing')

    try:

@@ -1137,7 +1137,7 @@ def image_proxy():
                yield chunk

        return Response(forward_chunk(), mimetype=resp.headers['Content-Type'], headers=headers)
    except httpx.HTTPError:
    except aiohttp.ClientError:
        return '', 400
@@ -1,13 +1,59 @@
# SPDX-License-Identifier: AGPL-3.0-or-later

import asyncio
from searx.network.utils import URLPattern
from mock import patch

import httpx
import aiohttp
from multidict import CIMultiDict, CIMultiDictProxy
import yarl

from searx.network.network import Network, NETWORKS, initialize
from searx.testing import SearxTestCase


def create_fake_response(url, method='GET', content='', status_code=200):
    if isinstance(url, str):
        url = yarl.URL(url)
    if not isinstance(url, yarl.URL):
        raise ValueError('url must be of type yarl.URL. Currently of type ' + str(type(url)))
    loop = asyncio.get_event_loop()
    request_info = aiohttp.RequestInfo(
        url,
        method,
        CIMultiDictProxy(CIMultiDict()),
        url
    )
    response = aiohttp.ClientResponse(
        method,
        url,
        writer=None,
        continue100=False,
        timer=None,
        request_info=request_info,
        traces=[],
        loop=loop,
        session=None
    )

    async def async_nothing():
        pass

    def iter_content(_):
        yield content.encode()

    response.status = status_code
    response._headers = {}
    response.read = async_nothing
    response.close = lambda: None
    response.release = async_nothing
    response.content = content.encode()
    response._body = response.content
    response.get_encoding = lambda: 'utf-8'
    response.iter_content = iter_content
    return response


class TestNetwork(SearxTestCase):

    def setUp(self):

@@ -49,25 +95,26 @@ class TestNetwork(SearxTestCase):

    def test_proxy_cycles(self):
        network = Network(proxies='http://localhost:1337')
        self.assertEqual(next(network._proxies_cycle), (('all://', 'http://localhost:1337'),))
        P = URLPattern
        self.assertEqual(next(network._proxies_cycle), ((P('all://'), 'http://localhost:1337'),))

        network = Network(proxies={
            'https': 'http://localhost:1337',
            'http': 'http://localhost:1338'
        })
        self.assertEqual(next(network._proxies_cycle),
                         (('https://', 'http://localhost:1337'), ('http://', 'http://localhost:1338')))
                         ((P('https://'), 'http://localhost:1337'), (P('http://'), 'http://localhost:1338')))
        self.assertEqual(next(network._proxies_cycle),
                         (('https://', 'http://localhost:1337'), ('http://', 'http://localhost:1338')))
                         ((P('https://'), 'http://localhost:1337'), (P('http://'), 'http://localhost:1338')))

        network = Network(proxies={
            'https': ['http://localhost:1337', 'http://localhost:1339'],
            'http': 'http://localhost:1338'
        })
        self.assertEqual(next(network._proxies_cycle),
                         (('https://', 'http://localhost:1337'), ('http://', 'http://localhost:1338')))
                         ((P('https://'), 'http://localhost:1337'), (P('http://'), 'http://localhost:1338')))
        self.assertEqual(next(network._proxies_cycle),
                         (('https://', 'http://localhost:1339'), ('http://', 'http://localhost:1338')))
                         ((P('https://'), 'http://localhost:1339'), (P('http://'), 'http://localhost:1338')))

        with self.assertRaises(ValueError):
            Network(proxies=1)

@@ -75,45 +122,38 @@ class TestNetwork(SearxTestCase):
    def test_get_kwargs_clients(self):
        kwargs = {
            'verify': True,
            'max_redirects': 5,
            'timeout': 2,
        }
        kwargs_client = Network.get_kwargs_clients(kwargs)

        self.assertEqual(len(kwargs_client), 2)
        self.assertEqual(len(kwargs), 1)

        self.assertEqual(kwargs['timeout'], 2)
        self.assertEqual(len(kwargs_client), 1)
        self.assertEqual(len(kwargs), 0)

        self.assertTrue(kwargs_client['verify'])
        self.assertEqual(kwargs_client['max_redirects'], 5)

    async def test_get_client(self):
        network = Network(verify=True)
        client1 = network.get_client()
        client2 = network.get_client(verify=True)
        client3 = network.get_client(max_redirects=10)
        client4 = network.get_client(verify=True)
        client5 = network.get_client(verify=False)
        client6 = network.get_client(max_redirects=10)
        url = 'https://example.com'
        client1 = network.get_client(url)
        client2 = network.get_client(url, verify=True)
        client3 = network.get_client(url, verify=False)

        self.assertEqual(client1, client2)
        self.assertEqual(client1, client4)
        self.assertNotEqual(client1, client3)
        self.assertNotEqual(client1, client5)
        self.assertEqual(client3, client6)

        await network.aclose()

    async def test_aclose(self):
        network = Network(verify=True)
        network.get_client()
        network.get_client('https://example.com')
        await network.aclose()

    async def test_request(self):
        a_text = 'Lorem Ipsum'
        response = httpx.Response(status_code=200, text=a_text)
        with patch.object(httpx.AsyncClient, 'request', return_value=response):

        async def get_response(*args, **kwargs):
            return create_fake_response(url='https://example.com/', status_code=200, content=a_text)

        with patch.object(aiohttp.ClientSession, 'request', new=get_response):
            network = Network(enable_http=True)
            response = await network.request('GET', 'https://example.com/')
            self.assertEqual(response.text, a_text)

@@ -128,37 +168,47 @@ class TestNetworkRequestRetries(SearxTestCase):
    def get_response_404_then_200(cls):
        first = True

        async def get_response(*args, **kwargs):
        async def get_response(method, url, *args, **kwargs):
            nonlocal first
            if first:
                first = False
                return httpx.Response(status_code=403, text=TestNetworkRequestRetries.TEXT)
            return httpx.Response(status_code=200, text=TestNetworkRequestRetries.TEXT)
                return create_fake_response(
                    method=method,
                    url=url,
                    status_code=403,
                    content=TestNetworkRequestRetries.TEXT
                )
            return create_fake_response(
                method=method,
                url=url,
                status_code=200,
                content=TestNetworkRequestRetries.TEXT
            )
        return get_response

    async def test_retries_ok(self):
        with patch.object(httpx.AsyncClient, 'request', new=TestNetworkRequestRetries.get_response_404_then_200()):
        with patch.object(aiohttp.ClientSession, 'request', new=TestNetworkRequestRetries.get_response_404_then_200()):
            network = Network(enable_http=True, retries=1, retry_on_http_error=403)
            response = await network.request('GET', 'https://example.com/')
            self.assertEqual(response.text, TestNetworkRequestRetries.TEXT)
            await network.aclose()

    async def test_retries_fail_int(self):
        with patch.object(httpx.AsyncClient, 'request', new=TestNetworkRequestRetries.get_response_404_then_200()):
        with patch.object(aiohttp.ClientSession, 'request', new=TestNetworkRequestRetries.get_response_404_then_200()):
            network = Network(enable_http=True, retries=0, retry_on_http_error=403)
            response = await network.request('GET', 'https://example.com/')
            self.assertEqual(response.status_code, 403)
            await network.aclose()

    async def test_retries_fail_list(self):
        with patch.object(httpx.AsyncClient, 'request', new=TestNetworkRequestRetries.get_response_404_then_200()):
        with patch.object(aiohttp.ClientSession, 'request', new=TestNetworkRequestRetries.get_response_404_then_200()):
            network = Network(enable_http=True, retries=0, retry_on_http_error=[403, 429])
            response = await network.request('GET', 'https://example.com/')
            self.assertEqual(response.status_code, 403)
            await network.aclose()

    async def test_retries_fail_bool(self):
        with patch.object(httpx.AsyncClient, 'request', new=TestNetworkRequestRetries.get_response_404_then_200()):
        with patch.object(aiohttp.ClientSession, 'request', new=TestNetworkRequestRetries.get_response_404_then_200()):
            network = Network(enable_http=True, retries=0, retry_on_http_error=True)
            response = await network.request('GET', 'https://example.com/')
            self.assertEqual(response.status_code, 403)

@@ -167,14 +217,19 @@ class TestNetworkRequestRetries(SearxTestCase):
    async def test_retries_exception_then_200(self):
        request_count = 0

        async def get_response(*args, **kwargs):
        async def get_response(method, url, *args, **kwargs):
            nonlocal request_count
            request_count += 1
            if request_count < 3:
                raise httpx.RequestError('fake exception', request=None)
            return httpx.Response(status_code=200, text=TestNetworkRequestRetries.TEXT)
                raise aiohttp.ClientError()
            return create_fake_response(
                url,
                method,
                status_code=200,
                content=TestNetworkRequestRetries.TEXT
            )

        with patch.object(httpx.AsyncClient, 'request', new=get_response):
        with patch.object(aiohttp.ClientSession, 'request', new=get_response):
            network = Network(enable_http=True, retries=2)
            response = await network.request('GET', 'https://example.com/')
            self.assertEqual(response.status_code, 200)

@@ -182,12 +237,12 @@ class TestNetworkRequestRetries(SearxTestCase):
            await network.aclose()

    async def test_retries_exception(self):
        async def get_response(*args, **kwargs):
            raise httpx.RequestError('fake exception', request=None)
        def get_response(*args, **kwargs):
            raise aiohttp.ClientError()

        with patch.object(httpx.AsyncClient, 'request', new=get_response):
        with patch.object(aiohttp.ClientSession, 'request', new=get_response):
            network = Network(enable_http=True, retries=0)
            with self.assertRaises(httpx.RequestError):
            with self.assertRaises(aiohttp.ClientError):
                await network.request('GET', 'https://example.com/')
            await network.aclose()

@@ -200,40 +255,48 @@ class TestNetworkStreamRetries(SearxTestCase):
    def get_response_exception_then_200(cls):
        first = True

        def stream(*args, **kwargs):
        async def stream(method, url, *args, **kwargs):
            nonlocal first
            if first:
                first = False
                raise httpx.RequestError('fake exception', request=None)
            return httpx.Response(status_code=200, text=TestNetworkStreamRetries.TEXT)
                raise aiohttp.ClientError()
            return create_fake_response(url, method, content=TestNetworkStreamRetries.TEXT, status_code=200)
        return stream

    async def test_retries_ok(self):
        with patch.object(httpx.AsyncClient, 'stream', new=TestNetworkStreamRetries.get_response_exception_then_200()):
        with patch.object(
            aiohttp.ClientSession,
            'request',
            new=TestNetworkStreamRetries.get_response_exception_then_200()
        ):
            network = Network(enable_http=True, retries=1, retry_on_http_error=403)
            response = network.stream('GET', 'https://example.com/')
            response = await network.request('GET', 'https://example.com/', read_response=False)
            self.assertEqual(response.text, TestNetworkStreamRetries.TEXT)
            await network.aclose()

    async def test_retries_fail(self):
        with patch.object(httpx.AsyncClient, 'stream', new=TestNetworkStreamRetries.get_response_exception_then_200()):
        with patch.object(
            aiohttp.ClientSession,
            'request',
            new=TestNetworkStreamRetries.get_response_exception_then_200()
        ):
            network = Network(enable_http=True, retries=0, retry_on_http_error=403)
            with self.assertRaises(httpx.RequestError):
                network.stream('GET', 'https://example.com/')
            with self.assertRaises(aiohttp.ClientError):
                await network.request('GET', 'https://example.com/', read_response=False)
            await network.aclose()

    async def test_retries_exception(self):
        first = True

        def stream(*args, **kwargs):
        async def request(method, url, *args, **kwargs):
            nonlocal first
            if first:
                first = False
                return httpx.Response(status_code=403, text=TestNetworkRequestRetries.TEXT)
            return httpx.Response(status_code=200, text=TestNetworkRequestRetries.TEXT)
                return create_fake_response(url, method, status_code=403, content=TestNetworkRequestRetries.TEXT)
            return create_fake_response(url, method, status_code=200, content=TestNetworkRequestRetries.TEXT)

        with patch.object(httpx.AsyncClient, 'stream', new=stream):
        with patch.object(aiohttp.ClientSession, 'request', new=request):
            network = Network(enable_http=True, retries=0, retry_on_http_error=403)
            response = network.stream('GET', 'https://example.com/')
            response = await network.request('GET', 'https://example.com/', read_response=False)
            self.assertEqual(response.status_code, 403)
            await network.aclose()
tests/unit/network/test_utils.py (new file, 48 lines)

@@ -0,0 +1,48 @@
# SPDX-License-Identifier: AGPL-3.0-or-later

import random
from yarl import URL

from searx.network.utils import URLPattern
from searx.testing import SearxTestCase


class TestNetworkUtils(SearxTestCase):
    def test_pattern_priority(self):
        matchers = [
            URLPattern("all://"),
            URLPattern("http://"),
            URLPattern("http://example.com"),
            URLPattern("http://example.com:123"),
        ]
        random.shuffle(matchers)
        self.maxDiff = None
        self.assertEqual(
            sorted(matchers),
            [
                URLPattern("http://example.com:123"),
                URLPattern("http://example.com"),
                URLPattern("http://"),
                URLPattern("all://"),
            ],
        )

    def test_url_matches(self):
        parameters = [
            ("http://example.com", "http://example.com", True),
            ("http://example.com", "https://example.com", False),
            ("http://example.com", "http://other.com", False),
            ("http://example.com:123", "http://example.com:123", True),
            ("http://example.com:123", "http://example.com:456", False),
            ("http://example.com:123", "http://example.com", False),
            ("all://example.com", "http://example.com", True),
            ("all://example.com", "https://example.com", True),
            ("http://", "http://example.com", True),
            ("http://", "https://example.com", False),
            ("all://", "https://example.com:123", True),
            ("", "https://example.com:123", True),
        ]

        for pattern, url, expected in parameters:
            pattern = URLPattern(pattern)
            self.assertEqual(pattern.matches(URL(url)), expected)