diff --git a/searx/engines/__init__.py b/searx/engines/__init__.py
index 9a6706871..bcfd20455 100644
--- a/searx/engines/__init__.py
+++ b/searx/engines/__init__.py
@@ -45,7 +45,7 @@ ENGINE_DEFAULT_ARGS = {
     "using_tor_proxy": False,
     "display_error_messages": True,
     "send_accept_language_header": False,
-    "rate_limit": [{"max_requests": float('inf'), "interval": 1}],
+    "rate_limit": [{"max_requests": None, "interval": 1}],
     "tokens": [],
     "about": {},
 }
diff --git a/searx/search/__init__.py b/searx/search/__init__.py
index cc85a834c..9d337916c 100644
--- a/searx/search/__init__.py
+++ b/searx/search/__init__.py
@@ -98,10 +98,6 @@ class Search:
             if request_params is None:
                 continue
 
-            # stop request if it exceeds engine's rate limit
-            if processor.exceeds_rate_limit():
-                continue
-
             counter_inc('engine', engineref.name, 'search', 'count', 'sent')
 
             # append request to list
diff --git a/searx/search/processors/abstract.py b/searx/search/processors/abstract.py
index a775fd904..7ba7c521c 100644
--- a/searx/search/processors/abstract.py
+++ b/searx/search/processors/abstract.py
@@ -136,34 +136,27 @@ class EngineProcessor(ABC):
                 self.engine_name, self.suspended_status.suspend_reason, suspended=True
             )
             return True
+        if self.exceeds_rate_limit():
+            result_container.add_unresponsive_engine(self.engine_name, "Exceed rate limit", suspended=True)
+            return True
         return False
 
     def exceeds_rate_limit(self):
-        def check_rate_limiter(engine_name, max_requests, interval):
-            key = f'rate_limiter_{engine_name}_{max_requests}r/{interval}s'
-            # check requests count
-            count = storage.get_int(key)
-            if count is None:
-                # initialize counter with expiration time
-                storage.set_int(key, 1, interval)
-            elif count >= max_requests:
-                logger.debug(f"{engine_name} exceeded rate limit of {max_requests} requests per {interval} seconds")
-                return True
-            else:
-                # update counter
-                storage.set_int(key, count + 1)
-            return False
-
         result = False
+        # add counter to all of the engine's rate limiters
         for rate_limit in self.engine.rate_limit:
             max_requests = rate_limit['max_requests']
             interval = rate_limit.get('interval', 1)
-            if max_requests == float('inf'):
+            if not max_requests:
                 continue
-            if run_locked(check_rate_limiter, self.engine_name, max_requests, interval):
+            name = f'{self.engine_name}_{max_requests}r/{interval}s'
+            # the returned counter already includes the current request, so the
+            # limit is only exceeded once the count is strictly greater than max_requests
+            if run_locked(storage.incr_counter, name, max_requests, interval) > max_requests:
+                logger.debug(
+                    f"{self.engine_name} exceeded rate limit of {max_requests} requests per {interval} seconds"
+                )
                 result = True
         return result
diff --git a/searx/shared/shared_abstract.py b/searx/shared/shared_abstract.py
index 32f9d98bf..31478566e 100644
--- a/searx/shared/shared_abstract.py
+++ b/searx/shared/shared_abstract.py
@@ -1,8 +1,11 @@
 # SPDX-License-Identifier: AGPL-3.0-or-later
 # pyright: strict
+import hmac
 from abc import ABC, abstractmethod
 from typing import Optional
 
+from searx import get_setting
+
 
 class SharedDict(ABC):
     @abstractmethod
@@ -20,3 +23,23 @@ class SharedDict(ABC):
     @abstractmethod
     def set_str(self, key: str, value: str, expire: Optional[int] = None):
         pass
+
+    def incr_counter(self, name: str, limit: int = 0, expire: int = 0) -> int:
+        # generate dict key from name
+        m = hmac.new(bytes(name, encoding='utf-8'), digestmod='sha256')
+        m.update(bytes(get_setting('server.secret_key'), encoding='utf-8'))
+        key = 'SearXNG_counter_' + m.hexdigest()
+
+        # check requests count
+        count = self.get_int(key)
+        if count is None:
+            # initialize counter with expiration time
+            self.set_int(key, 1, expire)
+            return 1
+        elif limit >= count or not limit:
+            # update counter
+            new_count = count + 1
+            self.set_int(key, new_count, expire)
+            return new_count
+        else:
+            return count
diff --git a/searx/shared/shared_simple.py b/searx/shared/shared_simple.py
index 0b0866982..f66f0c78a 100644
--- a/searx/shared/shared_simple.py
+++ b/searx/shared/shared_simple.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: AGPL-3.0-or-later
 
+import time
 import threading
 from typing import Optional
 
@@ -12,27 +13,32 @@ class SimpleSharedDict(shared_abstract.SharedDict):
 
     def __init__(self):
         self.d = {}
+        self.expire_times = {}
+        schedule(1, self._expire)
 
     def get_int(self, key: str) -> Optional[int]:
         return self.d.get(key, None)
 
     def set_int(self, key: str, value: int, expire: Optional[int] = None):
         self.d[key] = value
-        if expire:
-            self._expire(key, expire)
+        if expire and not self.expire_times.get(key):
+            self.expire_times[key] = (time.time(), expire)
 
     def get_str(self, key: str) -> Optional[str]:
         return self.d.get(key, None)
 
     def set_str(self, key: str, value: str, expire: Optional[int] = None):
         self.d[key] = value
-        if expire:
-            self._expire(key, expire)
+        if expire and not self.expire_times.get(key):
+            self.expire_times[key] = (time.time(), expire)
 
-    def _expire(self, key: str, expire: int):
-        t = threading.Timer(expire, lambda k, d: d.pop(k), args=[key, self.d])
-        t.daemon = True
-        t.start()
+    def _expire(self):
+        now = time.time()
+        # iterate over a copy so keys can be dropped while looping
+        for key, val in list(self.expire_times.items()):
+            created_at, expire = val
+            if now - created_at >= expire:
+                self.d.pop(key)
+                self.expire_times.pop(key)
 
 
 def run_locked(func, *args):
diff --git a/searx/shared/shared_uwsgi.py b/searx/shared/shared_uwsgi.py
index d2c250bf6..aa19ce7e9 100644
--- a/searx/shared/shared_uwsgi.py
+++ b/searx/shared/shared_uwsgi.py
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: AGPL-3.0-or-later
 
+import json
 import time
-from typing import Optional
-import threading
+from typing import Optional, Tuple, Union
 
 import uwsgi  # pyright: ignore # pylint: disable=E0401
 from . import shared_abstract
@@ -12,35 +12,50 @@ _last_signal = 10
 
 
 class UwsgiCacheSharedDict(shared_abstract.SharedDict):
     def get_int(self, key: str) -> Optional[int]:
-        value = uwsgi.cache_get(key)
+        value, _, _ = self._get_value(key)
         if value is None:
             return value
         else:
-            return int.from_bytes(value, 'big')
+            return int(value)
 
     def set_int(self, key: str, value: int, expire: Optional[int] = None):
-        b = value.to_bytes(4, 'big')
-        uwsgi.cache_update(key, b)
-        if expire:
-            self._expire(key, expire)
+        self._set_value(key, value, expire)
 
     def get_str(self, key: str) -> Optional[str]:
-        value = uwsgi.cache_get(key)
+        value, _, _ = self._get_value(key)
         if value is None:
             return value
         else:
-            return value.decode('utf-8')
+            return str(value)
 
     def set_str(self, key: str, value: str, expire: Optional[int] = None):
-        b = value.encode('utf-8')
-        uwsgi.cache_update(key, b)
-        if expire:
-            self._expire(key, expire)
+        self._set_value(key, value, expire)
 
-    def _expire(self, key: str, expire: int):
-        t = threading.Timer(expire, uwsgi.cache_del, args=[key])
-        t.daemon = True
-        t.start()
+    def _get_value(self, key: str) -> Tuple[Optional[Union[str, int]], Optional[float], Optional[int]]:
+        # values are stored as JSON together with their creation time and expiration
+        serialized_data = uwsgi.cache_get(key)
+        if not serialized_data:
+            return None, None, None
+        else:
+            data = json.loads(serialized_data.decode())
+            if 'expire' in data:
+                now = time.time()
+                if now - data['created_at'] >= data['expire']:
+                    uwsgi.cache_del(key)
+                    return None, None, None
+            return data.get('value'), data.get('created_at'), data.get('expire')
+
+    def _set_value(self, key: str, value: Union[str, int], expire: Optional[int] = None):
+        # keep the original creation time and expiration when a key is updated
+        _, created_at, original_expire = self._get_value(key)
+
+        data = {'value': value}
+        if expire is None and created_at is None:
+            serialized_data = json.dumps(data).encode()
+            uwsgi.cache_update(key, serialized_data)
+        else:
+            data['created_at'] = created_at or time.time()
+            data['expire'] = original_expire or expire
+            serialized_data = json.dumps(data).encode()
+            uwsgi.cache_update(key, serialized_data, expire)
 
 
 def run_locked(func, *args):
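
For reference, a minimal sketch of a per-engine override of the new "rate_limit" default. This is an illustration only, not part of the patch: the engine name, shortcut and window values below are made up. The shape follows ENGINE_DEFAULT_ARGS above: "rate_limit" is a list of windows, "interval" falls back to 1 second when omitted, and a "max_requests" of None (the new default) leaves the check disabled.

# hypothetical engine definition allowing at most 10 requests per second
# and 100 requests per minute for this engine
example_engine_settings = {
    "name": "example engine",
    "engine": "example",
    "shortcut": "ex",
    "rate_limit": [
        {"max_requests": 10, "interval": 1},
        {"max_requests": 100, "interval": 60},
    ],
}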
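The window accounting itself lives in SharedDict.incr_counter(). The sketch below reproduces only that counting logic against a plain dict, with no hashing, locking or expiry, to show which request in a window is the first one rejected by exceeds_rate_limit(). The helper name incr and the counter key are hypothetical; this is an assumption-level illustration, not code from the patch.

# standalone sketch of the incr_counter() accounting for limit = 3
def incr(counters, name, limit=0):
    count = counters.get(name)
    if count is None:
        counters[name] = 1              # first request in this window
        return 1
    if limit >= count or not limit:
        counters[name] = count + 1      # counter is capped at limit + 1
        return counters[name]
    return count

counters = {}
for request in range(1, 6):
    count = incr(counters, 'engine_3r/1s', limit=3)
    print(request, count, count > 3)    # requests 1-3 pass, 4 and 5 are rejected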