Skip to content

Commit

Permalink
Merge pull request #21687 from rafaelcfsousa/bug_comparison
Browse files Browse the repository at this point in the history
BUG: switch _CMP_NEQ_OQ to _CMP_NEQ_UQ for npyv_cmpneq_f[32,64]
  • Loading branch information
seberg committed Jun 8, 2022
2 parents 11cc8a2 + ad9a030 commit 4a0e507
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
4 changes: 2 additions & 2 deletions numpy/core/src/common/simd/avx2/operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ NPY_FINLINE __m256i npyv_cmpge_u32(__m256i a, __m256i b)
// precision comparison
#define npyv_cmpeq_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_EQ_OQ))
#define npyv_cmpeq_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_EQ_OQ))
#define npyv_cmpneq_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_NEQ_OQ))
#define npyv_cmpneq_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_NEQ_OQ))
#define npyv_cmpneq_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_NEQ_UQ))
#define npyv_cmpneq_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_NEQ_UQ))
#define npyv_cmplt_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_LT_OQ))
#define npyv_cmplt_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_LT_OQ))
#define npyv_cmple_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_LE_OQ))
Expand Down
4 changes: 2 additions & 2 deletions numpy/core/src/common/simd/avx512/operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@
// precision comparison
#define npyv_cmpeq_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_EQ_OQ)
#define npyv_cmpeq_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_EQ_OQ)
#define npyv_cmpneq_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_NEQ_OQ)
#define npyv_cmpneq_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_NEQ_OQ)
#define npyv_cmpneq_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_NEQ_UQ)
#define npyv_cmpneq_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_NEQ_UQ)
#define npyv_cmplt_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_LT_OQ)
#define npyv_cmplt_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_LT_OQ)
#define npyv_cmple_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_LE_OQ)
Expand Down
28 changes: 28 additions & 0 deletions numpy/core/tests/test_simd.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,34 @@ def test_special_cases(self):
nnan = self.notnan(self.setall(self._nan()))
assert nnan == [0]*self.nlanes

import operator

@pytest.mark.parametrize('py_comp,np_comp', [
(operator.lt, "cmplt"),
(operator.le, "cmple"),
(operator.gt, "cmpgt"),
(operator.ge, "cmpge"),
(operator.eq, "cmpeq"),
(operator.ne, "cmpneq")
])
def test_comparison_with_nan(self, py_comp, np_comp):
pinf, ninf, nan = self._pinfinity(), self._ninfinity(), self._nan()
mask_true = self._true_mask()

def to_bool(vector):
return [lane == mask_true for lane in vector]

intrin = getattr(self, np_comp)
cmp_cases = ((0, nan), (nan, 0), (nan, nan), (pinf, nan), (ninf, nan))
for case_operand1, case_operand2 in cmp_cases:
data_a = [case_operand1]*self.nlanes
data_b = [case_operand2]*self.nlanes
vdata_a = self.setall(case_operand1)
vdata_b = self.setall(case_operand2)
vcmp = to_bool(intrin(vdata_a, vdata_b))
data_cmp = [py_comp(a, b) for a, b in zip(data_a, data_b)]
assert vcmp == data_cmp

class _SIMD_ALL(_Test_Utility):
"""
To test all vector types at once
Expand Down

0 comments on commit 4a0e507

Please sign in to comment.