Skip to content

Commit 886ef10

Browse files
vector_algorithms.cpp: find, find_last, count: make AVX2 path avoid SSE path and (for some types) fallback (#4570)
Co-authored-by: Stephan T. Lavavej <[email protected]>
1 parent 9839187 commit 886ef10

File tree

2 files changed

+99
-42
lines changed

2 files changed

+99
-42
lines changed

benchmarks/src/find_and_count.cpp

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
#include <benchmark/benchmark.h>
66
#include <cstddef>
77
#include <cstdint>
8+
#include <cstdlib>
89
#include <ranges>
10+
#include <vector>
911

1012
enum class Op {
1113
FindSized,
@@ -15,39 +17,50 @@ enum class Op {
1517

1618
using namespace std;
1719

18-
template <class T, size_t Size, size_t Pos, Op Operation>
20+
template <class T, Op Operation>
1921
void bm(benchmark::State& state) {
20-
T a[Size];
22+
const auto size = static_cast<size_t>(state.range(0));
23+
const auto pos = static_cast<size_t>(state.range(1));
2124

22-
fill_n(a, Size, T{'0'});
23-
if constexpr (Pos < Size) {
24-
a[Pos] = T{'1'};
25+
vector<T> a(size, T{'0'});
26+
27+
if (pos < size) {
28+
a[pos] = T{'1'};
2529
} else {
26-
static_assert(Operation != Op::FindUnsized);
30+
if constexpr (Operation == Op::FindUnsized) {
31+
abort();
32+
}
2733
}
2834

2935
for (auto _ : state) {
3036
if constexpr (Operation == Op::FindSized) {
31-
benchmark::DoNotOptimize(ranges::find(a, a + Size, T{'1'}));
37+
benchmark::DoNotOptimize(ranges::find(a.begin(), a.end(), T{'1'}));
3238
} else if constexpr (Operation == Op::FindUnsized) {
33-
benchmark::DoNotOptimize(ranges::find(a, unreachable_sentinel, T{'1'}));
39+
benchmark::DoNotOptimize(ranges::find(a.begin(), unreachable_sentinel, T{'1'}));
3440
} else if constexpr (Operation == Op::Count) {
35-
benchmark::DoNotOptimize(ranges::count(a, a + Size, T{'1'}));
41+
benchmark::DoNotOptimize(ranges::count(a.begin(), a.end(), T{'1'}));
3642
}
3743
}
3844
}
3945

40-
BENCHMARK(bm<uint8_t, 8021, 3056, Op::FindSized>);
41-
BENCHMARK(bm<uint8_t, 8021, 3056, Op::FindUnsized>);
42-
BENCHMARK(bm<uint8_t, 8021, 3056, Op::Count>);
46+
void common_args(auto bm) {
47+
bm->Args({8021, 3056});
48+
// AVX tail tests
49+
bm->Args({63, 62})->Args({31, 30})->Args({15, 14})->Args({7, 6});
50+
}
51+
52+
53+
BENCHMARK(bm<uint8_t, Op::FindSized>)->Apply(common_args);
54+
BENCHMARK(bm<uint8_t, Op::FindUnsized>)->Apply(common_args);
55+
BENCHMARK(bm<uint8_t, Op::Count>)->Apply(common_args);
4356

44-
BENCHMARK(bm<uint16_t, 8021, 3056, Op::FindSized>);
45-
BENCHMARK(bm<uint16_t, 8021, 3056, Op::Count>);
57+
BENCHMARK(bm<uint16_t, Op::FindSized>)->Apply(common_args);
58+
BENCHMARK(bm<uint16_t, Op::Count>)->Apply(common_args);
4659

47-
BENCHMARK(bm<uint32_t, 8021, 3056, Op::FindSized>);
48-
BENCHMARK(bm<uint32_t, 8021, 3056, Op::Count>);
60+
BENCHMARK(bm<uint32_t, Op::FindSized>)->Apply(common_args);
61+
BENCHMARK(bm<uint32_t, Op::Count>)->Apply(common_args);
4962

50-
BENCHMARK(bm<uint64_t, 8021, 3056, Op::FindSized>);
51-
BENCHMARK(bm<uint64_t, 8021, 3056, Op::Count>);
63+
BENCHMARK(bm<uint64_t, Op::FindSized>)->Apply(common_args);
64+
BENCHMARK(bm<uint64_t, Op::Count>)->Apply(common_args);
5265

5366
BENCHMARK_MAIN();

stl/src/vector_algorithms.cpp

Lines changed: 68 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,15 +1837,15 @@ namespace {
18371837
template <class _Traits, class _Ty>
18381838
const void* __stdcall __std_find_trivial_impl(const void* _First, const void* _Last, _Ty _Val) noexcept {
18391839
#ifndef _M_ARM64EC
1840-
size_t _Size_bytes = _Byte_length(_First, _Last);
1840+
const size_t _Size_bytes = _Byte_length(_First, _Last);
18411841

1842-
const size_t _Avx_size = _Size_bytes & ~size_t{0x1F};
1843-
if (_Avx_size != 0 && _Use_avx2()) {
1842+
if (const size_t _Avx_size = _Size_bytes & ~size_t{0x1F}; _Avx_size != 0 && _Use_avx2()) {
18441843
_Zeroupper_on_exit _Guard; // TRANSITION, DevCom-10331414
18451844

18461845
const __m256i _Comparand = _Traits::_Set_avx(_Val);
18471846
const void* _Stop_at = _First;
18481847
_Advance_bytes(_Stop_at, _Avx_size);
1848+
18491849
do {
18501850
const __m256i _Data = _mm256_loadu_si256(static_cast<const __m256i*>(_First));
18511851
const int _Bingo = _mm256_movemask_epi8(_Traits::_Cmp_avx(_Data, _Comparand));
@@ -1858,14 +1858,30 @@ namespace {
18581858

18591859
_Advance_bytes(_First, 32);
18601860
} while (_First != _Stop_at);
1861-
_Size_bytes &= 0x1F;
1862-
}
18631861

1864-
const size_t _Sse_size = _Size_bytes & ~size_t{0xF};
1865-
if (_Sse_size != 0 && _Use_sse42()) {
1862+
if (const size_t _Avx_tail_size = _Size_bytes & 0x1C; _Avx_tail_size != 0) {
1863+
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Avx_tail_size >> 2);
1864+
const __m256i _Data = _mm256_maskload_epi32(static_cast<const int*>(_First), _Tail_mask);
1865+
const int _Bingo =
1866+
_mm256_movemask_epi8(_mm256_and_si256(_Traits::_Cmp_avx(_Data, _Comparand), _Tail_mask));
1867+
1868+
if (_Bingo != 0) {
1869+
const unsigned long _Offset = _tzcnt_u32(_Bingo);
1870+
_Advance_bytes(_First, _Offset);
1871+
return _First;
1872+
}
1873+
1874+
_Advance_bytes(_First, _Avx_tail_size);
1875+
}
1876+
1877+
if constexpr (sizeof(_Ty) >= 4) {
1878+
return _First;
1879+
}
1880+
} else if (const size_t _Sse_size = _Size_bytes & ~size_t{0xF}; _Sse_size != 0 && _Use_sse42()) {
18661881
const __m128i _Comparand = _Traits::_Set_sse(_Val);
18671882
const void* _Stop_at = _First;
18681883
_Advance_bytes(_Stop_at, _Sse_size);
1884+
18691885
do {
18701886
const __m128i _Data = _mm_loadu_si128(static_cast<const __m128i*>(_First));
18711887
const int _Bingo = _mm_movemask_epi8(_Traits::_Cmp_sse(_Data, _Comparand));
@@ -1892,15 +1908,15 @@ namespace {
18921908
const void* __stdcall __std_find_last_trivial_impl(const void* _First, const void* _Last, _Ty _Val) noexcept {
18931909
const void* const _Real_last = _Last;
18941910
#ifndef _M_ARM64EC
1895-
size_t _Size_bytes = _Byte_length(_First, _Last);
1911+
const size_t _Size_bytes = _Byte_length(_First, _Last);
18961912

1897-
const size_t _Avx_size = _Size_bytes & ~size_t{0x1F};
1898-
if (_Avx_size != 0 && _Use_avx2()) {
1913+
if (const size_t _Avx_size = _Size_bytes & ~size_t{0x1F}; _Avx_size != 0 && _Use_avx2()) {
18991914
_Zeroupper_on_exit _Guard; // TRANSITION, DevCom-10331414
19001915

19011916
const __m256i _Comparand = _Traits::_Set_avx(_Val);
19021917
const void* _Stop_at = _Last;
19031918
_Rewind_bytes(_Stop_at, _Avx_size);
1919+
19041920
do {
19051921
_Rewind_bytes(_Last, 32);
19061922
const __m256i _Data = _mm256_loadu_si256(static_cast<const __m256i*>(_Last));
@@ -1912,14 +1928,29 @@ namespace {
19121928
return _Last;
19131929
}
19141930
} while (_Last != _Stop_at);
1915-
_Size_bytes &= 0x1F;
1916-
}
19171931

1918-
const size_t _Sse_size = _Size_bytes & ~size_t{0xF};
1919-
if (_Sse_size != 0 && _Use_sse42()) {
1932+
if (const size_t _Avx_tail_size = _Size_bytes & 0x1C; _Avx_tail_size != 0) {
1933+
_Rewind_bytes(_Last, _Avx_tail_size);
1934+
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Avx_tail_size >> 2);
1935+
const __m256i _Data = _mm256_maskload_epi32(static_cast<const int*>(_Last), _Tail_mask);
1936+
const int _Bingo =
1937+
_mm256_movemask_epi8(_mm256_and_si256(_Traits::_Cmp_avx(_Data, _Comparand), _Tail_mask));
1938+
1939+
if (_Bingo != 0) {
1940+
const unsigned long _Offset = _lzcnt_u32(_Bingo);
1941+
_Advance_bytes(_Last, (31 - _Offset) - (sizeof(_Ty) - 1));
1942+
return _Last;
1943+
}
1944+
}
1945+
1946+
if constexpr (sizeof(_Ty) >= 4) {
1947+
return _Real_last;
1948+
}
1949+
} else if (const size_t _Sse_size = _Size_bytes & ~size_t{0xF}; _Sse_size != 0 && _Use_sse42()) {
19201950
const __m128i _Comparand = _Traits::_Set_sse(_Val);
19211951
const void* _Stop_at = _Last;
19221952
_Rewind_bytes(_Stop_at, _Sse_size);
1953+
19231954
do {
19241955
_Rewind_bytes(_Last, 16);
19251956
const __m128i _Data = _mm_loadu_si128(static_cast<const __m128i*>(_Last));
@@ -1952,40 +1983,53 @@ namespace {
19521983
size_t _Result = 0;
19531984

19541985
#ifndef _M_ARM64EC
1955-
size_t _Size_bytes = _Byte_length(_First, _Last);
1986+
const size_t _Size_bytes = _Byte_length(_First, _Last);
19561987

1957-
const size_t _Avx_size = _Size_bytes & ~size_t{0x1F};
1958-
if (_Avx_size != 0 && _Use_avx2()) {
1988+
if (const size_t _Avx_size = _Size_bytes & ~size_t{0x1F}; _Avx_size != 0 && _Use_avx2()) {
19591989
const __m256i _Comparand = _Traits::_Set_avx(_Val);
19601990
const void* _Stop_at = _First;
19611991
_Advance_bytes(_Stop_at, _Avx_size);
1992+
19621993
do {
19631994
const __m256i _Data = _mm256_loadu_si256(static_cast<const __m256i*>(_First));
19641995
const int _Bingo = _mm256_movemask_epi8(_Traits::_Cmp_avx(_Data, _Comparand));
19651996
_Result += __popcnt(_Bingo); // Assume available with SSE4.2
19661997
_Advance_bytes(_First, 32);
19671998
} while (_First != _Stop_at);
1968-
_Size_bytes &= 0x1F;
1999+
2000+
if (const size_t _Avx_tail_size = _Size_bytes & 0x1C; _Avx_tail_size != 0) {
2001+
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Avx_tail_size >> 2);
2002+
const __m256i _Data = _mm256_maskload_epi32(static_cast<const int*>(_First), _Tail_mask);
2003+
const int _Bingo =
2004+
_mm256_movemask_epi8(_mm256_and_si256(_Traits::_Cmp_avx(_Data, _Comparand), _Tail_mask));
2005+
_Result += __popcnt(_Bingo); // Assume available with SSE4.2
2006+
_Advance_bytes(_First, _Avx_tail_size);
2007+
}
19692008

19702009
_mm256_zeroupper(); // TRANSITION, DevCom-10331414
1971-
}
19722010

1973-
const size_t _Sse_size = _Size_bytes & ~size_t{0xF};
1974-
if (_Sse_size != 0 && _Use_sse42()) {
2011+
_Result >>= _Traits::_Shift;
2012+
2013+
if constexpr (sizeof(_Ty) >= 4) {
2014+
return _Result;
2015+
}
2016+
} else if (const size_t _Sse_size = _Size_bytes & ~size_t{0xF}; _Sse_size != 0 && _Use_sse42()) {
19752017
const __m128i _Comparand = _Traits::_Set_sse(_Val);
19762018
const void* _Stop_at = _First;
19772019
_Advance_bytes(_Stop_at, _Sse_size);
2020+
19782021
do {
19792022
const __m128i _Data = _mm_loadu_si128(static_cast<const __m128i*>(_First));
19802023
const int _Bingo = _mm_movemask_epi8(_Traits::_Cmp_sse(_Data, _Comparand));
19812024
_Result += __popcnt(_Bingo); // Assume available with SSE4.2
19822025
_Advance_bytes(_First, 16);
19832026
} while (_First != _Stop_at);
2027+
2028+
_Result >>= _Traits::_Shift;
19842029
}
19852030
#endif // !_M_ARM64EC
1986-
_Result >>= _Traits::_Shift;
1987-
auto _Ptr = static_cast<const _Ty*>(_First);
1988-
for (; _Ptr != _Last; ++_Ptr) {
2031+
2032+
for (auto _Ptr = static_cast<const _Ty*>(_First); _Ptr != _Last; ++_Ptr) {
19892033
if (*_Ptr == _Val) {
19902034
++_Result;
19912035
}

0 commit comments

Comments
 (0)