Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG/ENH: Allow bit generators to supply their own constructor #22014

Merged
merged 2 commits into from Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion numpy/random/_generator.pyx
Expand Up @@ -220,8 +220,11 @@ cdef class Generator:
self.bit_generator.state = state

def __reduce__(self):
ctor, name_tpl, state = self._bit_generator.__reduce__()

from ._pickle import __generator_ctor
return __generator_ctor, (self.bit_generator.state['bit_generator'],), self.bit_generator.state
# Requirements of __generator_ctor are (name, ctor)
return __generator_ctor, (name_tpl[0], ctor), state

@property
def bit_generator(self):
Expand Down
49 changes: 23 additions & 26 deletions numpy/random/_pickle.py
Expand Up @@ -14,70 +14,67 @@
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if we could refactor this to avoid the circular import _pickle -> _generator/mtrand -> _pickle

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it is not really circular, since _generator/mtrand only import _pickle inside the __reduce__ method, but the pattern is convoluted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the only way to avoid this would be to move these closer to the class definitions. There is always something circular required here, since the main class needs its constructor and the constructor needs the main class.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's very common in __reduce__ implementations, IME.



def __generator_ctor(bit_generator_name='MT19937'):
def __bit_generator_ctor(bit_generator_name='MT19937'):
"""
Pickling helper function that returns a Generator object
Pickling helper function that returns a bit generator object

Parameters
----------
bit_generator_name : str
String containing the core BitGenerator
String containing the name of the BitGenerator

Returns
-------
rg : Generator
Generator using the named core BitGenerator
bit_generator : BitGenerator
BitGenerator instance
"""
if bit_generator_name in BitGenerators:
bit_generator = BitGenerators[bit_generator_name]
else:
raise ValueError(str(bit_generator_name) + ' is not a known '
'BitGenerator module.')

return Generator(bit_generator())
return bit_generator()


def __bit_generator_ctor(bit_generator_name='MT19937'):
def __generator_ctor(bit_generator_name="MT19937",
bit_generator_ctor=__bit_generator_ctor):
"""
Pickling helper function that returns a bit generator object
Pickling helper function that returns a Generator object

Parameters
----------
bit_generator_name : str
String containing the name of the BitGenerator
String containing the core BitGenerator's name
bit_generator_ctor : callable, optional
Callable function that takes bit_generator_name as its only argument
and returns an instantiated bit generator.

Returns
-------
bit_generator : BitGenerator
BitGenerator instance
rg : Generator
Generator using the named core BitGenerator
"""
if bit_generator_name in BitGenerators:
bit_generator = BitGenerators[bit_generator_name]
else:
raise ValueError(str(bit_generator_name) + ' is not a known '
'BitGenerator module.')

return bit_generator()
return Generator(bit_generator_ctor(bit_generator_name))


def __randomstate_ctor(bit_generator_name='MT19937'):
def __randomstate_ctor(bit_generator_name="MT19937",
bit_generator_ctor=__bit_generator_ctor):
"""
Pickling helper function that returns a legacy RandomState-like object

Parameters
----------
bit_generator_name : str
String containing the core BitGenerator
String containing the core BitGenerator's name
bit_generator_ctor : callable, optional
Callable function that takes bit_generator_name as its only argument
and returns an instantiated bit generator.

Returns
-------
rs : RandomState
Legacy RandomState using the named core BitGenerator
"""
if bit_generator_name in BitGenerators:
bit_generator = BitGenerators[bit_generator_name]
else:
raise ValueError(str(bit_generator_name) + ' is not a known '
'BitGenerator module.')

return RandomState(bit_generator())
return RandomState(bit_generator_ctor(bit_generator_name))
5 changes: 3 additions & 2 deletions numpy/random/mtrand.pyx
Expand Up @@ -213,9 +213,10 @@ cdef class RandomState:
self.set_state(state)

def __reduce__(self):
state = self.get_state(legacy=False)
ctor, name_tpl, _ = self._bit_generator.__reduce__()

from ._pickle import __randomstate_ctor
return __randomstate_ctor, (state['bit_generator'],), state
return __randomstate_ctor, (name_tpl[0], ctor), self.get_state(legacy=False)

cdef _reset_gauss(self):
self._aug_state.has_gauss = 0
Expand Down
13 changes: 13 additions & 0 deletions numpy/random/tests/test_generator_mt19937.py
Expand Up @@ -2695,3 +2695,16 @@ def test_contig_req_out(dist, order, dtype):
assert variates is out
variates = dist(out=out, dtype=dtype, size=out.shape)
assert variates is out


def test_generator_ctor_old_style_pickle():
rg = np.random.Generator(np.random.PCG64DXSM(0))
rg.standard_normal(1)
# Directly call reduce which is used in pickline
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Directly call reduce which is used in pickline
# Directly call reduce which is used in pickling

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, thanks.

ctor, args, state_a = rg.__reduce__()
# Simulate unpickling an old pickle that only has the name
assert args[:1] == ("PCG64DXSM",)
b = ctor(*args[:1])
b.bit_generator.state = state_a
state_b = b.bit_generator.state
assert state_a == state_b
18 changes: 18 additions & 0 deletions numpy/random/tests/test_randomstate.py
Expand Up @@ -2020,3 +2020,21 @@ def test_broadcast_size_error():
random.binomial([1, 2], 0.3, size=(2, 1))
with pytest.raises(ValueError):
random.binomial([1, 2], [0.3, 0.7], size=(2, 1))


def test_randomstate_ctor_old_style_pickle():
rs = np.random.RandomState(MT19937(0))
rs.standard_normal(1)
# Directly call reduce which is used in pickline
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Directly call reduce which is used in pickline
# Directly call reduce which is used in pickling

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, thanks.

ctor, args, state_a = rs.__reduce__()
# Simulate unpickling an old pickle that only has the name
assert args[:1] == ("MT19937",)
b = ctor(*args[:1])
b.set_state(state_a)
state_b = b.get_state(legacy=False)

assert_equal(state_a['bit_generator'], state_b['bit_generator'])
assert_array_equal(state_a['state']['key'], state_b['state']['key'])
assert_array_equal(state_a['state']['pos'], state_b['state']['pos'])
assert_equal(state_a['has_gauss'], state_b['has_gauss'])
assert_equal(state_a['gauss'], state_b['gauss'])