Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix scan iter command issued to different replicas #3220

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
114 changes: 113 additions & 1 deletion redis/asyncio/sentinel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import asyncio
import random
import weakref
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type
from typing import (
Any,
AsyncIterator,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
)

from redis.asyncio.client import Redis
from redis.asyncio.connection import (
Expand Down Expand Up @@ -66,6 +75,22 @@ async def connect(self):
lambda error: asyncio.sleep(0),
)

async def _connect_to_address_retry(self, host: str, port: int) -> None:
    """
    Make one attempt to connect this connection to ``host:port``.

    This is the unit of work that ``connect_to_address`` hands to the retry
    helper. It is a no-op when the connection already has a reader, i.e. it
    is already connected.

    :param host: host of the replica to connect to
    :param port: port of the replica to connect to
    :raises SlaveNotFoundError: if the connection attempt fails with a
        ``ConnectionError``; the original error is kept as ``__cause__``
    """
    if self._reader:
        return  # already connected
    try:
        return await self.connect_to((host, port))
    except ConnectionError as error:
        # Name the address in the message so the failure can be diagnosed,
        # and chain the underlying error explicitly.
        raise SlaveNotFoundError(
            f"Could not connect to replica {host}:{port}"
        ) from error

async def connect_to_address(self, host: str, port: int) -> None:
    """
    Connect to an explicit ``host:port`` through the retry helper.

    Unlike the regular connect path, this does not resolve the master or
    rotate through the slaves; it targets exactly the address given.

    :param host: host to connect to
    :param port: port to connect to
    """

    def attempt():
        # Build a fresh coroutine for every retry attempt.
        return self._connect_to_address_retry(host, port)

    def on_failure(error):
        # Nothing to clean up; just yield to the event loop before retrying.
        return asyncio.sleep(0)

    return await self.retry.call_with_retry(attempt, on_failure)

async def read_response(
self,
disable_decoding: bool = False,
Expand Down Expand Up @@ -122,6 +147,7 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
self.sentinel_manager = sentinel_manager
self.master_address = None
self.slave_rr_counter = None
self._request_id_to_replica_address = {}

def __repr__(self):
return (
Expand Down Expand Up @@ -167,6 +193,92 @@ async def rotate_slaves(self) -> AsyncIterator:
pass
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")

async def get_connection(
    self, command_name: str, *keys: Any, **options: Any
) -> "SentinelManagedConnection":
    """
    Get a connection from the pool, pinning iterator scans to one replica.

    The ``*scan_iter`` helpers (``scan_iter``, ``sscan_iter``,
    ``hscan_iter``, ``zscan_iter``) issue a series of scan commands that
    share a cursor. Each server positions its keys differently and the
    cursor acts as an offset into the scan, so when the pool is in replica
    mode every scan that belongs to one iterator has to be sent to the same
    replica.

    An iterator identifies its scans with the ``_iter_req_id`` option. The
    first scan with a given ID connects to whichever replica the normal
    connect path selects and records that replica's address; later scans
    with the same ID reuse an idle connection to that address or open a new
    one to it.

    Calls without ``_iter_req_id``, and all calls in master mode, are
    delegated to the superclass unchanged.

    :param command_name: name of the command about to be executed
    :param keys: passed through to the superclass implementation
    :param options: command options; ``_iter_req_id`` triggers pinning
    :return: a connected connection that is ready to send a command
    :raises ConnectionError: if the connection still has unread data after
        a reconnect
    """
    iter_req_id = options.get("_iter_req_id")
    # Not an iter command, or master mode: nothing to pin.
    if not iter_req_id or self.is_master:
        return await super().get_connection(command_name, *keys, **options)

    # NOTE(review): entries are added to this mapping below but nothing in
    # this method removes them - confirm they are cleaned up elsewhere once
    # the iterator is exhausted.
    server_host, server_port = self._request_id_to_replica_address.get(
        iter_req_id, (None, None)
    )
    is_first_scan = server_host is None or server_port is None

    if is_first_scan:
        # No replica recorded yet: take any idle connection or make one.
        try:
            connection = self._available_connections.pop()
        except IndexError:
            connection = self.make_connection()
    else:
        # Reuse a single idle connection that already targets the pinned
        # replica; every other idle connection stays in the pool.
        connection = next(
            (
                candidate
                for candidate in reversed(self._available_connections)
                if candidate.host == server_host
                and candidate.port == server_port
            ),
            None,
        )
        if connection is None:
            # No idle connection to that replica: create a blank one whose
            # host and port are set by ``connect_to_address`` below.
            connection = self.make_connection()
        else:
            self._available_connections.remove(connection)
    self._in_use_connections.add(connection)

    async def ensure_connected() -> None:
        # The first scan goes through the normal connect path; later scans
        # must reach the replica that served the first one.
        if is_first_scan:
            await connection.connect()
        else:
            await connection.connect_to_address(server_host, server_port)

    try:
        await ensure_connected()
        # Connections that the pool provides should be ready to send a
        # command. If not, the connection was either returned to the pool
        # before all data has been read or the socket has been closed.
        # Either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, OSError):
            await connection.disconnect()
            # Reconnect to the same target; falling back to a plain
            # ``connect()`` here could move a pinned scan to another
            # replica and invalidate its cursor.
            await ensure_connected()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None
    except BaseException:
        # Release the connection back to the pool so that we don't leak it.
        await self.release(connection)
        raise

    # Remember which replica serves this iterator for its following scans.
    self._request_id_to_replica_address[iter_req_id] = (
        connection.host,
        connection.port,
    )
    return connection


class Sentinel(AsyncSentinelCommands):
"""
Expand Down
43 changes: 37 additions & 6 deletions redis/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datetime
import hashlib
import uuid
import warnings
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -3052,6 +3053,7 @@ def sscan(
cursor: int = 0,
match: Union[PatternT, None] = None,
count: Union[int, None] = None,
**kwargs,
) -> ResponseT:
"""
Incrementally return lists of elements in a set. Also return a cursor
Expand All @@ -3068,7 +3070,7 @@ def sscan(
pieces.extend([b"MATCH", match])
if count is not None:
pieces.extend([b"COUNT", count])
return self.execute_command("SSCAN", *pieces)
return self.execute_command("SSCAN", *pieces, **kwargs)

def sscan_iter(
self,
Expand Down Expand Up @@ -3096,6 +3098,7 @@ def hscan(
match: Union[PatternT, None] = None,
count: Union[int, None] = None,
no_values: Union[bool, None] = None,
**kwargs,
) -> ResponseT:
"""
Incrementally return key/value slices in a hash. Also return a cursor
Expand All @@ -3116,7 +3119,7 @@ def hscan(
pieces.extend([b"COUNT", count])
if no_values is not None:
pieces.extend([b"NOVALUES"])
return self.execute_command("HSCAN", *pieces, no_values=no_values)
return self.execute_command("HSCAN", *pieces, no_values=no_values, **kwargs)

def hscan_iter(
self,
Expand Down Expand Up @@ -3152,6 +3155,7 @@ def zscan(
match: Union[PatternT, None] = None,
count: Union[int, None] = None,
score_cast_func: Union[type, Callable] = float,
**kwargs,
) -> ResponseT:
"""
Incrementally return lists of elements in a sorted set. Also return a
Expand All @@ -3171,7 +3175,7 @@ def zscan(
if count is not None:
pieces.extend([b"COUNT", count])
options = {"score_cast_func": score_cast_func}
return self.execute_command("ZSCAN", *pieces, **options)
return self.execute_command("ZSCAN", *pieces, **options, **kwargs)

def zscan_iter(
self,
Expand Down Expand Up @@ -3224,10 +3228,19 @@ async def scan_iter(
HASH, LIST, SET, STREAM, STRING, ZSET
Additionally, Redis modules can expose other types as well.
"""
# Do NOT inline this into the scan call below: the ID must be created
# once per iterator so that every scan it issues carries the same ID
# and is routed to the same replica.
iter_req_id = uuid.uuid4()
cursor = "0"
while cursor != 0:
cursor, data = await self.scan(
cursor=cursor, match=match, count=count, _type=_type, **kwargs
cursor=cursor,
match=match,
count=count,
_type=_type,
_iter_req_id=iter_req_id,
**kwargs,
)
for d in data:
yield d
Expand All @@ -3246,10 +3259,14 @@ async def sscan_iter(

``count`` is a hint for the minimum number of elements returned
"""
# Do NOT inline this into the scan call below: the ID must be created
# once per iterator so that every scan it issues carries the same ID
# and is routed to the same replica.
iter_req_id = uuid.uuid4()
cursor = "0"
while cursor != 0:
cursor, data = await self.sscan(
name, cursor=cursor, match=match, count=count
name, cursor=cursor, match=match, count=count, _iter_req_id=iter_req_id
)
for d in data:
yield d
Expand All @@ -3271,10 +3288,19 @@ async def hscan_iter(

``no_values`` indicates to return only the keys, without values
"""
# Do NOT inline this into the scan call below: the ID must be created
# once per iterator so that every scan it issues carries the same ID
# and is routed to the same replica.
iter_req_id = uuid.uuid4()
cursor = "0"
while cursor != 0:
cursor, data = await self.hscan(
name, cursor=cursor, match=match, count=count, no_values=no_values
name,
cursor=cursor,
match=match,
count=count,
no_values=no_values,
_iter_req_id=iter_req_id,
)
if no_values:
for it in data:
Expand All @@ -3300,6 +3326,10 @@ async def zscan_iter(

``score_cast_func`` a callable used to cast the score return value
"""
# Do NOT inline this into the scan call below: the ID must be created
# once per iterator so that every scan it issues carries the same ID
# and is routed to the same replica.
iter_req_id = uuid.uuid4()
cursor = "0"
while cursor != 0:
cursor, data = await self.zscan(
Expand All @@ -3308,6 +3338,7 @@ async def zscan_iter(
match=match,
count=count,
score_cast_func=score_cast_func,
_iter_req_id=iter_req_id,
)
for d in data:
yield d
Expand Down