diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx index 77f70e9fec7a..da8ab64e2217 100644 --- a/numpy/random/_generator.pyx +++ b/numpy/random/_generator.pyx @@ -220,8 +220,11 @@ cdef class Generator: self.bit_generator.state = state def __reduce__(self): + ctor, name_tpl, state = self._bit_generator.__reduce__() + from ._pickle import __generator_ctor - return __generator_ctor, (self.bit_generator.state['bit_generator'],), self.bit_generator.state + # Requirements of __generator_ctor are (name, ctor) + return __generator_ctor, (name_tpl[0], ctor), state @property def bit_generator(self): diff --git a/numpy/random/_pickle.py b/numpy/random/_pickle.py index 5e89071e88ab..073993726eb3 100644 --- a/numpy/random/_pickle.py +++ b/numpy/random/_pickle.py @@ -14,19 +14,19 @@ } -def __generator_ctor(bit_generator_name='MT19937'): +def __bit_generator_ctor(bit_generator_name='MT19937'): """ - Pickling helper function that returns a Generator object + Pickling helper function that returns a bit generator object Parameters ---------- bit_generator_name : str - String containing the core BitGenerator + String containing the name of the BitGenerator Returns ------- - rg : Generator - Generator using the named core BitGenerator + bit_generator : BitGenerator + BitGenerator instance """ if bit_generator_name in BitGenerators: bit_generator = BitGenerators[bit_generator_name] @@ -34,50 +34,47 @@ def __generator_ctor(bit_generator_name='MT19937'): raise ValueError(str(bit_generator_name) + ' is not a known ' 'BitGenerator module.') - return Generator(bit_generator()) + return bit_generator() -def __bit_generator_ctor(bit_generator_name='MT19937'): +def __generator_ctor(bit_generator_name="MT19937", + bit_generator_ctor=__bit_generator_ctor): """ - Pickling helper function that returns a bit generator object + Pickling helper function that returns a Generator object Parameters ---------- bit_generator_name : str - String containing the name of the BitGenerator + String containing the core BitGenerator's name + bit_generator_ctor : callable, optional + Callable function that takes bit_generator_name as its only argument + and returns an instantized bit generator. Returns ------- - bit_generator : BitGenerator - BitGenerator instance + rg : Generator + Generator using the named core BitGenerator """ - if bit_generator_name in BitGenerators: - bit_generator = BitGenerators[bit_generator_name] - else: - raise ValueError(str(bit_generator_name) + ' is not a known ' - 'BitGenerator module.') - - return bit_generator() + return Generator(bit_generator_ctor(bit_generator_name)) -def __randomstate_ctor(bit_generator_name='MT19937'): +def __randomstate_ctor(bit_generator_name="MT19937", + bit_generator_ctor=__bit_generator_ctor): """ Pickling helper function that returns a legacy RandomState-like object Parameters ---------- bit_generator_name : str - String containing the core BitGenerator + String containing the core BitGenerator's name + bit_generator_ctor : callable, optional + Callable function that takes bit_generator_name as its only argument + and returns an instantized bit generator. Returns ------- rs : RandomState Legacy RandomState using the named core BitGenerator """ - if bit_generator_name in BitGenerators: - bit_generator = BitGenerators[bit_generator_name] - else: - raise ValueError(str(bit_generator_name) + ' is not a known ' - 'BitGenerator module.') - return RandomState(bit_generator()) + return RandomState(bit_generator_ctor(bit_generator_name)) diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx index 19d23f6a856b..c9cdb5839a04 100644 --- a/numpy/random/mtrand.pyx +++ b/numpy/random/mtrand.pyx @@ -213,9 +213,10 @@ cdef class RandomState: self.set_state(state) def __reduce__(self): - state = self.get_state(legacy=False) + ctor, name_tpl, _ = self._bit_generator.__reduce__() + from ._pickle import __randomstate_ctor - return __randomstate_ctor, (state['bit_generator'],), state + return __randomstate_ctor, (name_tpl[0], ctor), self.get_state(legacy=False) cdef _reset_gauss(self): self._aug_state.has_gauss = 0 diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index fa55ac0ee96a..1180cca5a2f4 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -2695,3 +2695,16 @@ def test_contig_req_out(dist, order, dtype): assert variates is out variates = dist(out=out, dtype=dtype, size=out.shape) assert variates is out + + +def test_generator_ctor_old_style_pickle(): + rg = np.random.Generator(np.random.PCG64DXSM(0)) + rg.standard_normal(1) + # Directly call reduce which is used in pickling + ctor, args, state_a = rg.__reduce__() + # Simulate unpickling an old pickle that only has the name + assert args[:1] == ("PCG64DXSM",) + b = ctor(*args[:1]) + b.bit_generator.state = state_a + state_b = b.bit_generator.state + assert state_a == state_b diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py index 861813a95d1f..1e22880ead38 100644 --- a/numpy/random/tests/test_randomstate.py +++ b/numpy/random/tests/test_randomstate.py @@ -2020,3 +2020,21 @@ def test_broadcast_size_error(): random.binomial([1, 2], 0.3, size=(2, 1)) with pytest.raises(ValueError): random.binomial([1, 2], [0.3, 0.7], size=(2, 1)) + + +def test_randomstate_ctor_old_style_pickle(): + rs = np.random.RandomState(MT19937(0)) + rs.standard_normal(1) + # Directly call reduce which is used in pickling + ctor, args, state_a = rs.__reduce__() + # Simulate unpickling an old pickle that only has the name + assert args[:1] == ("MT19937",) + b = ctor(*args[:1]) + b.set_state(state_a) + state_b = b.get_state(legacy=False) + + assert_equal(state_a['bit_generator'], state_b['bit_generator']) + assert_array_equal(state_a['state']['key'], state_b['state']['key']) + assert_array_equal(state_a['state']['pos'], state_b['state']['pos']) + assert_equal(state_a['has_gauss'], state_b['has_gauss']) + assert_equal(state_a['gauss'], state_b['gauss'])