Skip to content

Commit

Permalink
fix race condition in parallel tests
Browse files Browse the repository at this point in the history
Signed-off-by: Nathaniel Starkman (@nstarman) <nstarkman@protonmail.com>
  • Loading branch information
nstarman committed Aug 15, 2021
1 parent c0a335d commit 97ee653
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 30 deletions.
2 changes: 1 addition & 1 deletion astropy/units/quantity_helper/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def _is_ulike(unit, allow_structured=False): # TODO! make a public-scoped funct

def _is_seq_ulike(seq, allow_structured=False):
"""Check if a sequence is unit-like."""
return ( isinstance(seq, Sequence)
return (isinstance(seq, Sequence)
and all(_is_ulike(x, allow_structured) for x in seq))


Expand Down
50 changes: 21 additions & 29 deletions astropy/units/tests/test_quantity_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_coverage(self):
all_erfa_ufuncs = set([ufunc for ufunc in erfa_ufunc.__dict__.values()
if isinstance(ufunc, np.ufunc)])
assert ((all_q_ufuncs - all_np_ufuncs - all_erfa_ufuncs
- qh.REGISTERED_NARG_UFUNCS) == set())
- qh.REGISTERED_NARG_UFUNCS) == set()), str(qh.REGISTERED_NARG_UFUNCS)

def test_scipy_registered(self):
# Should be registered as existing even if scipy is not available.
Expand Down Expand Up @@ -1353,7 +1353,7 @@ def test_jv_invalid_units(self, function):

# -------------------------------------------------------------------

def make_ufunc_list():
def make_func_list():

def func_1_1(x: "km") -> "km2":
return x**2
Expand All @@ -1373,7 +1373,7 @@ def func_2_2(x: "km", y: "km") -> ("km2", None):
def funcobj_2_2(x: "km", y: "km") -> (object, object):
return x**2, y**2

ufunc_list = np.array(
func_list = np.array(
# ( func, (input,), (output,) )
[(func_1_1, (2,), (4,)),
(funcobj_1_1, (2,), (4,)),
Expand All @@ -1392,10 +1392,10 @@ def funcobj_2_2(x: "km", y: "km") -> (object, object):
],
dtype=object)

return ufunc_list
return func_list


ufunc_list = make_ufunc_list()
func_list = make_func_list()


class TestRegisterUfunc:
Expand All @@ -1404,19 +1404,18 @@ def setup_class(self):
# registry of normal ufuncs
self.ufunc_registry = {
func: np.frompyfunc(func, *map(int, func.__name__.split("_")[1:]))
for func in ufunc_list[:, 0]}
for func in func_list[:, 0]}

# registry of normal ufuncs, where the units will be assumed correct
self.ufunc_assume_registry = {
func: np.frompyfunc(func, *map(int, func.__name__.split("_")[1:]))
for func in ufunc_list[:, 0]}
for func in func_list[:, 0]}

@pytest.fixture(autouse=True)
def setup(self):
# start by saving a copy of REGISTERED_NARG_UFUNCS
# TODO! when py3.8+ use copy.deepcopy
ORIGINAL_REGISTERED_NARG_UFUNCS = {k for k in qh.REGISTERED_NARG_UFUNCS}
ORIGINAL_UFUNC_HELPERS = {k: v for k, v in qh.UFUNC_HELPERS.items()}

# register as quantity-ufuncs
# NOTE! output units are "u.km" even though the functions are squaring
Expand All @@ -1437,17 +1436,14 @@ def setup(self):

# restore states
ADDED_NARG_UFUNCS = ORIGINAL_REGISTERED_NARG_UFUNCS - qh.REGISTERED_NARG_UFUNCS

qh.REGISTERED_NARG_UFUNCS = ORIGINAL_REGISTERED_NARG_UFUNCS
qh.UFUNC_HELPERS = ORIGINAL_UFUNC_HELPERS

for k in ADDED_NARG_UFUNCS: # double extra make sure it's clean
for k in ADDED_NARG_UFUNCS:
qh.UFUNC_HELPERS.pop(k, None)
qh.REGISTERED_NARG_UFUNCS.pop(k, None)

# -------------------
# variety of funcs

@pytest.mark.parametrize("func, inp, res", ufunc_list[:7])
@pytest.mark.parametrize("func, inp, res", func_list[:7])
def test_raw_func(self, func, inp, res):
"""
In this case, the output will also not have units.
Expand All @@ -1459,7 +1455,7 @@ def test_raw_func(self, func, inp, res):
# got = np.array(got).astype(float) if isinstance(got, (np.ndarray, tuple)) else got
assert_allclose(got, res)

@pytest.mark.parametrize("func, inp, res", ufunc_list)
@pytest.mark.parametrize("func, inp, res", func_list)
def test_no_units(self, func, inp, res):
"""Test unitless input has unitless output. NO UNITS ATTACHED!"""
got = self.ufunc_registry[func](*inp) # no units
Expand All @@ -1468,7 +1464,7 @@ def test_no_units(self, func, inp, res):
got = np.array(got).astype(float) if isinstance(got, (np.ndarray, tuple)) else got
assert_allclose(got, res)

@pytest.mark.parametrize("func, inp, res", ufunc_list)
@pytest.mark.parametrize("func, inp, res", func_list)
def test_has_units(self, func, inp, res):
"""Test unitful input has unitful output."""
inp = [(x * u.km if x is not None else None) for x in inp]
Expand All @@ -1484,7 +1480,7 @@ def test_has_units(self, func, inp, res):
got = (u.Quantity(got, dtype=float) if isinstance(got, (np.ndarray, tuple)) else got)
assert_quantity_allclose(got, res * u.km)

@pytest.mark.parametrize("func, inp, res", ufunc_list)
@pytest.mark.parametrize("func, inp, res", func_list)
def test_has_units_assumed_correct(self, func, inp, res):
"""
Test unitful input has unitful output and units are assumed to be
Expand All @@ -1501,7 +1497,7 @@ def test_has_units_assumed_correct(self, func, inp, res):

@pytest.mark.parametrize("registry",
["ufunc_registry", "ufunc_assume_registry"])
@pytest.mark.parametrize("func, inp, res", ufunc_list)
@pytest.mark.parametrize("func, inp, res", func_list)
def test_has_wrong_units(self, registry, func, inp, res):
"""Test wrong unitful input raises errors."""
inp = [(x * u.deg if x is not None else None) for x in inp]
Expand Down Expand Up @@ -1531,14 +1527,13 @@ def setup(self):
# start by saving a copy of REGISTERED_NARG_UFUNCS
# TODO! when py3.8+ use copy.deepcopy
ORIGINAL_REGISTERED_NARG_UFUNCS = {k for k in qh.REGISTERED_NARG_UFUNCS}
ORIGINAL_UFUNC_HELPERS = {k: v for k, v in qh.UFUNC_HELPERS.items()}

# registry of ufuncs
self.ufunc_registry = {}
self.ufunc_assume_registry = {}
self.ufunc_introspect_registry = {}
self.ufunc_introspect_assume_registry = {}
for func in ufunc_list[:, 0]:
for func in func_list[:, 0]:
nin, nout = map(int, func.__name__.split("_")[1:])

self.ufunc_registry[func] = frompyfunc(
Expand All @@ -1556,12 +1551,9 @@ def setup(self):

# restore states
ADDED_NARG_UFUNCS = ORIGINAL_REGISTERED_NARG_UFUNCS - qh.REGISTERED_NARG_UFUNCS

qh.REGISTERED_NARG_UFUNCS = ORIGINAL_REGISTERED_NARG_UFUNCS
qh.UFUNC_HELPERS = ORIGINAL_UFUNC_HELPERS

for k in ADDED_NARG_UFUNCS: # double extra make sure it's clean
for k in ADDED_NARG_UFUNCS:
qh.UFUNC_HELPERS.pop(k, None)
qh.REGISTERED_NARG_UFUNCS.pop(k, None)

# -------------------
# variety of funcs
Expand All @@ -1570,7 +1562,7 @@ def setup(self):
"registry",
["ufunc_registry", "ufunc_assume_registry",
"ufunc_introspect_registry", "ufunc_introspect_assume_registry"])
@pytest.mark.parametrize("func, inp, res", ufunc_list)
@pytest.mark.parametrize("func, inp, res", func_list)
def test_no_units(self, registry, func, inp, res):
"""Test unitless input has unitless output. NO UNITS ATTACHED!"""
got = getattr(self, registry)[func](*inp) # no units
Expand All @@ -1582,7 +1574,7 @@ def test_no_units(self, registry, func, inp, res):
@pytest.mark.parametrize(
"registry",
["ufunc_registry", "ufunc_assume_registry"])
@pytest.mark.parametrize("func, inp, res", ufunc_list)
@pytest.mark.parametrize("func, inp, res", func_list)
def test_has_units(self, registry, func, inp, res):
"""Test unitful input has unitful output."""
inp = [(x * u.km if x is not None else None) for x in inp]
Expand All @@ -1601,7 +1593,7 @@ def test_has_units(self, registry, func, inp, res):
@pytest.mark.parametrize(
"registry",
["ufunc_introspect_registry", "ufunc_introspect_assume_registry"])
@pytest.mark.parametrize("func, inp, res", ufunc_list)
@pytest.mark.parametrize("func, inp, res", func_list)
def test_has_introspected_units(self, registry, func, inp, res):
"""Test unitful input has unitful output."""
# give units to inputs
Expand Down Expand Up @@ -1629,7 +1621,7 @@ def test_has_introspected_units(self, registry, func, inp, res):
@pytest.mark.parametrize(
"registry",
["ufunc_registry", "ufunc_assume_registry"])
@pytest.mark.parametrize("func, inp, res", ufunc_list)
@pytest.mark.parametrize("func, inp, res", func_list)
def test_has_wrong_units(self, registry, func, inp, res):
"""Test wrong unitful input raises errors."""
inp = [(x * u.deg if x is not None else None) for x in inp]
Expand Down

0 comments on commit 97ee653

Please sign in to comment.