Skip to content

Commit

Permalink
Merge pull request #4 from masenf/next-lock
Browse files Browse the repository at this point in the history
Optionally take a lock when stepping the iterator
  • Loading branch information
masenf committed Feb 27, 2020
2 parents 2b7ca32 + 36f72e0 commit 376109b
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 11 deletions.
1 change: 1 addition & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
# Regexes for lines to exclude from consideration
exclude_lines =
except ImportError:
pragma: no cover
103 changes: 94 additions & 9 deletions src/iterlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,28 @@
# python 2 compatible
from collections import Sequence
import itertools
import threading

izip = getattr(itertools, "izip", zip) # python2 compatible iter zip
try:
from typing import Any, Callable, Iterable, Iterator, List, Optional, Union
from typing import (
Any,
Callable,
ContextManager,
Iterable,
Iterator,
List,
Optional,
Union,
)
except ImportError:
pass # typing is only used for static analysis


class ConcurrentGeneratorAccess(ValueError):
"""Raised when stepping a generator that is already executing"""


class CachedIterator(object):
"""a tuple-like interface over an iterable that stores iterated values."""

Expand Down Expand Up @@ -44,22 +59,38 @@ def _positive_index(self, index):
self._consume_rest()
pos = len(self._list) - abs(index)
if pos < 0:
raise IndexError('list index out of range')
raise IndexError("list index out of range")
return pos

@staticmethod
def _warn_concurrent_access(exc):
# type: (BaseException) -> None
if "generator already executing" in str(exc):
# MJF: use raise from when py27 support is dropped
raise ConcurrentGeneratorAccess(
"Concurrent access to iterable detected. When using this interface"
"in a multithreaded environment, use ThreadsafeIterTuple, "
"ThreadsafeIterList, or mix in LockingCachedIterator ahead of "
"IterTuple or IterList bases.\nOriginal Exception: {}".format(exc),
)

def _consume_next(self):
# type: () -> None
exhausted = False
try:
self._list.append(next(self._iterable))
except StopIteration:
exhausted = True
if exhausted:
raise IndexError
except ValueError as ve:
self._warn_concurrent_access(ve)
raise # pragma: no cover

def _consume_rest(self):
# type: () -> None
self._list.extend(self._iterable)
try:
self._list.extend(self._iterable)
except ValueError as ve:
self._warn_concurrent_access(ve)
raise # pragma: no cover

def _consume_up_to_index(self, index):
# type: (int) -> None
Expand Down Expand Up @@ -140,8 +171,7 @@ def __eq__(self, other):
# type: (Any) -> bool
if not isinstance(other, (CachedIterator, Sequence)):
return False
return (all(a == b for a, b in izip(self, other))
and len(self) == len(other))
return all(a == b for a, b in izip(self, other)) and len(self) == len(other)

def __ne__(self, other):
# type: (Any) -> bool
Expand Down Expand Up @@ -174,7 +204,7 @@ def index(self, item, start=0, stop=None):
if e == item:
return i + start

raise ValueError('{} is not in list'.format(item))
raise ValueError("{} is not in list".format(item))

def count(self, item):
# type: (Any) -> int
Expand Down Expand Up @@ -325,3 +355,58 @@ def pop(self, index=-1):
item = self._list[index]
del self._list[index]
return item


class LockingCachedIterator(CachedIterator):
"""protect CachedIterator generator execution with an RLock"""

@staticmethod
def lock_factory():
# type: () -> ContextManager[Any]
"""
Return a contextmanager-like lock implementation.
The default lock is threading.RLock.
Subclasses may use a different lock implementation as long as it
follows contextmanager protocol and is re-entrant.
:return: the lock used to protect generator access
"""
return threading.RLock()

def __init__(self, iterable):
# type: (Iterable) -> None
"""Initialize
:type iterable: Iterable
"""
super(LockingCachedIterator, self).__init__(iterable)
self._lock = self.lock_factory()

def _consume_next(self):
# type: () -> None
with self._lock:
super(LockingCachedIterator, self)._consume_next()

def _consume_rest(self):
# type: () -> None
with self._lock:
super(LockingCachedIterator, self)._consume_rest()

def _consume_up_to_index(self, index):
# type: (int) -> None
with self._lock:
super(LockingCachedIterator, self)._consume_up_to_index(index)


class ThreadsafeIterTuple(LockingCachedIterator, IterTuple):
"""IterTuple which can be safely accessed from multiple threads"""


class ThreadsafeIterList(LockingCachedIterator, IterTuple):
"""IterList which can be safely accessed from multiple threads.
Note that regular list manipulations are NOT protected by lock, only generator
access is protected
"""
51 changes: 51 additions & 0 deletions tests/test_iterlist.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import concurrent.futures
import itertools
import math
import threading
import time
import unittest

import iterlist
Expand Down Expand Up @@ -503,5 +506,53 @@ def test_iter_consume_while_iter(self):
self.assertEqual(v, orig[ix])


class TestConcurrentAccess(unittest.TestCase):
def gen_test_multiple_iterators(self, iterlist_clz, delay=0.005, n_threads=5):
orig = list(range(range_size))
delay_generator = (ix for ix in orig if time.sleep(delay) is None)
lazy = iterlist_clz(delay_generator)
with concurrent.futures.ThreadPoolExecutor(n_threads) as tp:
future_results = [tp.submit(lambda: [x for x in lazy]) for _ in
range(n_threads)]
results = [f.result(timeout=1) for f in future_results]
for r in results:
self.assertEqual(r, orig)

def test_multiple_iterators_itertuple(self):
self.gen_test_multiple_iterators(iterlist_clz=iterlist.ThreadsafeIterTuple)

def test_multiple_iterators_iterlist(self):
self.gen_test_multiple_iterators(iterlist_clz=iterlist.ThreadsafeIterList)

def test_multiple_iterators_no_lock(self):
with self.assertRaises(iterlist.ConcurrentGeneratorAccess) as cga:
self.gen_test_multiple_iterators(iterlist_clz=iterlist.IterTuple)
with self.assertRaises(iterlist.ConcurrentGeneratorAccess) as cga:
self.gen_test_multiple_iterators(iterlist_clz=iterlist.IterList)

def gen_test_concurrent_length(self, iterlist_clz, delay=0.005, n_threads=5):
orig = list(range(range_size))
delay_generator = (ix for ix in orig if time.sleep(delay) is None)
lazy = iterlist_clz(delay_generator)
with concurrent.futures.ThreadPoolExecutor(n_threads) as tp:
future_results = [tp.submit(lambda: len(lazy)) for _ in
range(n_threads)]
results = [f.result(timeout=1) for f in future_results]
for r in results:
self.assertEqual(r, len(orig))

def test_concurrent_length_itertuple(self):
self.gen_test_concurrent_length(iterlist_clz=iterlist.ThreadsafeIterTuple)

def test_concurrent_length_iterlist(self):
self.gen_test_concurrent_length(iterlist_clz=iterlist.ThreadsafeIterList)

def test_concurrent_length_no_lock(self):
with self.assertRaises(iterlist.ConcurrentGeneratorAccess) as cga:
self.gen_test_concurrent_length(iterlist_clz=iterlist.IterTuple)
with self.assertRaises(iterlist.ConcurrentGeneratorAccess) as cga:
self.gen_test_concurrent_length(iterlist_clz=iterlist.IterList)


if __name__ == '__main__':
unittest.main()
10 changes: 8 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@ envlist =

[testenv]
deps =
pytest ~= 4.6.0
pytest-cov ~= 2.8.0
pytest
pytest-cov
commands =
pytest --cov iterlist --cov-fail-under 100 --cov-report term-missing {posargs:tests}

[testenv:py27]
deps =
pytest ~= 4.6.0
pytest-cov ~= 2.8.0
futures

[testenv:static]
basepython = python3.7
deps =
Expand Down

0 comments on commit 376109b

Please sign in to comment.