Skip to content

Commit

Permalink
Lazy import either 'anyio' or 'trio' (#639)
Browse files Browse the repository at this point in the history
* Lazy import either 'anyio' or 'trio'

* Add comments in _synchronization.py

* Add comments in _synchronization.py
  • Loading branch information
tomchristie committed Dec 12, 2022
1 parent b4fd42b commit 6a97dad
Showing 1 changed file with 110 additions and 13 deletions.
123 changes: 110 additions & 13 deletions httpcore/_synchronization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,44 @@
from types import TracebackType
from typing import Optional, Type

import anyio
import sniffio

from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions

# Our async synchronization primatives use either 'anyio' or 'trio' depending
# on if they're running under asyncio or trio.
#
# We take care to only lazily import whichever of these two we need.


class AsyncLock:
def __init__(self) -> None:
self._lock = anyio.Lock()
self._backend = ""

def setup(self) -> None:
"""
Detect if we're running under 'asyncio' or 'trio' and create
a lock with the correct implementation.
"""
self._backend = sniffio.current_async_library()
if self._backend == "trio":
import trio

self._trio_lock = trio.Lock()
else:
import anyio

self._anyio_lock = anyio.Lock()

async def __aenter__(self) -> "AsyncLock":
await self._lock.acquire()
if not self._backend:
self.setup()

if self._backend == "trio":
await self._trio_lock.acquire()
else:
await self._anyio_lock.acquire()

return self

async def __aexit__(
Expand All @@ -21,32 +48,102 @@ async def __aexit__(
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
self._lock.release()
if self._backend == "trio":
self._trio_lock.release()
else:
self._anyio_lock.release()


class AsyncEvent:
def __init__(self) -> None:
self._event = anyio.Event()
self._backend = ""

def setup(self) -> None:
"""
Detect if we're running under 'asyncio' or 'trio' and create
a lock with the correct implementation.
"""
self._backend = sniffio.current_async_library()
if self._backend == "trio":
import trio

self._trio_event = trio.Event()
else:
import anyio

self._anyio_event = anyio.Event()

def set(self) -> None:
self._event.set()
if not self._backend:
self.setup()

if self._backend == "trio":
self._trio_event.set()
else:
self._anyio_event.set()

async def wait(self, timeout: Optional[float] = None) -> None:
exc_map: ExceptionMapping = {TimeoutError: PoolTimeout}
with map_exceptions(exc_map):
with anyio.fail_after(timeout):
await self._event.wait()
if not self._backend:
self.setup()

if self._backend == "trio":
import trio

trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout}
timeout_or_inf = float("inf") if timeout is None else timeout
with map_exceptions(trio_exc_map):
with trio.fail_after(timeout_or_inf):
await self._trio_event.wait()
else:
import anyio

anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout}
with map_exceptions(anyio_exc_map):
with anyio.fail_after(timeout):
await self._anyio_event.wait()


class AsyncSemaphore:
def __init__(self, bound: int) -> None:
self._semaphore = anyio.Semaphore(initial_value=bound, max_value=bound)
self._bound = bound
self._backend = ""

def setup(self) -> None:
"""
Detect if we're running under 'asyncio' or 'trio' and create
a semaphore with the correct implementation.
"""
self._backend = sniffio.current_async_library()
if self._backend == "trio":
import trio

self._trio_semaphore = trio.Semaphore(
initial_value=self._bound, max_value=self._bound
)
else:
import anyio

self._anyio_semaphore = anyio.Semaphore(
initial_value=self._bound, max_value=self._bound
)

async def acquire(self) -> None:
await self._semaphore.acquire()
if not self._backend:
self.setup()

if self._backend == "trio":
await self._trio_semaphore.acquire()
else:
await self._anyio_semaphore.acquire()

async def release(self) -> None:
self._semaphore.release()
if self._backend == "trio":
self._trio_semaphore.release()
else:
self._anyio_semaphore.release()


# Our thread-based synchronization primitives...


class Lock:
Expand Down

0 comments on commit 6a97dad

Please sign in to comment.