maint_notifications.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199
  1. import enum
  2. import ipaddress
  3. import logging
  4. import re
  5. import threading
  6. import time
  7. from abc import ABC, abstractmethod
  8. from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
  9. from redis.observability.attributes import get_pool_name
  10. from redis.observability.recorder import (
  11. record_connection_handoff,
  12. record_connection_relaxed_timeout,
  13. record_maint_notification_count,
  14. )
  15. from redis.typing import Number
  16. if TYPE_CHECKING:
  17. from redis.cluster import MaintNotificationsAbstractRedisCluster
  # Module-level logger named after this module; the notification handlers
  # below send their debug/error messages through it.
  logger: logging.Logger = logging.getLogger(__name__)
  class MaintenanceState(enum.Enum):
      """
      Maintenance states that can be applied to connections of a pool.

      ``MOVING`` is the state set while a node MOVING notification is being
      handled; ``NONE`` is the "no maintenance" value.
      """

      NONE = "none"
      MOVING = "moving"
      MAINTENANCE = "maintenance"
  class EndpointType(enum.Enum):
      """
      Endpoint types accepted by the CLIENT MAINT_NOTIFICATIONS command.

      ``str()`` of a member returns its raw value (e.g. ``"internal-ip"``)
      instead of the default ``"EndpointType.INTERNAL_IP"`` form.
      """

      INTERNAL_IP = "internal-ip"
      INTERNAL_FQDN = "internal-fqdn"
      EXTERNAL_IP = "external-ip"
      EXTERNAL_FQDN = "external-fqdn"
      NONE = "none"

      def __str__(self):
          """Return the member's raw string value."""
          raw_value = self.value
          return raw_value
  33. if TYPE_CHECKING:
  34. from redis.connection import (
  35. MaintNotificationsAbstractConnection,
  36. MaintNotificationsAbstractConnectionPool,
  37. )
  class MaintenanceNotification(ABC):
      """
      Abstract base for maintenance notifications that the Redis server sends
      as push messages.

      Each notification has an identifier and a time-to-live. The TTL is counted
      from object creation using the monotonic clock.

      Attributes:
          id (int): Identifier of the notification.
          ttl (int): Lifetime of the notification, in seconds.
          creation_time (float): ``time.monotonic()`` value taken at creation.
          expire_at (float): ``creation_time + ttl``, computed once at creation.
      """

      def __init__(self, id: int, ttl: int):
          """
          Store the identifier and TTL and stamp the creation time.

          Args:
              id (int): Identifier of the notification.
              ttl (int): Lifetime of the notification, in seconds.
          """
          self.id = id
          self.ttl = ttl
          stamped_at = time.monotonic()
          self.creation_time = stamped_at
          self.expire_at = stamped_at + ttl

      def is_expired(self) -> bool:
          """
          Report whether this notification's TTL has elapsed.

          The deadline is recomputed from ``creation_time`` and ``ttl`` on each
          call; ``expire_at`` is not consulted.

          Returns:
              bool: True when the current monotonic time is strictly past
              ``creation_time + ttl``, False otherwise.
          """
          deadline = self.creation_time + self.ttl
          return deadline < time.monotonic()

      @abstractmethod
      def __repr__(self) -> str:
          """
          Describe the notification as a string; concrete subclasses must
          provide this.
          """

      @abstractmethod
      def __eq__(self, other) -> bool:
          """
          Compare with ``other``; concrete subclasses must provide this.

          Subclasses in this module treat notifications of the same type and
          id as equal (some also compare extra fields).
          """

      @abstractmethod
      def __hash__(self) -> int:
          """
          Hash the notification so it can live in sets / dict keys; concrete
          subclasses must provide this, consistent with ``__eq__``.
          """
  class NodeMovingNotification(MaintenanceNotification):
      """
      MOVING notification: the node the client is connected to is being
      replaced by another node during rebalancing or maintenance.

      ``new_node_host`` / ``new_node_port`` may be ``None`` when the push
      message did not carry the replacement endpoint.
      """

      def __init__(
          self,
          id: int,
          new_node_host: Optional[str],
          new_node_port: Optional[int],
          ttl: int,
      ):
          """
          Args:
              id (int): Identifier of the notification.
              new_node_host (Optional[str]): Host name or IP of the replacement node.
              new_node_port (Optional[int]): Port of the replacement node.
              ttl (int): Lifetime of the notification, in seconds.
          """
          super().__init__(id, ttl)
          self.new_node_host = new_node_host
          self.new_node_port = new_node_port

      def __repr__(self) -> str:
          # Uses the expire_at value computed at construction time.
          deadline = self.expire_at
          seconds_left = max(0, deadline - time.monotonic())
          parts = (
              f"id={self.id}",
              f"new_node_host='{self.new_node_host}'",
              f"new_node_port={self.new_node_port}",
              f"ttl={self.ttl}",
              f"creation_time={self.creation_time}",
              f"expires_at={deadline}",
              f"remaining={seconds_left:.1f}s",
              f"expired={self.is_expired()}",
          )
          return f"{self.__class__.__name__}({', '.join(parts)})"

      def __eq__(self, other) -> bool:
          """
          Equal to another NodeMovingNotification with the same id,
          new_node_host and new_node_port.
          """
          if isinstance(other, NodeMovingNotification):
              return (
                  self.id == other.id
                  and self.new_node_host == other.new_node_host
                  and self.new_node_port == other.new_node_port
              )
          return False

      def __hash__(self) -> int:
          """
          Hash of (class name, id, host as str, normalised port).

          The port is converted with ``int()``; a falsy port hashes as ``None``
          and a port that ``int()`` rejects with ValueError hashes as ``0``.
          """
          try:
              port_key = None
              if self.new_node_port:
                  port_key = int(self.new_node_port)
          except ValueError:
              port_key = 0
          identity = (
              self.__class__.__name__,
              int(self.id),
              str(self.new_node_host),
              port_key,
          )
          return hash(identity)
  class NodeMigratingNotification(MaintenanceNotification):
      """
      MIGRATING notification: a cluster node has started migrating its slots
      to another node during rebalancing or maintenance.

      Args:
          id (int): Identifier of the notification.
          ttl (int): Lifetime of the notification, in seconds.
      """

      def __init__(self, id: int, ttl: int):
          super().__init__(id, ttl)

      def __repr__(self) -> str:
          deadline = self.creation_time + self.ttl
          seconds_left = max(0, deadline - time.monotonic())
          parts = (
              f"id={self.id}",
              f"ttl={self.ttl}",
              f"creation_time={self.creation_time}",
              f"expires_at={deadline}",
              f"remaining={seconds_left:.1f}s",
              f"expired={self.is_expired()}",
          )
          return f"{self.__class__.__name__}({', '.join(parts)})"

      def __eq__(self, other) -> bool:
          """Equal to a notification of exactly the same type with the same id."""
          if isinstance(other, NodeMigratingNotification):
              return self.id == other.id and type(self) is type(other)
          return False

      def __hash__(self) -> int:
          """Hash of (class name, id as int)."""
          identity = (self.__class__.__name__, int(self.id))
          return hash(identity)
  class NodeMigratedNotification(MaintenanceNotification):
      """
      MIGRATED notification: a cluster node has finished migrating its slots
      to other nodes during rebalancing or maintenance.

      The TTL is fixed to ``NodeMigratedNotification.DEFAULT_TTL`` seconds.

      Args:
          id (int): Identifier of the notification.
      """

      DEFAULT_TTL = 5

      def __init__(self, id: int):
          super().__init__(id, NodeMigratedNotification.DEFAULT_TTL)

      def __repr__(self) -> str:
          deadline = self.creation_time + self.ttl
          seconds_left = max(0, deadline - time.monotonic())
          parts = (
              f"id={self.id}",
              f"ttl={self.ttl}",
              f"creation_time={self.creation_time}",
              f"expires_at={deadline}",
              f"remaining={seconds_left:.1f}s",
              f"expired={self.is_expired()}",
          )
          return f"{self.__class__.__name__}({', '.join(parts)})"

      def __eq__(self, other) -> bool:
          """Equal to a notification of exactly the same type with the same id."""
          if isinstance(other, NodeMigratedNotification):
              return self.id == other.id and type(self) is type(other)
          return False

      def __hash__(self) -> int:
          """Hash of (class name, id as int)."""
          identity = (self.__class__.__name__, int(self.id))
          return hash(identity)
  class NodeFailingOverNotification(MaintenanceNotification):
      """
      FAILING_OVER notification: a cluster node has started a failover, during
      maintenance or while handling a node failure.

      Args:
          id (int): Identifier of the notification.
          ttl (int): Lifetime of the notification, in seconds.
      """

      def __init__(self, id: int, ttl: int):
          super().__init__(id, ttl)

      def __repr__(self) -> str:
          deadline = self.creation_time + self.ttl
          seconds_left = max(0, deadline - time.monotonic())
          parts = (
              f"id={self.id}",
              f"ttl={self.ttl}",
              f"creation_time={self.creation_time}",
              f"expires_at={deadline}",
              f"remaining={seconds_left:.1f}s",
              f"expired={self.is_expired()}",
          )
          return f"{self.__class__.__name__}({', '.join(parts)})"

      def __eq__(self, other) -> bool:
          """Equal to a notification of exactly the same type with the same id."""
          if isinstance(other, NodeFailingOverNotification):
              return self.id == other.id and type(self) is type(other)
          return False

      def __hash__(self) -> int:
          """Hash of (class name, id as int)."""
          identity = (self.__class__.__name__, int(self.id))
          return hash(identity)
  class NodeFailedOverNotification(MaintenanceNotification):
      """
      FAILED_OVER notification: a cluster node has completed its failover.

      The TTL is fixed to ``NodeFailedOverNotification.DEFAULT_TTL`` seconds.

      Args:
          id (int): Identifier of the notification.
      """

      DEFAULT_TTL = 5

      def __init__(self, id: int):
          super().__init__(id, NodeFailedOverNotification.DEFAULT_TTL)

      def __repr__(self) -> str:
          deadline = self.creation_time + self.ttl
          seconds_left = max(0, deadline - time.monotonic())
          parts = (
              f"id={self.id}",
              f"ttl={self.ttl}",
              f"creation_time={self.creation_time}",
              f"expires_at={deadline}",
              f"remaining={seconds_left:.1f}s",
              f"expired={self.is_expired()}",
          )
          return f"{self.__class__.__name__}({', '.join(parts)})"

      def __eq__(self, other) -> bool:
          """Equal to a notification of exactly the same type with the same id."""
          if isinstance(other, NodeFailedOverNotification):
              return self.id == other.id and type(self) is type(other)
          return False

      def __hash__(self) -> int:
          """Hash of (class name, id as int)."""
          identity = (self.__class__.__name__, int(self.id))
          return hash(identity)
  class OSSNodeMigratingNotification(MaintenanceNotification):
      """
      Notification for when a Redis OSS API client is used and a node is in the
      process of migrating slots.

      This notification is received when a node starts migrating its slots to
      another node during cluster rebalancing or maintenance operations. The TTL
      is fixed to ``OSSNodeMigratingNotification.DEFAULT_TTL`` seconds.

      Equality and hashing consider only the type and ``id``; ``slots`` is
      informational.

      Args:
          id (int): Unique identifier for this notification.
          slots (Optional[str]): The slots being migrated, stored exactly as
              given (a single string, not a list of ints). Presumably the raw
              slot specification from the push message - TODO confirm format.
              Defaults to None.
      """

      DEFAULT_TTL = 30

      def __init__(
          self,
          id: int,
          slots: Optional[str] = None,
      ):
          super().__init__(id, OSSNodeMigratingNotification.DEFAULT_TTL)
          self.slots = slots

      def __repr__(self) -> str:
          expiry_time = self.creation_time + self.ttl
          remaining = max(0, expiry_time - time.monotonic())
          return (
              f"{self.__class__.__name__}("
              f"id={self.id}, "
              f"slots={self.slots}, "
              f"ttl={self.ttl}, "
              f"creation_time={self.creation_time}, "
              f"expires_at={expiry_time}, "
              f"remaining={remaining:.1f}s, "
              f"expired={self.is_expired()}"
              f")"
          )

      def __eq__(self, other) -> bool:
          """
          Two OSSNodeMigratingNotification notifications are considered equal if
          they have the same id and are of the same type (``slots`` is ignored).
          """
          if not isinstance(other, OSSNodeMigratingNotification):
              return False
          return self.id == other.id and type(self) is type(other)

      def __hash__(self) -> int:
          """
          Return a hash value for the notification to allow
          instances to be used in sets and as dictionary keys.

          Returns:
              int: Hash value based on notification type and id
          """
          return hash((self.__class__.__name__, int(self.id)))
  class OSSNodeMigratedNotification(MaintenanceNotification):
      """
      SMIGRATED notification: with the Redis OSS API, a node has finished
      migrating its slots to other nodes during rebalancing or maintenance.

      The TTL is fixed to ``OSSNodeMigratedNotification.DEFAULT_TTL`` seconds.
      Equality and hashing consider only the type and ``id``.

      Args:
          id (int): Identifier of the notification.
          nodes_to_slots_mapping (Dict[str, List[Dict[str, str]]]): Maps each
              source node address ("host:port") to a list of one-entry dicts,
              each mapping a destination node address to a slot range, e.g.::

                  {
                      "127.0.0.1:6379": [
                          {"127.0.0.1:6380": "1-100"},
                          {"127.0.0.1:6381": "101-200"},
                      ],
                      "127.0.0.1:6382": [
                          {"127.0.0.1:6383": "201-300"},
                      ],
                  }
      """

      DEFAULT_TTL = 120

      def __init__(
          self,
          id: int,
          nodes_to_slots_mapping: Dict[str, List[Dict[str, str]]],
      ):
          super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL)
          self.nodes_to_slots_mapping = nodes_to_slots_mapping

      def __repr__(self) -> str:
          deadline = self.creation_time + self.ttl
          seconds_left = max(0, deadline - time.monotonic())
          parts = (
              f"id={self.id}",
              f"nodes_to_slots_mapping={self.nodes_to_slots_mapping}",
              f"ttl={self.ttl}",
              f"creation_time={self.creation_time}",
              f"expires_at={deadline}",
              f"remaining={seconds_left:.1f}s",
              f"expired={self.is_expired()}",
          )
          return f"{self.__class__.__name__}({', '.join(parts)})"

      def __eq__(self, other) -> bool:
          """Equal to a notification of exactly the same type with the same id."""
          if isinstance(other, OSSNodeMigratedNotification):
              return self.id == other.id and type(self) is type(other)
          return False

      def __hash__(self) -> int:
          """Hash of (class name, id as int)."""
          identity = (self.__class__.__name__, int(self.id))
          return hash(identity)
  439. def _is_private_fqdn(host: str) -> bool:
  440. """
  441. Determine if an FQDN is likely to be internal/private.
  442. This uses heuristics based on RFC 952 and RFC 1123 standards:
  443. - .local domains (RFC 6762 - Multicast DNS)
  444. - .internal domains (common internal convention)
  445. - Single-label hostnames (no dots)
  446. - Common internal TLDs
  447. Args:
  448. host (str): The FQDN to check
  449. Returns:
  450. bool: True if the FQDN appears to be internal/private
  451. """
  452. host_lower = host.lower().rstrip(".")
  453. # Single-label hostnames (no dots) are typically internal
  454. if "." not in host_lower:
  455. return True
  456. # Common internal/private domain patterns
  457. internal_patterns = [
  458. r"\.local$", # mDNS/Bonjour domains
  459. r"\.internal$", # Common internal convention
  460. r"\.corp$", # Corporate domains
  461. r"\.lan$", # Local area network
  462. r"\.intranet$", # Intranet domains
  463. r"\.private$", # Private domains
  464. ]
  465. for pattern in internal_patterns:
  466. if re.search(pattern, host_lower):
  467. return True
  468. # If none of the internal patterns match, assume it's external
  469. return False
  # Upper-case notification name for each notification class (presumably the
  # push-message type keyword - TODO confirm against the push parser).
  notification_types_mapping: dict[type[MaintenanceNotification], str] = dict(
      [
          (NodeMovingNotification, "MOVING"),
          (NodeMigratingNotification, "MIGRATING"),
          (NodeMigratedNotification, "MIGRATED"),
          (NodeFailingOverNotification, "FAILING_OVER"),
          (NodeFailedOverNotification, "FAILED_OVER"),
          (OSSNodeMigratingNotification, "SMIGRATING"),
          (OSSNodeMigratedNotification, "SMIGRATED"),
      ]
  )
  def add_debug_log_for_notification(
      connection: "MaintNotificationsAbstractConnection",
      notification: "Union[str, MaintenanceNotification]",
  ):
      """
      Log, at DEBUG level, a maintenance notification and the connection that
      received it.

      Nothing is computed unless DEBUG is enabled on this module's logger.
      The local port comes from ``connection._sock.getsockname()``. It is only
      reported when the socket name is a tuple (e.g. ``(host, port)`` for
      AF_INET/AF_INET6). For other address families, such as AF_UNIX, whose
      name is a path string, or when the socket is missing or closed, it is
      logged as ``None``.

      Args:
          connection: Connection the notification arrived on.
          notification: The notification object, or a string describing it.
      """
      if not logger.isEnabledFor(logging.DEBUG):
          return
      try:
          sock = connection._sock
          sockname = sock.getsockname() if sock else None
      except (AttributeError, OSError):
          # Best effort: a missing or closed socket must not break logging.
          sockname = None
      # Only tuple socket names carry a port at index 1; indexing a path
      # string (AF_UNIX) would log a single character of the path instead.
      local_port = None
      if isinstance(sockname, tuple) and len(sockname) > 1:
          local_port = sockname[1]
      logger.debug(
          "Handling maintenance notification: %s, with connection: %s, "
          "connected to ip %s, local socket port: %s",
          notification,
          connection,
          connection.get_resolved_ip(),
          local_port,
      )
  class MaintNotificationsConfig:
      """
      Configuration class for maintenance notifications handling behaviour.
      Notifications are received through push notifications.

      This class defines how the Redis client should react to different push
      notifications such as node moving, migrations, etc. in a Redis cluster.
      """

      def __init__(
          self,
          enabled: Union[bool, Literal["auto"]] = "auto",
          proactive_reconnect: bool = True,
          relaxed_timeout: Optional[Number] = 10,
          endpoint_type: Optional[EndpointType] = None,
      ):
          """
          Initialize a new MaintNotificationsConfig.

          Args:
              enabled (bool | "auto"): Controls maintenance notifications handling behavior.
                  - True: The CLIENT MAINT_NOTIFICATIONS command must succeed during
                    connection setup, otherwise a ResponseError is raised.
                  - "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but
                    failures are gracefully handled - a warning is logged and normal
                    operation continues.
                  - False: Maintenance notifications are completely disabled.
                  Defaults to "auto".
              proactive_reconnect (bool): Whether to proactively reconnect when a
                  node is replaced. Defaults to True.
              relaxed_timeout (Optional[Number]): The relaxed timeout to use for the
                  connection during maintenance. If -1 is provided, the relaxed
                  timeout is disabled. Defaults to 10.
              endpoint_type (Optional[EndpointType]): Override for the endpoint type
                  to use in CLIENT MAINT_NOTIFICATIONS. If None, the endpoint type is
                  determined automatically from the host (see ``get_endpoint_type``).
                  Defaults to None.

          Note:
              All arguments are stored as given; no validation is performed here
              (in particular ``endpoint_type`` is not checked).
          """
          self.enabled = enabled
          self.relaxed_timeout = relaxed_timeout
          self.proactive_reconnect = proactive_reconnect
          self.endpoint_type = endpoint_type

      def __repr__(self) -> str:
          return (
              f"{self.__class__.__name__}("
              f"enabled={self.enabled}, "
              f"proactive_reconnect={self.proactive_reconnect}, "
              f"relaxed_timeout={self.relaxed_timeout}, "
              f"endpoint_type={self.endpoint_type!r}"
              f")"
          )

      def is_relaxed_timeouts_enabled(self) -> bool:
          """
          Check if the relaxed_timeout is enabled. The '-1' value is used to
          disable the relaxed_timeout. Any other value, including None, counts as
          enabled; a None timeout presumably means blocking until a response
          arrives (standard socket-timeout semantics).

          Returns:
              True if the relaxed_timeout is enabled, False otherwise.
          """
          return self.relaxed_timeout != -1

      def get_endpoint_type(
          self, host: str, connection: "MaintNotificationsAbstractConnection"
      ) -> EndpointType:
          """
          Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command.

          Logic:
          1. If endpoint_type is explicitly set, use it.
          2. If ``host`` is an IP literal, classify it directly as internal-ip
             or external-ip using ``ipaddress``'s ``is_private``.
          3. Otherwise ``host`` is an FQDN: classify the connection's resolved
             IP as internal-fqdn or external-fqdn.
          4. If no usable resolved IP is available, fall back to name-based
             heuristics (``_is_private_fqdn``).

          Args:
              host: User provided hostname to analyze.
              connection: The connection whose resolved IP is used for FQDN hosts.

          Returns:
              EndpointType: ``self.endpoint_type`` when it is set; otherwise
              INTERNAL_IP / EXTERNAL_IP for an IP-literal host, or
              INTERNAL_FQDN / EXTERNAL_FQDN for a host name.
          """
          # If endpoint_type is explicitly set, use it
          if self.endpoint_type is not None:
              return self.endpoint_type

          # Check if the host is an IP address
          try:
              ip_addr = ipaddress.ip_address(host)
              # Host is an IP address - use it directly
              is_private = ip_addr.is_private
              return EndpointType.INTERNAL_IP if is_private else EndpointType.EXTERNAL_IP
          except ValueError:
              # Host is an FQDN - need to check resolved IP to determine internal vs external
              pass

          # Host is an FQDN, get the resolved IP to determine if it's internal or external
          resolved_ip = connection.get_resolved_ip()
          if resolved_ip:
              try:
                  ip_addr = ipaddress.ip_address(resolved_ip)
                  is_private = ip_addr.is_private
                  # Use FQDN types since the original host was an FQDN
                  return (
                      EndpointType.INTERNAL_FQDN
                      if is_private
                      else EndpointType.EXTERNAL_FQDN
                  )
              except ValueError:
                  # This shouldn't happen since we got the IP from the socket, but fallback
                  pass

          # Final fallback: use heuristics on the FQDN itself
          is_private = _is_private_fqdn(host)
          return EndpointType.INTERNAL_FQDN if is_private else EndpointType.EXTERNAL_FQDN
  598. class MaintNotificationsPoolHandler:
  def __init__(
      self,
      pool: "MaintNotificationsAbstractConnectionPool",
      config: MaintNotificationsConfig,
  ) -> None:
      """
      Create a handler for ``pool`` using ``config``.

      A fresh processed-notifications set and a reentrant lock are created
      here. ``get_handler_for_connection`` later shares both with the
      per-connection copies. ``connection`` stays ``None`` until
      ``set_connection`` is called.
      """
      self.pool = pool
      self.config = config
      self.connection = None
      self._lock = threading.RLock()
      self._processed_notifications = set()
  def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
      """Attach the connection whose notifications this handler processes."""
      self.connection = connection
  def get_handler_for_connection(self):
      """
      Build a separate handler for one connection.

      The new handler uses the same pool and config and shares this handler's
      processed-notification set and lock. A notification processed through any
      copy is therefore seen as processed by all of them. Only ``connection`` is
      per-copy, and it starts as ``None``.
      """
      handler = MaintNotificationsPoolHandler(self.pool, self.config)
      # Replace the fresh set/lock created by __init__ with the shared ones.
      handler._lock = self._lock
      handler._processed_notifications = self._processed_notifications
      handler.connection = None
      return handler
  def remove_expired_notifications(self):
      """
      Prune expired notifications from the processed set.

      The set is modified in place (it may be shared with other handlers),
      while holding the handler lock.
      """
      with self._lock:
          stale = [
              entry
              for entry in tuple(self._processed_notifications)
              if entry.is_expired()
          ]
          self._processed_notifications.difference_update(stale)
  def handle_notification(self, notification: MaintenanceNotification):
      """
      Dispatch a notification at the pool level.

      Expired entries are pruned first. Only NodeMovingNotification is handled
      here. Any other type is logged as an error and otherwise ignored.

      Returns:
          The result of ``handle_node_moving_notification`` for MOVING
          notifications, otherwise None.
      """
      self.remove_expired_notifications()
      if not isinstance(notification, NodeMovingNotification):
          logger.error(f"Unhandled notification type: {notification}")
          return None
      return self.handle_node_moving_notification(notification)
  def handle_node_moving_notification(self, notification: NodeMovingNotification):
      """
      Apply a node MOVING notification to the whole pool.

      Nothing happens when both proactive reconnect and relaxed timeouts are
      disabled. Nothing happens either when an equal notification is already in
      the processed set, which is shared with the other per-connection handlers.

      Otherwise, holding the handler lock and then the pool lock, it:
        1. updates pool connections connected to the same peer address as this
           handler's connection (free ones included): MOVING state,
           notification hash, relaxed timeout and, if given, the new host;
        2. if proactive reconnect is enabled, reconnects those connections
           immediately when the new host is known, or after ``ttl / 2``
           seconds (on a timer) when it is not;
        3. updates the pool's connection kwargs so connections created from
           now on get the MOVING state, the new host (if any) and the relaxed
           timeouts (if enabled);
        4. schedules ``handle_node_moved_notification`` after ``ttl`` seconds
           to revert the temporary changes, records the handoff and marks the
           notification as processed.

      If the pool exposes ``set_in_maintenance``, maintenance mode is switched
      on for steps 1-3. It is switched off in a ``finally`` clause, so an
      exception in those steps cannot leave the pool stuck in maintenance mode.

      The timers are daemon threads, so pending timers never delay
      interpreter shutdown.
      """
      config = self.config
      if not config.proactive_reconnect and not config.is_relaxed_timeouts_enabled():
          return

      with self._lock:
          if notification in self._processed_notifications:
              # Already handled (by this or another per-connection handler)
              # and not yet expired - nothing to do at the pool level.
              return

          with self.pool._lock:
              if logger.isEnabledFor(logging.DEBUG):
                  # Guarded so get_resolved_ip() is not called when DEBUG is off.
                  logger.debug(
                      f"Handling node MOVING notification: {notification}, "
                      f"with connection: {self.connection}, connected to ip "
                      f"{self.connection.get_resolved_ip() if self.connection else None}"
                  )

              # Peer address of the connection that received the notification:
              # this is the address being moved, and only pool connections
              # connected to it are affected.
              moving_address_src = (
                  self.connection.getpeername() if self.connection else None
              )

              # Only some pools expose set_in_maintenance (the original code
              # notes BlockingConnectionPool).
              toggles_maintenance = bool(
                  getattr(self.pool, "set_in_maintenance", False)
              )
              if toggles_maintenance:
                  self.pool.set_in_maintenance(True)
              try:
                  self.pool.update_connections_settings(
                      state=MaintenanceState.MOVING,
                      maintenance_notification_hash=hash(notification),
                      relaxed_timeout=config.relaxed_timeout,
                      host_address=notification.new_node_host,
                      matching_address=moving_address_src,
                      matching_pattern="connected_address",
                      update_notification_hash=True,
                      include_free_connections=True,
                  )

                  if config.proactive_reconnect:
                      if notification.new_node_host is not None:
                          self.run_proactive_reconnect(moving_address_src)
                      else:
                          # No new host was sent: reconnect after half of the
                          # notification TTL instead of immediately.
                          reconnect_timer = threading.Timer(
                              notification.ttl / 2,
                              self.run_proactive_reconnect,
                              args=(moving_address_src,),
                          )
                          reconnect_timer.daemon = True
                          reconnect_timer.start()

                  # Settings for connections created from now on.
                  new_connection_kwargs: dict = {
                      "maintenance_state": MaintenanceState.MOVING,
                      "maintenance_notification_hash": hash(notification),
                  }
                  if notification.new_node_host is not None:
                      # Without a new host in the MOVING notification only the
                      # timeouts are changed.
                      new_connection_kwargs["host"] = notification.new_node_host
                  if config.is_relaxed_timeouts_enabled():
                      new_connection_kwargs["socket_timeout"] = config.relaxed_timeout
                      new_connection_kwargs["socket_connect_timeout"] = (
                          config.relaxed_timeout
                      )
                  self.pool.update_connection_kwargs(**new_connection_kwargs)
              finally:
                  if toggles_maintenance:
                      self.pool.set_in_maintenance(False)

          # Revert the temporary MOVING changes once the notification expires.
          revert_timer = threading.Timer(
              notification.ttl,
              self.handle_node_moved_notification,
              args=(notification,),
          )
          revert_timer.daemon = True
          revert_timer.start()

          record_connection_handoff(
              pool_name=get_pool_name(self.pool),
          )

          self._processed_notifications.add(notification)
  721. def run_proactive_reconnect(self, moving_address_src: Optional[str] = None):
  722. """
  723. Run proactive reconnect for the pool.
  724. Active connections are marked for reconnect after they complete the current command.
  725. Inactive connections are disconnected and will be connected on next use.
  726. """
  727. with self._lock:
  728. with self.pool._lock:
  729. # take care for the active connections in the pool
  730. # mark them for reconnect after they complete the current command
  731. self.pool.update_active_connections_for_reconnect(
  732. moving_address_src=moving_address_src,
  733. )
  734. # take care for the inactive connections in the pool
  735. # delete them and create new ones
  736. self.pool.disconnect_free_connections(
  737. moving_address_src=moving_address_src,
  738. )
  739. def handle_node_moved_notification(self, notification: NodeMovingNotification):
  740. """
  741. Handle the cleanup after a node moving notification expires.
  742. """
  743. notification_hash = hash(notification)
  744. with self._lock:
  745. logger.debug(
  746. f"Reverting temporary changes related to notification: {notification}, "
  747. f"with connection: {self.connection}, connected to ip "
  748. f"{self.connection.get_resolved_ip() if self.connection else None}"
  749. )
  750. # if the current maintenance_notification_hash in kwargs is not matching the notification
  751. # it means there has been a new moving notification after this one
  752. # and we don't need to revert the kwargs yet
  753. if (
  754. self.pool.connection_kwargs.get("maintenance_notification_hash")
  755. == notification_hash
  756. ):
  757. orig_host = self.pool.connection_kwargs.get("orig_host_address")
  758. orig_socket_timeout = self.pool.connection_kwargs.get(
  759. "orig_socket_timeout"
  760. )
  761. orig_connect_timeout = self.pool.connection_kwargs.get(
  762. "orig_socket_connect_timeout"
  763. )
  764. kwargs: dict = {
  765. "maintenance_state": MaintenanceState.NONE,
  766. "maintenance_notification_hash": None,
  767. "host": orig_host,
  768. "socket_timeout": orig_socket_timeout,
  769. "socket_connect_timeout": orig_connect_timeout,
  770. }
  771. self.pool.update_connection_kwargs(**kwargs)
  772. with self.pool._lock:
  773. reset_relaxed_timeout = self.config.is_relaxed_timeouts_enabled()
  774. reset_host_address = self.config.proactive_reconnect
  775. self.pool.update_connections_settings(
  776. relaxed_timeout=-1,
  777. state=MaintenanceState.NONE,
  778. maintenance_notification_hash=None,
  779. matching_notification_hash=notification_hash,
  780. matching_pattern="notification_hash",
  781. update_notification_hash=True,
  782. reset_relaxed_timeout=reset_relaxed_timeout,
  783. reset_host_address=reset_host_address,
  784. include_free_connections=True,
  785. )
  786. class MaintNotificationsConnectionHandler:
  787. # 1 = "starting maintenance" notifications, 0 = "completed maintenance" notifications
  788. _NOTIFICATION_TYPES: dict[type["MaintenanceNotification"], int] = {
  789. NodeMigratingNotification: 1,
  790. NodeFailingOverNotification: 1,
  791. OSSNodeMigratingNotification: 1,
  792. NodeMigratedNotification: 0,
  793. NodeFailedOverNotification: 0,
  794. OSSNodeMigratedNotification: 0,
  795. }
  796. def __init__(
  797. self,
  798. connection: "MaintNotificationsAbstractConnection",
  799. config: MaintNotificationsConfig,
  800. ) -> None:
  801. self.connection = connection
  802. self.config = config
  803. def _get_pool_name(self) -> str:
  804. """
  805. Get the pool name from the connection's pool handler.
  806. Falls back to connection representation if pool is not available.
  807. """
  808. pool_handler = getattr(
  809. self.connection, "_maint_notifications_pool_handler", None
  810. )
  811. if pool_handler and getattr(pool_handler, "pool", None):
  812. return get_pool_name(pool_handler.pool)
  813. # Fallback for standalone connections without a pool
  814. return repr(self.connection)
  815. def handle_notification(self, notification: MaintenanceNotification):
  816. # get the notification type by checking its class in the _NOTIFICATION_TYPES dict
  817. notification_type = self._NOTIFICATION_TYPES.get(notification.__class__, None)
  818. maint_notification = notification_types_mapping.get(notification.__class__, "")
  819. record_maint_notification_count(
  820. server_address=self.connection.host,
  821. server_port=self.connection.port,
  822. network_peer_address=self.connection.host,
  823. network_peer_port=self.connection.port,
  824. maint_notification=maint_notification,
  825. )
  826. if notification_type is None:
  827. logger.error(f"Unhandled notification type: {notification}")
  828. return
  829. if notification_type:
  830. self.handle_maintenance_start_notification(
  831. MaintenanceState.MAINTENANCE, notification
  832. )
  833. else:
  834. self.handle_maintenance_completed_notification(notification=notification)
  835. def handle_maintenance_start_notification(
  836. self, maintenance_state: MaintenanceState, notification: MaintenanceNotification
  837. ):
  838. add_debug_log_for_notification(self.connection, notification)
  839. if (
  840. self.connection.maintenance_state == MaintenanceState.MOVING
  841. or not self.config.is_relaxed_timeouts_enabled()
  842. ):
  843. return
  844. self.connection.maintenance_state = maintenance_state
  845. self.connection.set_tmp_settings(
  846. tmp_relaxed_timeout=self.config.relaxed_timeout
  847. )
  848. # extend the timeout for all created connections
  849. self.connection.update_current_socket_timeout(self.config.relaxed_timeout)
  850. if isinstance(notification, OSSNodeMigratingNotification):
  851. # add the notification id to the set of processed start maint notifications
  852. # this is used to skip the unrelaxing of the timeouts if we have received more than
  853. # one start notification before the the final end notification
  854. self.connection.add_maint_start_notification(notification.id)
  855. maint_notification = notification_types_mapping.get(notification.__class__, "")
  856. record_connection_relaxed_timeout(
  857. connection_name=self._get_pool_name(),
  858. maint_notification=maint_notification,
  859. relaxed=True,
  860. )
  861. def handle_maintenance_completed_notification(self, **kwargs):
  862. # Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
  863. if (
  864. self.connection.maintenance_state == MaintenanceState.MOVING
  865. or not self.config.is_relaxed_timeouts_enabled()
  866. ):
  867. return
  868. notification = None
  869. if kwargs.get("notification"):
  870. notification = kwargs["notification"]
  871. add_debug_log_for_notification(self.connection, "MAINTENANCE_COMPLETED")
  872. self.connection.reset_tmp_settings(reset_relaxed_timeout=True)
  873. # Maintenance completed - reset the connection
  874. # timeouts by providing -1 as the relaxed timeout
  875. self.connection.update_current_socket_timeout(-1)
  876. self.connection.maintenance_state = MaintenanceState.NONE
  877. # reset the sets that keep track of received start maint
  878. # notifications and skipped end maint notifications
  879. self.connection.reset_received_notifications()
  880. if notification:
  881. maint_notification = notification_types_mapping.get(
  882. notification.__class__, ""
  883. )
  884. record_connection_relaxed_timeout(
  885. connection_name=self._get_pool_name(),
  886. maint_notification=maint_notification,
  887. relaxed=False,
  888. )
  889. class OSSMaintNotificationsHandler:
  890. def __init__(
  891. self,
  892. cluster_client: "MaintNotificationsAbstractRedisCluster",
  893. config: MaintNotificationsConfig,
  894. ) -> None:
  895. self.cluster_client = cluster_client
  896. self.config = config
  897. self._processed_notifications = set()
  898. self._in_progress = set()
  899. self._lock = threading.RLock()
  900. def get_handler_for_connection(self):
  901. # Copy all data that should be shared between connections
  902. # but each connection should have its own pool handler
  903. # since each connection can be in a different state
  904. copy = OSSMaintNotificationsHandler(self.cluster_client, self.config)
  905. copy._processed_notifications = self._processed_notifications
  906. copy._in_progress = self._in_progress
  907. copy._lock = self._lock
  908. return copy
  909. def remove_expired_notifications(self):
  910. with self._lock:
  911. for notification in tuple(self._processed_notifications):
  912. if notification.is_expired():
  913. self._processed_notifications.remove(notification)
  914. def handle_notification(self, notification: MaintenanceNotification):
  915. if isinstance(notification, OSSNodeMigratedNotification):
  916. self.handle_oss_maintenance_completed_notification(notification)
  917. else:
  918. logger.error(f"Unhandled notification type: {notification}")
  919. def handle_oss_maintenance_completed_notification(
  920. self, notification: OSSNodeMigratedNotification
  921. ):
  922. self.remove_expired_notifications()
  923. with self._lock:
  924. if (
  925. notification in self._in_progress
  926. or notification in self._processed_notifications
  927. ):
  928. # we are already handling this notification or it has already been processed
  929. # we should skip in_progress notification since when we reinitialize the cluster
  930. # we execute a CLUSTER SLOTS command that can use a different connection
  931. # that has also has the notification and we don't want to
  932. # process the same notification twice
  933. return
  934. logger.debug(f"Handling SMIGRATED notification: {notification}")
  935. self._in_progress.add(notification)
  936. # Extract the information about the src and destination nodes that are affected
  937. # by the maintenance. nodes_to_slots_mapping structure:
  938. # {
  939. # "src_host:port": [
  940. # {"dest_host:port": "slot_range"},
  941. # ...
  942. # ],
  943. # ...
  944. # }
  945. additional_startup_nodes_info = []
  946. affected_nodes = set()
  947. for (
  948. src_address,
  949. dest_mappings,
  950. ) in notification.nodes_to_slots_mapping.items():
  951. src_host, src_port = src_address.split(":")
  952. src_node = self.cluster_client.nodes_manager.get_node(
  953. host=src_host, port=src_port
  954. )
  955. if src_node is not None:
  956. affected_nodes.add(src_node)
  957. for dest_mapping in dest_mappings:
  958. for dest_address in dest_mapping.keys():
  959. dest_host, dest_port = dest_address.split(":")
  960. additional_startup_nodes_info.append(
  961. (dest_host, int(dest_port))
  962. )
  963. # Updates the cluster slots cache with the new slots mapping
  964. # This will also update the nodes cache with the new nodes mapping
  965. self.cluster_client.nodes_manager.initialize(
  966. disconnect_startup_nodes_pools=False,
  967. additional_startup_nodes_info=additional_startup_nodes_info,
  968. )
  969. all_nodes = set(affected_nodes)
  970. all_nodes = all_nodes.union(
  971. self.cluster_client.nodes_manager.nodes_cache.values()
  972. )
  973. for current_node in all_nodes:
  974. if current_node.redis_connection is None:
  975. continue
  976. with current_node.redis_connection.connection_pool._lock:
  977. handoff_recorded = False
  978. if current_node in affected_nodes:
  979. # mark for reconnect all in use connections to the node - this will force them to
  980. # disconnect after they complete their current commands
  981. # Some of them might be used by sub sub and we don't know which ones - so we disconnect
  982. # all in flight connections after they are done with current command execution
  983. for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
  984. conn.mark_for_reconnect()
  985. record_connection_handoff(
  986. pool_name=get_pool_name(
  987. current_node.redis_connection.connection_pool
  988. )
  989. )
  990. handoff_recorded = True
  991. else:
  992. if logger.isEnabledFor(logging.DEBUG):
  993. logger.debug(
  994. f"SMIGRATED: Node {current_node.name} not affected by maintenance, "
  995. f"skipping mark for reconnect"
  996. )
  997. if (
  998. current_node
  999. not in self.cluster_client.nodes_manager.nodes_cache.values()
  1000. ):
  1001. # disconnect all free connections to the node - this node will be dropped
  1002. # from the cluster, so we don't need to revert the timeouts
  1003. for conn in current_node.redis_connection.connection_pool._get_free_connections():
  1004. conn.disconnect()
  1005. # Only record handoff if not already recorded for this node
  1006. if not handoff_recorded:
  1007. record_connection_handoff(
  1008. pool_name=get_pool_name(
  1009. current_node.redis_connection.connection_pool
  1010. )
  1011. )
  1012. # mark the notification as processed
  1013. self._processed_notifications.add(notification)
  1014. self._in_progress.remove(notification)