WIP::ENH:SIMD Improve the performance of comparison operators #16960

Closed
wants to merge 9 commits
16 changes: 8 additions & 8 deletions numpy/core/code_generators/generate_umath.py
@@ -425,55 +425,55 @@ def english_upper(s):
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.greater'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(all, out='?', simd=[('avx2', ints)]),
TD(all, out='?', dispatch=[('loops_cmp', nodatetime_or_obj)]),
[TypeDescription('O', FullTypeDescr, 'OO', 'O')],
TD('O', out='?'),
),
'greater_equal':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.greater_equal'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(all, out='?', simd=[('avx2', ints)]),
TD(all, out='?', dispatch=[('loops_cmp', nodatetime_or_obj)]),
[TypeDescription('O', FullTypeDescr, 'OO', 'O')],
TD('O', out='?'),
),
'less':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.less'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(all, out='?', simd=[('avx2', ints)]),
TD(all, out='?', dispatch=[('loops_cmp', nodatetime_or_obj)]),
[TypeDescription('O', FullTypeDescr, 'OO', 'O')],
TD('O', out='?'),
),
'less_equal':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.less_equal'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(all, out='?', simd=[('avx2', ints)]),
TD(all, out='?', dispatch=[('loops_cmp', nodatetime_or_obj)]),
[TypeDescription('O', FullTypeDescr, 'OO', 'O')],
TD('O', out='?'),
),
'equal':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.equal'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(all, out='?', simd=[('avx2', ints)]),
TD(all, out='?', dispatch=[('loops_cmp', nodatetime_or_obj)]),
[TypeDescription('O', FullTypeDescr, 'OO', 'O')],
TD('O', out='?'),
),
'not_equal':
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.not_equal'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(all, out='?', simd=[('avx2', ints)]),
TD(all, out='?', dispatch=[('loops_cmp', nodatetime_or_obj)]),
[TypeDescription('O', FullTypeDescr, 'OO', 'O')],
TD('O', out='?'),
),
'logical_and':
Ufunc(2, 1, True_,
docstrings.get('numpy.core.umath.logical_and'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(nodatetime_or_obj, out='?', simd=[('avx2', ints)]),
TD(nodatetime_or_obj, out='?', dispatch=[('loops_cmp', nodatetime_or_obj)]),
TD(O, f='npy_ObjectLogicalAnd'),
TD(O, f='npy_ObjectLogicalAnd', out='?'),
),
@@ -489,7 +489,7 @@ def english_upper(s):
Ufunc(2, 1, False_,
docstrings.get('numpy.core.umath.logical_or'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(nodatetime_or_obj, out='?', simd=[('avx2', ints)]),
TD(nodatetime_or_obj, out='?', dispatch=[('loops_cmp', nodatetime_or_obj)]),
TD(O, f='npy_ObjectLogicalOr'),
TD(O, f='npy_ObjectLogicalOr', out='?'),
),
1 change: 1 addition & 0 deletions numpy/core/setup.py
@@ -902,6 +902,7 @@ def generate_umath_c(ext, build_dir):
join('src', 'umath', 'simd.inc.src'),
join('src', 'umath', 'loops.h.src'),
join('src', 'umath', 'loops.c.src'),
join('src', 'umath', 'loops_cmp.dispatch.pyas.c'),
join('src', 'umath', 'matmul.h.src'),
join('src', 'umath', 'matmul.c.src'),
join('src', 'umath', 'clip.h.src'),
28 changes: 28 additions & 0 deletions numpy/core/src/common/simd/avx2/conversion.h
@@ -29,4 +29,32 @@
#define npyv_cvt_b32_f32(BL) _mm256_castps_si256(BL)
#define npyv_cvt_b64_f64(BL) _mm256_castpd_si256(BL)

// pack two 16-bit boolean vectors into one 8-bit boolean vector
NPY_FINLINE npyv_b8 npyv_pack_b16(npyv_b16 a, npyv_b16 b)
{
__m256i ab = _mm256_packs_epi16(a, b);
return npyv256_shuffle_odd(ab);
}
// pack four 32-bit boolean vectors into one 8-bit boolean vector
NPY_FINLINE npyv_b8 npyv_pack_b8_b32(npyv_b32 a, npyv_b32 b, npyv_b32 c, npyv_b32 d)
{
__m256i ab = _mm256_packs_epi32(a, b);
__m256i cd = _mm256_packs_epi32(c, d);
__m256i abcd = npyv_pack_b16(ab, cd);
return _mm256_shuffle_epi32(abcd, _MM_SHUFFLE(3, 1, 2, 0));
}
// pack eight 64-bit boolean vectors into one 8-bit boolean vector
NPY_FINLINE npyv_b8 npyv_pack_b8_b64(npyv_b64 a, npyv_b64 b, npyv_b64 c, npyv_b64 d,
npyv_b64 e, npyv_b64 f, npyv_b64 g, npyv_b64 h)
{
__m256i ab = _mm256_packs_epi32(a, b);
__m256i cd = _mm256_packs_epi32(c, d);
__m256i ef = _mm256_packs_epi32(e, f);
__m256i gh = _mm256_packs_epi32(g, h);
__m256i abcd = _mm256_packs_epi32(ab, cd);
__m256i efgh = _mm256_packs_epi32(ef, gh);
__m256i all = npyv256_shuffle_odd(_mm256_packs_epi16(abcd, efgh));
__m256i rev128 = _mm256_alignr_epi8(all, all, 8);
return _mm256_unpacklo_epi16(all, rev128);
}
#endif // _NPY_SIMD_AVX2_CVT_H
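
As a rough scalar reference for the pack helpers added above (illustrative sketch only, not part of the patch): assuming each input lane is a proper boolean (all bits zero or all bits set), the result of npyv_pack_b16, npyv_pack_b8_b32 and npyv_pack_b8_b64 is simply the input vectors concatenated in argument order, each lane truncated to 8 bits.

#include <stdint.h>

/* Scalar model of npyv_pack_b8_b32: four vectors of eight 32-bit booleans
 * packed into 32 byte-sized booleans, in argument order. The 16-bit and
 * 64-bit variants follow the same rule with 16 and 4 lanes per input. */
static void pack_b8_b32_ref(const uint32_t a[8], const uint32_t b[8],
                            const uint32_t c[8], const uint32_t d[8],
                            uint8_t out[32])
{
    const uint32_t *in[4] = {a, b, c, d};
    for (int v = 0; v < 4; ++v) {
        for (int i = 0; i < 8; ++i) {
            out[v*8 + i] = (uint8_t)in[v][i];   /* 0 -> 0x00, all-ones -> 0xFF */
        }
    }
}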
157 changes: 157 additions & 0 deletions numpy/core/src/common/simd/avx2/memory.h
@@ -67,4 +67,161 @@ NPYV_IMPL_AVX2_MEM_INT(npy_int64, s64)
#define npyv_storeh_f32(PTR, VEC) _mm_storeu_ps(PTR, _mm256_extractf128_ps(VEC, 1))
#define npyv_storeh_f64(PTR, VEC) _mm_storeu_pd(PTR, _mm256_extractf128_pd(VEC, 1))

/***************************
* Non-contiguous Load
***************************/
//// 8
NPY_FINLINE npyv_u8 npyv_loadn_u8(const npy_uint8 *ptr, int stride)
{
const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
const __m256i idx = _mm256_mullo_epi32(_mm256_set1_epi32(stride), steps);
const __m256i cut32 = _mm256_set1_epi32(0xFF);
const __m256i sort_odd = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
__m256i a = _mm256_i32gather_epi32((const int*)ptr, idx, 1);
__m256i b = _mm256_i32gather_epi32((const int*)(ptr + stride*8), idx, 1);
__m256i c = _mm256_i32gather_epi32((const int*)(ptr + stride*16), idx, 1);
__m256i d = _mm256_i32gather_epi32((const int*)((ptr-3/*overflow guard*/) + stride*24), idx, 1);
a = _mm256_and_si256(a, cut32);
b = _mm256_and_si256(b, cut32);
c = _mm256_and_si256(c, cut32);
d = _mm256_srli_epi32(d, 24);
a = _mm256_packus_epi32(a, b);
c = _mm256_packus_epi32(c, d);
return _mm256_permutevar8x32_epi32(_mm256_packus_epi16(a, c), sort_odd);
}
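
A note on the "overflow guard" in the last gather above: each _mm256_i32gather_epi32 reads a full 32-bit word per lane, so fetching the very last strided byte directly would read three bytes past it. Starting that read three bytes earlier keeps the access inside the strided block and leaves the wanted byte in bits 24..31, which the srli by 24 recovers. A scalar model of the same trick (illustrative only, little-endian as on x86):

#include <stdint.h>
#include <stddef.h>
#include <string.h>

static uint8_t load_last_byte_guarded(const uint8_t *ptr, int stride)
{
    /* last of the 32 lanes, shifted back 3 bytes so the 4-byte read
     * ends exactly on the wanted element instead of running past it */
    const uint8_t *p = (ptr - 3) + (ptrdiff_t)stride * 31;
    uint32_t word;
    memcpy(&word, p, sizeof word);
    return (uint8_t)(word >> 24);   /* same as the _mm256_srli_epi32(d, 24) above */
}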
NPY_FINLINE npyv_s8 npyv_loadn_s8(const npy_int8 *ptr, int stride)
{ return npyv_loadn_u8((const npy_uint8 *)ptr, stride); }
//// 16
NPY_FINLINE npyv_u16 npyv_loadn_u16(const npy_uint16 *ptr, int stride)
{
const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
const __m256i idx = _mm256_mullo_epi32(_mm256_set1_epi32(stride), steps);
const __m256i cut32 = _mm256_set1_epi32(0xFFFF); // keep the low 16 bits of each gathered dword
__m256i a = _mm256_i32gather_epi32((const int*)ptr, idx, 2);
__m256i b = _mm256_i32gather_epi32((const int*)((ptr-1/*overflow guard*/) + stride*8), idx, 2);
a = _mm256_and_si256(a, cut32);
b = _mm256_srli_epi32(b, 16);
return npyv256_shuffle_odd(_mm256_packus_epi32(a, b));
}
NPY_FINLINE npyv_s16 npyv_loadn_s16(const npy_int16 *ptr, int stride)
{ return npyv_loadn_u16((const npy_uint16 *)ptr, stride); }
//// 32
NPY_FINLINE npyv_u32 npyv_loadn_u32(const npy_uint32 *ptr, int stride)
{
const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
const __m256i idx = _mm256_mullo_epi32(_mm256_set1_epi32(stride), steps);
return _mm256_i32gather_epi32((const int*)ptr, idx, 4);
}
NPY_FINLINE npyv_s32 npyv_loadn_s32(const npy_int32 *ptr, int stride)
{ return npyv_loadn_u32((const npy_uint32*)ptr, stride); }
NPY_FINLINE npyv_f32 npyv_loadn_f32(const float *ptr, int stride)
{ return _mm256_castsi256_ps(npyv_loadn_u32((const npy_uint32*)ptr, stride)); }
//// 64
NPY_FINLINE npyv_f64 npyv_loadn_f64(const double *ptr, int stride)
{
__m128d a0 = _mm_castsi128_pd(_mm_loadl_epi64((const __m128i*)ptr));
__m128d a2 = _mm_castsi128_pd(_mm_loadl_epi64((const __m128i*)(ptr + stride*2)));
__m128d a01 = _mm_loadh_pd(a0, ptr + stride);
__m128d a23 = _mm_loadh_pd(a2, ptr + stride*3);
return _mm256_insertf128_pd(_mm256_castpd128_pd256(a01), a23, 1);
}
NPY_FINLINE npyv_u64 npyv_loadn_u64(const npy_uint64 *ptr, int stride)
{ return _mm256_castpd_si256(npyv_loadn_f64((const double*)ptr, stride)); }
NPY_FINLINE npyv_s64 npyv_loadn_s64(const npy_int64 *ptr, int stride)
{ return _mm256_castpd_si256(npyv_loadn_f64((const double*)ptr, stride)); }
/*
NPY_FINLINE npyv_u64 npyv_loadn_u64(const npy_uint64 *ptr, int stride)
{
const __m128i steps = _mm_setr_epi32(0, 1, 2, 3);
const __m128i idx = _mm_mullo_epi32(_mm_set1_epi32(stride), steps);
return _mm256_i32gather_epi64((const void*)ptr, idx, 8);
}
NPY_FINLINE npyv_s64 npyv_loadn_s64(const npy_int64 *ptr, int stride)
{ return npyv_loadn_u64((const npy_uint64*)ptr, stride); }
NPY_FINLINE npyv_f64 npyv_loadn_f64(const double *ptr, int stride)
{ return _mm256_castsi256_pd(npyv_loadn_u64((const npy_uint64*)ptr, stride)); }
*/
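
Functionally, every npyv_loadn_* helper above gathers lanes at a constant element stride; the gather-based 64-bit variant is left commented out and the active version composes the vector from two 128-bit halves instead. A scalar reference of the intended semantics (illustrative only, not part of the patch):

#include <stddef.h>

/* lane i of the returned vector is ptr[stride * i] */
static void loadn_f64_ref(double *dst, const double *ptr, int stride, size_t nlanes)
{
    for (size_t i = 0; i < nlanes; ++i) {
        dst[i] = ptr[(ptrdiff_t)stride * (ptrdiff_t)i];
    }
}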

/***************************
* Non-contiguous Store
***************************/
//// 8
NPY_FINLINE void npyv_storen_u8(npy_uint8 *ptr, int stride, npyv_u8 a)
{
__m128i a0 = _mm256_castsi256_si128(a);
__m128i a1 = _mm256_extracti128_si256(a, 1);
#define NPYV_IMPL_AVX2_STOREN8(VEC, EI, I) \
{ \
unsigned e = (unsigned)_mm_extract_epi32(VEC, EI); \
ptr[stride*(I+0)] = (npy_uint8)e; \
ptr[stride*(I+1)] = (npy_uint8)(e >> 8); \
ptr[stride*(I+2)] = (npy_uint8)(e >> 16); \
ptr[stride*(I+3)] = (npy_uint8)(e >> 24); \
}
NPYV_IMPL_AVX2_STOREN8(a0, 0, 0)
NPYV_IMPL_AVX2_STOREN8(a0, 1, 4)
NPYV_IMPL_AVX2_STOREN8(a0, 2, 8)
NPYV_IMPL_AVX2_STOREN8(a0, 3, 12)
NPYV_IMPL_AVX2_STOREN8(a1, 0, 16)
NPYV_IMPL_AVX2_STOREN8(a1, 1, 20)
NPYV_IMPL_AVX2_STOREN8(a1, 2, 24)
NPYV_IMPL_AVX2_STOREN8(a1, 3, 28)
}
NPY_FINLINE void npyv_storen_s8(npy_int8 *ptr, int stride, npyv_s8 a)
{ npyv_storen_u8((npy_uint8*)ptr, stride, a); }
//// 16
NPY_FINLINE void npyv_storen_u16(npy_uint16 *ptr, int stride, npyv_u16 a)
{
__m128i a0 = _mm256_castsi256_si128(a);
__m128i a1 = _mm256_extracti128_si256(a, 1);
#define NPYV_IMPL_AVX2_STOREN16(VEC, EI, I) \
{ \
unsigned e = (unsigned)_mm_extract_epi32(VEC, EI); \
ptr[stride*(I+0)] = (npy_uint16)e; \
ptr[stride*(I+1)] = (npy_uint16)(e >> 16); \
}
NPYV_IMPL_AVX2_STOREN16(a0, 0, 0)
NPYV_IMPL_AVX2_STOREN16(a0, 1, 2)
NPYV_IMPL_AVX2_STOREN16(a0, 2, 4)
NPYV_IMPL_AVX2_STOREN16(a0, 3, 6)
NPYV_IMPL_AVX2_STOREN16(a1, 0, 8)
NPYV_IMPL_AVX2_STOREN16(a1, 1, 10)
NPYV_IMPL_AVX2_STOREN16(a1, 2, 12)
NPYV_IMPL_AVX2_STOREN16(a1, 3, 14)
}
NPY_FINLINE void npyv_storen_s16(npy_int16 *ptr, int stride, npyv_s16 a)
{ npyv_storen_u16((npy_uint16*)ptr, stride, a); }
//// 32
NPY_FINLINE void npyv_storen_s32(npy_int32 *ptr, int stride, npyv_s32 a)
{
__m128i a0 = _mm256_castsi256_si128(a);
__m128i a1 = _mm256_extracti128_si256(a, 1);
ptr[stride * 0] = _mm_cvtsi128_si32(a0);
ptr[stride * 1] = _mm_extract_epi32(a0, 1);
ptr[stride * 2] = _mm_extract_epi32(a0, 2);
ptr[stride * 3] = _mm_extract_epi32(a0, 3);
ptr[stride * 4] = _mm_cvtsi128_si32(a1);
ptr[stride * 5] = _mm_extract_epi32(a1, 1);
ptr[stride * 6] = _mm_extract_epi32(a1, 2);
ptr[stride * 7] = _mm_extract_epi32(a1, 3);
}
NPY_FINLINE void npyv_storen_u32(npy_uint32 *ptr, int stride, npyv_u32 a)
{ npyv_storen_s32((npy_int32*)ptr, stride, a); }
NPY_FINLINE void npyv_storen_f32(float *ptr, int stride, npyv_f32 a)
{ npyv_storen_s32((npy_int32*)ptr, stride, _mm256_castps_si256(a)); }
//// 64
NPY_FINLINE void npyv_storen_f64(double *ptr, int stride, npyv_f64 a)
{
__m128d a0 = _mm256_castpd256_pd128(a);
__m128d a1 = _mm256_extractf128_pd(a, 1);
_mm_storel_pd(ptr + stride * 0, a0);
_mm_storeh_pd(ptr + stride * 1, a0);
_mm_storel_pd(ptr + stride * 2, a1);
_mm_storeh_pd(ptr + stride * 3, a1);
}
NPY_FINLINE void npyv_storen_u64(npy_uint64 *ptr, int stride, npyv_u64 a)
{ npyv_storen_f64((double*)ptr, stride, _mm256_castsi256_pd(a)); }
NPY_FINLINE void npyv_storen_s64(npy_int64 *ptr, int stride, npyv_s64 a)
{ npyv_storen_f64((double*)ptr, stride, _mm256_castsi256_pd(a)); }

#endif // _NPY_SIMD_AVX2_MEMORY_H
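
The non-contiguous stores mirror the loads: lane i of the vector is written to ptr[stride * i]. AVX2 provides gather but no scatter, which is presumably why the versions above extract lanes one at a time (or in 32-bit groups for the 8/16-bit element sizes). A scalar reference of the intended semantics (illustrative only, not part of the patch):

#include <stddef.h>

static void storen_f64_ref(double *ptr, int stride, const double *vec, size_t nlanes)
{
    for (size_t i = 0; i < nlanes; ++i) {
        ptr[(ptrdiff_t)stride * (ptrdiff_t)i] = vec[i];
    }
}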
27 changes: 25 additions & 2 deletions numpy/core/src/common/simd/avx2/operators.h
@@ -53,48 +53,64 @@ NPY_FINLINE __m256i npyv_shr_s64(__m256i a, int c)
// AND
#define npyv_and_u8 _mm256_and_si256
#define npyv_and_s8 _mm256_and_si256
#define npyv_and_b8 _mm256_and_si256
#define npyv_and_u16 _mm256_and_si256
#define npyv_and_s16 _mm256_and_si256
#define npyv_and_b16 _mm256_and_si256
#define npyv_and_u32 _mm256_and_si256
#define npyv_and_s32 _mm256_and_si256
#define npyv_and_b32 _mm256_and_si256
#define npyv_and_u64 _mm256_and_si256
#define npyv_and_s64 _mm256_and_si256
#define npyv_and_b64 _mm256_and_si256
#define npyv_and_f32 _mm256_and_ps
#define npyv_and_f64 _mm256_and_pd

// OR
#define npyv_or_u8 _mm256_or_si256
#define npyv_or_s8 _mm256_or_si256
#define npyv_or_b8 _mm256_or_si256
#define npyv_or_u16 _mm256_or_si256
#define npyv_or_s16 _mm256_or_si256
#define npyv_or_b16 _mm256_or_si256
#define npyv_or_u32 _mm256_or_si256
#define npyv_or_s32 _mm256_or_si256
#define npyv_or_b32 _mm256_or_si256
#define npyv_or_u64 _mm256_or_si256
#define npyv_or_s64 _mm256_or_si256
#define npyv_or_b64 _mm256_or_si256
#define npyv_or_f32 _mm256_or_ps
#define npyv_or_f64 _mm256_or_pd

// XOR
#define npyv_xor_u8 _mm256_xor_si256
#define npyv_xor_s8 _mm256_xor_si256
#define npyv_xor_b8 _mm256_xor_si256
#define npyv_xor_u16 _mm256_xor_si256
#define npyv_xor_s16 _mm256_xor_si256
#define npyv_xor_b16 _mm256_xor_si256
#define npyv_xor_u32 _mm256_xor_si256
#define npyv_xor_s32 _mm256_xor_si256
#define npyv_xor_b32 _mm256_xor_si256
#define npyv_xor_u64 _mm256_xor_si256
#define npyv_xor_s64 _mm256_xor_si256
#define npyv_xor_b64 _mm256_xor_si256
#define npyv_xor_f32 _mm256_xor_ps
#define npyv_xor_f64 _mm256_xor_pd

// NOT
#define npyv_not_u8(A) _mm256_xor_si256(A, _mm256_set1_epi32(-1))
#define npyv_not_u8(A) _mm256_andnot_si256(A, _mm256_set1_epi32(-1))
#define npyv_not_s8 npyv_not_u8
#define npyv_not_b8 npyv_not_u8
#define npyv_not_u16 npyv_not_u8
#define npyv_not_s16 npyv_not_u8
#define npyv_not_b16 npyv_not_u8
#define npyv_not_u32 npyv_not_u8
#define npyv_not_s32 npyv_not_u8
#define npyv_not_b32 npyv_not_u8
#define npyv_not_u64 npyv_not_u8
#define npyv_not_s64 npyv_not_u8
#define npyv_not_b64 npyv_not_u8
#define npyv_not_f32(A) _mm256_xor_ps(A, _mm256_castsi256_ps(_mm256_set1_epi32(-1)))
#define npyv_not_f64(A) _mm256_xor_pd(A, _mm256_castsi256_pd(_mm256_set1_epi32(-1)))
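
The NOT change above swaps xor-with-all-ones for andnot-with-all-ones; both yield the bitwise complement, since _mm256_andnot_si256(A, B) computes (~A) & B. A scalar check of the identity (illustrative only):

#include <assert.h>
#include <stdint.h>

static void not_identity_check(uint32_t a)
{
    const uint32_t ones = UINT32_MAX;
    assert((a ^ ones) == (~a & ones));   /* xor form == andnot form == ~a */
}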

@@ -135,6 +151,7 @@ NPY_FINLINE __m256i npyv_shr_s64(__m256i a, int c)
#define npyv_cmpge_s64(A, B) npyv_not_s64(_mm256_cmpgt_epi64(B, A))

// unsigned greater than
/*
#define NPYV_IMPL_AVX2_UNSIGNED_GT(LEN, SIGN) \
NPY_FINLINE __m256i npyv_cmpgt_u##LEN(__m256i a, __m256i b) \
{ \
@@ -147,7 +164,13 @@ NPY_FINLINE __m256i npyv_shr_s64(__m256i a, int c)
NPYV_IMPL_AVX2_UNSIGNED_GT(8, 0x80808080)
NPYV_IMPL_AVX2_UNSIGNED_GT(16, 0x80008000)
NPYV_IMPL_AVX2_UNSIGNED_GT(32, 0x80000000)

*/
NPY_FINLINE __m256i npyv_cmpgt_u8(__m256i a, __m256i b)
{ return npyv_not_u8(_mm256_cmpeq_epi8(b, _mm256_max_epu8(b, a))); }
NPY_FINLINE __m256i npyv_cmpgt_u16(__m256i a, __m256i b)
{ return npyv_not_u16(_mm256_cmpeq_epi16(b, _mm256_max_epu16(b, a))); }
NPY_FINLINE __m256i npyv_cmpgt_u32(__m256i a, __m256i b)
{ return npyv_not_u32(_mm256_cmpeq_epi32(b, _mm256_max_epu32(b, a))); }
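
The new 8/16/32-bit unsigned comparisons rely on the identity a > b (unsigned) iff b != max(a, b), i.e. not(cmpeq(b, max(b, a))), which avoids the sign-flip trick of the commented-out version. A scalar check (illustrative only):

#include <assert.h>
#include <stdint.h>

static void cmpgt_u8_identity_check(uint8_t a, uint8_t b)
{
    uint8_t m = (a > b) ? a : b;         /* plays the role of _mm256_max_epu8 */
    assert(((a > b) ? 1 : 0) == (b != m));
}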
NPY_FINLINE __m256i npyv_cmpgt_u64(__m256i a, __m256i b)
{
const __m256i sbit = _mm256_set1_epi64x(0x8000000000000000);