Skip to content

Commit c50514c

Browse files
mnorris11facebook-github-bot
authored andcommitted
Add ARM_SVE and ARM_NEON fvec_madd, yy
Summary: buck test mode/opt -c fbcode.arch=aarch64 //faiss/tests/:test_fast_scan_ivf Differential Revision: D74275821
1 parent 33b646f commit c50514c

19 files changed

+914
-1364
lines changed

faiss/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ set(FAISS_SRC
7676
impl/pq4_fast_scan.cpp
7777
impl/pq4_fast_scan_search_1.cpp
7878
impl/pq4_fast_scan_search_qbs.cpp
79-
impl/residual_quantizer_encode_steps.cpp
79+
impl/residual_quantizer_encode_steps/residual_quantizer_encode_steps.cpp
80+
impl/residual_quantizer_encode_steps/residual_quantizer_encode_steps-avx2.cpp
81+
impl/residual_quantizer_encode_steps/residual_quantizer_encode_steps-neon.cpp
8082
impl/zerocopy_io.cpp
8183
impl/NNDescent.cpp
8284
invlists/BlockInvertedLists.cpp
@@ -183,7 +185,7 @@ set(FAISS_HEADERS
183185
impl/lattice_Zn.h
184186
impl/platform_macros.h
185187
impl/pq4_fast_scan.h
186-
impl/residual_quantizer_encode_steps.h
188+
impl/residual_quantizer_encode_steps/residual_quantizer_encode_steps.h
187189
impl/simd_result_handlers.h
188190
impl/zerocopy_io.h
189191
impl/pq_code_distance/code_distance.h

faiss/IndexIVFPQ.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,8 @@ void initialize_IVFPQ_precomputed_table(
438438

439439
float* tab = &precomputed_table[i * pq.M * pq.ksub];
440440
pq.compute_inner_prod_table(centroid.data(), tab);
441-
fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
441+
fvec_madd<SIMDLevel::NONE>(
442+
pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
442443
}
443444
} else if (use_precomputed_table == 2) {
444445
const MultiIndexQuantizer* miq =
@@ -465,7 +466,8 @@ void initialize_IVFPQ_precomputed_table(
465466

466467
for (size_t i = 0; i < cpq.ksub; i++) {
467468
float* tab = &precomputed_table[i * pq.M * pq.ksub];
468-
fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
469+
fvec_madd<SIMDLevel::NONE>(
470+
pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
469471
}
470472
}
471473
}
@@ -657,7 +659,7 @@ struct QueryTables {
657659
} else if (use_precomputed_table == 1) {
658660
dis0 = coarse_dis;
659661

660-
fvec_madd(
662+
fvec_madd<SIMDLevel::NONE>(
661663
pq.M * pq.ksub,
662664
ivfpq.precomputed_table.data() + key * pq.ksub * pq.M,
663665
-2.0,
@@ -693,7 +695,8 @@ struct QueryTables {
693695

694696
if (polysemous_ht == 0) {
695697
// sum up with query-specific table
696-
fvec_madd(Mf * pq.ksub, pc, -2.0, qtab, ltab);
698+
fvec_madd<SIMDLevel::NONE>(
699+
Mf * pq.ksub, pc, -2.0, qtab, ltab);
697700
ltab += Mf * pq.ksub;
698701
qtab += Mf * pq.ksub;
699702
} else {
@@ -1355,7 +1358,6 @@ InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
13551358
const IDSelector* sel,
13561359
const IVFSearchParameters*) const {
13571360
DISPATCH_SIMDLevel(get_InvertedListScanner3, *this, store_pairs, sel);
1358-
return nullptr;
13591361
}
13601362

13611363
IndexIVFPQStats indexIVFPQ_stats;

faiss/IndexIVFPQFastScan.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@
77

88
#include <faiss/IndexIVFPQFastScan.h>
99

10-
#include <cassert>
11-
#include <cstdio>
12-
13-
#include <memory>
14-
1510
#include <faiss/impl/AuxIndexStructures.h>
1611
#include <faiss/impl/FaissAssert.h>
1712
#include <faiss/utils/distances.h>
@@ -181,7 +176,8 @@ bool IndexIVFPQFastScan::lookup_table_is_3d() const {
181176
return by_residual && metric_type == METRIC_L2;
182177
}
183178

184-
void IndexIVFPQFastScan::compute_LUT(
179+
template <SIMDLevel SL>
180+
void IndexIVFPQFastScan::compute_LUT_helper(
185181
size_t n,
186182
const float* x,
187183
const CoarseQuantized& cq,
@@ -209,8 +205,7 @@ void IndexIVFPQFastScan::compute_LUT(
209205
idx_t cij = cq.ids[ij];
210206

211207
if (cij >= 0) {
212-
// TODO avoid dynamic dispatch
213-
fvec_madd(
208+
fvec_madd<SL>(
214209
dim12,
215210
precomputed_table.get() + cij * dim12,
216211
-2,
@@ -269,4 +264,13 @@ void IndexIVFPQFastScan::compute_LUT(
269264
}
270265
}
271266

267+
void IndexIVFPQFastScan::compute_LUT(
268+
size_t n,
269+
const float* x,
270+
const CoarseQuantized& cq,
271+
AlignedTable<float>& dis_tables,
272+
AlignedTable<float>& biases) const {
273+
DISPATCH_SIMDLevel(compute_LUT_helper, n, x, cq, dis_tables, biases);
274+
}
275+
272276
} // namespace faiss

faiss/IndexIVFPQFastScan.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <faiss/IndexIVFPQ.h>
1414
#include <faiss/impl/ProductQuantizer.h>
1515
#include <faiss/utils/AlignedTable.h>
16+
#include <faiss/utils/simd_levels.h>
1617

1718
namespace faiss {
1819

@@ -81,6 +82,15 @@ struct IndexIVFPQFastScan : IndexIVFFastScan {
8182
const CoarseQuantized& cq,
8283
AlignedTable<float>& dis_tables,
8384
AlignedTable<float>& biases) const override;
85+
86+
protected:
87+
template <SIMDLevel SL>
88+
void compute_LUT_helper(
89+
size_t n,
90+
const float* x,
91+
const CoarseQuantized& cq,
92+
AlignedTable<float>& dis_tables,
93+
AlignedTable<float>& biases) const;
8494
};
8595

8696
} // namespace faiss

faiss/impl/ResidualQuantizer.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include <faiss/IndexFlat.h>
1818
#include <faiss/VectorTransform.h>
1919
#include <faiss/impl/FaissAssert.h>
20-
#include <faiss/impl/residual_quantizer_encode_steps.h>
20+
#include <faiss/impl/residual_quantizer_encode_steps/residual_quantizer_encode_steps.h>
2121
#include <faiss/utils/distances.h>
2222
#include <faiss/utils/hamming.h>
2323
#include <faiss/utils/utils.h>
@@ -456,7 +456,7 @@ void ResidualQuantizer::compute_codes_add_centroids(
456456
cent,
457457
pool0);
458458
} else if (use_beam_LUT == 1) {
459-
compute_codes_add_centroids_mp_lut1(
459+
compute_codes_add_centroids_mp_lut1<SIMDLevel::NONE>(
460460
*this,
461461
x + i0 * d,
462462
codes_out + i0 * code_size,
@@ -500,7 +500,8 @@ void ResidualQuantizer::refine_beam_LUT(
500500
int32_t* out_codes,
501501
float* out_distances) const {
502502
RefineBeamLUTMemoryPool pool;
503-
refine_beam_LUT_mp(
503+
504+
refine_beam_LUT_mp<SIMDLevel::NONE>(
504505
*this,
505506
n,
506507
query_norms,

faiss/impl/pq_4bit/pq4_fast_scan.cpp

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,6 @@
99
#include <faiss/impl/pq_4bit/pq4_fast_scan.h>
1010
#include <faiss/impl/pq_4bit/simd_result_handlers.h>
1111

12-
#ifdef __x86_64__
13-
#ifdef __AVX2__
14-
#error "this should not be compiled with AVX2"
15-
#endif
16-
#endif
17-
1812
#include <faiss/impl/pq_4bit/kernels_simd256.h>
1913

2014
#include <faiss/impl/pq_4bit/dispatching.h>
@@ -376,6 +370,8 @@ int pq4_preferred_qbs(int n) {
376370

377371
/**************************** Dispatching */
378372

373+
#ifdef COMPILE_SIMD_NONE
374+
379375
template <>
380376
PQ4CodeScanner* make_pq4_scanner<SIMDLevel::NONE, false>(KNN_ARGS_LIST) {
381377
return make_pq4_scanner_1<SIMDLevel::NONE, false>(KNN_ARGS_LIST_2);
@@ -395,6 +391,7 @@ template <>
395391
PQ4CodeScanner* make_pq4_scanner<SIMDLevel::NONE, true>(PRES_ARGS_LIST) {
396392
return make_pq4_scanner_1<SIMDLevel::NONE, true>(PRES_ARGS_LIST_2);
397393
}
394+
#endif // COMPILE_SIMD_NONE
398395

399396
template <bool with_id_map>
400397
PQ4CodeScanner* make_knn_scanner(
@@ -407,22 +404,31 @@ PQ4CodeScanner* make_knn_scanner(
407404
float* dis,
408405
idx_t* ids,
409406
const IDSelector* sel) {
410-
#ifdef COMPILE_SIMD_AVX512
411-
if (SIMDConfig::level == SIMDLevel::AVX512) {
412-
return make_pq4_scanner<SIMDLevel::AVX512, with_id_map>(
413-
is_max, ns, ur, nq, ntotal, k, dis, ids, sel);
414-
} else
407+
// code for dynamic dispatching is commented out
408+
#ifdef COMPILE_SIMD_AVX512F
409+
// if (SIMDConfig::level == SIMDLevel::AVX512F) {
410+
return make_pq4_scanner<SIMDLevel::AVX512F, with_id_map>(
411+
is_max, ns, ur, nq, ntotal, k, dis, ids, sel);
412+
// } else
415413
#endif
416414
#ifdef COMPILE_SIMD_AVX2
417-
if (SIMDConfig::level == SIMDLevel::AVX2) {
418-
return make_pq4_scanner<SIMDLevel::AVX2, with_id_map>(
419-
is_max, ns, ur, nq, ntotal, k, dis, ids, sel);
420-
} else
415+
// if (SIMDConfig::level == SIMDLevel::AVX2) {
416+
return make_pq4_scanner<SIMDLevel::AVX2, with_id_map>(
417+
is_max, ns, ur, nq, ntotal, k, dis, ids, sel);
418+
// } else
421419
#endif
420+
#ifdef COMPILE_SIMD_NEON
421+
// if (SIMDConfig::level == SIMDLevel::ARM_NEON) {
422+
return make_pq4_scanner<SIMDLevel::ARM_NEON, with_id_map>(
423+
is_max, ns, ur, nq, ntotal, k, dis, ids, sel);
424+
// } else
425+
#endif
426+
#ifdef COMPILE_SIMD_NONE
422427
{
423428
return make_pq4_scanner<SIMDLevel::NONE, with_id_map>(
424429
is_max, ns, ur, nq, ntotal, k, dis, ids, sel);
425430
}
431+
#endif
426432
}
427433

428434
PQ4CodeScanner* pq4_make_flat_knn_handler(
@@ -462,22 +468,30 @@ PQ4CodeScanner* pq4_make_ivf_range_handler(
462468
float radius,
463469
int norm_scale,
464470
const IDSelector* sel) {
465-
#ifdef COMPILE_SIMD_AVX512
466-
if (SIMDConfig::level == SIMDLevel::AVX512) {
467-
return make_pq4_scanner<SIMDLevel::AVX512, true>(
468-
is_max, norm_scale, &rres, radius, 0, sel);
469-
} else
471+
#ifdef COMPILE_SIMD_AVX512F
472+
// if (SIMDConfig::level == SIMDLevel::AVX512F) {
473+
return make_pq4_scanner<SIMDLevel::AVX512F, true>(
474+
is_max, norm_scale, &rres, radius, 0, sel);
475+
// } else
470476
#endif
471477
#ifdef COMPILE_SIMD_AVX2
472-
if (SIMDConfig::level == SIMDLevel::AVX2) {
473-
return make_pq4_scanner<SIMDLevel::AVX2, true>(
474-
is_max, norm_scale, &rres, radius, 0, sel);
475-
} else
478+
// if (SIMDConfig::level == SIMDLevel::AVX2) {
479+
return make_pq4_scanner<SIMDLevel::AVX2, true>(
480+
is_max, norm_scale, &rres, radius, 0, sel);
481+
// } else
476482
#endif
483+
#ifdef COMPILE_SIMD_NEON
484+
// if (SIMDConfig::level == SIMDLevel::ARM_NEON) {
485+
return make_pq4_scanner<SIMDLevel::ARM_NEON, true>(
486+
is_max, norm_scale, &rres, radius, 0, sel);
487+
// } else
488+
#endif
489+
#ifdef COMPILE_SIMD_NONE
477490
{
478491
return make_pq4_scanner<SIMDLevel::NONE, true>(
479492
is_max, norm_scale, &rres, radius, 0, sel);
480493
}
494+
#endif
481495
}
482496

483497
PQ4CodeScanner* pq4_make_ivf_partial_range_handler(
@@ -488,22 +502,30 @@ PQ4CodeScanner* pq4_make_ivf_partial_range_handler(
488502
idx_t i1,
489503
int norm_scale,
490504
const IDSelector* sel) {
491-
#ifdef COMPILE_SIMD_AVX512
492-
if (SIMDConfig::level == SIMDLevel::AVX512) {
493-
return make_pq4_scanner<SIMDLevel::AVX512, true>(
494-
is_max, norm_scale, &pres, radius, 0, i0, i1, sel);
495-
} else
505+
#ifdef COMPILE_SIMD_AVX512F
506+
// if (SIMDConfig::level == SIMDLevel::AVX512F) {
507+
return make_pq4_scanner<SIMDLevel::AVX512F, true>(
508+
is_max, norm_scale, &pres, radius, 0, i0, i1, sel);
509+
// } else
496510
#endif
497511
#ifdef COMPILE_SIMD_AVX2
498-
if (SIMDConfig::level == SIMDLevel::AVX2) {
499-
return make_pq4_scanner<SIMDLevel::AVX2, true>(
500-
is_max, norm_scale, &pres, radius, 0, i0, i1, sel);
501-
} else
512+
// if (SIMDConfig::level == SIMDLevel::AVX2) {
513+
return make_pq4_scanner<SIMDLevel::AVX2, true>(
514+
is_max, norm_scale, &pres, radius, 0, i0, i1, sel);
515+
// } else
502516
#endif
517+
#ifdef COMPILE_SIMD_NEON
518+
// if (SIMDConfig::level == SIMDLevel::ARM_NEON) {
519+
return make_pq4_scanner<SIMDLevel::ARM_NEON, true>(
520+
is_max, norm_scale, &pres, radius, 0, i0, i1, sel);
521+
// } else
522+
#endif
523+
#ifdef COMPILE_SIMD_NONE
503524
{
504525
return make_pq4_scanner<SIMDLevel::NONE, true>(
505526
is_max, norm_scale, &pres, radius, 0, i0, i1, sel);
506527
}
528+
#endif
507529
}
508530

509531
} // namespace faiss

0 commit comments

Comments
 (0)