event.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. import asyncio
  2. import threading
  3. from abc import ABC, abstractmethod
  4. from enum import Enum
  5. from typing import Dict, List, Optional, Type, Union
  6. from redis.auth.token import TokenInterface
  7. from redis.credentials import CredentialProvider, StreamingCredentialProvider
  8. from redis.observability.recorder import (
  9. init_connection_count,
  10. register_pools_connection_count,
  11. )
  12. from redis.utils import check_protocol_version, deprecated_function
  13. class EventListenerInterface(ABC):
  14. """
  15. Represents a listener for given event object.
  16. """
  17. @abstractmethod
  18. def listen(self, event: object):
  19. pass
  20. class AsyncEventListenerInterface(ABC):
  21. """
  22. Represents an async listener for given event object.
  23. """
  24. @abstractmethod
  25. async def listen(self, event: object):
  26. pass
  27. class EventDispatcherInterface(ABC):
  28. """
  29. Represents a dispatcher that dispatches events to listeners
  30. associated with given event.
  31. """
  32. @abstractmethod
  33. def dispatch(self, event: object):
  34. pass
  35. @abstractmethod
  36. async def dispatch_async(self, event: object):
  37. pass
  38. @abstractmethod
  39. def register_listeners(
  40. self,
  41. mappings: Dict[
  42. Type[object],
  43. List[Union[EventListenerInterface, AsyncEventListenerInterface]],
  44. ],
  45. ):
  46. """Register additional listeners."""
  47. pass
  48. class EventException(Exception):
  49. """
  50. Exception wrapper that adds an event object into exception context.
  51. """
  52. def __init__(self, exception: Exception, event: object):
  53. self.exception = exception
  54. self.event = event
  55. super().__init__(exception)
  56. class EventDispatcher(EventDispatcherInterface):
  57. # TODO: Make dispatcher to accept external mappings.
  58. def __init__(
  59. self,
  60. event_listeners: Optional[
  61. Dict[Type[object], List[EventListenerInterface]]
  62. ] = None,
  63. ):
  64. """
  65. Dispatcher that dispatches events to listeners associated with given event.
  66. """
  67. self._event_listeners_mapping: Dict[
  68. Type[object], List[EventListenerInterface]
  69. ] = {
  70. AfterConnectionReleasedEvent: [
  71. ReAuthConnectionListener(),
  72. ],
  73. AfterPooledConnectionsInstantiationEvent: [
  74. RegisterReAuthForPooledConnections(),
  75. ],
  76. AfterSingleConnectionInstantiationEvent: [
  77. RegisterReAuthForSingleConnection()
  78. ],
  79. AfterPubSubConnectionInstantiationEvent: [RegisterReAuthForPubSub()],
  80. AfterAsyncClusterInstantiationEvent: [RegisterReAuthForAsyncClusterNodes()],
  81. AsyncAfterConnectionReleasedEvent: [
  82. AsyncReAuthConnectionListener(),
  83. ],
  84. }
  85. self._lock = threading.Lock()
  86. self._async_lock = None
  87. if event_listeners:
  88. self.register_listeners(event_listeners)
  89. def dispatch(self, event: object):
  90. with self._lock:
  91. listeners = self._event_listeners_mapping.get(type(event), [])
  92. for listener in listeners:
  93. listener.listen(event)
  94. async def dispatch_async(self, event: object):
  95. if self._async_lock is None:
  96. self._async_lock = asyncio.Lock()
  97. async with self._async_lock:
  98. listeners = self._event_listeners_mapping.get(type(event), [])
  99. for listener in listeners:
  100. await listener.listen(event)
  101. def register_listeners(
  102. self,
  103. mappings: Dict[
  104. Type[object],
  105. List[Union[EventListenerInterface, AsyncEventListenerInterface]],
  106. ],
  107. ):
  108. with self._lock:
  109. for event_type in mappings:
  110. if event_type in self._event_listeners_mapping:
  111. self._event_listeners_mapping[event_type] = list(
  112. set(
  113. self._event_listeners_mapping[event_type]
  114. + mappings[event_type]
  115. )
  116. )
  117. else:
  118. self._event_listeners_mapping[event_type] = mappings[event_type]
  119. class AfterConnectionReleasedEvent:
  120. """
  121. Event that will be fired before each command execution.
  122. """
  123. def __init__(self, connection):
  124. self._connection = connection
  125. @property
  126. def connection(self):
  127. return self._connection
  128. class AsyncAfterConnectionReleasedEvent(AfterConnectionReleasedEvent):
  129. pass
  130. class ClientType(Enum):
  131. SYNC = ("sync",)
  132. ASYNC = ("async",)
  133. class AfterPooledConnectionsInstantiationEvent:
  134. """
  135. Event that will be fired after pooled connection instances was created.
  136. """
  137. def __init__(
  138. self,
  139. connection_pools: List,
  140. client_type: ClientType,
  141. credential_provider: Optional[CredentialProvider] = None,
  142. ):
  143. self._connection_pools = connection_pools
  144. self._client_type = client_type
  145. self._credential_provider = credential_provider
  146. @property
  147. def connection_pools(self):
  148. return self._connection_pools
  149. @property
  150. def client_type(self) -> ClientType:
  151. return self._client_type
  152. @property
  153. def credential_provider(self) -> Union[CredentialProvider, None]:
  154. return self._credential_provider
  155. class AfterSingleConnectionInstantiationEvent:
  156. """
  157. Event that will be fired after single connection instances was created.
  158. :param connection_lock: For sync client thread-lock should be provided,
  159. for async asyncio.Lock
  160. """
  161. def __init__(
  162. self,
  163. connection,
  164. client_type: ClientType,
  165. connection_lock: Union[threading.RLock, asyncio.Lock],
  166. ):
  167. self._connection = connection
  168. self._client_type = client_type
  169. self._connection_lock = connection_lock
  170. @property
  171. def connection(self):
  172. return self._connection
  173. @property
  174. def client_type(self) -> ClientType:
  175. return self._client_type
  176. @property
  177. def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
  178. return self._connection_lock
  179. class AfterPubSubConnectionInstantiationEvent:
  180. def __init__(
  181. self,
  182. pubsub_connection,
  183. connection_pool,
  184. client_type: ClientType,
  185. connection_lock: Union[threading.RLock, asyncio.Lock],
  186. ):
  187. self._pubsub_connection = pubsub_connection
  188. self._connection_pool = connection_pool
  189. self._client_type = client_type
  190. self._connection_lock = connection_lock
  191. @property
  192. def pubsub_connection(self):
  193. return self._pubsub_connection
  194. @property
  195. def connection_pool(self):
  196. return self._connection_pool
  197. @property
  198. def client_type(self) -> ClientType:
  199. return self._client_type
  200. @property
  201. def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
  202. return self._connection_lock
  203. class AfterAsyncClusterInstantiationEvent:
  204. """
  205. Event that will be fired after async cluster instance was created.
  206. Async cluster doesn't use connection pools,
  207. instead ClusterNode object manages connections.
  208. """
  209. def __init__(
  210. self,
  211. nodes: dict,
  212. credential_provider: Optional[CredentialProvider] = None,
  213. ):
  214. self._nodes = nodes
  215. self._credential_provider = credential_provider
  216. @property
  217. def nodes(self) -> dict:
  218. return self._nodes
  219. @property
  220. def credential_provider(self) -> Union[CredentialProvider, None]:
  221. return self._credential_provider
  222. class OnCommandsFailEvent:
  223. """
  224. Event fired whenever a command fails during the execution.
  225. """
  226. def __init__(
  227. self,
  228. commands: tuple,
  229. exception: Exception,
  230. ):
  231. self._commands = commands
  232. self._exception = exception
  233. @property
  234. def commands(self) -> tuple:
  235. return self._commands
  236. @property
  237. def exception(self) -> Exception:
  238. return self._exception
  239. class AsyncOnCommandsFailEvent(OnCommandsFailEvent):
  240. pass
  241. class ReAuthConnectionListener(EventListenerInterface):
  242. """
  243. Listener that performs re-authentication of given connection.
  244. """
  245. def listen(self, event: AfterConnectionReleasedEvent):
  246. event.connection.re_auth()
  247. class AsyncReAuthConnectionListener(AsyncEventListenerInterface):
  248. """
  249. Async listener that performs re-authentication of given connection.
  250. """
  251. async def listen(self, event: AsyncAfterConnectionReleasedEvent):
  252. await event.connection.re_auth()
  253. class RegisterReAuthForPooledConnections(EventListenerInterface):
  254. """
  255. Listener that registers a re-authentication callback for pooled connections.
  256. Required by :class:`StreamingCredentialProvider`.
  257. """
  258. def __init__(self):
  259. self._event = None
  260. def listen(self, event: AfterPooledConnectionsInstantiationEvent):
  261. if isinstance(event.credential_provider, StreamingCredentialProvider):
  262. self._event = event
  263. if event.client_type == ClientType.SYNC:
  264. event.credential_provider.on_next(self._re_auth)
  265. event.credential_provider.on_error(self._raise_on_error)
  266. else:
  267. event.credential_provider.on_next(self._re_auth_async)
  268. event.credential_provider.on_error(self._raise_on_error_async)
  269. def _re_auth(self, token):
  270. for pool in self._event.connection_pools:
  271. pool.re_auth_callback(token)
  272. async def _re_auth_async(self, token):
  273. for pool in self._event.connection_pools:
  274. await pool.re_auth_callback(token)
  275. def _raise_on_error(self, error: Exception):
  276. raise EventException(error, self._event)
  277. async def _raise_on_error_async(self, error: Exception):
  278. raise EventException(error, self._event)
  279. class RegisterReAuthForSingleConnection(EventListenerInterface):
  280. """
  281. Listener that registers a re-authentication callback for single connection.
  282. Required by :class:`StreamingCredentialProvider`.
  283. """
  284. def __init__(self):
  285. self._event = None
  286. def listen(self, event: AfterSingleConnectionInstantiationEvent):
  287. if isinstance(
  288. event.connection.credential_provider, StreamingCredentialProvider
  289. ):
  290. self._event = event
  291. if event.client_type == ClientType.SYNC:
  292. event.connection.credential_provider.on_next(self._re_auth)
  293. event.connection.credential_provider.on_error(self._raise_on_error)
  294. else:
  295. event.connection.credential_provider.on_next(self._re_auth_async)
  296. event.connection.credential_provider.on_error(
  297. self._raise_on_error_async
  298. )
  299. def _re_auth(self, token):
  300. with self._event.connection_lock:
  301. self._event.connection.send_command(
  302. "AUTH", token.try_get("oid"), token.get_value()
  303. )
  304. self._event.connection.read_response()
  305. async def _re_auth_async(self, token):
  306. async with self._event.connection_lock:
  307. await self._event.connection.send_command(
  308. "AUTH", token.try_get("oid"), token.get_value()
  309. )
  310. await self._event.connection.read_response()
  311. def _raise_on_error(self, error: Exception):
  312. raise EventException(error, self._event)
  313. async def _raise_on_error_async(self, error: Exception):
  314. raise EventException(error, self._event)
  315. class RegisterReAuthForAsyncClusterNodes(EventListenerInterface):
  316. def __init__(self):
  317. self._event = None
  318. def listen(self, event: AfterAsyncClusterInstantiationEvent):
  319. if isinstance(event.credential_provider, StreamingCredentialProvider):
  320. self._event = event
  321. event.credential_provider.on_next(self._re_auth)
  322. event.credential_provider.on_error(self._raise_on_error)
  323. async def _re_auth(self, token: TokenInterface):
  324. for key in self._event.nodes:
  325. await self._event.nodes[key].re_auth_callback(token)
  326. async def _raise_on_error(self, error: Exception):
  327. raise EventException(error, self._event)
  328. class RegisterReAuthForPubSub(EventListenerInterface):
  329. def __init__(self):
  330. self._connection = None
  331. self._connection_pool = None
  332. self._client_type = None
  333. self._connection_lock = None
  334. self._event = None
  335. def listen(self, event: AfterPubSubConnectionInstantiationEvent):
  336. if isinstance(
  337. event.pubsub_connection.credential_provider, StreamingCredentialProvider
  338. ) and check_protocol_version(event.pubsub_connection.get_protocol(), 3):
  339. self._event = event
  340. self._connection = event.pubsub_connection
  341. self._connection_pool = event.connection_pool
  342. self._client_type = event.client_type
  343. self._connection_lock = event.connection_lock
  344. if self._client_type == ClientType.SYNC:
  345. self._connection.credential_provider.on_next(self._re_auth)
  346. self._connection.credential_provider.on_error(self._raise_on_error)
  347. else:
  348. self._connection.credential_provider.on_next(self._re_auth_async)
  349. self._connection.credential_provider.on_error(
  350. self._raise_on_error_async
  351. )
  352. def _re_auth(self, token: TokenInterface):
  353. with self._connection_lock:
  354. self._connection.send_command(
  355. "AUTH", token.try_get("oid"), token.get_value()
  356. )
  357. self._connection.read_response()
  358. self._connection_pool.re_auth_callback(token)
  359. async def _re_auth_async(self, token: TokenInterface):
  360. async with self._connection_lock:
  361. await self._connection.send_command(
  362. "AUTH", token.try_get("oid"), token.get_value()
  363. )
  364. await self._connection.read_response()
  365. await self._connection_pool.re_auth_callback(token)
  366. def _raise_on_error(self, error: Exception):
  367. raise EventException(error, self._event)
  368. async def _raise_on_error_async(self, error: Exception):
  369. raise EventException(error, self._event)
  370. class InitializeConnectionCountObservability(EventListenerInterface):
  371. """
  372. Listener that initializes connection count observability.
  373. """
  374. @deprecated_function(
  375. reason="Connection count is now tracked via record_connection_count(). "
  376. "This functionality will be removed in the next major version",
  377. version="7.4.0",
  378. )
  379. def listen(self, event: AfterPooledConnectionsInstantiationEvent):
  380. # Initialize gauge only once, subsequent calls won't have an affect.
  381. # Note: init_connection_count() and register_pools_connection_count()
  382. # are deprecated and will emit their own warnings.
  383. init_connection_count()
  384. # Register pools for connection count observability.
  385. register_pools_connection_count(event.connection_pools)