token.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. from abc import ABC, abstractmethod
  2. from datetime import datetime, timezone
  3. from redis.auth.err import InvalidTokenSchemaErr
  4. class TokenInterface(ABC):
  5. @abstractmethod
  6. def is_expired(self) -> bool:
  7. pass
  8. @abstractmethod
  9. def ttl(self) -> float:
  10. pass
  11. @abstractmethod
  12. def try_get(self, key: str) -> str:
  13. pass
  14. @abstractmethod
  15. def get_value(self) -> str:
  16. pass
  17. @abstractmethod
  18. def get_expires_at_ms(self) -> float:
  19. pass
  20. @abstractmethod
  21. def get_received_at_ms(self) -> float:
  22. pass
  23. class TokenResponse:
  24. def __init__(self, token: TokenInterface):
  25. self._token = token
  26. def get_token(self) -> TokenInterface:
  27. return self._token
  28. def get_ttl_ms(self) -> float:
  29. return self._token.get_expires_at_ms() - self._token.get_received_at_ms()
  30. class SimpleToken(TokenInterface):
  31. def __init__(
  32. self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
  33. ) -> None:
  34. self.value = value
  35. self.expires_at = expires_at_ms
  36. self.received_at = received_at_ms
  37. self.claims = claims
  38. def ttl(self) -> float:
  39. if self.expires_at == -1:
  40. return -1
  41. return self.expires_at - (datetime.now(timezone.utc).timestamp() * 1000)
  42. def is_expired(self) -> bool:
  43. if self.expires_at == -1:
  44. return False
  45. return self.ttl() <= 0
  46. def try_get(self, key: str) -> str:
  47. return self.claims.get(key)
  48. def get_value(self) -> str:
  49. return self.value
  50. def get_expires_at_ms(self) -> float:
  51. return self.expires_at
  52. def get_received_at_ms(self) -> float:
  53. return self.received_at
  54. class JWToken(TokenInterface):
  55. REQUIRED_FIELDS = {"exp"}
  56. def __init__(self, token: str):
  57. try:
  58. import jwt
  59. except ImportError as ie:
  60. raise ImportError(
  61. f"The PyJWT library is required for {self.__class__.__name__}.",
  62. ) from ie
  63. self._value = token
  64. self._decoded = jwt.decode(
  65. self._value,
  66. options={"verify_signature": False},
  67. algorithms=[jwt.get_unverified_header(self._value).get("alg")],
  68. )
  69. self._validate_token()
  70. def is_expired(self) -> bool:
  71. exp = self._decoded["exp"]
  72. if exp == -1:
  73. return False
  74. return (
  75. self._decoded["exp"] * 1000 <= datetime.now(timezone.utc).timestamp() * 1000
  76. )
  77. def ttl(self) -> float:
  78. exp = self._decoded["exp"]
  79. if exp == -1:
  80. return -1
  81. return (
  82. self._decoded["exp"] * 1000 - datetime.now(timezone.utc).timestamp() * 1000
  83. )
  84. def try_get(self, key: str) -> str:
  85. return self._decoded.get(key)
  86. def get_value(self) -> str:
  87. return self._value
  88. def get_expires_at_ms(self) -> float:
  89. return float(self._decoded["exp"] * 1000)
  90. def get_received_at_ms(self) -> float:
  91. return datetime.now(timezone.utc).timestamp() * 1000
  92. def _validate_token(self):
  93. actual_fields = {x for x in self._decoded.keys()}
  94. if len(self.REQUIRED_FIELDS - actual_fields) != 0:
  95. raise InvalidTokenSchemaErr(self.REQUIRED_FIELDS - actual_fields)