Skip to content

Commit

Permalink
Merge pull request #23107 from gnossen/contextvars_propagation
Browse files Browse the repository at this point in the history
Propagate contextvars to auxiliary threads
  • Loading branch information
gnossen committed Jun 9, 2020
2 parents 12c3440 + a0e23e2 commit 80e834a
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def fork_handlers_and_grpc_init():
_fork_state.fork_handler_registered = True




class ForkManagedThread(object):
def __init__(self, target, args=()):
if _GRPC_ENABLE_FORK_SUPPORT:
Expand All @@ -102,9 +104,9 @@ class ForkManagedThread(object):
target(*args)
finally:
_fork_state.active_thread_count.decrement()
self._thread = threading.Thread(target=managed_target, args=args)
self._thread = threading.Thread(target=_run_with_context(managed_target), args=args)
else:
self._thread = threading.Thread(target=target, args=args)
self._thread = threading.Thread(target=_run_with_context(target), args=args)

def setDaemon(self, daemonic):
self._thread.daemon = daemonic
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def fork_handlers_and_grpc_init():

class ForkManagedThread(object):
def __init__(self, target, args=()):
self._thread = threading.Thread(target=target, args=args)
self._thread = threading.Thread(target=_run_with_context(target), args=args)

def setDaemon(self, daemonic):
self._thread.daemon = daemonic
Expand Down
59 changes: 59 additions & 0 deletions src/python/grpcio/grpc/_cython/_cygrpc/thread.pyx.pxi
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2020 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

def _contextvars_supported():
"""Determines if the contextvars module is supported.
We use a 'try it and see if it works approach' here rather than predicting
based on interpreter version in order to support older interpreters that
may have a backported module based on, e.g. `threading.local`.
Returns:
A bool indicating whether `contextvars` are supported in the current
environment.
"""
try:
import contextvars
return True
except ImportError:
return False


def _run_with_context(target):
"""Runs a callable with contextvars propagated.
If contextvars are supported, the calling thread's context will be copied
and propagated. If they are not supported, this function is equivalent
to the identity function.
Args:
target: A callable object to wrap.
Returns:
A callable object with the same signature as `target` but with
contextvars propagated.
"""


if _contextvars_supported():
import contextvars
def _run_with_context(target):
ctx = contextvars.copy_context()
def _run(*args):
ctx.run(target, *args)
return _run
else:
def _run_with_context(target):
def _run(*args):
target(*args)
return _run
2 changes: 2 additions & 0 deletions src/python/grpcio/grpc/_cython/cygrpc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ include "_cygrpc/iomgr.pyx.pxi"

include "_cygrpc/grpc_gevent.pyx.pxi"

include "_cygrpc/thread.pyx.pxi"

IF UNAME_SYSNAME == "Windows":
include "_cygrpc/fork_windows.pyx.pxi"
ELSE:
Expand Down
3 changes: 3 additions & 0 deletions src/python/grpcio_tests/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ class TestGevent(setuptools.Command):
'unit._cython._channel_test.ChannelTest.test_negative_deadline_connectivity',
# TODO(https://github.com/grpc/grpc/issues/15411) enable this test
'unit._local_credentials_test.LocalCredentialsTest',
# TODO(https://github.com/grpc/grpc/issues/22020) LocalCredentials
# aren't supported with custom io managers.
'unit._contextvars_propagation_test',
'testing._time_test.StrictRealTimeTest',
)
BANNED_WINDOWS_TESTS = (
Expand Down
1 change: 1 addition & 0 deletions src/python/grpcio_tests/tests/tests.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"unit._channel_connectivity_test.ChannelConnectivityTest",
"unit._channel_ready_future_test.ChannelReadyFutureTest",
"unit._compression_test.CompressionTest",
"unit._contextvars_propagation_test.ContextVarsPropagationTest",
"unit._credentials_test.CredentialsTest",
"unit._cython._cancel_many_calls_test.CancelManyCallsTest",
"unit._cython._channel_test.ChannelTest",
Expand Down
1 change: 1 addition & 0 deletions src/python/grpcio_tests/tests/unit/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ GRPCIO_TESTS_UNIT = [
"_channel_connectivity_test.py",
"_channel_ready_future_test.py",
"_compression_test.py",
"_contextvars_propagation_test.py",
"_credentials_test.py",
"_dns_resolver_test.py",
"_empty_message_test.py",
Expand Down
118 changes: 118 additions & 0 deletions src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright 2020 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test of propagation of contextvars to AuthMetadataPlugin threads.."""

import contextlib
import logging
import os
import sys
import unittest

import grpc

from tests.unit import test_common

_UNARY_UNARY = "/test/UnaryUnary"
_REQUEST = b"0000"


def _unary_unary_handler(request, context):
return request


def contextvars_supported():
try:
import contextvars
return True
except ImportError:
return False


class _GenericHandler(grpc.GenericRpcHandler):

def service(self, handler_call_details):
if handler_call_details.method == _UNARY_UNARY:
return grpc.unary_unary_rpc_method_handler(_unary_unary_handler)
else:
raise NotImplementedError()


@contextlib.contextmanager
def _server():
try:
server = test_common.test_server()
target = 'localhost:0'
port = server.add_insecure_port(target)
server.add_generic_rpc_handlers((_GenericHandler(),))
server.start()
yield port
finally:
server.stop(None)


if contextvars_supported():
import contextvars

_EXPECTED_VALUE = 24601
test_var = contextvars.ContextVar("test_var", default=None)

def set_up_expected_context():
test_var.set(_EXPECTED_VALUE)

class TestCallCredentials(grpc.AuthMetadataPlugin):

def __call__(self, context, callback):
if test_var.get() != _EXPECTED_VALUE:
raise AssertionError("{} != {}".format(test_var.get(),
_EXPECTED_VALUE))
callback((), None)

def assert_called(self, test):
test.assertTrue(self._invoked)
test.assertEqual(_EXPECTED_VALUE, self._recorded_value)

else:

def set_up_expected_context():
pass

class TestCallCredentials(grpc.AuthMetadataPlugin):

def __call__(self, context, callback):
callback((), None)


# TODO(https://github.com/grpc/grpc/issues/22257)
@unittest.skipIf(os.name == "nt", "LocalCredentials not supported on Windows.")
class ContextVarsPropagationTest(unittest.TestCase):

def test_propagation_to_auth_plugin(self):
set_up_expected_context()
with _server() as port:
target = "localhost:{}".format(port)
local_credentials = grpc.local_channel_credentials()
test_call_credentials = TestCallCredentials()
call_credentials = grpc.metadata_call_credentials(
test_call_credentials, "test call credentials")
composite_credentials = grpc.composite_channel_credentials(
local_credentials, call_credentials)
with grpc.secure_channel(target, composite_credentials) as channel:
stub = channel.unary_unary(_UNARY_UNARY)
response = stub(_REQUEST, wait_for_ready=True)
self.assertEqual(_REQUEST, response)


if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)

0 comments on commit 80e834a

Please sign in to comment.