From 011189c94115df86ec7222e91a44391af2c59518 Mon Sep 17 00:00:00 2001 From: Kevin Sheppard Date: Wed, 13 Jul 2022 09:19:43 +0100 Subject: [PATCH] ENH: Add the capability to swap the singleton bit generator Add a new version or seed that supports seeding any bit gen Add set/get_bit_generator as explicity methodds to support swapping closes #21808 --- numpy/random/mtrand.pyx | 94 ++++++++++++++++++++++++-- numpy/random/tests/test_randomstate.py | 41 +++++++++++ 2 files changed, 130 insertions(+), 5 deletions(-) diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx index 19d23f6a856b..5d36a5db6b2a 100644 --- a/numpy/random/mtrand.pyx +++ b/numpy/random/mtrand.pyx @@ -223,7 +223,7 @@ cdef class RandomState: def seed(self, seed=None): """ - seed(self, seed=None) + seed(seed=None) Reseed a legacy MT19937 BitGenerator @@ -248,7 +248,7 @@ cdef class RandomState: def get_state(self, legacy=True): """ - get_state() + get_state(legacy=True) Return a tuple representing the internal state of the generator. @@ -258,12 +258,13 @@ cdef class RandomState: ---------- legacy : bool, optional Flag indicating to return a legacy tuple state when the BitGenerator - is MT19937, instead of a dict. + is MT19937, instead of a dict. Raises ValueError if the underlying + bit generator is not an instance of MT19937. Returns ------- out : {tuple(str, ndarray of 624 uints, int, int, float), dict} - The returned tuple has the following items: + If legacy is True, the returned tuple has the following items: 1. the string 'MT19937'. 2. a 1-D array of 624 unsigned integer keys. @@ -293,6 +294,11 @@ cdef class RandomState: legacy = False st['has_gauss'] = self._aug_state.has_gauss st['gauss'] = self._aug_state.gauss + if legacy and not isinstance(self._bit_generator, _MT19937): + raise ValueError( + "legacy can only be True when the underlyign bitgenerator is " + "an instance of MT19937." + ) if legacy: return (st['bit_generator'], st['state']['key'], st['state']['pos'], st['has_gauss'], st['gauss']) @@ -4689,7 +4695,6 @@ random = _rand.random random_integers = _rand.random_integers random_sample = _rand.random_sample rayleigh = _rand.rayleigh -seed = _rand.seed set_state = _rand.set_state shuffle = _rand.shuffle standard_cauchy = _rand.standard_cauchy @@ -4704,6 +4709,83 @@ wald = _rand.wald weibull = _rand.weibull zipf = _rand.zipf +def seed(seed=None): + """ + seed(seed=None) + + Reseed the singleton RandomState instance. + + Notes + ----- + This is a convenience, legacy function that exists to support + older code that uses the singleton RandomState. Best practice + is to use a dedicated ``Generator`` instance rather than + the random variate generation methods exposed directly in + the random module. + + See Also + -------- + numpy.random.Generator + """ + if isinstance(_rand._bit_generator, _MT19937): + return _rand.seed(seed) + else: + bg_type = type(_rand._bit_generator) + _rand._bit_generator.state = bg_type(seed).state + +def get_bit_generator(): + """ + Returns the singleton RandomState's bit generator + + Returns + ------- + BitGenerator + The bit generator that underlies the singleton RandomState instance + + Notes + ----- + The singleton RandomState provides the random variate generators in the + NumPy random namespace. This function, and its counterpart set method, + provides a path to hot-swap the default MT19937 bit generator with a + user provided alternative. These function are intended to provide + a continuous path where a single underlying bit generator can be + used both with an instance of ``Generator`` and with the singleton + instance of RandomState. + + See Also + -------- + set_bit_generator + numpy.random.Generator + """ + return _rand._bit_generator + +def set_bit_generator(bitgen): + """ + Sets the singleton RandomState's bit generator + + Parameters + ---------- + bitgen + A bit generator instance + + Notes + ----- + The singleton RandomState provides the random variate generators in the + NumPy random namespace. This function, and its counterpart get method, + provides a path to hot-swap the default MT19937 bit generator with a + user provided alternative. These function are intended to provide + a continuous path where a single underlying bit generator can be + used both with an instance of ``Generator`` and with the singleton + instance of RandomState. + + See Also + -------- + get_bit_generator + numpy.random.Generator + """ + _rand._bit_generator = bitgen + + # Old aliases that should not be removed def sample(*args, **kwargs): """ @@ -4730,6 +4812,7 @@ __all__ = [ 'f', 'gamma', 'geometric', + 'get_bit_generator', 'get_state', 'gumbel', 'hypergeometric', @@ -4757,6 +4840,7 @@ __all__ = [ 'rayleigh', 'sample', 'seed', + 'set_bit_generator', 'set_state', 'shuffle', 'standard_cauchy', diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py index 861813a95d1f..f1e4d880234f 100644 --- a/numpy/random/tests/test_randomstate.py +++ b/numpy/random/tests/test_randomstate.py @@ -53,6 +53,14 @@ def int_func(request): INT_FUNC_HASHES[request.param]) +@pytest.fixture +def restore_singleton_bitgen(): + """Ensures that the singleton bitgen is restored after a test""" + orig_bitgen = np.random.get_bit_generator() + yield + np.random.set_bit_generator(orig_bitgen) + + def assert_mt19937_state_equal(a, b): assert_equal(a['bit_generator'], b['bit_generator']) assert_array_equal(a['state']['key'], b['state']['key']) @@ -2020,3 +2028,36 @@ def test_broadcast_size_error(): random.binomial([1, 2], 0.3, size=(2, 1)) with pytest.raises(ValueError): random.binomial([1, 2], [0.3, 0.7], size=(2, 1)) + + +def test_hot_swap(restore_singleton_bitgen): + # GH 21808 + def_bg = np.random.default_rng(0) + bg = def_bg.bit_generator + np.random.set_bit_generator(bg) + assert isinstance(np.random.mtrand._rand._bit_generator, type(bg)) + + second_bg = np.random.get_bit_generator() + assert bg is second_bg + + +def test_seed_alt_bit_gen(restore_singleton_bitgen): + bg = PCG64(0) + np.random.set_bit_generator(bg) + state = np.random.get_state(legacy=False) + np.random.seed(1) + new_state = np.random.get_state(legacy=False) + print(state) + print(new_state) + assert state["bit_generator"] == "PCG64" + assert state["state"]["state"] != new_state["state"]["state"] + assert state["state"]["inc"] != new_state["state"]["inc"] + + +def test_state_error_alt_bit_gen(restore_singleton_bitgen): + # GH 21808 + state = np.random.get_state() + bg = PCG64(0) + np.random.set_bit_generator(bg) + with pytest.raises(ValueError, match="state must be for a PCG64"): + np.random.set_state(state)