Skip to content

Commit 6705ba7

Browse files
Adding concept to a part of the code (#2842)
1 parent cd07eaa commit 6705ba7

File tree

16 files changed

+355
-323
lines changed

16 files changed

+355
-323
lines changed

include/xtensor/containers/xfixed.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,9 @@ namespace xt
325325
explicit xfixed_container(const inner_shape_type& shape, layout_type l = L);
326326
explicit xfixed_container(const inner_shape_type& shape, value_type v, layout_type l = L);
327327

328-
// remove this enable_if when removing the other value_type constructor
329-
template <class IX = std::integral_constant<std::size_t, N>, class EN = std::enable_if_t<IX::value != 0, int>>
330-
xfixed_container(nested_initializer_list_t<value_type, N> t);
328+
template <class IX = std::integral_constant<std::size_t, N>>
329+
xfixed_container(nested_initializer_list_t<value_type, N> t)
330+
requires(IX::value != 0);
331331

332332
~xfixed_container() = default;
333333

@@ -639,8 +639,9 @@ namespace xt
639639
* Note: for clang < 3.8 this is an initializer_list and the size is not checked at compile-or runtime.
640640
*/
641641
template <class ET, class S, layout_type L, bool SH, class Tag>
642-
template <class IX, class EN>
642+
template <class IX>
643643
inline xfixed_container<ET, S, L, SH, Tag>::xfixed_container(nested_initializer_list_t<value_type, N> t)
644+
requires(IX::value != 0)
644645
{
645646
XTENSOR_ASSERT_MSG(
646647
detail::check_initializer_list_shape<N>::run(t, this->shape()) == true,

include/xtensor/containers/xscalar.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,9 @@ namespace xt
317317
template <class E>
318318
using is_xscalar = detail::is_xscalar_impl<E>;
319319

320+
template <class E>
321+
concept xscalar_concept = is_xscalar<std::decay_t<E>>::value;
322+
320323
namespace detail
321324
{
322325
template <class... E>

include/xtensor/containers/xstorage.hpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,6 @@
2525

2626
namespace xt
2727
{
28-
29-
namespace detail
30-
{
31-
template <class It>
32-
using require_input_iter = typename std::enable_if<
33-
std::is_convertible<typename std::iterator_traits<It>::iterator_category, std::input_iterator_tag>::value>::type;
34-
}
35-
3628
template <class C>
3729
struct is_contiguous_container : std::true_type
3830
{
@@ -64,7 +56,7 @@ namespace xt
6456
explicit uvector(size_type count, const allocator_type& alloc = allocator_type());
6557
uvector(size_type count, const_reference value, const allocator_type& alloc = allocator_type());
6658

67-
template <class InputIt, class = detail::require_input_iter<InputIt>>
59+
template <std::input_iterator InputIt>
6860
uvector(InputIt first, InputIt last, const allocator_type& alloc = allocator_type());
6961

7062
uvector(std::initializer_list<T> init, const allocator_type& alloc = allocator_type());
@@ -277,7 +269,7 @@ namespace xt
277269
}
278270

279271
template <class T, class A>
280-
template <class InputIt, class>
272+
template <std::input_iterator InputIt>
281273
inline uvector<T, A>::uvector(InputIt first, InputIt last, const allocator_type& alloc)
282274
: m_allocator(alloc)
283275
, p_begin(nullptr)
@@ -675,19 +667,21 @@ namespace xt
675667

676668
svector(const std::vector<T>& vec);
677669

678-
template <class IT, class = detail::require_input_iter<IT>>
670+
template <std::input_iterator IT>
679671
svector(IT begin, IT end, const allocator_type& alloc = allocator_type());
680672

681-
template <std::size_t N2, bool I2, class = std::enable_if_t<N != N2, void>>
682-
explicit svector(const svector<T, N2, A, I2>& rhs);
673+
template <std::size_t N2, bool I2>
674+
explicit svector(const svector<T, N2, A, I2>& rhs)
675+
requires(N != N2);
683676

684677
svector& operator=(const svector& rhs);
685678
svector& operator=(svector&& rhs) noexcept(std::is_nothrow_move_assignable<value_type>::value);
686679
svector& operator=(const std::vector<T>& rhs);
687680
svector& operator=(std::initializer_list<T> il);
688681

689-
template <std::size_t N2, bool I2, class = std::enable_if_t<N != N2, void>>
690-
svector& operator=(const svector<T, N2, A, I2>& rhs);
682+
template <std::size_t N2, bool I2>
683+
svector& operator=(const svector<T, N2, A, I2>& rhs)
684+
requires(N != N2);
691685

692686
svector(const svector& other);
693687
svector(svector&& other) noexcept(std::is_nothrow_move_constructible<value_type>::value);
@@ -809,16 +803,17 @@ namespace xt
809803
}
810804

811805
template <class T, std::size_t N, class A, bool Init>
812-
template <class IT, class>
806+
template <std::input_iterator IT>
813807
inline svector<T, N, A, Init>::svector(IT begin, IT end, const allocator_type& alloc)
814808
: m_allocator(alloc)
815809
{
816810
assign(begin, end);
817811
}
818812

819813
template <class T, std::size_t N, class A, bool Init>
820-
template <std::size_t N2, bool I2, class>
814+
template <std::size_t N2, bool I2>
821815
inline svector<T, N, A, Init>::svector(const svector<T, N2, A, I2>& rhs)
816+
requires(N != N2)
822817
: m_allocator(rhs.get_allocator())
823818
{
824819
assign(rhs.begin(), rhs.end());
@@ -876,8 +871,9 @@ namespace xt
876871
}
877872

878873
template <class T, std::size_t N, class A, bool Init>
879-
template <std::size_t N2, bool I2, class>
874+
template <std::size_t N2, bool I2>
880875
inline svector<T, N, A, Init>& svector<T, N, A, Init>::operator=(const svector<T, N2, A, I2>& rhs)
876+
requires(N != N2)
881877
{
882878
m_allocator = std::allocator_traits<allocator_type>::select_on_container_copy_construction(
883879
rhs.get_allocator()

include/xtensor/core/xexpression.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ namespace xt
179179
template <class E>
180180
using is_xexpression = is_crtp_base_of<xexpression, E>;
181181

182+
template <class E>
183+
concept xexpression_concept = is_xexpression<E>::value;
184+
182185
template <class E, class R = void>
183186
using enable_xexpression = typename std::enable_if<is_xexpression<E>::value, R>::type;
184187

include/xtensor/core/xshape.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,9 @@ namespace xt
508508
};
509509
}
510510

511+
template <typename T>
512+
concept fixed_shape_container_concept = detail::is_fixed<typename std::decay_t<T>::shape_type>::value;
513+
511514
template <class... S>
512515
struct promote_shape
513516
{

include/xtensor/generators/xbuilder.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ namespace xt
833833
return detail::make_xgenerator(detail::concatenate_impl<CT...>(std::move(t), axis), shape);
834834
}
835835

836-
template <std::size_t axis, class... CT, typename = std::enable_if_t<detail::all_fixed_shapes<CT...>::value>>
836+
template <std::size_t axis, fixed_shape_container_concept... CT>
837837
inline auto concatenate(std::tuple<CT...>&& t)
838838
{
839839
using shape_type = detail::concat_fixed_shape_t<axis, typename std::decay_t<CT>::shape_type...>;

include/xtensor/generators/xgenerator.hpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ namespace xt
6060
template <class F, class R, class S>
6161
class xgenerator;
6262

63+
template <typename T>
64+
concept xgenerator_concept = is_specialization_of<xgenerator, std::decay_t<T>>::value;
65+
6366
template <class C, class R, class S>
6467
struct xiterable_inner_types<xgenerator<C, R, S>>
6568
{
@@ -80,10 +83,9 @@ namespace xt
8083
* overlapping_memory_checker_traits *
8184
*************************************/
8285

83-
template <class E>
84-
struct overlapping_memory_checker_traits<
85-
E,
86-
std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xgenerator, E>::value>>
86+
template <xgenerator_concept E>
87+
requires(without_memory_address_concept<E>)
88+
struct overlapping_memory_checker_traits<E>
8789
{
8890
static bool check_overlap(const E&, const memory_range&)
8991
{
@@ -165,8 +167,9 @@ namespace xt
165167
template <class O>
166168
const_stepper stepper_end(const O& shape, layout_type) const noexcept;
167169

168-
template <class E, class FE = F, class = std::enable_if_t<has_assign_to<E, FE>::value>>
169-
void assign_to(xexpression<E>& e) const noexcept;
170+
template <class E, class FE = F>
171+
void assign_to(xexpression<E>& e) const noexcept
172+
requires(has_assign_to_v<E, FE>);
170173

171174
const functor_type& functor() const noexcept;
172175

@@ -371,8 +374,9 @@ namespace xt
371374
}
372375

373376
template <class F, class R, class S>
374-
template <class E, class, class>
377+
template <class E, class FE>
375378
inline void xgenerator<F, R, S>::assign_to(xexpression<E>& e) const noexcept
379+
requires(has_assign_to_v<E, FE>)
376380
{
377381
e.derived_cast().resize(m_shape);
378382
m_f.assign_to(e);

include/xtensor/generators/xrandom.hpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "../core/xtensor_config.hpp"
2828
#include "../generators/xbuilder.hpp"
2929
#include "../generators/xgenerator.hpp"
30+
#include "../misc/xtl_concepts.hpp"
3031
#include "../views/xindex_view.hpp"
3132
#include "../views/xview.hpp"
3233

@@ -175,13 +176,11 @@ namespace xt
175176
template <class T, class E = random::default_engine_type>
176177
void shuffle(xexpression<T>& e, E& engine = random::get_default_random_engine());
177178

178-
template <class T, class E = random::default_engine_type>
179-
std::enable_if_t<xtl::is_integral<T>::value, xtensor<T, 1>>
180-
permutation(T e, E& engine = random::get_default_random_engine());
179+
template <xtl::integral_concept T, class E = random::default_engine_type>
180+
xtensor<T, 1> permutation(T e, E& engine = random::get_default_random_engine());
181181

182-
template <class T, class E = random::default_engine_type>
183-
std::enable_if_t<is_xexpression<std::decay_t<T>>::value, std::decay_t<T>>
184-
permutation(T&& e, E& engine = random::get_default_random_engine());
182+
template <xexpression_concept T, class E = random::default_engine_type>
183+
std::decay_t<T> permutation(T&& e, E& engine = random::get_default_random_engine());
185184

186185
template <class T, class E = random::default_engine_type>
187186
xtensor<typename T::value_type, 1> choice(
@@ -835,17 +834,17 @@ namespace xt
835834
*
836835
* @return randomly permuted copy of container or arange.
837836
*/
838-
template <class T, class E>
839-
std::enable_if_t<xtl::is_integral<T>::value, xtensor<T, 1>> permutation(T e, E& engine)
837+
template <xtl::integral_concept T, class E>
838+
xtensor<T, 1> permutation(T e, E& engine)
840839
{
841840
xt::xtensor<T, 1> res = xt::arange<T>(e);
842841
shuffle(res, engine);
843842
return res;
844843
}
845844

846845
/// @cond DOXYGEN_INCLUDE_SFINAE
847-
template <class T, class E>
848-
std::enable_if_t<is_xexpression<std::decay_t<T>>::value, std::decay_t<T>> permutation(T&& e, E& engine)
846+
template <xexpression_concept T, class E>
847+
std::decay_t<T> permutation(T&& e, E& engine)
849848
{
850849
using copy_type = std::decay_t<T>;
851850
copy_type res = e;

include/xtensor/misc/xfft.hpp

Lines changed: 43 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,15 @@
1212
#include "../misc/xcomplex.hpp"
1313
#include "../views/xaxis_slice_iterator.hpp"
1414
#include "../views/xview.hpp"
15+
#include "./xtl_concepts.hpp"
1516

1617
namespace xt
1718
{
1819
namespace fft
1920
{
2021
namespace detail
2122
{
22-
template <
23-
class E,
24-
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
23+
template <xtl::complex_concept E>
2524
inline auto radix2(E&& e)
2625
{
2726
using namespace xt::placeholders;
@@ -125,72 +124,59 @@ namespace xt
125124
* @param axis the axis along which to perform the 1D FFT
126125
* @return a transformed xarray of the specified precision
127126
*/
128-
template <
129-
class E,
130-
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
127+
template <class E>
131128
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
132129
{
133-
using value_type = typename std::decay_t<E>::value_type;
134-
using precision = typename value_type::value_type;
135-
const auto saxis = xt::normalize_axis(e.dimension(), axis);
136-
const size_t N = e.shape(saxis);
137-
const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
138-
xt::xarray<std::complex<precision>> out = xt::eval(e);
139-
auto begin = xt::axis_slice_begin(out, saxis);
140-
auto end = xt::axis_slice_end(out, saxis);
141-
for (auto iter = begin; iter != end; iter++)
130+
using value_type = typename std::decay<E>::type::value_type;
131+
if constexpr (xtl::is_complex<typename std::decay<E>::type::value_type>::value)
142132
{
143-
if (powerOfTwo)
144-
{
145-
xt::noalias(*iter) = detail::radix2(*iter);
146-
}
147-
else
133+
using precision = typename value_type::value_type;
134+
const auto saxis = xt::normalize_axis(e.dimension(), axis);
135+
const size_t N = e.shape(saxis);
136+
const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
137+
xt::xarray<std::complex<precision>> out = xt::eval(e);
138+
auto begin = xt::axis_slice_begin(out, saxis);
139+
auto end = xt::axis_slice_end(out, saxis);
140+
for (auto iter = begin; iter != end; iter++)
148141
{
149-
xt::noalias(*iter) = detail::transform_bluestein(*iter);
142+
if (powerOfTwo)
143+
{
144+
xt::noalias(*iter) = detail::radix2(*iter);
145+
}
146+
else
147+
{
148+
xt::noalias(*iter) = detail::transform_bluestein(*iter);
149+
}
150150
}
151+
return out;
151152
}
152-
return out;
153-
}
154-
155-
/**
156-
* @brief 1D FFT of an Nd array along a specified axis
157-
* @param e an Nd expression to be transformed to the fourier domain
158-
* @param axis the axis along which to perform the 1D FFT
159-
* @return a transformed xarray of the specified precision
160-
*/
161-
template <
162-
class E,
163-
typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
164-
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
165-
{
166-
using value_type = typename std::decay<E>::type::value_type;
167-
return fft(xt::cast<std::complex<value_type>>(e), axis);
168-
}
169-
170-
template <
171-
class E,
172-
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
173-
auto ifft(E&& e, std::ptrdiff_t axis = -1)
174-
{
175-
// check the length of the data on that axis
176-
const std::size_t n = e.shape(axis);
177-
if (n == 0)
153+
else
178154
{
179-
XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention");
155+
return fft(xt::cast<std::complex<value_type>>(e), axis);
180156
}
181-
auto complex_args = xt::conj(e);
182-
auto fft_res = xt::fft::fft(complex_args, axis);
183-
fft_res = xt::conj(fft_res);
184-
return fft_res;
185157
}
186158

187-
template <
188-
class E,
189-
typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
159+
template <class E>
190160
inline auto ifft(E&& e, std::ptrdiff_t axis = -1)
191161
{
192-
using value_type = typename std::decay<E>::type::value_type;
193-
return ifft(xt::cast<std::complex<value_type>>(e), axis);
162+
if constexpr (xtl::is_complex<typename std::decay<E>::type::value_type>::value)
163+
{
164+
// check the length of the data on that axis
165+
const std::size_t n = e.shape(axis);
166+
if (n == 0)
167+
{
168+
XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention");
169+
}
170+
auto complex_args = xt::conj(e);
171+
auto fft_res = xt::fft::fft(complex_args, axis);
172+
fft_res = xt::conj(fft_res);
173+
return fft_res;
174+
}
175+
else
176+
{
177+
using value_type = typename std::decay<E>::type::value_type;
178+
return ifft(xt::cast<std::complex<value_type>>(e), axis);
179+
}
194180
}
195181

196182
/*

0 commit comments

Comments
 (0)