Skip to content

Commit

Permalink
ENH: Add the capability to swap the singleton bit generator
Browse files Browse the repository at this point in the history
Add a new version or seed that supports seeding any bit gen
Add set/get_bit_generator as explicity methodds to support swapping

closes #21808
  • Loading branch information
bashtage committed Jul 13, 2022
1 parent 3784656 commit 011189c
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 5 deletions.
94 changes: 89 additions & 5 deletions numpy/random/mtrand.pyx
Expand Up @@ -223,7 +223,7 @@ cdef class RandomState:

def seed(self, seed=None):
"""
seed(self, seed=None)
seed(seed=None)
Reseed a legacy MT19937 BitGenerator
Expand All @@ -248,7 +248,7 @@ cdef class RandomState:

def get_state(self, legacy=True):
"""
get_state()
get_state(legacy=True)
Return a tuple representing the internal state of the generator.
Expand All @@ -258,12 +258,13 @@ cdef class RandomState:
----------
legacy : bool, optional
Flag indicating to return a legacy tuple state when the BitGenerator
is MT19937, instead of a dict.
is MT19937, instead of a dict. Raises ValueError if the underlying
bit generator is not an instance of MT19937.
Returns
-------
out : {tuple(str, ndarray of 624 uints, int, int, float), dict}
The returned tuple has the following items:
If legacy is True, the returned tuple has the following items:
1. the string 'MT19937'.
2. a 1-D array of 624 unsigned integer keys.
Expand Down Expand Up @@ -293,6 +294,11 @@ cdef class RandomState:
legacy = False
st['has_gauss'] = self._aug_state.has_gauss
st['gauss'] = self._aug_state.gauss
if legacy and not isinstance(self._bit_generator, _MT19937):
raise ValueError(
"legacy can only be True when the underlyign bitgenerator is "
"an instance of MT19937."
)
if legacy:
return (st['bit_generator'], st['state']['key'], st['state']['pos'],
st['has_gauss'], st['gauss'])
Expand Down Expand Up @@ -4689,7 +4695,6 @@ random = _rand.random
random_integers = _rand.random_integers
random_sample = _rand.random_sample
rayleigh = _rand.rayleigh
seed = _rand.seed
set_state = _rand.set_state
shuffle = _rand.shuffle
standard_cauchy = _rand.standard_cauchy
Expand All @@ -4704,6 +4709,83 @@ wald = _rand.wald
weibull = _rand.weibull
zipf = _rand.zipf

def seed(seed=None):
"""
seed(seed=None)
Reseed the singleton RandomState instance.
Notes
-----
This is a convenience, legacy function that exists to support
older code that uses the singleton RandomState. Best practice
is to use a dedicated ``Generator`` instance rather than
the random variate generation methods exposed directly in
the random module.
See Also
--------
numpy.random.Generator
"""
if isinstance(_rand._bit_generator, _MT19937):
return _rand.seed(seed)
else:
bg_type = type(_rand._bit_generator)
_rand._bit_generator.state = bg_type(seed).state

def get_bit_generator():
"""
Returns the singleton RandomState's bit generator
Returns
-------
BitGenerator
The bit generator that underlies the singleton RandomState instance
Notes
-----
The singleton RandomState provides the random variate generators in the
NumPy random namespace. This function, and its counterpart set method,
provides a path to hot-swap the default MT19937 bit generator with a
user provided alternative. These function are intended to provide
a continuous path where a single underlying bit generator can be
used both with an instance of ``Generator`` and with the singleton
instance of RandomState.
See Also
--------
set_bit_generator
numpy.random.Generator
"""
return _rand._bit_generator

def set_bit_generator(bitgen):
"""
Sets the singleton RandomState's bit generator
Parameters
----------
bitgen
A bit generator instance
Notes
-----
The singleton RandomState provides the random variate generators in the
NumPy random namespace. This function, and its counterpart get method,
provides a path to hot-swap the default MT19937 bit generator with a
user provided alternative. These function are intended to provide
a continuous path where a single underlying bit generator can be
used both with an instance of ``Generator`` and with the singleton
instance of RandomState.
See Also
--------
get_bit_generator
numpy.random.Generator
"""
_rand._bit_generator = bitgen


# Old aliases that should not be removed
def sample(*args, **kwargs):
"""
Expand All @@ -4730,6 +4812,7 @@ __all__ = [
'f',
'gamma',
'geometric',
'get_bit_generator',
'get_state',
'gumbel',
'hypergeometric',
Expand Down Expand Up @@ -4757,6 +4840,7 @@ __all__ = [
'rayleigh',
'sample',
'seed',
'set_bit_generator',
'set_state',
'shuffle',
'standard_cauchy',
Expand Down
41 changes: 41 additions & 0 deletions numpy/random/tests/test_randomstate.py
Expand Up @@ -53,6 +53,14 @@ def int_func(request):
INT_FUNC_HASHES[request.param])


@pytest.fixture
def restore_singleton_bitgen():
"""Ensures that the singleton bitgen is restored after a test"""
orig_bitgen = np.random.get_bit_generator()
yield
np.random.set_bit_generator(orig_bitgen)


def assert_mt19937_state_equal(a, b):
assert_equal(a['bit_generator'], b['bit_generator'])
assert_array_equal(a['state']['key'], b['state']['key'])
Expand Down Expand Up @@ -2020,3 +2028,36 @@ def test_broadcast_size_error():
random.binomial([1, 2], 0.3, size=(2, 1))
with pytest.raises(ValueError):
random.binomial([1, 2], [0.3, 0.7], size=(2, 1))


def test_hot_swap(restore_singleton_bitgen):
# GH 21808
def_bg = np.random.default_rng(0)
bg = def_bg.bit_generator
np.random.set_bit_generator(bg)
assert isinstance(np.random.mtrand._rand._bit_generator, type(bg))

second_bg = np.random.get_bit_generator()
assert bg is second_bg


def test_seed_alt_bit_gen(restore_singleton_bitgen):
bg = PCG64(0)
np.random.set_bit_generator(bg)
state = np.random.get_state(legacy=False)
np.random.seed(1)
new_state = np.random.get_state(legacy=False)
print(state)
print(new_state)
assert state["bit_generator"] == "PCG64"
assert state["state"]["state"] != new_state["state"]["state"]
assert state["state"]["inc"] != new_state["state"]["inc"]


def test_state_error_alt_bit_gen(restore_singleton_bitgen):
# GH 21808
state = np.random.get_state()
bg = PCG64(0)
np.random.set_bit_generator(bg)
with pytest.raises(ValueError, match="state must be for a PCG64"):
np.random.set_state(state)

0 comments on commit 011189c

Please sign in to comment.