| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- from abc import ABC, abstractmethod
- from datetime import datetime, timezone
- from redis.auth.err import InvalidTokenSchemaErr
class TokenInterface(ABC):
    """Abstract contract shared by every token implementation in this module.

    Concrete subclasses must override all six methods below; the class
    cannot be instantiated until they do.
    """

    @abstractmethod
    def is_expired(self) -> bool:
        """Report whether the token has expired."""

    @abstractmethod
    def ttl(self) -> float:
        """Return the token's remaining time to live.

        The implementations in this module return milliseconds, and ``-1``
        for a token that never expires.
        """

    @abstractmethod
    def try_get(self, key: str) -> str:
        """Look up the claim stored under ``key``."""

    @abstractmethod
    def get_value(self) -> str:
        """Return the raw token value."""

    @abstractmethod
    def get_expires_at_ms(self) -> float:
        """Return the expiry instant, in milliseconds."""

    @abstractmethod
    def get_received_at_ms(self) -> float:
        """Return the instant the token was received, in milliseconds."""
class TokenResponse:
    """Wrap a token and expose the lifetime it had when it was received."""

    def __init__(self, token: TokenInterface):
        """Store ``token`` for later retrieval.

        Args:
            token: The token this response carries.
        """
        self._token = token

    def get_token(self) -> TokenInterface:
        """Return the wrapped token object unchanged."""
        return self._token

    def get_ttl_ms(self) -> float:
        """Return the span between receipt and expiry, in milliseconds.

        Returns:
            ``expires_at_ms - received_at_ms`` for an expiring token, or
            ``-1`` when the token reports an expiry of ``-1``.
        """
        expires_at_ms = self._token.get_expires_at_ms()
        # -1 is the "never expires" sentinel (see SimpleToken.ttl); subtracting
        # the receipt time from it would yield a meaningless negative number.
        if expires_at_ms == -1:
            return -1
        return expires_at_ms - self._token.get_received_at_ms()
class SimpleToken(TokenInterface):
    """Token whose value, timestamps and claims are supplied by the caller.

    An ``expires_at_ms`` of ``-1`` marks a token that never expires.
    """

    def __init__(
        self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
    ) -> None:
        """Record the token data exactly as given.

        Args:
            value: Raw token value.
            expires_at_ms: Expiry instant in milliseconds, or ``-1`` for none.
            received_at_ms: Instant the token was received, in milliseconds.
            claims: Mapping consulted by ``try_get``.
        """
        self.claims = claims
        self.received_at = received_at_ms
        self.expires_at = expires_at_ms
        self.value = value

    def ttl(self) -> float:
        """Return milliseconds left until expiry, or ``-1`` if it never expires.

        The result is zero or negative once the expiry instant has passed.
        """
        if self.expires_at == -1:
            return -1

        # Compare against the current UTC wall clock, expressed in ms.
        current_ms = datetime.now(timezone.utc).timestamp() * 1000
        return self.expires_at - current_ms

    def is_expired(self) -> bool:
        """Return ``True`` once the remaining lifetime is zero or negative.

        A token with the ``-1`` sentinel is never considered expired.
        """
        return self.expires_at != -1 and self.ttl() <= 0

    def try_get(self, key: str) -> str:
        """Return the claim stored under ``key``, or ``None`` when absent."""
        claim = self.claims.get(key)
        return claim

    def get_value(self) -> str:
        """Return the raw token value."""
        return self.value

    def get_expires_at_ms(self) -> float:
        """Return the expiry instant passed to the constructor."""
        return self.expires_at

    def get_received_at_ms(self) -> float:
        """Return the receipt instant passed to the constructor."""
        return self.received_at
class JWToken(TokenInterface):
    """Token backed by a JSON Web Token decoded with PyJWT.

    The payload is decoded once in the constructor and must contain every
    field listed in ``REQUIRED_FIELDS``. An ``exp`` claim of ``-1`` is treated
    as "never expires".
    """

    REQUIRED_FIELDS = {"exp"}

    def __init__(self, token: str):
        """Decode ``token`` and check that the required claims are present.

        Args:
            token: The encoded JWT string.

        Raises:
            ImportError: If the PyJWT library is not installed.
            InvalidTokenSchemaErr: If a required claim is missing.
        """
        # Imported lazily so PyJWT is only needed when this class is used.
        try:
            import jwt
        except ImportError as ie:
            raise ImportError(
                f"The PyJWT library is required for {self.__class__.__name__}.",
            ) from ie

        self._value = token
        # NOTE(review): the signature is NOT verified here, so the decoded
        # claims are unauthenticated - confirm callers do not rely on them
        # for trust decisions.
        self._decoded = jwt.decode(
            self._value,
            options={"verify_signature": False},
            algorithms=[jwt.get_unverified_header(self._value).get("alg")],
        )
        self._validate_token()

    @staticmethod
    def _now_ms() -> float:
        """Return the current UTC time as milliseconds since the epoch."""
        return datetime.now(timezone.utc).timestamp() * 1000

    def is_expired(self) -> bool:
        """Return ``True`` when ``exp`` is at or before the current time.

        A token whose ``exp`` is ``-1`` is never considered expired.
        """
        exp = self._decoded["exp"]
        if exp == -1:
            return False
        # ``exp`` is multiplied by 1000 to compare in milliseconds.
        return exp * 1000 <= self._now_ms()

    def ttl(self) -> float:
        """Return milliseconds left until ``exp``, or ``-1`` if it never expires.

        The result is zero or negative once ``exp`` has passed.
        """
        exp = self._decoded["exp"]
        if exp == -1:
            return -1
        return exp * 1000 - self._now_ms()

    def try_get(self, key: str) -> str:
        """Return the decoded claim stored under ``key``, or ``None`` if absent."""
        return self._decoded.get(key)

    def get_value(self) -> str:
        """Return the encoded token string passed to the constructor."""
        return self._value

    def get_expires_at_ms(self) -> float:
        """Return ``exp`` multiplied by 1000, or ``-1.0`` if it never expires."""
        exp = self._decoded["exp"]
        # Keep the sentinel intact instead of scaling it to -1000, so it stays
        # consistent with is_expired() and ttl().
        if exp == -1:
            return -1.0
        return float(exp * 1000)

    def get_received_at_ms(self) -> float:
        """Return the current UTC time in milliseconds.

        NOTE(review): this is evaluated on every call rather than captured
        when the token was constructed - confirm that is intended.
        """
        return self._now_ms()

    def _validate_token(self):
        """Ensure every field in ``REQUIRED_FIELDS`` is in the decoded payload.

        Raises:
            InvalidTokenSchemaErr: Raised with the set of missing field names.
        """
        missing_fields = self.REQUIRED_FIELDS - set(self._decoded)
        if missing_fields:
            raise InvalidTokenSchemaErr(missing_fields)
|