# client.py
  1. import asyncio
  2. import logging
  3. import threading
  4. from typing import Any, Callable, List, Optional
  5. from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy
  6. from redis.background import BackgroundScheduler
  7. from redis.backoff import NoBackoff
  8. from redis.client import PubSubWorkerThread
  9. from redis.commands import CoreCommands, RedisModuleCommands
  10. from redis.maint_notifications import MaintNotificationsConfig
  11. from redis.multidb.circuit import CircuitBreaker
  12. from redis.multidb.circuit import State as CBState
  13. from redis.multidb.command_executor import DefaultCommandExecutor
  14. from redis.multidb.config import (
  15. DEFAULT_GRACE_PERIOD,
  16. DatabaseConfig,
  17. InitialHealthCheck,
  18. MultiDbConfig,
  19. )
  20. from redis.multidb.database import Database, Databases, SyncDatabase
  21. from redis.multidb.exception import (
  22. InitialHealthCheckFailedError,
  23. NoValidDatabaseException,
  24. UnhealthyDatabaseException,
  25. )
  26. from redis.multidb.failure_detector import FailureDetector
  27. from redis.observability.attributes import GeoFailoverReason
  28. from redis.retry import Retry
  29. from redis.utils import experimental
  30. logger = logging.getLogger(__name__)
  31. @experimental
  32. class MultiDBClient(RedisModuleCommands, CoreCommands):
  33. """
  34. Client that operates on multiple logical Redis databases.
  35. Should be used in Client-side geographic failover database setups.
  36. """
  37. def __init__(self, config: MultiDbConfig):
  38. self._databases = config.databases()
  39. self._health_checks = (
  40. config.default_health_checks()
  41. if not config.health_checks
  42. else config.health_checks
  43. )
  44. self._health_check_interval = config.health_check_interval
  45. self._health_check_policy: HealthCheckPolicy = (
  46. config.health_check_policy.value()
  47. )
  48. self._failure_detectors = (
  49. config.default_failure_detectors()
  50. if not config.failure_detectors
  51. else config.failure_detectors
  52. )
  53. self._failover_strategy = (
  54. config.default_failover_strategy()
  55. if config.failover_strategy is None
  56. else config.failover_strategy
  57. )
  58. self._failover_strategy.set_databases(self._databases)
  59. self._auto_fallback_interval = config.auto_fallback_interval
  60. self._event_dispatcher = config.event_dispatcher
  61. self._command_retry = config.command_retry
  62. self._command_retry.update_supported_errors((ConnectionRefusedError,))
  63. self.command_executor = DefaultCommandExecutor(
  64. failure_detectors=self._failure_detectors,
  65. databases=self._databases,
  66. command_retry=self._command_retry,
  67. failover_strategy=self._failover_strategy,
  68. failover_attempts=config.failover_attempts,
  69. failover_delay=config.failover_delay,
  70. event_dispatcher=self._event_dispatcher,
  71. auto_fallback_interval=self._auto_fallback_interval,
  72. )
  73. self.initialized = False
  74. self._bg_scheduler = BackgroundScheduler()
  75. self._hc_lock = threading.Lock()
  76. self._config = config
  77. def __del__(self):
  78. try:
  79. self.close()
  80. except Exception:
  81. # Suppress exceptions during garbage collection.
  82. # close() may fail if called during interpreter shutdown
  83. # or while an event loop is already running.
  84. pass
  85. def initialize(self):
  86. """
  87. Perform initialization of databases to define their initial state.
  88. """
  89. # Initial databases check to define initial state.
  90. # Uses run_coro_sync to run in the shared background loop - this ensures
  91. # connection pools created during initial health check remain valid for
  92. # subsequent recurring health checks (they use the same event loop).
  93. self._bg_scheduler.run_coro_sync(self._perform_initial_health_check)
  94. # Starts recurring health checks on the background.
  95. # Uses run_recurring_coro which shares the same event loop as run_coro_sync
  96. self._bg_scheduler.run_recurring_coro(
  97. self._health_check_interval,
  98. self._check_databases_health,
  99. )
  100. is_active_db_found = False
  101. for database, weight in self._databases:
  102. # Set on state changed callback for each circuit.
  103. database.circuit.on_state_changed(self._on_circuit_state_change_callback)
  104. # Set states according to a weights and circuit state
  105. if database.circuit.state == CBState.CLOSED and not is_active_db_found:
  106. # Directly set the active database during initialization
  107. # without recording a geo failover metric
  108. self.command_executor._active_database = database
  109. is_active_db_found = True
  110. if not is_active_db_found:
  111. raise NoValidDatabaseException(
  112. "Initial connection failed - no active database found"
  113. )
  114. self.initialized = True
  115. def get_databases(self) -> Databases:
  116. """
  117. Returns a sorted (by weight) list of all databases.
  118. """
  119. return self._databases
  120. def set_active_database(self, database: SyncDatabase) -> None:
  121. """
  122. Promote one of the existing databases to become an active.
  123. """
  124. exists = None
  125. for existing_db, _ in self._databases:
  126. if existing_db == database:
  127. exists = True
  128. break
  129. if not exists:
  130. raise ValueError("Given database is not a member of database list")
  131. self._bg_scheduler.run_coro_sync(self._check_db_health, database)
  132. if database.circuit.state == CBState.CLOSED:
  133. highest_weighted_db, _ = self._databases.get_top_n(1)[0]
  134. self.command_executor.active_database = (
  135. database,
  136. GeoFailoverReason.MANUAL,
  137. )
  138. return
  139. raise NoValidDatabaseException(
  140. "Cannot set active database, database is unhealthy"
  141. )
  142. def add_database(
  143. self, config: DatabaseConfig, skip_initial_health_check: bool = True
  144. ):
  145. """
  146. Adds a new database to the database list.
  147. Args:
  148. config: DatabaseConfig object that contains the database configuration.
  149. skip_initial_health_check: If True, adds the database even if it is unhealthy.
  150. """
  151. # The retry object is not used in the lower level clients, so we can safely remove it.
  152. # We rely on command_retry in terms of global retries.
  153. config.client_kwargs["retry"] = Retry(retries=0, backoff=NoBackoff())
  154. # Maintenance notifications are disabled by default in underlying clients,
  155. # but user can override this by providing their own config.
  156. if "maint_notifications_config" not in config.client_kwargs:
  157. config.client_kwargs["maint_notifications_config"] = (
  158. MaintNotificationsConfig(enabled=False)
  159. )
  160. if config.from_url:
  161. client = self._config.client_class.from_url(
  162. config.from_url, **config.client_kwargs
  163. )
  164. elif config.from_pool:
  165. config.from_pool.set_retry(Retry(retries=0, backoff=NoBackoff()))
  166. client = self._config.client_class.from_pool(
  167. connection_pool=config.from_pool
  168. )
  169. else:
  170. client = self._config.client_class(**config.client_kwargs)
  171. circuit = (
  172. config.default_circuit_breaker()
  173. if config.circuit is None
  174. else config.circuit
  175. )
  176. database = Database(
  177. client=client,
  178. circuit=circuit,
  179. weight=config.weight,
  180. health_check_url=config.health_check_url,
  181. )
  182. try:
  183. self._bg_scheduler.run_coro_sync(self._check_db_health, database)
  184. except UnhealthyDatabaseException:
  185. if not skip_initial_health_check:
  186. raise
  187. highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
  188. self._databases.add(database, database.weight)
  189. self._change_active_database(database, highest_weighted_db)
  190. def _change_active_database(
  191. self, new_database: SyncDatabase, highest_weight_database: SyncDatabase
  192. ):
  193. if (
  194. new_database.weight > highest_weight_database.weight
  195. and new_database.circuit.state == CBState.CLOSED
  196. ):
  197. self.command_executor.active_database = (
  198. new_database,
  199. GeoFailoverReason.AUTOMATIC,
  200. )
  201. def remove_database(self, database: Database):
  202. """
  203. Removes a database from the database list.
  204. """
  205. weight = self._databases.remove(database)
  206. highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
  207. if (
  208. highest_weight <= weight
  209. and highest_weighted_db.circuit.state == CBState.CLOSED
  210. ):
  211. self.command_executor.active_database = (
  212. highest_weighted_db,
  213. GeoFailoverReason.MANUAL,
  214. )
  215. def update_database_weight(self, database: SyncDatabase, weight: float):
  216. """
  217. Updates a database from the database list.
  218. """
  219. exists = None
  220. for existing_db, _ in self._databases:
  221. if existing_db == database:
  222. exists = True
  223. break
  224. if not exists:
  225. raise ValueError("Given database is not a member of database list")
  226. highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
  227. self._databases.update_weight(database, weight)
  228. database.weight = weight
  229. self._change_active_database(database, highest_weighted_db)
  230. def add_failure_detector(self, failure_detector: FailureDetector):
  231. """
  232. Adds a new failure detector to the database.
  233. """
  234. self._failure_detectors.append(failure_detector)
  235. def add_health_check(self, healthcheck: HealthCheck):
  236. """
  237. Adds a new health check to the database.
  238. """
  239. with self._hc_lock:
  240. self._health_checks.append(healthcheck)
  241. def execute_command(self, *args, **options):
  242. """
  243. Executes a single command and return its result.
  244. """
  245. if not self.initialized:
  246. self.initialize()
  247. return self.command_executor.execute_command(*args, **options)
  248. def pipeline(self):
  249. """
  250. Enters into pipeline mode of the client.
  251. """
  252. return Pipeline(self)
  253. def transaction(self, func: Callable[["Pipeline"], None], *watches, **options):
  254. """
  255. Executes callable as transaction.
  256. """
  257. if not self.initialized:
  258. self.initialize()
  259. return self.command_executor.execute_transaction(func, *watches, *options)
  260. def pubsub(self, **kwargs):
  261. """
  262. Return a Publish/Subscribe object. With this object, you can
  263. subscribe to channels and listen for messages that get published to
  264. them.
  265. """
  266. if not self.initialized:
  267. self.initialize()
  268. return PubSub(self, **kwargs)
  269. async def _check_db_health(self, database: SyncDatabase) -> bool:
  270. """
  271. Runs health checks on the given database until first failure.
  272. """
  273. with self._hc_lock:
  274. health_checks = list(self._health_checks)
  275. # Health check will setup circuit state
  276. is_healthy = await self._health_check_policy.execute(health_checks, database)
  277. if not is_healthy:
  278. if database.circuit.state != CBState.OPEN:
  279. database.circuit.state = CBState.OPEN
  280. return is_healthy
  281. elif is_healthy and database.circuit.state != CBState.CLOSED:
  282. database.circuit.state = CBState.CLOSED
  283. return is_healthy
  284. async def _check_databases_health(self) -> dict[Database, bool]:
  285. """
  286. Runs health checks as a recurring task.
  287. Runs health checks against all databases.
  288. """
  289. task_to_db: dict[asyncio.Task, Database] = {}
  290. self._hc_tasks = []
  291. for database, _ in self._databases:
  292. task = asyncio.create_task(self._check_db_health(database))
  293. task_to_db[task] = database
  294. self._hc_tasks.append(task)
  295. results = await asyncio.gather(*self._hc_tasks, return_exceptions=True)
  296. # Map end results to databases
  297. db_results = {
  298. task_to_db[task]: result for task, result in zip(self._hc_tasks, results)
  299. }
  300. for database, result in db_results.items():
  301. if isinstance(result, UnhealthyDatabaseException):
  302. unhealthy_db = result.database
  303. unhealthy_db.circuit.state = CBState.OPEN
  304. logger.debug(
  305. "Health check failed, due to exception",
  306. exc_info=result.original_exception,
  307. )
  308. db_results[unhealthy_db] = False
  309. return db_results
  310. async def _perform_initial_health_check(self):
  311. """
  312. Runs initial health check and evaluate healthiness based on initial_health_check_policy.
  313. """
  314. results = await self._check_databases_health()
  315. is_healthy = True
  316. if self._config.initial_health_check_policy == InitialHealthCheck.ALL_AVAILABLE:
  317. is_healthy = False not in results.values()
  318. elif (
  319. self._config.initial_health_check_policy
  320. == InitialHealthCheck.MAJORITY_AVAILABLE
  321. ):
  322. is_healthy = sum(results.values()) > len(results) / 2
  323. elif (
  324. self._config.initial_health_check_policy == InitialHealthCheck.ONE_AVAILABLE
  325. ):
  326. is_healthy = True in results.values()
  327. if not is_healthy:
  328. raise InitialHealthCheckFailedError(
  329. f"Initial health check failed. Initial health check policy: {self._config.initial_health_check_policy}"
  330. )
  331. def _on_circuit_state_change_callback(
  332. self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState
  333. ):
  334. if new_state == CBState.HALF_OPEN:
  335. self._bg_scheduler.run_coro_fire_and_forget(
  336. self._check_db_health, circuit.database
  337. )
  338. return
  339. if old_state == CBState.CLOSED and new_state == CBState.OPEN:
  340. logger.warning(
  341. f"Database {circuit.database} is unreachable. Failover has been initiated."
  342. )
  343. self._bg_scheduler.run_once(
  344. DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit
  345. )
  346. if old_state != CBState.CLOSED and new_state == CBState.CLOSED:
  347. logger.info(f"Database {circuit.database} is reachable again.")
  348. def close(self):
  349. """
  350. Closes the client and all its resources.
  351. """
  352. # Close health check policy BEFORE stopping the scheduler.
  353. # The policy's connection pools were created on the shared health check
  354. # event loop, so they must be disconnected on that same loop to avoid
  355. # leaking sockets/file descriptors.
  356. if self._bg_scheduler:
  357. try:
  358. self._bg_scheduler.run_coro_sync(self._health_check_policy.close)
  359. except Exception:
  360. pass
  361. self._bg_scheduler.stop()
  362. if self.command_executor.active_database:
  363. self.command_executor.active_database.client.close()
def _half_open_circuit(circuit: CircuitBreaker):
    # Scheduled after the grace period following an OPEN transition;
    # moving to HALF_OPEN triggers a background health re-probe via the
    # circuit's state-change callback.
    circuit.state = CBState.HALF_OPEN
  366. class Pipeline(RedisModuleCommands, CoreCommands):
  367. """
  368. Pipeline implementation for multiple logical Redis databases.
  369. """
  370. def __init__(self, client: MultiDBClient):
  371. self._command_stack = []
  372. self._client = client
  373. def __enter__(self) -> "Pipeline":
  374. return self
  375. def __exit__(self, exc_type, exc_value, traceback):
  376. self.reset()
  377. def __del__(self):
  378. try:
  379. self.reset()
  380. except Exception:
  381. pass
  382. def __len__(self) -> int:
  383. return len(self._command_stack)
  384. def __bool__(self) -> bool:
  385. """Pipeline instances should always evaluate to True"""
  386. return True
  387. def reset(self) -> None:
  388. self._command_stack = []
  389. def close(self) -> None:
  390. """Close the pipeline"""
  391. self.reset()
  392. def pipeline_execute_command(self, *args, **options) -> "Pipeline":
  393. """
  394. Stage a command to be executed when execute() is next called
  395. Returns the current Pipeline object back so commands can be
  396. chained together, such as:
  397. pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
  398. At some other point, you can then run: pipe.execute(),
  399. which will execute all commands queued in the pipe.
  400. """
  401. self._command_stack.append((args, options))
  402. return self
  403. def execute_command(self, *args, **kwargs):
  404. """Adds a command to the stack"""
  405. return self.pipeline_execute_command(*args, **kwargs)
  406. def execute(self) -> List[Any]:
  407. """Execute all the commands in the current pipeline"""
  408. if not self._client.initialized:
  409. self._client.initialize()
  410. try:
  411. return self._client.command_executor.execute_pipeline(
  412. tuple(self._command_stack)
  413. )
  414. finally:
  415. self.reset()
  416. class PubSub:
  417. """
  418. PubSub object for multi database client.
  419. """
  420. def __init__(self, client: MultiDBClient, **kwargs):
  421. """Initialize the PubSub object for a multi-database client.
  422. Args:
  423. client: MultiDBClient instance to use for pub/sub operations
  424. **kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
  425. """
  426. self._client = client
  427. self._client.command_executor.pubsub(**kwargs)
  428. def __enter__(self) -> "PubSub":
  429. return self
  430. def __del__(self) -> None:
  431. try:
  432. # if this object went out of scope prior to shutting down
  433. # subscriptions, close the connection manually before
  434. # returning it to the connection pool
  435. self.reset()
  436. except Exception:
  437. pass
  438. def reset(self) -> None:
  439. return self._client.command_executor.execute_pubsub_method("reset")
  440. def close(self) -> None:
  441. self.reset()
  442. @property
  443. def subscribed(self) -> bool:
  444. return self._client.command_executor.active_pubsub.subscribed
  445. def execute_command(self, *args):
  446. return self._client.command_executor.execute_pubsub_method(
  447. "execute_command", *args
  448. )
  449. def psubscribe(self, *args, **kwargs):
  450. """
  451. Subscribe to channel patterns. Patterns supplied as keyword arguments
  452. expect a pattern name as the key and a callable as the value. A
  453. pattern's callable will be invoked automatically when a message is
  454. received on that pattern rather than producing a message via
  455. ``listen()``.
  456. """
  457. return self._client.command_executor.execute_pubsub_method(
  458. "psubscribe", *args, **kwargs
  459. )
  460. def punsubscribe(self, *args):
  461. """
  462. Unsubscribe from the supplied patterns. If empty, unsubscribe from
  463. all patterns.
  464. """
  465. return self._client.command_executor.execute_pubsub_method(
  466. "punsubscribe", *args
  467. )
  468. def subscribe(self, *args, **kwargs):
  469. """
  470. Subscribe to channels. Channels supplied as keyword arguments expect
  471. a channel name as the key and a callable as the value. A channel's
  472. callable will be invoked automatically when a message is received on
  473. that channel rather than producing a message via ``listen()`` or
  474. ``get_message()``.
  475. """
  476. return self._client.command_executor.execute_pubsub_method(
  477. "subscribe", *args, **kwargs
  478. )
  479. def unsubscribe(self, *args):
  480. """
  481. Unsubscribe from the supplied channels. If empty, unsubscribe from
  482. all channels
  483. """
  484. return self._client.command_executor.execute_pubsub_method("unsubscribe", *args)
  485. def ssubscribe(self, *args, **kwargs):
  486. """
  487. Subscribes the client to the specified shard channels.
  488. Channels supplied as keyword arguments expect a channel name as the key
  489. and a callable as the value. A channel's callable will be invoked automatically
  490. when a message is received on that channel rather than producing a message via
  491. ``listen()`` or ``get_sharded_message()``.
  492. """
  493. return self._client.command_executor.execute_pubsub_method(
  494. "ssubscribe", *args, **kwargs
  495. )
  496. def sunsubscribe(self, *args):
  497. """
  498. Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
  499. all shard_channels
  500. """
  501. return self._client.command_executor.execute_pubsub_method(
  502. "sunsubscribe", *args
  503. )
  504. def get_message(
  505. self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
  506. ):
  507. """
  508. Get the next message if one is available, otherwise None.
  509. If timeout is specified, the system will wait for `timeout` seconds
  510. before returning. Timeout should be specified as a floating point
  511. number, or None, to wait indefinitely.
  512. """
  513. return self._client.command_executor.execute_pubsub_method(
  514. "get_message",
  515. ignore_subscribe_messages=ignore_subscribe_messages,
  516. timeout=timeout,
  517. )
  518. def get_sharded_message(
  519. self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
  520. ):
  521. """
  522. Get the next message if one is available in a sharded channel, otherwise None.
  523. If timeout is specified, the system will wait for `timeout` seconds
  524. before returning. Timeout should be specified as a floating point
  525. number, or None, to wait indefinitely.
  526. """
  527. return self._client.command_executor.execute_pubsub_method(
  528. "get_sharded_message",
  529. ignore_subscribe_messages=ignore_subscribe_messages,
  530. timeout=timeout,
  531. )
  532. def run_in_thread(
  533. self,
  534. sleep_time: float = 0.0,
  535. daemon: bool = False,
  536. exception_handler: Optional[Callable] = None,
  537. sharded_pubsub: bool = False,
  538. ) -> "PubSubWorkerThread":
  539. return self._client.command_executor.execute_pubsub_run(
  540. sleep_time,
  541. daemon=daemon,
  542. exception_handler=exception_handler,
  543. pubsub=self,
  544. sharded_pubsub=sharded_pubsub,
  545. )