-
Notifications
You must be signed in to change notification settings - Fork 10.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23107 from gnossen/contextvars_propagation
Propagate contextvars to auxiliary threads
- Loading branch information
Showing
8 changed files
with
189 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright 2020 The gRPC authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
def _contextvars_supported(): | ||
"""Determines if the contextvars module is supported. | ||
We use a 'try it and see if it works approach' here rather than predicting | ||
based on interpreter version in order to support older interpreters that | ||
may have a backported module based on, e.g. `threading.local`. | ||
Returns: | ||
A bool indicating whether `contextvars` are supported in the current | ||
environment. | ||
""" | ||
try: | ||
import contextvars | ||
return True | ||
except ImportError: | ||
return False | ||
|
||
|
||
def _run_with_context(target): | ||
"""Runs a callable with contextvars propagated. | ||
If contextvars are supported, the calling thread's context will be copied | ||
and propagated. If they are not supported, this function is equivalent | ||
to the identity function. | ||
Args: | ||
target: A callable object to wrap. | ||
Returns: | ||
A callable object with the same signature as `target` but with | ||
contextvars propagated. | ||
""" | ||
|
||
|
||
if _contextvars_supported(): | ||
import contextvars | ||
def _run_with_context(target): | ||
ctx = contextvars.copy_context() | ||
def _run(*args): | ||
ctx.run(target, *args) | ||
return _run | ||
else: | ||
def _run_with_context(target): | ||
def _run(*args): | ||
target(*args) | ||
return _run |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
118 changes: 118 additions & 0 deletions
118
src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# Copyright 2020 The gRPC authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Test of propagation of contextvars to AuthMetadataPlugin threads..""" | ||
|
||
import contextlib | ||
import logging | ||
import os | ||
import sys | ||
import unittest | ||
|
||
import grpc | ||
|
||
from tests.unit import test_common | ||
|
||
_UNARY_UNARY = "/test/UnaryUnary" | ||
_REQUEST = b"0000" | ||
|
||
|
||
def _unary_unary_handler(request, context): | ||
return request | ||
|
||
|
||
def contextvars_supported(): | ||
try: | ||
import contextvars | ||
return True | ||
except ImportError: | ||
return False | ||
|
||
|
||
class _GenericHandler(grpc.GenericRpcHandler): | ||
|
||
def service(self, handler_call_details): | ||
if handler_call_details.method == _UNARY_UNARY: | ||
return grpc.unary_unary_rpc_method_handler(_unary_unary_handler) | ||
else: | ||
raise NotImplementedError() | ||
|
||
|
||
@contextlib.contextmanager | ||
def _server(): | ||
try: | ||
server = test_common.test_server() | ||
target = 'localhost:0' | ||
port = server.add_insecure_port(target) | ||
server.add_generic_rpc_handlers((_GenericHandler(),)) | ||
server.start() | ||
yield port | ||
finally: | ||
server.stop(None) | ||
|
||
|
||
if contextvars_supported(): | ||
import contextvars | ||
|
||
_EXPECTED_VALUE = 24601 | ||
test_var = contextvars.ContextVar("test_var", default=None) | ||
|
||
def set_up_expected_context(): | ||
test_var.set(_EXPECTED_VALUE) | ||
|
||
class TestCallCredentials(grpc.AuthMetadataPlugin): | ||
|
||
def __call__(self, context, callback): | ||
if test_var.get() != _EXPECTED_VALUE: | ||
raise AssertionError("{} != {}".format(test_var.get(), | ||
_EXPECTED_VALUE)) | ||
callback((), None) | ||
|
||
def assert_called(self, test): | ||
test.assertTrue(self._invoked) | ||
test.assertEqual(_EXPECTED_VALUE, self._recorded_value) | ||
|
||
else: | ||
|
||
def set_up_expected_context(): | ||
pass | ||
|
||
class TestCallCredentials(grpc.AuthMetadataPlugin): | ||
|
||
def __call__(self, context, callback): | ||
callback((), None) | ||
|
||
|
||
# TODO(https://github.com/grpc/grpc/issues/22257) | ||
@unittest.skipIf(os.name == "nt", "LocalCredentials not supported on Windows.") | ||
class ContextVarsPropagationTest(unittest.TestCase): | ||
|
||
def test_propagation_to_auth_plugin(self): | ||
set_up_expected_context() | ||
with _server() as port: | ||
target = "localhost:{}".format(port) | ||
local_credentials = grpc.local_channel_credentials() | ||
test_call_credentials = TestCallCredentials() | ||
call_credentials = grpc.metadata_call_credentials( | ||
test_call_credentials, "test call credentials") | ||
composite_credentials = grpc.composite_channel_credentials( | ||
local_credentials, call_credentials) | ||
with grpc.secure_channel(target, composite_credentials) as channel: | ||
stub = channel.unary_unary(_UNARY_UNARY) | ||
response = stub(_REQUEST, wait_for_ready=True) | ||
self.assertEqual(_REQUEST, response) | ||
|
||
|
||
if __name__ == '__main__': | ||
logging.basicConfig() | ||
unittest.main(verbosity=2) |