Skip to content

Commit 79293aa

Browse files
Subhadeep Karanfacebook-github-bot
authored andcommitted
Dynamic Dispatch Task 7 (distances.cpp) (#4559)
Summary: Pull Request resolved: #4559 Reviewed By: mnorris11 Differential Revision: D80500327
1 parent c50514c commit 79293aa

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+2449
-1688
lines changed

demos/demo_simd_levels.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import os
77
import time
88
from collections import defaultdict
9+
import numpy as np
10+
import os
11+
from collections import defaultdict
912

1013
import faiss
11-
import numpy as np
1214
from faiss.contrib.datasets import SyntheticDataset
1315

1416

@@ -26,14 +28,13 @@
2628
index.train(ds.get_train())
2729
index.add(ds.get_database())
2830

31+
simd_levels = [faiss.SIMDLevel_NONE, faiss.SIMDLevel_AVX2, faiss.SIMDLevel_AVX512]
2932

3033
if False:
3134
faiss.omp_set_num_threads(1)
3235
print("PID=", os.getpid())
3336
input("press enter to continue")
34-
# for simd_level in faiss.NONE, faiss.AVX2, faiss.AVX512F:
35-
for simd_level in faiss.AVX2, faiss.AVX512F:
36-
37+
for simd_level in simd_levels:
3738
faiss.SIMDConfig.set_level(simd_level)
3839
print("simd_level=", faiss.SIMDConfig.get_level_name())
3940
for _ in range(1000):
@@ -42,8 +43,7 @@
4243
times = defaultdict(list)
4344

4445
for _ in range(10):
45-
for simd_level in (faiss.SIMDLevel_NONE, faiss.SIMDLevel_AVX2,
46-
faiss.SIMDLevel_AVX512F):
46+
for simd_level in simd_levels:
4747
faiss.SIMDConfig.set_level(simd_level)
4848

4949
t0 = time.time()

faiss/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ set(FAISS_SRC
9999
utils/utils.cpp
100100
utils/distances_fused/avx512.cpp
101101
utils/distances_fused/distances_fused.cpp
102-
utils/distances_fused/simdlib_based.cpp
103102
)
104103

105104
set(FAISS_HEADERS
@@ -227,7 +226,6 @@ set(FAISS_HEADERS
227226
utils/utils.h
228227
utils/distances_fused/avx512.h
229228
utils/distances_fused/distances_fused.h
230-
utils/distances_fused/simdlib_based.h
231229
utils/approx_topk/approx_topk.h
232230
utils/approx_topk/avx2-inl.h
233231
utils/approx_topk/generic.h

faiss/IndexIVFPQ.cpp

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,24 @@ size_t precomputed_table_max_bytes = ((size_t)1) << 31;
373373
* is faster when the length of the lists is > ksub * M.
374374
*/
375375

376+
void initialize_IVFPQ_precomputed_table(
377+
int& use_precomputed_table,
378+
const Index* quantizer,
379+
const ProductQuantizer& pq,
380+
AlignedTable<float>& precomputed_table,
381+
bool by_residual,
382+
bool verbose) {
383+
DISPATCH_SIMDLevel(
384+
initialize_IVFPQ_precomputed_table,
385+
use_precomputed_table,
386+
quantizer,
387+
pq,
388+
precomputed_table,
389+
by_residual,
390+
verbose);
391+
}
392+
393+
template <SIMDLevel SL>
376394
void initialize_IVFPQ_precomputed_table(
377395
int& use_precomputed_table,
378396
const Index* quantizer,
@@ -438,8 +456,7 @@ void initialize_IVFPQ_precomputed_table(
438456

439457
float* tab = &precomputed_table[i * pq.M * pq.ksub];
440458
pq.compute_inner_prod_table(centroid.data(), tab);
441-
fvec_madd<SIMDLevel::NONE>(
442-
pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
459+
fvec_madd<SL>(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
443460
}
444461
} else if (use_precomputed_table == 2) {
445462
const MultiIndexQuantizer* miq =
@@ -466,8 +483,7 @@ void initialize_IVFPQ_precomputed_table(
466483

467484
for (size_t i = 0; i < cpq.ksub; i++) {
468485
float* tab = &precomputed_table[i * pq.M * pq.ksub];
469-
fvec_madd<SIMDLevel::NONE>(
470-
pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
486+
fvec_madd<SL>(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
471487
}
472488
}
473489
}
@@ -645,6 +661,11 @@ struct QueryTables {
645661
* compute tables for L2 distance
646662
*****************************************************/
647663

664+
float precompute_list_tables_L2() {
665+
DISPATCH_SIMDLevel(precompute_list_tables_L2);
666+
}
667+
668+
template <SIMDLevel SL>
648669
float precompute_list_tables_L2() {
649670
float dis0 = 0;
650671

@@ -659,7 +680,7 @@ struct QueryTables {
659680
} else if (use_precomputed_table == 1) {
660681
dis0 = coarse_dis;
661682

662-
fvec_madd<SIMDLevel::NONE>(
683+
fvec_madd<SL>(
663684
pq.M * pq.ksub,
664685
ivfpq.precomputed_table.data() + key * pq.ksub * pq.M,
665686
-2.0,
@@ -695,13 +716,12 @@ struct QueryTables {
695716

696717
if (polysemous_ht == 0) {
697718
// sum up with query-specific table
698-
fvec_madd<SIMDLevel::NONE>(
699-
Mf * pq.ksub, pc, -2.0, qtab, ltab);
719+
fvec_madd<SL>(Mf * pq.ksub, pc, -2.0, qtab, ltab);
700720
ltab += Mf * pq.ksub;
701721
qtab += Mf * pq.ksub;
702722
} else {
703723
for (int m = cm * Mf; m < (cm + 1) * Mf; m++) {
704-
q_code[m] = fvec_madd_and_argmin(
724+
q_code[m] = fvec_madd_and_argmin<SL>(
705725
pq.ksub, pc, -2, qtab, ltab);
706726
pc += pq.ksub;
707727
ltab += pq.ksub;

faiss/IndexIVFPQ.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <faiss/IndexPQ.h>
1717
#include <faiss/impl/platform_macros.h>
1818
#include <faiss/utils/AlignedTable.h>
19+
#include <faiss/utils/simd_levels.h>
1920

2021
namespace faiss {
2122

@@ -172,6 +173,15 @@ void initialize_IVFPQ_precomputed_table(
172173
bool by_residual,
173174
bool verbose);
174175

176+
template <SIMDLevel SL>
177+
void initialize_IVFPQ_precomputed_table(
178+
int& use_precomputed_table,
179+
const Index* quantizer,
180+
const ProductQuantizer& pq,
181+
AlignedTable<float>& precomputed_table,
182+
bool by_residual,
183+
bool verbose);
184+
175185
/// statistics are robust to internal threading, but not if
176186
/// IndexIVFPQ::search_preassigned is called by multiple threads
177187
struct IndexIVFPQStats {

faiss/impl/FaissAssert.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575

7676
#define FAISS_THROW_FMT(FMT, ...) \
7777
do { \
78-
std::string __s; \
78+
::std::string __s; \
7979
int __size = snprintf(nullptr, 0, FMT, __VA_ARGS__); \
8080
__s.resize(__size + 1); \
8181
snprintf(&__s[0], __s.size(), FMT, __VA_ARGS__); \

faiss/impl/ResidualQuantizer.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ void ResidualQuantizer::compute_codes_add_centroids(
456456
cent,
457457
pool0);
458458
} else if (use_beam_LUT == 1) {
459-
compute_codes_add_centroids_mp_lut1<SIMDLevel::NONE>(
459+
compute_codes_add_centroids_mp_lut1(
460460
*this,
461461
x + i0 * d,
462462
codes_out + i0 * code_size,
@@ -501,7 +501,8 @@ void ResidualQuantizer::refine_beam_LUT(
501501
float* out_distances) const {
502502
RefineBeamLUTMemoryPool pool;
503503

504-
refine_beam_LUT_mp<SIMDLevel::NONE>(
504+
DISPATCH_SIMDLevel(
505+
refine_beam_LUT_mp,
505506
*this,
506507
n,
507508
query_norms,

faiss/impl/pq_4bit/LookupTableScaler.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77

88
#pragma once
99

10-
#include <cstdint>
11-
#include <cstdlib>
12-
10+
#include <faiss/impl/FaissAssert.h>
1311
#include <faiss/utils/simdlib.h>
1412

1513
/*******************************************
@@ -43,7 +41,7 @@ struct DummyScaler {
4341
return simd16uint16(0);
4442
}
4543

46-
#ifdef __AVX512F__
44+
#if defined(COMPILE_SIMD_AVX512) && defined(__AVX512F__)
4745

4846
using simd64uint8 = simd64uint8<SIMDLevel::AVX512>;
4947
using simd32uint16 = simd32uint16<SIMDLevel::AVX512>;
@@ -96,7 +94,7 @@ struct Scaler2x4bit {
9694
return (simd16uint16(res) >> 8) * scale_simd;
9795
}
9896

99-
#ifdef __AVX512F__
97+
#if defined(COMPILE_SIMD_AVX512) && defined(__AVX512F__)
10098
using simd64uint8 = simd64uint8<SIMDLevel::AVX512>;
10199
using simd32uint16 = simd32uint16<SIMDLevel::AVX512>;
102100

faiss/impl/pq_4bit/decompose_qbs.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
// This is not standalone code, it is intended to be included in the kernels-X.h
1111
// files.
1212

13+
#include <cstddef>
14+
#include <cstdint>
15+
16+
#include <faiss/impl/pq_4bit/simd_result_handlers.h>
17+
#include <faiss/utils/simd_levels.h>
18+
1319
// handle at most 4 blocks of queries
1420
template <int QBS, class ResultHandler, class Scaler>
1521
void accumulate_q_4step(
@@ -19,15 +25,15 @@ void accumulate_q_4step(
1925
const uint8_t* LUT0,
2026
ResultHandler& res,
2127
const Scaler& scaler) {
22-
constexpr SIMDLevel SL = ResultHandler::SL;
28+
constexpr faiss::SIMDLevel SL = ResultHandler::SL;
2329
constexpr int Q1 = QBS & 15;
2430
constexpr int Q2 = (QBS >> 4) & 15;
2531
constexpr int Q3 = (QBS >> 8) & 15;
2632
constexpr int Q4 = (QBS >> 12) & 15;
2733
constexpr int SQ = Q1 + Q2 + Q3 + Q4;
2834

2935
for (size_t j0 = 0; j0 < ntotal2; j0 += 32) {
30-
FixedStorageHandler<SQ, 2, SL> res2;
36+
faiss::simd_result_handlers::FixedStorageHandler<SQ, 2, SL> res2;
3137
const uint8_t* LUT = LUT0;
3238
kernel_accumulate_block<Q1>(nsq, codes, LUT, res2, scaler);
3339
LUT += Q1 * nsq * 16;
@@ -77,8 +83,8 @@ void accumulate(
7783
ResultHandler& res,
7884
const Scaler& scaler) {
7985
assert(nsq % 2 == 0);
80-
assert(is_aligned_pointer(codes));
81-
assert(is_aligned_pointer(LUT));
86+
assert(faiss::is_aligned_pointer(codes));
87+
assert(faiss::is_aligned_pointer(LUT));
8288

8389
#define DISPATCH(NQ) \
8490
case NQ: \
@@ -107,8 +113,8 @@ void pq4_accumulate_loop_qbs_fixed_scaler(
107113
ResultHandler& res,
108114
const Scaler& scaler) {
109115
assert(nsq % 2 == 0);
110-
assert(is_aligned_pointer(codes));
111-
assert(is_aligned_pointer(LUT0));
116+
assert(faiss::is_aligned_pointer(codes));
117+
assert(faiss::is_aligned_pointer(LUT0));
112118

113119
// try out optimized versions
114120
switch (qbs) {

faiss/impl/pq_4bit/dispatching.h

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212

1313
#include <faiss/impl/FaissAssert.h>
1414
#include <faiss/impl/platform_macros.h>
15+
#include <faiss/impl/pq_4bit/LookupTableScaler.h>
1516
#include <faiss/impl/pq_4bit/pq4_fast_scan.h>
1617
#include <faiss/impl/pq_4bit/simd_result_handlers.h>
1718

18-
#include <faiss/impl/pq_4bit/kernels_simd256.h>
19+
#include <faiss/impl/pq_4bit/kernels_common.h>
1920

2021
namespace faiss {
22+
using namespace simd_result_handlers;
2123

2224
/** Mix-in class that manages both an SIMD result hander and offers the actual
2325
* scanning routines. */
@@ -69,14 +71,18 @@ PQ4CodeScanner* make_handler_2(
6971
const IDSelector* sel) {
7072
if (k == 1) {
7173
return new ScannerMixIn<
72-
SingleResultHandler<C, with_id_map, SL>,
74+
faiss::simd_result_handlers::
75+
SingleResultHandler<C, with_id_map, SL>,
7376
Scaler>(norm_scale, nq, ntotal, dis, ids, sel);
7477
} else if (use_reservoir) {
75-
return new ScannerMixIn<ReservoirHandler<C, with_id_map, SL>, Scaler>(
76-
norm_scale, nq, ntotal, k, 2 * k, dis, ids, sel);
78+
return new ScannerMixIn<
79+
faiss::simd_result_handlers::
80+
ReservoirHandler<C, with_id_map, SL>,
81+
Scaler>(norm_scale, nq, ntotal, k, 2 * k, dis, ids, sel);
7782
} else {
78-
return new ScannerMixIn<HeapHandler<C, with_id_map, SL>, Scaler>(
79-
norm_scale, nq, ntotal, k, dis, ids, sel);
83+
return new ScannerMixIn<
84+
faiss::simd_result_handlers::HeapHandler<C, with_id_map, SL>,
85+
Scaler>(norm_scale, nq, ntotal, k, dis, ids, sel);
8086
}
8187
}
8288

@@ -87,8 +93,9 @@ PQ4CodeScanner* make_handler_2(
8793
float radius,
8894
size_t ntotal,
8995
const IDSelector* sel) {
90-
return new ScannerMixIn<RangeHandler<C, with_id_map, SL>, Scaler>(
91-
norm_scale, rres, radius, ntotal, sel);
96+
return new ScannerMixIn<
97+
faiss::simd_result_handlers::RangeHandler<C, with_id_map, SL>,
98+
Scaler>(norm_scale, rres, radius, ntotal, sel);
9299
}
93100

94101
template <SIMDLevel SL, bool with_id_map, class C, class Scaler>
@@ -100,8 +107,10 @@ PQ4CodeScanner* make_handler_2(
100107
size_t q0,
101108
size_t q1,
102109
const IDSelector* sel) {
103-
return new ScannerMixIn<PartialRangeHandler<C, with_id_map, SL>, Scaler>(
104-
norm_scale, pres, radius, ntotal, q0, q1, sel);
110+
return new ScannerMixIn<
111+
faiss::simd_result_handlers::
112+
PartialRangeHandler<C, with_id_map, SL>,
113+
Scaler>(norm_scale, pres, radius, ntotal, q0, q1, sel);
105114
}
106115

107116
// this function dispatches runtime -> template parameters. It is generic for

faiss/impl/pq_4bit/impl-neon.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
#include <faiss/utils/simd_levels.h>
99

10-
#ifdef COMPILE_SIMD_NEON
10+
#ifdef COMPILE_SIMD_ARM_NEON
1111

1212
#ifndef __aarch64__
1313
#error "this can only run on aarch64"
@@ -42,4 +42,4 @@ PQ4CodeScanner* make_pq4_scanner<SIMDLevel::ARM_NEON, true>(PRES_ARGS_LIST) {
4242

4343
} // namespace faiss
4444

45-
#endif // COMPILE_SIMD_NEON
45+
#endif // COMPILE_SIMD_ARM_NEON

0 commit comments

Comments
 (0)