| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717 |
- import asyncio
- import copy
- import enum
- import inspect
- import socket
- import sys
- import time
- import warnings
- import weakref
- from abc import abstractmethod
- from itertools import chain
- from types import MappingProxyType
- from typing import (
- Any,
- Callable,
- Iterable,
- List,
- Mapping,
- Optional,
- Protocol,
- Set,
- Tuple,
- Type,
- TypedDict,
- TypeVar,
- Union,
- )
- from urllib.parse import ParseResult, parse_qs, unquote, urlparse
- from ..observability.attributes import (
- DB_CLIENT_CONNECTION_POOL_NAME,
- DB_CLIENT_CONNECTION_STATE,
- AttributeBuilder,
- ConnectionState,
- get_pool_name,
- )
- from ..utils import SSL_AVAILABLE
- if SSL_AVAILABLE:
- import ssl
- from ssl import SSLContext, TLSVersion, VerifyFlags
- else:
- ssl = None
- TLSVersion = None
- SSLContext = None
- VerifyFlags = None
- from ..auth.token import TokenInterface
- from ..driver_info import DriverInfo, resolve_driver_info
- from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
- from ..utils import deprecated_args, format_error_message
- # the functionality is available in 3.11.x but has a major issue before
- # 3.11.3. See https://github.com/redis/redis-py/issues/2633
- if sys.version_info >= (3, 11, 3):
- from asyncio import timeout as async_timeout
- else:
- from async_timeout import timeout as async_timeout
- from redis.asyncio.observability.recorder import (
- record_connection_closed,
- record_connection_count,
- record_connection_create_time,
- record_connection_wait_time,
- record_error_count,
- )
- from redis.asyncio.retry import Retry
- from redis.backoff import NoBackoff
- from redis.connection import DEFAULT_RESP_VERSION
- from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
- from redis.exceptions import (
- AuthenticationError,
- AuthenticationWrongNumberOfArgsError,
- ConnectionError,
- DataError,
- MaxConnectionsError,
- RedisError,
- ResponseError,
- TimeoutError,
- )
- from redis.observability.metrics import CloseReason
- from redis.typing import EncodableT
- from redis.utils import HIREDIS_AVAILABLE, str_if_bytes
- from .._parsers import (
- BaseParser,
- Encoder,
- _AsyncHiredisParser,
- _AsyncRESP2Parser,
- _AsyncRESP3Parser,
- )
# RESP protocol framing bytes used when packing commands for the wire.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_LF = b"\n"
SYM_EMPTY = b""


class _Sentinel(enum.Enum):
    """Private sentinel type used to distinguish "argument not supplied"
    from an explicit ``None`` in keyword arguments."""

    sentinel = object()


# Module-level singleton sentinel value.
SENTINEL = _Sentinel.sentinel

# Prefer the C-accelerated hiredis parser when the hiredis package is
# installed; otherwise fall back to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _AsyncHiredisParser
else:
    DefaultParser = _AsyncRESP2Parser
class ConnectCallbackProtocol(Protocol):
    """Structural type for a synchronous connection-established callback."""

    def __call__(self, connection: "AbstractConnection"): ...


class AsyncConnectCallbackProtocol(Protocol):
    """Structural type for an asynchronous connection-established callback."""

    async def __call__(self, connection: "AbstractConnection"): ...


# Either a sync or an async callback may be registered on a connection.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
class AbstractConnection:
    """Manages communication to and from a Redis server"""

    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        # __dict__ is kept so attributes not listed above (e.g. driver_info,
        # retry, _event_dispatcher, _re_auth_token) can still be assigned.
        "__dict__",
    )

    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[list, _Sentinel] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class: Type[BaseParser] = DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: float = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional[DriverInfo] = None,
        username: Optional[str] = None,
        retry: Optional[Retry] = None,
        redis_connect_func: Optional[ConnectCallbackT] = None,
        encoder_class: Type[Encoder] = Encoder,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new async Connection.

        Parameters
        ----------
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
            are ignored. If not provided, a DriverInfo will be created from lib_name
            and lib_version (or defaults if those are also None).
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.

        Raises
        ------
        DataError
            If both username/password and a credential_provider are given.
        ConnectionError
            If ``protocol`` is not an integer, or not 2 or 3.
        """
        # Credentials may come either from username/password or from a
        # credential provider -- never both.
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.db = db
        self.client_name = client_name
        # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        # The connect timeout falls back to the read/write timeout when unset.
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        if retry_on_timeout:
            # retry_on_timeout implies retrying on all timeout flavours.
            # NOTE(review): this appends to a caller-supplied list in place --
            # callers reusing that list will see it grow.
            retry_on_error.append(TimeoutError)
            retry_on_error.append(socket.timeout)
            retry_on_error.append(asyncio.TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if not retry:
                # Errors were requested but no policy given: one retry, no backoff.
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            # No retry configuration at all: a single attempt only.
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        # -1 means "no health check performed yet".
        self.next_health_check: float = -1
        self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
        self.redis_connect_func = redis_connect_func
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
        # Commands larger than this are chunked rather than concatenated
        # when packed (see pack_command / pack_commands).
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        self._should_reconnect = False
        try:
            p = int(protocol)
        except TypeError:
            # protocol=None (or other non-int-convertible): use library default.
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        else:
            if p < 2 or p > 3:
                raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p

    def __del__(self, _warnings: Any = warnings):
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )
            try:
                # _close() requires a running loop; skip cleanup otherwise.
                asyncio.get_running_loop()
                self._close()
            except RuntimeError:
                # No actions been taken if pool already closed.
                pass

    def _close(self):
        """
        Internal method to silently close the connection without waiting
        """
        if self._writer:
            self._writer.close()
            self._writer = self._reader = None

    def __repr__(self):
        repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return (key, value) pairs describing this connection for repr()."""
        pass

    @property
    def is_connected(self):
        # Connected iff both stream endpoints are present.
        return self._reader is not None and self._writer is not None

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            # Callback was never registered (or already removed) -- ignore.
            pass

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    async def connect(self):
        """Connects to the Redis server if not already connected"""
        # try once the socket connect with the handshake, retry the whole
        # connect/handshake flow based on retry policy
        await self.retry.call_with_retry(
            lambda: self.connect_check_health(
                check_health=True, retry_socket_connect=False
            ),
            lambda error, failure_count: self.disconnect(
                error=error, failure_count=failure_count
            ),
            with_failure_count=True,
        )

    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Establish the socket connection (optionally with retries), run the
        handshake (on_connect or redis_connect_func), then invoke any
        registered connect callbacks. No-op if already connected."""
        if self.is_connected:
            return
        # Track actual retry attempts for error reporting
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            # Record how many attempts failed, then disconnect before retry.
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self.disconnect(error=error, failure_count=failure_count)

        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(),
                    failure_callback,
                    with_failure_count=True,
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            e = TimeoutError("Timeout connecting to server")
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise e
        except OSError as e:
            e = ConnectionError(self._error_message(e))
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise e
        except Exception as exc:
            # Wrap any other failure so callers can treat it as a
            # connection-level error.
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

    def mark_for_reconnect(self):
        """Flag this connection to be re-established before next use."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return True if the connection has been marked for reconnect."""
        return self._should_reconnect

    def reset_should_reconnect(self):
        """Clear the reconnect flag."""
        self._should_reconnect = False

    @abstractmethod
    async def _connect(self):
        """Open the underlying transport (implemented by subclasses)."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a human-readable endpoint string for error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        """Return the configured RESP protocol version (2 or 3)."""
        return self.protocol

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the post-connect handshake: authenticate (AUTH or HELLO AUTH),
        switch RESP version if requested, set client name, send CLIENT SETINFO,
        and SELECT the configured database."""
        self._parser.on_connect(self)
        parser = self._parser
        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()
        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # Password only: authenticate as the "default" user.
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            # The "proto" key may come back as bytes or str depending on parser.
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)
            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()
            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")
        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            response = await self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")
        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")
        # Set the library name and version from driver_info, pipeline for lower startup latency
        lib_name_sent = False
        lib_version_sent = False
        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True
        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True
        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)
        # read responses from pipeline
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                # CLIENT SETINFO may be unsupported on older servers; ignore.
                pass
        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    async def disconnect(
        self,
        nowait: bool = False,
        error: Optional[Exception] = None,
        failure_count: Optional[int] = None,
        health_check_failed: bool = False,
    ) -> None:
        """Disconnects from the Redis server"""
        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                # Reset the reconnect flag
                self.reset_should_reconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None
        # Observability: record why the connection was closed.
        if error:
            if health_check_failed:
                close_reason = CloseReason.HEALTHCHECK_FAILED
            else:
                close_reason = CloseReason.ERROR
            # Only count the error once all retries have been exhausted.
            if failure_count is not None and failure_count > self.retry.get_retries():
                await record_error_count(
                    server_address=getattr(self, "host", None),
                    server_port=getattr(self, "port", None),
                    network_peer_address=getattr(self, "host", None),
                    network_peer_port=getattr(self, "port", None),
                    error_type=error,
                    retry_attempts=failure_count,
                )
            await record_connection_closed(
                close_reason=close_reason,
                error_type=error,
            )
        else:
            await record_connection_closed(
                close_reason=CloseReason.APPLICATION_CLOSE,
            )

    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error, failure_count):
        """Function to call when PING fails"""
        await self.disconnect(
            error=error, failure_count=failure_count, health_check_failed=True
        )

    async def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(
                self._send_ping, self._ping_failed, with_failure_count=True
            )

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        # Write all chunks then flush; separated out so it can be wrapped
        # in asyncio.wait_for by send_packed_command.
        self._writer.writelines(command)
        await self._writer.drain()

    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Send an already-packed command, reconnecting first if needed.

        On any write failure the connection is disconnected before the
        error is re-raised."""
        if not self.is_connected:
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()
        try:
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                self._writer.writelines(command)
                await self._writer.drain()
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            if len(e.args) == 1:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise

    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server"""
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command"""
        # An explicit per-call timeout overrides the connection default.
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        try:
            # push_request is only meaningful for RESP3 parsers.
            if read_timeout is not None and self.protocol in ["3", 3]:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding, push_request=push_request
                    )
            elif read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(
                        disable_decoding=disable_decoding
                    )
            elif self.protocol in ["3", 3]:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = await self._parser.read_response(
                    disable_decoding=disable_decoding
                )
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise
        # A successful read counts as proof of health: push the next check out.
        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time
        if isinstance(response, ResponseError):
            raise response from None
        return response

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol"""
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]
        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol"""
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff
        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                # Flush the accumulated pieces once the buffer gets large, or
                # when a big chunk / memoryview must be sent standalone.
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []
                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen
        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): peeks at the StreamReader's private _buffer; there is
        # no public API for "bytes currently buffered".
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        """Drain and process any buffered push (invalidation) messages."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be used for re-authentication on demand."""
        self._re_auth_token = token

    async def re_auth(self):
        """Re-authenticate with the stored token (if any), then clear it."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None
class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """Create a TCP connection description (no socket is opened here).

        All remaining keyword arguments are forwarded to
        AbstractConnection.__init__.
        """
        self.host = host
        # Accept a string port (e.g. from a parsed URL) as well as an int.
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return (key, value) pairs shown in this connection's repr()."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        # Arguments passed to asyncio.open_connection().
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection"""
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        self._reader = reader
        self._writer = writer
        sock = writer.transport.get_extra_info("socket")
        if sock:
            # Disable Nagle's algorithm so small commands are sent immediately.
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            try:
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)
            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise

    def _host_error(self) -> str:
        # Endpoint description used in error messages.
        return f"{self.host}:{self.port}"
class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).

    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        **kwargs,
    ):
        """Collect the SSL parameters into a ``RedisSSLContext`` and
        initialize the underlying TCP connection.

        :raises RedisError: if Python was built without SSL support.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")
        # The actual ssl.SSLContext is built lazily in RedisSSLContext.get(),
        # i.e. on first connect, not here.
        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            include_verify_flags=ssl_include_verify_flags,
            exclude_verify_flags=ssl_exclude_verify_flags,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            ca_path=ssl_ca_path,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
            password=ssl_password,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        # Adding "ssl" makes asyncio.open_connection wrap the transport in TLS.
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    # The properties below are read-only views onto the RedisSSLContext
    # parameters, preserved for backwards-compatible introspection.

    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def include_verify_flags(self):
        return self.ssl_context.include_verify_flags

    @property
    def exclude_verify_flags(self):
        return self.ssl_context.exclude_verify_flags

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version
class RedisSSLContext:
    """Holds SSL configuration and lazily builds a cached ``ssl.SSLContext``.

    The context is created on the first call to :meth:`get` and reused for
    every subsequent connection.
    """

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "ca_path",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
        "password",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        ca_path: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Normalize and store the SSL parameters.

        :raises RedisError: if Python lacks SSL support or ``cert_reqs`` is
            a string other than "none"/"optional"/"required".
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")
        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            verify_modes = {
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in verify_modes:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = verify_modes[cert_reqs]
        self.cert_reqs = cert_reqs
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.ca_path = ca_path
        # Hostname checking is meaningless (and rejected by ssl) without
        # certificate verification, so force it off under CERT_NONE.
        if self.cert_reqs == ssl.CERT_NONE:
            self.check_hostname = False
        else:
            self.check_hostname = check_hostname
        self.min_version = min_version
        self.ciphers = ciphers
        self.password = password
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the cached ``SSLContext``, building it on first use."""
        if self.context is not None:
            return self.context
        ctx = ssl.create_default_context()
        # check_hostname must be assigned before verify_mode: ssl refuses
        # to set verify_mode=CERT_NONE while check_hostname is enabled.
        ctx.check_hostname = self.check_hostname
        ctx.verify_mode = self.cert_reqs
        for flag in self.include_verify_flags or ():
            ctx.verify_flags |= flag
        for flag in self.exclude_verify_flags or ():
            ctx.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            ctx.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.password,
            )
        if self.ca_certs or self.ca_data or self.ca_path:
            ctx.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.min_version is not None:
            ctx.minimum_version = self.min_version
        if self.ciphers is not None:
            ctx.set_ciphers(self.ciphers)
        self.context = ctx
        return self.context
class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server."""

    def __init__(self, *, path: str = "", **kwargs):
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        """Key/value pairs describing this connection, used by ``__repr__``."""
        details = [("path", self.path), ("db", self.db)]
        if self.client_name:
            details.append(("client_name", self.client_name))
        return details

    async def _connect(self):
        """Open the unix-domain socket and run the connect handshake."""
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_unix_connection(path=self.path)
        self._reader = reader
        self._writer = writer
        # NOTE(review): unlike Connection._connect, this also invokes
        # on_connect() here — presumably the caller runs it again after
        # _connect(); confirm before unifying with the TCP path.
        await self.on_connect()

    def _host_error(self) -> str:
        """Human-readable endpoint used in error messages."""
        return self.path
# Upper-cased string values that a URL query option treats as False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Coerce a URL query value to a bool.

    ``None`` and the empty string mean "not specified" and yield ``None``.
    Strings matching ``FALSE_STRINGS`` (case-insensitively) yield ``False``;
    everything else is coerced with ``bool()``.
    """
    if value is None or value == "":
        return None
    is_false_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_string else bool(value)
def parse_ssl_verify_flags(value):
    """Parse a URL query value into a list of ``ssl.VerifyFlags`` members.

    The value is a comma-separated list of flag names, optionally wrapped
    in brackets, e.g. ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"``.

    :raises ValueError: if any name is not a ``VerifyFlags`` member.
    """
    # flags are passed in as a string representation of a list,
    # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
    verify_flags_str = value.replace("[", "").replace("]", "")
    verify_flags = []
    for flag in verify_flags_str.split(","):
        flag = flag.strip()
        try:
            # Enum name lookup only matches real members; the previous
            # hasattr()/getattr() approach also accepted arbitrary class
            # attributes (e.g. "mro") as "valid" flags.
            verify_flags.append(VerifyFlags[flag])
        except KeyError:
            raise ValueError(f"Invalid ssl verify flag: {flag}") from None
    return verify_flags
# Maps known querystring option names to the callable used by parse_url() to
# coerce the raw string value; options absent here are passed through as
# plain strings. Wrapped in MappingProxyType so it cannot be mutated.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    {
        "db": int,
        "socket_timeout": float,
        "socket_connect_timeout": float,
        "socket_keepalive": to_bool,
        "retry_on_timeout": to_bool,
        "max_connections": int,
        "health_check_interval": int,
        "ssl_check_hostname": to_bool,
        "ssl_include_verify_flags": parse_ssl_verify_flags,
        "ssl_exclude_verify_flags": parse_ssl_verify_flags,
        "timeout": float,
    }
)
class ConnectKwargs(TypedDict, total=False):
    """Connection keyword arguments produced by :func:`parse_url`.

    ``total=False``: every key is optional; only the values present in the
    parsed URL are populated.
    """

    username: str
    password: str
    connection_class: Type[AbstractConnection]
    host: str
    port: int
    db: int
    path: str
def parse_url(url: str) -> ConnectKwargs:
    """Parse a ``redis://``, ``rediss://`` or ``unix://`` URL into
    connection keyword arguments.

    Username, password, hostname, path and all querystring values are
    percent-decoded with ``urllib.parse.unquote``. Known query options are
    coerced via ``URL_QUERY_ARGUMENT_PARSERS``; unknown options are passed
    through as strings. For redis(s) schemes, a path component such as
    ``/0`` selects the database unless a ``db`` query option was given.

    :raises ValueError: for an unsupported scheme or a query value that
        cannot be coerced.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, value_list in parse_qs(parsed.query).items():
        # parse_qs never yields empty lists, so a plain truthiness check
        # suffices (previously double-checked with len(...) > 0).
        if value_list:
            value = unquote(value_list[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError):
                    raise ValueError(f"Invalid value for '{name}' in connection URL.")
            else:
                kwargs[name] = value

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    elif parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                # Non-numeric path components are silently ignored.
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    else:
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    return kwargs
_CP = TypeVar("_CP", bound="ConnectionPool")


class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.

    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0
        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0
        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        # None (or 0) means "effectively unlimited": 2**31 connections.
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    # Keys that should be redacted in __repr__ to avoid exposing sensitive information
    SENSITIVE_REPR_KEYS = frozenset(
        {
            "password",
            "username",
            "ssl_password",
            "credential_provider",
        }
    )

    def __repr__(self):
        conn_kwargs = ",".join(
            [
                f"{k}={'<REDACTED>' if k in self.SENSITIVE_REPR_KEYS else v}"
                for k, v in self.connection_kwargs.items()
            ]
        )
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def reset(self):
        """Discard all pool state, decrementing connection-count metrics
        for any connections being dropped."""
        # Record metrics for connections being removed before clearing
        # (only if attributes exist - they won't during __init__)
        if hasattr(self, "_available_connections") and hasattr(
            self, "_in_use_connections"
        ):
            idle_count = len(self._available_connections)
            in_use_count = len(self._in_use_connections)

            if idle_count > 0 or in_use_count > 0:
                pool_name = get_pool_name(self)

                # Note: Using sync version since reset() is sync
                from redis.observability.recorder import (
                    record_connection_count as sync_record_connection_count,
                )

                if idle_count > 0:
                    sync_record_connection_count(
                        pool_name=pool_name,
                        connection_state=ConnectionState.IDLE,
                        counter=-idle_count,
                    )
                if in_use_count > 0:
                    sync_record_connection_count(
                        pool_name=pool_name,
                        connection_state=ConnectionState.USED,
                        counter=-in_use_count,
                    )

        self._available_connections = []
        # NOTE(review): __init__ installs a plain set() while reset() installs
        # a WeakSet — presumably intentional (a reset pool should not keep
        # in-flight connections alive), but worth confirming.
        self._in_use_connections = weakref.WeakSet()

    def __del__(self) -> None:
        """Clean up connection pool and record metrics when garbage collected."""
        # __del__ must never raise: any failure here is swallowed.
        try:
            if not hasattr(self, "_available_connections") or not hasattr(
                self, "_in_use_connections"
            ):
                return

            idle_count = len(self._available_connections)
            in_use_count = len(self._in_use_connections)

            if idle_count > 0 or in_use_count > 0:
                pool_name = get_pool_name(self)

                # Note: Using sync version since __del__ is sync
                from redis.observability.recorder import (
                    record_connection_count as sync_record_connection_count,
                )

                if idle_count > 0:
                    sync_record_connection_count(
                        pool_name=pool_name,
                        connection_state=ConnectionState.IDLE,
                        counter=-idle_count,
                    )
                if in_use_count > 0:
                    sync_record_connection_count(
                        pool_name=pool_name,
                        connection_state=ConnectionState.USED,
                        counter=-in_use_count,
                    )
        except Exception:
            pass

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        # Wrap in bool() so the annotated return type holds; previously this
        # returned the (truthy) _available_connections list itself.
        return bool(self._available_connections) or (
            len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool"""
        # Track connection count before to detect if a new connection is created
        async with self._lock:
            connections_before = len(self._available_connections) + len(
                self._in_use_connections
            )
            start_time_created = time.monotonic()
            connection = self.get_available_connection()
            connections_after = len(self._available_connections) + len(
                self._in_use_connections
            )
            is_created = connections_after > connections_before

            # Record state transition for observability
            # This ensures counters stay balanced if ensure_connection() fails and release() is called
            pool_name = get_pool_name(self)
            if is_created:
                # New connection created and acquired: just USED +1
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.USED,
                    counter=1,
                )
            else:
                # Existing connection acquired from pool: IDLE -> USED
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.IDLE,
                    counter=-1,
                )
                await record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.USED,
                    counter=1,
                )

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)
            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )
            return connection
        except BaseException:
            # Return the connection so the USED counter is decremented even
            # when the health check fails.
            await self.release(connection)
            raise

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected"""
        try:
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise MaxConnectionsError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        # Note: We don't record IDLE here because async uses a sync make_connection
        # but async record_connection_count. The recording is handled in get_connection.
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid"""
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool"""
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        if connection.should_reconnect():
            await connection.disconnect()
        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

        # Record state transition: USED -> IDLE
        pool_name = get_pool_name(self)
        await record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.USED,
            counter=-1,
        )
        await record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.IDLE,
            counter=1,
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        # Re-raise the first failure only after every disconnect ran.
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all active connections for reconnect.
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Install *retry* on every connection currently owned by the pool."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate idle connections immediately; in-use connections
        only have the token stored, to be applied when they are next used."""
        async with self._lock:
            for conn in self._available_connections:
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

    def get_connection_count(self) -> List[tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).
        """
        attributes = AttributeBuilder.build_base_attributes()
        attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self)
        free_connections_attributes = attributes.copy()
        in_use_connections_attributes = attributes.copy()
        free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.IDLE.value
        )
        in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.USED.value
        )
        return [
            (len(self._available_connections), free_connections_attributes),
            (len(self._in_use_connections), in_use_connections_attributes),
        ]
class BlockingConnectionPool(ConnectionPool):
    """
    A blocking connection pool::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple async redis clients.

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation does), it
    blocks the current `Task` for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[float] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        # ``queue_class`` is accepted for backwards compatibility only; this
        # implementation waits on an asyncio.Condition, not a queue.
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        self._condition = asyncio.Condition()
        # Seconds to wait for a free connection; None blocks forever.
        self.timeout = timeout

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Gets a connection from the pool, blocking until one is available"""
        # Start timing for wait time observability
        start_time_acquired = time.monotonic()
        try:
            async with self._condition:
                # With timeout=None this waits indefinitely; otherwise the
                # TimeoutError is converted to ConnectionError below.
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    # Track connection count before to detect if a new connection is created
                    connections_before = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    start_time_created = time.monotonic()
                    connection = super().get_available_connection()
                    connections_after = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    is_created = connections_after > connections_before
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # We now perform the connection check outside of the lock.
        # NOTE(review): unlike ConnectionPool.get_connection, this override
        # does not record USED/IDLE connection-count metrics on acquire,
        # while release() (via super()) does decrement USED — confirm the
        # counters are balanced elsewhere.
        try:
            await self.ensure_connection(connection)
            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )
            await record_connection_wait_time(
                pool_name=get_pool_name(self),
                duration_seconds=time.monotonic() - start_time_acquired,
            )
            return connection
        except BaseException:
            # Return the connection (and wake a waiter) even on failure.
            await self.release(connection)
            raise

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool."""
        async with self._condition:
            await super().release(connection)
            # Wake one task blocked in get_connection() now that a
            # connection (or a free slot) is available.
            self._condition.notify()
|