Skip to content

Commit 3da7d27

Browse files
Vectorize equality-comparing algorithms for more types (#5527)
Co-authored-by: Stephan T. Lavavej <[email protected]>
1 parent d38c194 commit 3da7d27

File tree

3 files changed

+70
-20
lines changed

3 files changed

+70
-20
lines changed

stl/inc/algorithm

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -351,11 +351,6 @@ _Ty* _Unique_copy_vectorized(const _Ty* const _First, const _Ty* const _Last, _T
351351
_STL_INTERNAL_STATIC_ASSERT(false); // Unexpected size
352352
}
353353
}
354-
355-
// Can we activate the vector algorithms for find_first_of?
356-
template <class _It1, class _It2, class _Pr>
357-
constexpr bool _Vector_alg_in_find_first_of_is_safe = _Equal_memcmp_is_safe<_It1, _It2, _Pr>;
358-
359354
// Can we activate the vector algorithms for replace?
360355
template <class _Iter, class _Ty1>
361356
constexpr bool _Vector_alg_in_replace_is_safe = _Vector_alg_in_find_is_safe<_Iter, _Ty1> // can search for the value
@@ -376,7 +371,7 @@ constexpr bool _Vector_alg_in_search_n_is_safe = _Vector_alg_in_find_is_safe<_It
376371
equal_to<>>;
377372
// Can we activate the vector algorithms for unique?
378373
template <class _Iter, class _Pr>
379-
constexpr bool _Vector_alg_in_unique_is_safe = _Equal_memcmp_is_safe<_Iter, _Iter, _Pr>;
374+
constexpr bool _Vector_alg_in_unique_is_safe = _Vector_alg_in_search_is_safe<_Iter, _Iter, _Pr>;
380375

381376
// Can we use this output iterator for remove_copy or unique_copy?
382377
template <class _Out, class _In>
@@ -712,7 +707,7 @@ _NODISCARD _CONSTEXPR20 _FwdIt adjacent_find(const _FwdIt _First, _FwdIt _Last,
712707
auto _ULast = _STD _Get_unwrapped(_Last);
713708
if (_UFirst != _ULast) {
714709
#if _USE_STD_VECTOR_ALGORITHMS
715-
if constexpr (_Equal_memcmp_is_safe<decltype(_UFirst), decltype(_UFirst), _Pr>) {
710+
if constexpr (_Vector_alg_in_search_is_safe<decltype(_UFirst), decltype(_UFirst), _Pr>) {
716711
if (!_STD _Is_constant_evaluated()) {
717712
const auto _First_ptr = _STD _To_address(_UFirst);
718713
const auto _Result = _STD _Adjacent_find_vectorized(_First_ptr, _STD _To_address(_ULast));
@@ -888,7 +883,7 @@ _NODISCARD _CONSTEXPR20 pair<_InIt1, _InIt2> mismatch(_InIt1 _First1, const _InI
888883
const auto _ULast1 = _STD _Get_unwrapped(_Last1);
889884
auto _UFirst2 = _STD _Get_unwrapped_n(_First2, _STD _Idl_distance<_InIt1>(_UFirst1, _ULast1));
890885
#if _USE_STD_VECTOR_ALGORITHMS
891-
if constexpr (_Equal_memcmp_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
886+
if constexpr (_Vector_alg_in_search_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
892887
if (!_STD _Is_constant_evaluated()) {
893888
constexpr size_t _Elem_size = sizeof(_Iter_value_t<_InIt1>);
894889

@@ -952,7 +947,7 @@ _NODISCARD _CONSTEXPR20 pair<_InIt1, _InIt2> mismatch(
952947
const auto _Count = static_cast<_Iter_diff_t<_InIt1>>((_STD min)(_Count1, _Count2));
953948
_ULast1 = _UFirst1 + _Count;
954949
#if _USE_STD_VECTOR_ALGORITHMS
955-
if constexpr (_Equal_memcmp_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
950+
if constexpr (_Vector_alg_in_search_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
956951
if (!_STD _Is_constant_evaluated()) {
957952
constexpr size_t _Elem_size = sizeof(_Iter_value_t<_InIt1>);
958953

@@ -3796,7 +3791,7 @@ _NODISCARD _CONSTEXPR20 _FwdIt1 find_first_of(
37963791
}
37973792

37983793
#if _USE_STD_VECTOR_ALGORITHMS
3799-
if constexpr (_Vector_alg_in_find_first_of_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
3794+
if constexpr (_Vector_alg_in_search_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
38003795
if (!_STD _Is_constant_evaluated() && _ULast1 - _UFirst1 >= _Threshold_find_first_of) {
38013796
const auto _First1_ptr = _STD _To_address(_UFirst1);
38023797
const auto _Result = _STD _Find_first_of_vectorized(
@@ -3900,7 +3895,7 @@ namespace ranges {
39003895
}
39013896

39023897
#if _USE_STD_VECTOR_ALGORITHMS
3903-
if constexpr (_Vector_alg_in_find_first_of_is_safe<_It1, _It2, _Pr> && sized_sentinel_for<_Se1, _It1>
3898+
if constexpr (_Vector_alg_in_search_is_safe<_It1, _It2, _Pr> && sized_sentinel_for<_Se1, _It1>
39043899
&& sized_sentinel_for<_Se2, _It2> && is_same_v<_Pj1, identity> && is_same_v<_Pj2, identity>) {
39053900
if (!_STD is_constant_evaluated() && _Last1 - _First1 >= _Threshold_find_first_of) {
39063901
const auto _Count1 = static_cast<ptrdiff_t>(_Last1 - _First1);

stl/inc/xutility

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ _STL_DISABLE_CLANG_WARNINGS
6363
#endif // _USE_STD_VECTOR_FLOATING_ALGORITHMS && !_USE_STD_VECTOR_ALGORITHMS
6464
#endif // ^^^ defined(_USE_STD_VECTOR_FLOATING_ALGORITHMS) ^^^
6565

66+
#ifndef _USE_BUILTIN_IS_TRIVIALLY_EQUALITY_COMPARABLE
67+
#if defined(__clang__) && __clang_major__ >= 19
68+
#define _USE_BUILTIN_IS_TRIVIALLY_EQUALITY_COMPARABLE 1
69+
#else // ^^^ defined(__clang__) && __clang_major__ >= 19 / !defined(__clang__) || __clang_major__ < 19 vvv
70+
#define _USE_BUILTIN_IS_TRIVIALLY_EQUALITY_COMPARABLE 0
71+
#endif // ^^^ !defined(__clang__) || __clang_major__ < 19 ^^^
72+
#endif // ^^^ !defined(_USE_BUILTIN_IS_TRIVIALLY_EQUALITY_COMPARABLE) ^^^
73+
6674
#if _USE_STD_VECTOR_ALGORITHMS
6775
extern "C" {
6876
// The "noalias" attribute tells the compiler optimizer that pointers going into these hand-vectorized algorithms
@@ -5481,8 +5489,14 @@ inline constexpr bool _Can_memcmp_elements<byte, byte, false> = true;
54815489
template <class _Ty1, class _Ty2>
54825490
constexpr bool _Can_memcmp_elements<_Ty1*, _Ty2*, false> = _Is_pointer_address_comparable<_Ty1, _Ty2>;
54835491

5492+
#if _USE_BUILTIN_IS_TRIVIALLY_EQUALITY_COMPARABLE
5493+
template <class _Elem1, class _Elem2>
5494+
constexpr bool _Can_memcmp_elements<_Elem1, _Elem2, false> =
5495+
is_same_v<_Elem1, _Elem2> && __is_trivially_equality_comparable(_Elem1);
5496+
#else // ^^^ _USE_BUILTIN_IS_TRIVIALLY_EQUALITY_COMPARABLE / !_USE_BUILTIN_IS_TRIVIALLY_EQUALITY_COMPARABLE vvv
54845497
template <class _Elem1, class _Elem2>
54855498
constexpr bool _Can_memcmp_elements<_Elem1, _Elem2, false> = false;
5499+
#endif // ^^^ !_USE_BUILTIN_IS_TRIVIALLY_EQUALITY_COMPARABLE ^^^
54865500

54875501
// _Can_memcmp_elements_with_pred<_Elem1, _Elem2, _Pr> reports whether the memcmp optimization is applicable,
54885502
// given contiguously stored elements. (This avoids having to repeat the metaprogramming that finds the element types.)
@@ -5519,9 +5533,15 @@ template <class _Iter1, class _Iter2, class _Pr>
55195533
constexpr bool _Equal_memcmp_is_safe =
55205534
_Equal_memcmp_is_safe_helper<remove_const_t<_Iter1>, remove_const_t<_Iter2>, remove_const_t<_Pr>>;
55215535

5536+
#if _USE_STD_VECTOR_ALGORITHMS
5537+
template <size_t _Size>
5538+
constexpr bool _Is_vector_element_size = _Size == 1 || _Size == 2 || _Size == 4 || _Size == 8;
5539+
55225540
// Can we activate the vector algorithms for std::search?
55235541
template <class _It1, class _It2, class _Pr>
5524-
constexpr bool _Vector_alg_in_search_is_safe = _Equal_memcmp_is_safe<_It1, _It2, _Pr>;
5542+
constexpr bool _Vector_alg_in_search_is_safe =
5543+
_Equal_memcmp_is_safe<_It1, _It2, _Pr> && _Is_vector_element_size<sizeof(_Iter_value_t<_It1>)>;
5544+
#endif // _USE_STD_VECTOR_ALGORITHMS
55255545

55265546
template <class _CtgIt1, class _CtgIt2>
55275547
_NODISCARD int _Memcmp_count(_CtgIt1 _First1, _CtgIt2 _First2, const size_t _Count) {
@@ -5679,7 +5699,7 @@ namespace ranges {
56795699
_It1 _First1, _It2 _First2, iter_difference_t<_It1> _Count, _Pr _Pred, _Pj1 _Proj1, _Pj2 _Proj2) {
56805700
_STL_INTERNAL_CHECK(_Count >= 0);
56815701
#if _USE_STD_VECTOR_ALGORITHMS
5682-
if constexpr (_Equal_memcmp_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity>
5702+
if constexpr (_Vector_alg_in_search_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity>
56835703
&& is_same_v<_Pj2, identity>) {
56845704
if (!_STD is_constant_evaluated()) {
56855705
constexpr size_t _Elem_size = sizeof(iter_value_t<_It1>);
@@ -6829,7 +6849,7 @@ namespace ranges {
68296849
}
68306850

68316851
#if _USE_STD_VECTOR_ALGORITHMS
6832-
if constexpr (_Equal_memcmp_is_safe<_It, _It, _Pr> && sized_sentinel_for<_Se, _It>
6852+
if constexpr (_Vector_alg_in_search_is_safe<_It, _It, _Pr> && sized_sentinel_for<_Se, _It>
68336853
&& is_same_v<_Pj, identity>) {
68346854
if (!_STD is_constant_evaluated()) {
68356855
const auto _First_ptr = _STD _To_address(_First);

tests/std/tests/GH_000431_equal_memcmp_is_safe/test.compile.pass.cpp

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
#include <type_traits>
1313
#include <vector>
1414

15+
#ifdef __clang__
16+
constexpr bool magic = true;
17+
#else
18+
constexpr bool magic = false;
19+
#endif
20+
1521
using namespace std;
1622

1723
#define STATIC_ASSERT(...) static_assert(__VA_ARGS__, #__VA_ARGS__)
@@ -137,6 +143,22 @@ struct StatefulDerived2 : EmptyBase, StatefulBase {};
137143

138144
struct StatefulPrivatelyDerived2 : private EmptyBase, private StatefulBase {};
139145

146+
#if _HAS_CXX20
147+
struct DefaultComparison {
148+
int i;
149+
150+
bool operator==(const DefaultComparison&) const noexcept = default;
151+
};
152+
153+
struct DefaultComparisonOddSize {
154+
int i;
155+
int j;
156+
int k;
157+
158+
bool operator==(const DefaultComparisonOddSize&) const noexcept = default;
159+
};
160+
#endif // _HAS_CXX20
161+
140162
#ifdef __cpp_lib_is_pointer_interconvertible
141163
STATIC_ASSERT(is_pointer_interconvertible_base_of_v<EmptyBase, EmptyDerived>);
142164
STATIC_ASSERT(is_pointer_interconvertible_base_of_v<EmptyBase, EmptyPrivatelyDerived>);
@@ -485,21 +507,34 @@ STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<false, void (*)(), void (*)(in
485507
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<is_convertible_v<void (*)(int), void*>, void (*)(int), void*>());
486508
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<is_convertible_v<void (*)(int), void*>, void*, void (*)(int)>());
487509

488-
// Don't allow member object pointers
489-
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<false, int EmptyBase::*, int EmptyBase::*>());
490-
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<false, int EmptyDerived::*, int EmptyDerived::*>());
510+
// Don't allow member object pointers, unless magic happens
511+
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<magic, int EmptyBase::*, int EmptyBase::*>());
512+
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<magic, int EmptyDerived::*, int EmptyDerived::*>());
491513
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<false, int EmptyBase::*, int EmptyDerived::*>());
492514
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<false, int EmptyDerived::*, int EmptyBase::*>());
493515

494-
// Don't allow member function pointers
495-
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<false, int (EmptyBase::*)(), int (EmptyBase::*)()>());
496-
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<false, int (EmptyDerived::*)(), int (EmptyDerived::*)()>());
516+
// Don't allow member function pointers, unless magic happens
517+
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<magic, int (EmptyBase::*)(), int (EmptyBase::*)()>());
518+
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<magic, int (EmptyDerived::*)(), int (EmptyDerived::*)()>());
497519
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<false, int (EmptyBase::*)(), int (EmptyDerived::*)()>());
498520
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<false, int (EmptyDerived::*)(), int (EmptyBase::*)()>());
499521

500522
// Don't allow user-defined types
501523
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<false, StatefulBase, StatefulBase>());
502524

525+
#if _HAS_CXX20
526+
// Don't allow user-defined types with default comparison, unless magic happens
527+
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<magic, DefaultComparison, DefaultComparison>());
528+
STATIC_ASSERT(test_equal_memcmp_is_safe_for_types<magic, DefaultComparisonOddSize, DefaultComparisonOddSize>());
529+
// The only difference between _Equal_memcmp_is_safe and _Vector_alg_in_search_is_safe is how the magic works
530+
STATIC_ASSERT(_Equal_memcmp_is_safe<DefaultComparison*, DefaultComparison*, equal_to<>> == magic);
531+
STATIC_ASSERT(_Equal_memcmp_is_safe<DefaultComparisonOddSize*, DefaultComparisonOddSize*, equal_to<>> == magic);
532+
#if _USE_STD_VECTOR_ALGORITHMS
533+
STATIC_ASSERT(_Vector_alg_in_search_is_safe<DefaultComparison*, DefaultComparison*, equal_to<>> == magic);
534+
STATIC_ASSERT(!_Vector_alg_in_search_is_safe<DefaultComparisonOddSize*, DefaultComparisonOddSize*, equal_to<>>);
535+
#endif // _USE_STD_VECTOR_ALGORITHMS
536+
#endif // _HAS_CXX20
537+
503538
// Test _Std_char_traits_eq
504539
STATIC_ASSERT(test_equal_memcmp_is_safe_for_pred<true, char, char, _Std_char_traits_eq<char>>());
505540
STATIC_ASSERT(test_equal_memcmp_is_safe_for_pred<true, wchar_t, wchar_t, _Std_char_traits_eq<wchar_t>>());

0 commit comments

Comments
 (0)