[SPARK-44424][CONNECT][PYTHON] Python client for reattaching to existing execute in Spark Connect

### What changes were proposed in this pull request?

This PR proposes to implement the Python client side for #42228.

Basically, this PR ports the `ExecutePlanResponseReattachableIterator` and the corresponding `SparkConnectClient` changes to PySpark, mirroring the Scala client.

### Why are the changes needed?

To enable the same feature as #42228 in the Python client.

### Does this PR introduce _any_ user-facing change?

Yes, see #42228.
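
For illustration, here is a minimal usage sketch (not part of this commit) showing the renamed snake_case constructor parameters and the new reattachable-execute toggle; the connection string, user ID, and retry settings are placeholders.

```python
# Minimal sketch (assumptions: a local Spark Connect server at sc://localhost; placeholder values).
from pyspark.sql.connect.client.core import SparkConnectClient

client = SparkConnectClient(
    "sc://localhost",                 # connection string, or a ChannelBuilder instance
    user_id="demo_user",              # falls back to the $USER environment variable when omitted
    retry_policy={"max_retries": 5},  # merged into the default retry policy
    use_reattachable_execute=True,    # enabled by default; passed explicitly here for clarity
)

# The behavior can also be toggled after construction:
client.disable_reattachable_execute()
client.enable_reattachable_execute()
```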

### How was this patch tested?

Existing unit tests, since the feature is enabled by default, plus manual E2E tests.

Closes #42235 from HyukjinKwon/SPARK-44599.

Lead-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Co-authored-by: Hyukjin Kwon <gurwls223@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon and HyukjinKwon committed Aug 2, 2023
1 parent 79938ee commit 68d8e65
Showing 6 changed files with 386 additions and 82 deletions.
207 changes: 133 additions & 74 deletions python/pyspark/sql/connect/client/core.py
@@ -20,6 +20,7 @@
"getLogLevel",
]

from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)
@@ -50,6 +51,7 @@
Generator,
Type,
TYPE_CHECKING,
Sequence,
)

import pandas as pd
@@ -558,8 +560,6 @@ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult":
class SparkConnectClient(object):
"""
Conceptually the remote spark session that communicates with the server
.. versionadded:: 3.4.0
"""

@classmethod
@@ -572,32 +572,48 @@ def retry_exception(cls, e: Exception) -> bool:
def __init__(
self,
connection: Union[str, ChannelBuilder],
userId: Optional[str] = None,
channelOptions: Optional[List[Tuple[str, Any]]] = None,
retryPolicy: Optional[Dict[str, Any]] = None,
user_id: Optional[str] = None,
channel_options: Optional[List[Tuple[str, Any]]] = None,
retry_policy: Optional[Dict[str, Any]] = None,
use_reattachable_execute: bool = True,
):
"""
Creates a new SparkConnectClient for the Spark Connect interface.
Parameters
----------
connection: Union[str,ChannelBuilder]
connection : str or :class:`ChannelBuilder`
Connection string that is used to extract the connection parameters and configure
the GRPC connection. Or instance of ChannelBuilder that creates GRPC connection.
Defaults to `sc://localhost`.
userId : Optional[str]
user_id : str, optional
Optional unique user ID that is used to differentiate multiple users and
isolate their Spark Sessions. If the `user_id` is not set, will default to
the $USER environment. Defining the user ID as part of the connection string
takes precedence.
channel_options: list of tuple, optional
Additional options that can be passed to the GRPC channel construction.
retry_policy: dict of str and any, optional
Additional configuration for retrying. There are four configurations as below:
* ``max_retries``
Maximum number of tries. Default: 15.
* ``backoff_multiplier``
Backoff multiplier for the policy. Default: 4.
* ``initial_backoff``
Backoff to wait before the first retry. Default: 50 (ms).
* ``max_backoff``
Maximum amount of time to wait before retrying a failed request. Default: 60000 (ms).
use_reattachable_execute: bool
Enable reattachable execution.
"""
self.thread_local = threading.local()

# Parse the connection string.
self._builder = (
connection
if isinstance(connection, ChannelBuilder)
else ChannelBuilder(connection, channelOptions)
else ChannelBuilder(connection, channel_options)
)
self._user_id = None
self._retry_policy = {
@@ -606,26 +622,35 @@ def __init__(
"initial_backoff": 50,
"max_backoff": 60000,
}
if retryPolicy:
self._retry_policy.update(retryPolicy)
if retry_policy:
self._retry_policy.update(retry_policy)

# Generate a unique session ID for this client. This UUID must be unique to allow
# concurrent Spark sessions of the same user. If the channel is closed, creating
# a new client will create a new session ID.
self._session_id = str(uuid.uuid4())
if self._builder.userId is not None:
self._user_id = self._builder.userId
elif userId is not None:
self._user_id = userId
elif user_id is not None:
self._user_id = user_id
else:
self._user_id = os.getenv("USER", None)

self._channel = self._builder.toChannel()
self._closed = False
self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
self._artifact_manager = ArtifactManager(self._user_id, self._session_id, self._channel)
self._use_reattachable_execute = use_reattachable_execute
# Configure logging for the SparkConnect client.

def disable_reattachable_execute(self) -> "SparkConnectClient":
self._use_reattachable_execute = False
return self

def enable_reattachable_execute(self) -> "SparkConnectClient":
self._use_reattachable_execute = True
return self

def register_udf(
self,
function: Any,
@@ -741,7 +766,7 @@ def _resources(self) -> Dict[str, ResourceInformation]:
return resources

def _build_observed_metrics(
self, metrics: List["pb2.ExecutePlanResponse.ObservedMetrics"]
self, metrics: Sequence["pb2.ExecutePlanResponse.ObservedMetrics"]
) -> Iterator[PlanObservedMetrics]:
return (PlanObservedMetrics(x.name, [v for v in x.values]) for x in metrics)

@@ -1065,17 +1090,29 @@ def _execute(self, req: pb2.ExecutePlanRequest) -> None:
"""
logger.info("Execute")

def handle_response(b: pb2.ExecutePlanResponse) -> None:
if b.session_id != self._session_id:
raise SparkConnectException(
"Received incorrect session identifier for request: "
f"{b.session_id} != {self._session_id}"
)

try:
for attempt in Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy
):
with attempt:
for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
if b.session_id != self._session_id:
raise SparkConnectException(
"Received incorrect session identifier for request: "
f"{b.session_id} != {self._session_id}"
)
if self._use_reattachable_execute:
# Don't use retryHandler - own retry handling is inside.
generator = ExecutePlanResponseReattachableIterator(
req, self._stub, self._retry_policy, self._builder.metadata()
)
for b in generator:
handle_response(b)
else:
for attempt in Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy
):
with attempt:
for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
handle_response(b)
except Exception as error:
self._handle_error(error)

@@ -1092,58 +1129,77 @@ def _execute_and_fetch_as_iterator(
]:
logger.info("ExecuteAndFetchAsIterator")

def handle_response(
b: pb2.ExecutePlanResponse,
) -> Iterator[
Union[
"pa.RecordBatch",
StructType,
PlanMetrics,
PlanObservedMetrics,
Dict[str, Any],
]
]:
if b.session_id != self._session_id:
raise SparkConnectException(
"Received incorrect session identifier for request: "
f"{b.session_id} != {self._session_id}"
)
if b.HasField("metrics"):
logger.debug("Received metric batch.")
yield from self._build_metrics(b.metrics)
if b.observed_metrics:
logger.debug("Received observed metric batch.")
yield from self._build_observed_metrics(b.observed_metrics)
if b.HasField("schema"):
logger.debug("Received the schema.")
dt = types.proto_schema_to_pyspark_data_type(b.schema)
assert isinstance(dt, StructType)
yield dt
if b.HasField("sql_command_result"):
logger.debug("Received the SQL command result.")
yield {"sql_command_result": b.sql_command_result.relation}
if b.HasField("write_stream_operation_start_result"):
field = "write_stream_operation_start_result"
yield {field: b.write_stream_operation_start_result}
if b.HasField("streaming_query_command_result"):
yield {"streaming_query_command_result": b.streaming_query_command_result}
if b.HasField("streaming_query_manager_command_result"):
cmd_result = b.streaming_query_manager_command_result
yield {"streaming_query_manager_command_result": cmd_result}
if b.HasField("get_resources_command_result"):
resources = {}
for key, resource in b.get_resources_command_result.resources.items():
name = resource.name
addresses = [address for address in resource.addresses]
resources[key] = ResourceInformation(name, addresses)
yield {"get_resources_command_result": resources}
if b.HasField("arrow_batch"):
logger.debug(
f"Received arrow batch rows={b.arrow_batch.row_count} "
f"size={len(b.arrow_batch.data)}"
)

with pa.ipc.open_stream(b.arrow_batch.data) as reader:
for batch in reader:
assert isinstance(batch, pa.RecordBatch)
yield batch

try:
for attempt in Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy
):
with attempt:
for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
if b.session_id != self._session_id:
raise SparkConnectException(
"Received incorrect session identifier for request: "
f"{b.session_id} != {self._session_id}"
)
if b.HasField("metrics"):
logger.debug("Received metric batch.")
yield from self._build_metrics(b.metrics)
if b.observed_metrics:
logger.debug("Received observed metric batch.")
yield from self._build_observed_metrics(b.observed_metrics)
if b.HasField("schema"):
logger.debug("Received the schema.")
dt = types.proto_schema_to_pyspark_data_type(b.schema)
assert isinstance(dt, StructType)
yield dt
if b.HasField("sql_command_result"):
logger.debug("Received the SQL command result.")
yield {"sql_command_result": b.sql_command_result.relation}
if b.HasField("write_stream_operation_start_result"):
field = "write_stream_operation_start_result"
yield {field: b.write_stream_operation_start_result}
if b.HasField("streaming_query_command_result"):
yield {
"streaming_query_command_result": b.streaming_query_command_result
}
if b.HasField("streaming_query_manager_command_result"):
cmd_result = b.streaming_query_manager_command_result
yield {"streaming_query_manager_command_result": cmd_result}
if b.HasField("get_resources_command_result"):
resources = {}
for key, resource in b.get_resources_command_result.resources.items():
name = resource.name
addresses = [address for address in resource.addresses]
resources[key] = ResourceInformation(name, addresses)
yield {"get_resources_command_result": resources}
if b.HasField("arrow_batch"):
logger.debug(
f"Received arrow batch rows={b.arrow_batch.row_count} "
f"size={len(b.arrow_batch.data)}"
)

with pa.ipc.open_stream(b.arrow_batch.data) as reader:
for batch in reader:
assert isinstance(batch, pa.RecordBatch)
yield batch
if self._use_reattachable_execute:
# Don't use retryHandler - own retry handling is inside.
generator = ExecutePlanResponseReattachableIterator(
req, self._stub, self._retry_policy, self._builder.metadata()
)
for b in generator:
yield from handle_response(b)
else:
for attempt in Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy
):
with attempt:
for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
yield from handle_response(b)
except Exception as error:
self._handle_error(error)

@@ -1502,6 +1558,9 @@ def __exit__(
self._retry_state.set_done()
return None

def is_first_try(self) -> bool:
return self._retry_state._count == 0


class Retrying:
"""
