utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. import datetime
  2. import inspect
  3. import logging
  4. import textwrap
  5. import warnings
  6. from collections.abc import Callable
  7. from contextlib import contextmanager
  8. from functools import wraps
  9. from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, TypeVar, Union
  10. from redis.exceptions import DataError
  11. from redis.typing import AbsExpiryT, EncodableT, ExpiryT
  12. if TYPE_CHECKING:
  13. from redis.client import Redis
  14. try:
  15. import hiredis # noqa
  16. # Only support Hiredis >= 3.0:
  17. hiredis_version = hiredis.__version__.split(".")
  18. HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or (
  19. int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2
  20. )
  21. if not HIREDIS_AVAILABLE:
  22. raise ImportError("hiredis package should be >= 3.2.0")
  23. except ImportError:
  24. HIREDIS_AVAILABLE = False
  25. try:
  26. import ssl # noqa
  27. SSL_AVAILABLE = True
  28. except ImportError:
  29. SSL_AVAILABLE = False
  30. try:
  31. import cryptography # noqa
  32. CRYPTOGRAPHY_AVAILABLE = True
  33. except ImportError:
  34. CRYPTOGRAPHY_AVAILABLE = False
  35. from importlib import metadata
  36. def from_url(url: str, **kwargs: Any) -> "Redis":
  37. """
  38. Returns an active Redis client generated from the given database URL.
  39. Will attempt to extract the database id from the path url fragment, if
  40. none is provided.
  41. """
  42. from redis.client import Redis
  43. return Redis.from_url(url, **kwargs)
  44. @contextmanager
  45. def pipeline(redis_obj):
  46. p = redis_obj.pipeline()
  47. yield p
  48. p.execute()
  49. def str_if_bytes(value: Union[str, bytes]) -> str:
  50. return (
  51. value.decode("utf-8", errors="replace") if isinstance(value, bytes) else value
  52. )
  53. def safe_str(value):
  54. return str(str_if_bytes(value))
  55. def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]:
  56. """
  57. Merge all provided dicts into 1 dict.
  58. *dicts : `dict`
  59. dictionaries to merge
  60. """
  61. merged = {}
  62. for d in dicts:
  63. merged.update(d)
  64. return merged
  65. def list_keys_to_dict(key_list, callback):
  66. return dict.fromkeys(key_list, callback)
  67. def merge_result(command, res):
  68. """
  69. Merge all items in `res` into a list.
  70. This command is used when sending a command to multiple nodes
  71. and the result from each node should be merged into a single list.
  72. res : 'dict'
  73. """
  74. result = set()
  75. for v in res.values():
  76. for value in v:
  77. result.add(value)
  78. return list(result)
  79. def warn_deprecated(name, reason="", version="", stacklevel=2):
  80. import warnings
  81. msg = f"Call to deprecated {name}."
  82. if reason:
  83. msg += f" ({reason})"
  84. if version:
  85. msg += f" -- Deprecated since version {version}."
  86. warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel)
  87. def deprecated_function(reason="", version="", name=None):
  88. """
  89. Decorator to mark a function as deprecated.
  90. """
  91. def decorator(func):
  92. if inspect.iscoroutinefunction(func):
  93. # Create async wrapper for async functions
  94. @wraps(func)
  95. async def async_wrapper(*args, **kwargs):
  96. warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
  97. return await func(*args, **kwargs)
  98. return async_wrapper
  99. else:
  100. # Create regular wrapper for sync functions
  101. @wraps(func)
  102. def wrapper(*args, **kwargs):
  103. warn_deprecated(name or func.__name__, reason, version, stacklevel=3)
  104. return func(*args, **kwargs)
  105. return wrapper
  106. return decorator
  107. def warn_deprecated_arg_usage(
  108. arg_name: Union[list, str],
  109. function_name: str,
  110. reason: str = "",
  111. version: str = "",
  112. stacklevel: int = 2,
  113. ):
  114. import warnings
  115. msg = (
  116. f"Call to '{function_name}' function with deprecated"
  117. f" usage of input argument/s '{arg_name}'."
  118. )
  119. if reason:
  120. msg += f" ({reason})"
  121. if version:
  122. msg += f" -- Deprecated since version {version}."
  123. warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel)
  124. C = TypeVar("C", bound=Callable)
  125. def _get_filterable_args(
  126. func: Callable, args: tuple, kwargs: dict, allowed_args: Optional[List[str]] = None
  127. ) -> dict:
  128. """
  129. Extract arguments from function call that should be checked for deprecation/experimental warnings.
  130. Excludes 'self' and any explicitly allowed args.
  131. """
  132. arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
  133. filterable_args = dict(zip(arg_names, args))
  134. filterable_args.update(kwargs)
  135. filterable_args.pop("self", None)
  136. if allowed_args:
  137. for allowed_arg in allowed_args:
  138. filterable_args.pop(allowed_arg, None)
  139. return filterable_args
  140. def deprecated_args(
  141. args_to_warn: Optional[List[str]] = None,
  142. allowed_args: Optional[List[str]] = None,
  143. reason: str = "",
  144. version: str = "",
  145. ) -> Callable[[C], C]:
  146. """
  147. Decorator to mark specified args of a function as deprecated.
  148. If '*' is in args_to_warn, all arguments will be marked as deprecated.
  149. """
  150. if args_to_warn is None:
  151. args_to_warn = ["*"]
  152. if allowed_args is None:
  153. allowed_args = []
  154. def _check_deprecated_args(func, filterable_args):
  155. """Check and warn about deprecated arguments."""
  156. for arg in args_to_warn:
  157. if arg == "*" and len(filterable_args) > 0:
  158. warn_deprecated_arg_usage(
  159. list(filterable_args.keys()),
  160. func.__name__,
  161. reason,
  162. version,
  163. stacklevel=5,
  164. )
  165. elif arg in filterable_args:
  166. warn_deprecated_arg_usage(
  167. arg, func.__name__, reason, version, stacklevel=5
  168. )
  169. def decorator(func: C) -> C:
  170. if inspect.iscoroutinefunction(func):
  171. @wraps(func)
  172. async def async_wrapper(*args, **kwargs):
  173. filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
  174. _check_deprecated_args(func, filterable_args)
  175. return await func(*args, **kwargs)
  176. return async_wrapper
  177. else:
  178. @wraps(func)
  179. def wrapper(*args, **kwargs):
  180. filterable_args = _get_filterable_args(func, args, kwargs, allowed_args)
  181. _check_deprecated_args(func, filterable_args)
  182. return func(*args, **kwargs)
  183. return wrapper
  184. return decorator
  185. def _set_info_logger():
  186. """
  187. Set up a logger that log info logs to stdout.
  188. (This is used by the default push response handler)
  189. """
  190. if "push_response" not in logging.root.manager.loggerDict.keys():
  191. logger = logging.getLogger("push_response")
  192. logger.setLevel(logging.INFO)
  193. handler = logging.StreamHandler()
  194. handler.setLevel(logging.INFO)
  195. logger.addHandler(handler)
  196. def check_protocol_version(
  197. protocol: Optional[Union[str, int]], expected_version: int = 3
  198. ) -> bool:
  199. if protocol is None:
  200. return False
  201. if isinstance(protocol, str):
  202. try:
  203. protocol = int(protocol)
  204. except ValueError:
  205. return False
  206. return protocol == expected_version
  207. def get_lib_version():
  208. try:
  209. libver = metadata.version("redis")
  210. except metadata.PackageNotFoundError:
  211. libver = "99.99.99"
  212. return libver
  213. def format_error_message(host_error: str, exception: BaseException) -> str:
  214. if not exception.args:
  215. return f"Error connecting to {host_error}."
  216. elif len(exception.args) == 1:
  217. return f"Error {exception.args[0]} connecting to {host_error}."
  218. else:
  219. return (
  220. f"Error {exception.args[0]} connecting to {host_error}. "
  221. f"{exception.args[1]}."
  222. )
  223. def compare_versions(version1: str, version2: str) -> int:
  224. """
  225. Compare two versions.
  226. :return: -1 if version1 > version2
  227. 0 if both versions are equal
  228. 1 if version1 < version2
  229. """
  230. num_versions1 = list(map(int, version1.split(".")))
  231. num_versions2 = list(map(int, version2.split(".")))
  232. if len(num_versions1) > len(num_versions2):
  233. diff = len(num_versions1) - len(num_versions2)
  234. for _ in range(diff):
  235. num_versions2.append(0)
  236. elif len(num_versions1) < len(num_versions2):
  237. diff = len(num_versions2) - len(num_versions1)
  238. for _ in range(diff):
  239. num_versions1.append(0)
  240. for i, ver in enumerate(num_versions1):
  241. if num_versions1[i] > num_versions2[i]:
  242. return -1
  243. elif num_versions1[i] < num_versions2[i]:
  244. return 1
  245. return 0
  246. def ensure_string(key):
  247. if isinstance(key, bytes):
  248. return key.decode("utf-8")
  249. elif isinstance(key, str):
  250. return key
  251. else:
  252. raise TypeError("Key must be either a string or bytes")
  253. def extract_expire_flags(
  254. ex: Optional[ExpiryT] = None,
  255. px: Optional[ExpiryT] = None,
  256. exat: Optional[AbsExpiryT] = None,
  257. pxat: Optional[AbsExpiryT] = None,
  258. ) -> List[EncodableT]:
  259. exp_options: list[EncodableT] = []
  260. if ex is not None:
  261. exp_options.append("EX")
  262. if isinstance(ex, datetime.timedelta):
  263. exp_options.append(int(ex.total_seconds()))
  264. elif isinstance(ex, int):
  265. exp_options.append(ex)
  266. elif isinstance(ex, str) and ex.isdigit():
  267. exp_options.append(int(ex))
  268. else:
  269. raise DataError("ex must be datetime.timedelta or int")
  270. elif px is not None:
  271. exp_options.append("PX")
  272. if isinstance(px, datetime.timedelta):
  273. exp_options.append(int(px.total_seconds() * 1000))
  274. elif isinstance(px, int):
  275. exp_options.append(px)
  276. else:
  277. raise DataError("px must be datetime.timedelta or int")
  278. elif exat is not None:
  279. if isinstance(exat, datetime.datetime):
  280. exat = int(exat.timestamp())
  281. exp_options.extend(["EXAT", exat])
  282. elif pxat is not None:
  283. if isinstance(pxat, datetime.datetime):
  284. pxat = int(pxat.timestamp() * 1000)
  285. exp_options.extend(["PXAT", pxat])
  286. return exp_options
  287. def truncate_text(txt, max_length=100):
  288. return textwrap.shorten(
  289. text=txt, width=max_length, placeholder="...", break_long_words=True
  290. )
  291. def dummy_fail():
  292. """
  293. Fake function for a Retry object if you don't need to handle each failure.
  294. """
  295. pass
  296. async def dummy_fail_async():
  297. """
  298. Async fake function for a Retry object if you don't need to handle each failure.
  299. """
  300. pass
  301. def experimental(cls):
  302. """
  303. Decorator to mark a class as experimental.
  304. """
  305. original_init = cls.__init__
  306. @wraps(original_init)
  307. def new_init(self, *args, **kwargs):
  308. warnings.warn(
  309. f"{cls.__name__} is an experimental and may change or be removed in future versions.",
  310. category=UserWarning,
  311. stacklevel=2,
  312. )
  313. original_init(self, *args, **kwargs)
  314. cls.__init__ = new_init
  315. return cls
  316. def warn_experimental(name, stacklevel=2):
  317. import warnings
  318. msg = (
  319. f"Call to experimental method {name}. "
  320. "Be aware that the function arguments can "
  321. "change or be removed in future versions."
  322. )
  323. warnings.warn(msg, category=UserWarning, stacklevel=stacklevel)
  324. def experimental_method() -> Callable[[C], C]:
  325. """
  326. Decorator to mark a function as experimental.
  327. """
  328. def decorator(func: C) -> C:
  329. if inspect.iscoroutinefunction(func):
  330. # Create async wrapper for async functions
  331. @wraps(func)
  332. async def async_wrapper(*args, **kwargs):
  333. warn_experimental(func.__name__, stacklevel=2)
  334. return await func(*args, **kwargs)
  335. return async_wrapper
  336. else:
  337. # Create regular wrapper for sync functions
  338. @wraps(func)
  339. def wrapper(*args, **kwargs):
  340. warn_experimental(func.__name__, stacklevel=2)
  341. return func(*args, **kwargs)
  342. return wrapper
  343. return decorator
  344. def warn_experimental_arg_usage(
  345. arg_name: Union[list, str],
  346. function_name: str,
  347. stacklevel: int = 2,
  348. ):
  349. import warnings
  350. msg = (
  351. f"Call to '{function_name}' method with experimental"
  352. f" usage of input argument/s '{arg_name}'."
  353. )
  354. warnings.warn(msg, category=UserWarning, stacklevel=stacklevel)
  355. def experimental_args(
  356. args_to_warn: Optional[List[str]] = None,
  357. ) -> Callable[[C], C]:
  358. """
  359. Decorator to mark specified args of a function as experimental.
  360. If '*' is in args_to_warn, all arguments will be marked as experimental.
  361. """
  362. if args_to_warn is None:
  363. args_to_warn = ["*"]
  364. def _check_experimental_args(func, filterable_args):
  365. """Check and warn about experimental arguments."""
  366. for arg in args_to_warn:
  367. if arg == "*" and len(filterable_args) > 0:
  368. warn_experimental_arg_usage(
  369. list(filterable_args.keys()), func.__name__, stacklevel=4
  370. )
  371. elif arg in filterable_args:
  372. warn_experimental_arg_usage(arg, func.__name__, stacklevel=4)
  373. def decorator(func: C) -> C:
  374. if inspect.iscoroutinefunction(func):
  375. @wraps(func)
  376. async def async_wrapper(*args, **kwargs):
  377. filterable_args = _get_filterable_args(func, args, kwargs)
  378. if len(filterable_args) > 0:
  379. _check_experimental_args(func, filterable_args)
  380. return await func(*args, **kwargs)
  381. return async_wrapper
  382. else:
  383. @wraps(func)
  384. def wrapper(*args, **kwargs):
  385. filterable_args = _get_filterable_args(func, args, kwargs)
  386. if len(filterable_args) > 0:
  387. _check_experimental_args(func, filterable_args)
  388. return func(*args, **kwargs)
  389. return wrapper
  390. return decorator