circuit.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. from abc import ABC, abstractmethod
  2. from enum import Enum
  3. from typing import Callable
  4. import pybreaker
  5. DEFAULT_GRACE_PERIOD = 60
  6. class State(Enum):
  7. CLOSED = "closed"
  8. OPEN = "open"
  9. HALF_OPEN = "half-open"
  10. class CircuitBreaker(ABC):
  11. @property
  12. @abstractmethod
  13. def grace_period(self) -> float:
  14. """The grace period in seconds when the circle should be kept open."""
  15. pass
  16. @grace_period.setter
  17. @abstractmethod
  18. def grace_period(self, grace_period: float):
  19. """Set the grace period in seconds."""
  20. @property
  21. @abstractmethod
  22. def state(self) -> State:
  23. """The current state of the circuit."""
  24. pass
  25. @state.setter
  26. @abstractmethod
  27. def state(self, state: State):
  28. """Set current state of the circuit."""
  29. pass
  30. @property
  31. @abstractmethod
  32. def database(self):
  33. """Database associated with this circuit."""
  34. pass
  35. @database.setter
  36. @abstractmethod
  37. def database(self, database):
  38. """Set database associated with this circuit."""
  39. pass
  40. @abstractmethod
  41. def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
  42. """Callback called when the state of the circuit changes."""
  43. pass
  44. class BaseCircuitBreaker(CircuitBreaker):
  45. """
  46. Base implementation of Circuit Breaker interface.
  47. """
  48. def __init__(self, cb: pybreaker.CircuitBreaker):
  49. self._cb = cb
  50. self._state_pb_mapper = {
  51. State.CLOSED: self._cb.close,
  52. State.OPEN: self._cb.open,
  53. State.HALF_OPEN: self._cb.half_open,
  54. }
  55. self._database = None
  56. @property
  57. def grace_period(self) -> float:
  58. return self._cb.reset_timeout
  59. @grace_period.setter
  60. def grace_period(self, grace_period: float):
  61. self._cb.reset_timeout = grace_period
  62. @property
  63. def state(self) -> State:
  64. return State(value=self._cb.state.name)
  65. @state.setter
  66. def state(self, state: State):
  67. self._state_pb_mapper[state]()
  68. @property
  69. def database(self):
  70. return self._database
  71. @database.setter
  72. def database(self, database):
  73. self._database = database
  74. @abstractmethod
  75. def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
  76. """Callback called when the state of the circuit changes."""
  77. pass
  78. class PBListener(pybreaker.CircuitBreakerListener):
  79. """Wrapper for callback to be compatible with pybreaker implementation."""
  80. def __init__(
  81. self,
  82. cb: Callable[[CircuitBreaker, State, State], None],
  83. database,
  84. ):
  85. """
  86. Initialize a PBListener instance.
  87. Args:
  88. cb: Callback function that will be called when the circuit breaker state changes.
  89. database: Database instance associated with this circuit breaker.
  90. """
  91. self._cb = cb
  92. self._database = database
  93. def state_change(self, cb, old_state, new_state):
  94. cb = PBCircuitBreakerAdapter(cb)
  95. cb.database = self._database
  96. old_state = State(value=old_state.name)
  97. new_state = State(value=new_state.name)
  98. self._cb(cb, old_state, new_state)
  99. class PBCircuitBreakerAdapter(BaseCircuitBreaker):
  100. def __init__(self, cb: pybreaker.CircuitBreaker):
  101. """
  102. Initialize a PBCircuitBreakerAdapter instance.
  103. This adapter wraps pybreaker's CircuitBreaker implementation to make it compatible
  104. with our CircuitBreaker interface.
  105. Args:
  106. cb: A pybreaker CircuitBreaker instance to be adapted.
  107. """
  108. super().__init__(cb)
  109. def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
  110. listener = PBListener(cb, self.database)
  111. self._cb.add_listener(listener)