Skip to content

Commit 0a0514c

Browse files
Use find for search_n when n=1 (#5346)
Co-authored-by: Stephan T. Lavavej <[email protected]>
1 parent 7b38f9f commit 0a0514c

File tree

4 files changed

+122
-0
lines changed

4 files changed

+122
-0
lines changed

benchmarks/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ add_benchmark(random_integer_generation src/random_integer_generation.cpp)
125125
add_benchmark(remove src/remove.cpp)
126126
add_benchmark(replace src/replace.cpp)
127127
add_benchmark(search src/search.cpp)
128+
add_benchmark(search_n src/search_n.cpp)
128129
add_benchmark(std_copy src/std_copy.cpp)
129130
add_benchmark(sv_equal src/sv_equal.cpp)
130131
add_benchmark(swap_ranges src/swap_ranges.cpp)

benchmarks/src/search_n.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include <algorithm>
5+
#include <benchmark/benchmark.h>
6+
#include <cstddef>
7+
#include <cstdint>
8+
#include <vector>
9+
10+
#include "skewed_allocator.hpp"
11+
12+
using namespace std;
13+
14+
// NB: This particular algorithm has std and ranges implementations with different perf characteristics!
15+
16+
enum class AlgType { Std, Rng };
17+
18+
template <class T, AlgType Alg>
19+
void bm(benchmark::State& state) {
20+
const auto size = static_cast<size_t>(state.range(0));
21+
22+
constexpr size_t N = 1;
23+
24+
constexpr T no_match{'-'};
25+
constexpr T match{'*'};
26+
27+
vector<T, not_highly_aligned_allocator<T>> v(size, no_match);
28+
29+
fill(v.begin() + v.size() / 2, v.end(), match);
30+
31+
for (auto _ : state) {
32+
if constexpr (Alg == AlgType::Std) {
33+
benchmark::DoNotOptimize(search_n(v.begin(), v.end(), N, match));
34+
} else if constexpr (Alg == AlgType::Rng) {
35+
benchmark::DoNotOptimize(ranges::search_n(v, N, match));
36+
}
37+
}
38+
}
39+
40+
void common_args(auto bm) {
41+
bm->Arg(3000);
42+
}
43+
44+
BENCHMARK(bm<uint8_t, AlgType::Std>)->Apply(common_args);
45+
BENCHMARK(bm<uint8_t, AlgType::Rng>)->Apply(common_args);
46+
47+
BENCHMARK(bm<uint16_t, AlgType::Std>)->Apply(common_args);
48+
BENCHMARK(bm<uint16_t, AlgType::Rng>)->Apply(common_args);
49+
50+
BENCHMARK(bm<uint32_t, AlgType::Std>)->Apply(common_args);
51+
BENCHMARK(bm<uint32_t, AlgType::Rng>)->Apply(common_args);
52+
53+
BENCHMARK(bm<uint64_t, AlgType::Std>)->Apply(common_args);
54+
BENCHMARK(bm<uint64_t, AlgType::Rng>)->Apply(common_args);
55+
56+
BENCHMARK_MAIN();

stl/inc/algorithm

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,6 +2262,16 @@ _NODISCARD _CONSTEXPR20 _FwdIt search_n(
22622262
return _First;
22632263
}
22642264

2265+
if constexpr (_Is_any_of_v<_Pr,
2266+
#if _HAS_CXX20
2267+
_RANGES equal_to,
2268+
#endif
2269+
equal_to<>>) {
2270+
if (_Count == 1) {
2271+
return _STD find(_First, _Last, _Val);
2272+
}
2273+
}
2274+
22652275
if (static_cast<uintmax_t>(_Count) > static_cast<uintmax_t>(_STD _Max_limit<_Iter_diff_t<_FwdIt>>())) {
22662276
// if the number of _Vals searched for is larger than the longest possible sequence, we can't find it
22672277
return _Last;
@@ -2362,6 +2372,19 @@ namespace ranges {
23622372
return {_First, _First};
23632373
}
23642374

2375+
if constexpr (_Is_any_of_v<_Pr, _STD equal_to<>, _RANGES equal_to>) {
2376+
if (_Count == 1) {
2377+
auto _Res = _RANGES find(_First, _Last, _Val, _Pass_fn(_Proj));
2378+
if (_Res != _Last) {
2379+
auto _Res_end = _Res;
2380+
++_Res_end;
2381+
return {_STD move(_Res), _STD move(_Res_end)};
2382+
} else {
2383+
return {_Res, _Res};
2384+
}
2385+
}
2386+
}
2387+
23652388
auto _UFirst = _RANGES _Unwrap_iter<_Se>(_STD move(_First));
23662389
auto _ULast = _RANGES _Unwrap_sent<_It>(_STD move(_Last));
23672390

@@ -2388,6 +2411,20 @@ namespace ranges {
23882411
return {_First, _First};
23892412
}
23902413

2414+
if constexpr (_Is_any_of_v<_Pr, _STD equal_to<>, _RANGES equal_to>) {
2415+
if (_Count == 1) {
2416+
auto _Res = _RANGES find(_Range, _Val, _Pass_fn(_Proj));
2417+
auto _Last = _RANGES end(_Range);
2418+
if (_Res != _Last) {
2419+
auto _Res_end = _Res;
2420+
++_Res_end;
2421+
return {_STD move(_Res), _STD move(_Res_end)};
2422+
} else {
2423+
return {_Res, _Res};
2424+
}
2425+
}
2426+
}
2427+
23912428
if constexpr (sized_range<_Rng>) {
23922429
const auto _Dist = _RANGES distance(_Range);
23932430

tests/std/tests/P0896R4_ranges_alg_search_n/test.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,34 @@ struct instantiator {
9898
assert(result.end() == range.begin());
9999
}
100100

101+
// trivial case: unit needle
102+
{
103+
const auto result = ranges::search_n(range, 1, P{1, 42});
104+
static_assert(same_as<decltype(result), const ranges::subrange<ranges::iterator_t<Fwd>>>);
105+
assert(result.begin() == ranges::next(range.begin(), 1));
106+
assert(result.end() == ranges::next(range.begin(), 2));
107+
}
108+
{
109+
const auto result = ranges::search_n(ranges::begin(range), ranges::end(range), 1, P{1, 42});
110+
static_assert(same_as<decltype(result), const ranges::subrange<ranges::iterator_t<Fwd>>>);
111+
assert(result.begin() == ranges::next(range.begin(), 1));
112+
assert(result.end() == ranges::next(range.begin(), 2));
113+
}
114+
115+
// trivial case: unit needle with predicate
116+
{
117+
const auto result = ranges::search_n(range, 1, 0, cmp, get_first);
118+
static_assert(same_as<decltype(result), const ranges::subrange<ranges::iterator_t<Fwd>>>);
119+
assert(result.begin() == ranges::next(range.begin(), 1));
120+
assert(result.end() == ranges::next(range.begin(), 2));
121+
}
122+
{
123+
const auto result = ranges::search_n(ranges::begin(range), ranges::end(range), 1, 0, cmp, get_first);
124+
static_assert(same_as<decltype(result), const ranges::subrange<ranges::iterator_t<Fwd>>>);
125+
assert(result.begin() == ranges::next(range.begin(), 1));
126+
assert(result.end() == ranges::next(range.begin(), 2));
127+
}
128+
101129
// trivial case: range too small
102130
{
103131
const auto result = ranges::search_n(range, 99999, 0, cmp, get_first);

0 commit comments

Comments
 (0)