Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions demos/demo_simd_levels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
from collections import defaultdict
import numpy as np

import faiss
from faiss.contrib.datasets import SyntheticDataset


print("compile options", faiss.get_compile_options())
print("SIMD level: ", faiss.SIMDConfig.get_level_name())


ds = SyntheticDataset(32, 8000, 10000, 8000)


index = faiss.index_factory(ds.d, "PQ16x4fs")
# index = faiss.index_factory(ds.d, "IVF64,PQ16x4fs")
# index = faiss.index_factory(ds.d, "SQ8")

index.train(ds.get_train())
index.add(ds.get_database())

simd_levels = [
faiss.SIMDLevel_NONE, faiss.SIMDLevel_AVX2, faiss.SIMDLevel_AVX512
]

if False:
faiss.omp_set_num_threads(1)
print("PID=", os.getpid())
input("press enter to continue")
for simd_level in simd_levels:
faiss.SIMDConfig.set_level(simd_level)
print("simd_level=", faiss.SIMDConfig.get_level_name())
for _ in range(1000):
D, I = index.search(ds.get_queries(), 10)

times = defaultdict(list)

for _ in range(10):
for simd_level in simd_levels:
faiss.SIMDConfig.set_level(simd_level)

t0 = time.time()
D, I = index.search(ds.get_queries(), 10)
t1 = time.time()

times[faiss.SIMDConfig.get_level_name()].append(t1 - t0)

for simd_level in times:
mean_search_time = np.mean(times[simd_level]) * 1000
print(
f"simd_level={simd_level} search time: {mean_search_time:.3f} ms"
)
18 changes: 9 additions & 9 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ set(FAISS_SRC
impl/pq4_fast_scan.cpp
impl/pq4_fast_scan_search_1.cpp
impl/pq4_fast_scan_search_qbs.cpp
impl/residual_quantizer_encode_steps.cpp
impl/residual_quantizer_encode_steps/residual_quantizer_encode_steps.cpp
impl/residual_quantizer_encode_steps/residual_quantizer_encode_steps-avx2.cpp
impl/residual_quantizer_encode_steps/residual_quantizer_encode_steps-neon.cpp
impl/zerocopy_io.cpp
impl/NNDescent.cpp
invlists/BlockInvertedLists.cpp
Expand All @@ -97,7 +99,6 @@ set(FAISS_SRC
utils/utils.cpp
utils/distances_fused/avx512.cpp
utils/distances_fused/distances_fused.cpp
utils/distances_fused/simdlib_based.cpp
)

set(FAISS_HEADERS
Expand Down Expand Up @@ -183,14 +184,14 @@ set(FAISS_HEADERS
impl/lattice_Zn.h
impl/platform_macros.h
impl/pq4_fast_scan.h
impl/residual_quantizer_encode_steps.h
impl/residual_quantizer_encode_steps/residual_quantizer_encode_steps.h
impl/simd_result_handlers.h
impl/zerocopy_io.h
impl/code_distance/code_distance.h
impl/code_distance/code_distance-generic.h
impl/code_distance/code_distance-avx2.h
impl/code_distance/code_distance-avx512.h
impl/code_distance/code_distance-sve.h
impl/pq_code_distance/code_distance.h
impl/pq_code_distance/code_distance-generic.h
impl/pq_code_distance/code_distance-avx2.h
impl/pq_code_distance/code_distance-avx512.h
impl/pq_code_distance/code_distance-sve.h
invlists/BlockInvertedLists.h
invlists/DirectMap.h
invlists/InvertedLists.h
Expand Down Expand Up @@ -225,7 +226,6 @@ set(FAISS_HEADERS
utils/utils.h
utils/distances_fused/avx512.h
utils/distances_fused/distances_fused.h
utils/distances_fused/simdlib_based.h
utils/approx_topk/approx_topk.h
utils/approx_topk/avx2-inl.h
utils/approx_topk/generic.h
Expand Down
13 changes: 2 additions & 11 deletions faiss/IndexAdditiveQuantizerFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,10 @@

#include <faiss/IndexAdditiveQuantizerFastScan.h>

#include <cassert>
#include <memory>

#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/LocalSearchQuantizer.h>
#include <faiss/impl/LookupTableScaler.h>
#include <faiss/impl/ResidualQuantizer.h>
#include <faiss/impl/pq4_fast_scan.h>
#include <faiss/impl/pq_4bit/pq4_fast_scan.h>
#include <faiss/utils/quantize_lut.h>
#include <faiss/utils/utils.h>

Expand Down Expand Up @@ -199,12 +195,7 @@ void IndexAdditiveQuantizerFastScan::search(
return;
}

NormTableScaler scaler(norm_scale);
if (metric_type == METRIC_L2) {
search_dispatch_implem<true>(n, x, k, distances, labels, &scaler);
} else {
search_dispatch_implem<false>(n, x, k, distances, labels, &scaler);
}
search_dispatch_implem(n, x, k, distances, labels, norm_scale);
}

void IndexAdditiveQuantizerFastScan::sa_decode(
Expand Down
Loading
Loading