mirror of
https://github.com/searxng/searxng
synced 2024-01-01 19:24:07 +01:00
fix rate limiting per engine
This commit is contained in:
parent
85eca4fb22
commit
17b9b334a3
6 changed files with 81 additions and 48 deletions
|
@ -45,7 +45,7 @@ ENGINE_DEFAULT_ARGS = {
|
||||||
"using_tor_proxy": False,
|
"using_tor_proxy": False,
|
||||||
"display_error_messages": True,
|
"display_error_messages": True,
|
||||||
"send_accept_language_header": False,
|
"send_accept_language_header": False,
|
||||||
"rate_limit": [{"max_requests": float('inf'), "interval": 1}],
|
"rate_limit": [{"max_requests": None, "interval": 1}],
|
||||||
"tokens": [],
|
"tokens": [],
|
||||||
"about": {},
|
"about": {},
|
||||||
}
|
}
|
||||||
|
|
|
@ -98,10 +98,6 @@ class Search:
|
||||||
if request_params is None:
|
if request_params is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# stop request if it exceeds engine's rate limit
|
|
||||||
if processor.exceeds_rate_limit():
|
|
||||||
continue
|
|
||||||
|
|
||||||
counter_inc('engine', engineref.name, 'search', 'count', 'sent')
|
counter_inc('engine', engineref.name, 'search', 'count', 'sent')
|
||||||
|
|
||||||
# append request to list
|
# append request to list
|
||||||
|
|
|
@ -136,34 +136,27 @@ class EngineProcessor(ABC):
|
||||||
self.engine_name, self.suspended_status.suspend_reason, suspended=True
|
self.engine_name, self.suspended_status.suspend_reason, suspended=True
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
if self.exceeds_rate_limit():
|
||||||
|
result_container.add_unresponsive_engine(self.engine_name, "Exceed rate limit", suspended=True)
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def exceeds_rate_limit(self):
    """Count this request against every rate limiter of the engine.

    Each entry of ``self.engine.rate_limit`` is a dict with a
    ``max_requests`` value and an optional ``interval`` (default ``1``).
    Entries whose ``max_requests`` is falsy (``None``/``0``) or the legacy
    ``float('inf')`` are treated as "no limit" and skipped.

    All limiters are always incremented (no early exit), so a request is
    counted against every configured window.

    :return: ``True`` if at least one limiter is exceeded, else ``False``.
    """
    result = False

    # add counter to all of the engine's rate limiters
    for rate_limit in self.engine.rate_limit:
        max_requests = rate_limit['max_requests']
        interval = rate_limit.get('interval', 1)

        if not max_requests or max_requests == float('inf'):
            continue

        name = f'{self.engine_name}_{max_requests}r/{interval}s'
        # Strict comparison: the counter value includes the current request,
        # so the request number ``max_requests`` is still within the limit.
        if run_locked(storage.incr_counter, name, max_requests, interval) > max_requests:
            logger.debug(
                f"{self.engine_name} exceeded rate limit of {max_requests} requests per {interval} seconds"
            )
            result = True

    return result
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
# pyright: strict
|
# pyright: strict
|
||||||
|
import hmac
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from searx import get_setting
|
||||||
|
|
||||||
|
|
||||||
class SharedDict(ABC):
|
class SharedDict(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -20,3 +23,23 @@ class SharedDict(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_str(self, key: str, value: str, expire: Optional[int] = None):
|
def set_str(self, key: str, value: str, expire: Optional[int] = None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def incr_counter(self, name: str, limit: int = 0, expire: int = 0) -> int:
    """Increment the shared counter ``name`` and return its value.

    The storage key is ``SearXNG_counter_`` followed by the hex
    HMAC-SHA256 of ``name``, keyed with the ``server.secret_key`` setting.

    :param name: logical name of the counter.
    :param limit: ceiling of the counter; a falsy value (``0``) disables
        the ceiling.  Once the stored count is greater than ``limit`` it is
        no longer incremented, so it never grows beyond ``limit + 1``.
    :param expire: expiration value forwarded unchanged to ``set_int`` on
        every write.
    :return: the counter value after this call.
    """
    # The secret is the HMAC key; the counter name is the message.
    secret = get_setting('server.secret_key').encode('utf-8')
    digest = hmac.new(secret, name.encode('utf-8'), digestmod='sha256').hexdigest()
    key = 'SearXNG_counter_' + digest

    count = self.get_int(key)
    if count is None:
        # first hit: initialize counter with expiration time
        self.set_int(key, 1, expire)
        return 1

    if not limit or count <= limit:
        new_count = count + 1
        self.set_int(key, new_count, expire)
        return new_count

    # ceiling already passed: leave the stored value untouched
    return count
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
|
||||||
|
import time
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -12,27 +13,32 @@ class SimpleSharedDict(shared_abstract.SharedDict):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.d = {}
|
self.d = {}
|
||||||
|
self.expire_times = {}
|
||||||
|
schedule(1, self._expire)
|
||||||
|
|
||||||
def get_int(self, key: str) -> Optional[int]:
|
def get_int(self, key: str) -> Optional[int]:
|
||||||
return self.d.get(key, None)
|
return self.d.get(key, None)
|
||||||
|
|
||||||
def set_int(self, key: str, value: int, expire: Optional[int] = None):
|
def set_int(self, key: str, value: int, expire: Optional[int] = None):
|
||||||
self.d[key] = value
|
self.d[key] = value
|
||||||
if expire:
|
if expire and not self.expire_times.get(key):
|
||||||
self._expire(key, expire)
|
self.expire_times[key] = (time.time(), expire)
|
||||||
|
|
||||||
def get_str(self, key: str) -> Optional[str]:
|
def get_str(self, key: str) -> Optional[str]:
|
||||||
return self.d.get(key, None)
|
return self.d.get(key, None)
|
||||||
|
|
||||||
def set_str(self, key: str, value: str, expire: Optional[int] = None):
|
def set_str(self, key: str, value: str, expire: Optional[int] = None):
|
||||||
self.d[key] = value
|
self.d[key] = value
|
||||||
if expire:
|
if expire and not self.expire_times.get(key):
|
||||||
self._expire(key, expire)
|
self.expire_times[key] = (time.time(), expire)
|
||||||
|
|
||||||
def _expire(self):
    """Remove every entry whose expiration interval has elapsed.

    ``self.expire_times`` maps a key to ``(created_at, expire)``.  A key is
    dropped from both ``self.d`` and ``self.expire_times`` once
    ``now - created_at >= expire``.
    """
    now = time.time()
    # Iterate over a snapshot: removing items from a dict while looping over
    # its live ``items()`` view raises RuntimeError.
    for key, (created_at, expire) in list(self.expire_times.items()):
        if now - created_at >= expire:
            # tolerate a value that has already been removed
            self.d.pop(key, None)
            self.expire_times.pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
def run_locked(func, *args):
|
def run_locked(func, *args):
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional, Tuple, Union
|
||||||
import threading
|
|
||||||
import uwsgi # pyright: ignore # pylint: disable=E0401
|
import uwsgi # pyright: ignore # pylint: disable=E0401
|
||||||
from . import shared_abstract
|
from . import shared_abstract
|
||||||
|
|
||||||
|
@ -12,35 +12,50 @@ _last_signal = 10
|
||||||
|
|
||||||
class UwsgiCacheSharedDict(shared_abstract.SharedDict):
|
class UwsgiCacheSharedDict(shared_abstract.SharedDict):
|
||||||
def get_int(self, key: str) -> Optional[int]:
|
def get_int(self, key: str) -> Optional[int]:
|
||||||
value = uwsgi.cache_get(key)
|
value, _, _ = self._get_value(key)
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
else:
|
else:
|
||||||
return int.from_bytes(value, 'big')
|
return int(value)
|
||||||
|
|
||||||
def set_int(self, key: str, value: int, expire: Optional[int] = None):
|
def set_int(self, key: str, value: int, expire: Optional[int] = None):
|
||||||
b = value.to_bytes(4, 'big')
|
self._set_value(key, value, expire)
|
||||||
uwsgi.cache_update(key, b)
|
|
||||||
if expire:
|
|
||||||
self._expire(key, expire)
|
|
||||||
|
|
||||||
def get_str(self, key: str) -> Optional[str]:
|
def get_str(self, key: str) -> Optional[str]:
|
||||||
value = uwsgi.cache_get(key)
|
value, _, _ = self._get_value(key)
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
else:
|
else:
|
||||||
return value.decode('utf-8')
|
return str(value)
|
||||||
|
|
||||||
def set_str(self, key: str, value: str, expire: Optional[int] = None):
|
def set_str(self, key: str, value: str, expire: Optional[int] = None):
|
||||||
b = value.encode('utf-8')
|
self._set_value(key, value, expire)
|
||||||
uwsgi.cache_update(key, b)
|
|
||||||
if expire:
|
|
||||||
self._expire(key, expire)
|
|
||||||
|
|
||||||
def _expire(self, key: str, expire: int):
|
def _get_value(self, key: str) -> Tuple[Optional[Union[str, int]], Optional[float], Optional[int]]:
|
||||||
t = threading.Timer(expire, uwsgi.cache_del, args=[key])
|
serialized_data = uwsgi.cache_get(key)
|
||||||
t.daemon = True
|
if not serialized_data:
|
||||||
t.start()
|
return None, None, None
|
||||||
|
else:
|
||||||
|
data = json.loads(serialized_data.decode())
|
||||||
|
if 'expire' in data:
|
||||||
|
now = time.time()
|
||||||
|
if now - data['created_at'] >= data['expire']:
|
||||||
|
uwsgi.cache_del(key)
|
||||||
|
return None, None, None
|
||||||
|
return data.get('value'), data.get('created_at'), data.get('expire')
|
||||||
|
|
||||||
|
def _set_value(self, key: str, value: Union[str, int], expire: Optional[int] = None):
    """Store ``value`` under ``key`` as a JSON document in the uWSGI cache.

    The document always holds ``value``.  When the entry has a lifetime it
    also holds ``created_at`` and ``expire``.  An existing, still valid
    entry keeps its original ``created_at``/``expire``, so updating a value
    does not extend its lifetime.

    :param key: cache key.
    :param value: string or integer to store.
    :param expire: lifetime for a new entry; ``None`` or ``0`` means the
        entry does not expire.
    """
    _, created_at, original_expire = self._get_value(key)

    data = {'value': value}
    ttl = original_expire or expire
    if not ttl:
        # No lifetime at all.  A stored ``expire`` of 0 would make
        # ``_get_value`` treat the entry as expired immediately.
        uwsgi.cache_update(key, json.dumps(data).encode())
        return

    now = time.time()
    if created_at is None:
        created_at = now
    data['created_at'] = created_at
    data['expire'] = ttl

    # Hand uWSGI the remaining lifetime (whole seconds, at least 1) instead
    # of the caller's ``expire``, which may be None for an existing entry.
    remaining = max(1, int(created_at + ttl - now) + 1)
    uwsgi.cache_update(key, json.dumps(data).encode(), remaining)
|
||||||
|
|
||||||
|
|
||||||
def run_locked(func, *args):
|
def run_locked(func, *args):
|
||||||
|
|
Loading…
Add table
Reference in a new issue