Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions extension/vector/src/catalog/hnsw_index_catalog_entry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ std::string HNSWIndexAuxInfo::toCypher(const IndexCatalogEntry& indexEntry,
catalog->getTableCatalogEntry(context->getTransaction(), indexEntry.getTableID());
auto tableName = tableEntry->getName();
auto propertyName = tableEntry->getProperty(indexEntry.getPropertyIDs()[0]).getName();
auto distFuncName = HNSWIndexConfig::distFuncToString(config.distFunc);
auto metricName = HNSWIndexConfig::metricToString(config.metric);
cypher += common::stringFormat("CALL CREATE_HNSW_INDEX('{}', '{}', '{}', mu := {}, ml := {}, "
"pu := {}, distFunc := '{}', alpha := {}, efc := {});",
"pu := {}, metric := '{}', alpha := {}, efc := {});",
tableName, indexEntry.getIndexName(), propertyName, config.mu, config.ml, config.pu,
distFuncName, config.alpha, config.efc);
metricName, config.alpha, config.efc);
return cypher;
}

Expand Down
11 changes: 6 additions & 5 deletions extension/vector/src/function/create_hnsw_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,11 @@ static std::unique_ptr<processor::PhysicalOperator> getPhysicalPlan(
auto sharedState = info.function.initSharedStateFunc(initInput);
auto finalizeFuncSharedState = sharedState->ptrCast<FinalizeHNSWSharedState>();
finalizeFuncSharedState->hnswIndex = createFuncSharedState->hnswIndex;
finalizeFuncSharedState->numRows =
(logicalCallBoundData->numRows + StorageConfig::NODE_GROUP_SIZE - 1) /
StorageConfig::NODE_GROUP_SIZE;
auto numNodeGroups =
logicalCallBoundData->numRows == 0 ?
0 :
storage::StorageUtils::getNodeGroupIdx(logicalCallBoundData->numRows) + 1;
finalizeFuncSharedState->numRows = numNodeGroups;
finalizeFuncSharedState->maxMorselSize = 1;
finalizeFuncSharedState->bindData = logicalCallBoundData->copy();
auto finalizeCallOp = std::make_unique<processor::TableFunctionCall>(std::move(info),
Expand Down Expand Up @@ -304,8 +306,7 @@ static std::string rewriteCreateHNSWQuery(main::ClientContext& context,
params += stringFormat("mu := {}, ", config.mu);
params += stringFormat("ml := {}, ", config.ml);
params += stringFormat("efc := {}, ", config.efc);
params +=
stringFormat("distFunc := '{}', ", HNSWIndexConfig::distFuncToString(config.distFunc));
params += stringFormat("metric := '{}', ", HNSWIndexConfig::metricToString(config.metric));
params += stringFormat("alpha := {}, ", config.alpha);
params += stringFormat("pu := {}", config.pu);
auto columnName = hnswBindData->tableEntry->getProperty(hnswBindData->propertyID).getName();
Expand Down
36 changes: 27 additions & 9 deletions extension/vector/src/include/index/hnsw_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
namespace kuzu {
namespace vector_extension {

enum class DistFuncType : uint8_t { Cosine = 0, L2 = 1, L2_SQUARE = 2, DotProduct = 3 };
enum class MetricType : uint8_t { Cosine = 0, L2 = 1, L2_SQUARE = 2, DotProduct = 3 };

// Max degree of the upper graph.
struct Mu {
Expand Down Expand Up @@ -35,12 +35,12 @@ struct Pu {
static void validate(double value);
};

struct DistFunc {
static constexpr const char* NAME = "distfunc";
struct Metric {
static constexpr const char* NAME = "metric";
static constexpr common::LogicalTypeID TYPE = common::LogicalTypeID::STRING;
static constexpr DistFuncType DEFAULT_VALUE = DistFuncType::Cosine;
static constexpr MetricType DEFAULT_VALUE = MetricType::Cosine;

static void validate(const std::string& distFuncName);
static void validate(const std::string& metric);
};

struct Alpha {
Expand All @@ -67,11 +67,27 @@ struct Efs {
static void validate(int64_t value);
};

struct BlindSearchUpSelThreshold {
static constexpr const char* NAME = "blind_search_up_sel";
static constexpr common::LogicalTypeID TYPE = common::LogicalTypeID::DOUBLE;
static constexpr double DEFAULT_VALUE = 0.08;

static void validate(double value);
};

struct DirectedSearchUpSelThreshold {
static constexpr const char* NAME = "directed_search_up_sel";
static constexpr common::LogicalTypeID TYPE = common::LogicalTypeID::DOUBLE;
static constexpr double DEFAULT_VALUE = 0.4;

static void validate(double value);
};

struct HNSWIndexConfig {
int64_t mu = Mu::DEFAULT_VALUE;
int64_t ml = Ml::DEFAULT_VALUE;
double pu = Pu::DEFAULT_VALUE;
DistFuncType distFunc = DistFunc::DEFAULT_VALUE;
MetricType metric = Metric::DEFAULT_VALUE;
double alpha = Alpha::DEFAULT_VALUE;
int64_t efc = Efc::DEFAULT_VALUE;

Expand All @@ -85,18 +101,20 @@ struct HNSWIndexConfig {

static HNSWIndexConfig deserialize(common::Deserializer& deSer);

static std::string distFuncToString(DistFuncType distFunc);
static std::string metricToString(MetricType metric);

private:
HNSWIndexConfig(const HNSWIndexConfig& other)
: mu{other.mu}, ml{other.ml}, pu{other.pu}, distFunc{other.distFunc}, alpha{other.alpha},
: mu{other.mu}, ml{other.ml}, pu{other.pu}, metric{other.metric}, alpha{other.alpha},
efc{other.efc} {}

static DistFuncType getDistFuncType(const std::string& funcName);
static MetricType getMetricType(const std::string& metricName);
};

struct QueryHNSWConfig {
int64_t efs = Efs::DEFAULT_VALUE;
double blindSearchUpSelThreshold = BlindSearchUpSelThreshold::DEFAULT_VALUE;
double directedSearchUpSelThreshold = DirectedSearchUpSelThreshold::DEFAULT_VALUE;

QueryHNSWConfig() = default;

Expand Down
4 changes: 2 additions & 2 deletions extension/vector/src/include/index/hnsw_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ struct NodeWithDistance {
struct HNSWGraphInfo {
common::offset_t numNodes;
EmbeddingColumn* embeddings;
DistFuncType distFunc;
MetricType distFunc;

HNSWGraphInfo(common::offset_t numNodes, EmbeddingColumn* embeddings, DistFuncType distFunc)
HNSWGraphInfo(common::offset_t numNodes, EmbeddingColumn* embeddings, MetricType distFunc)
: numNodes{numNodes}, embeddings{embeddings}, distFunc{distFunc} {}
};

Expand Down
11 changes: 4 additions & 7 deletions extension/vector/src/include/index/hnsw_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class HNSWIndex {
struct InMemHNSWLayerInfo {
common::offset_t numNodes;
InMemEmbeddings* embeddings;
DistFuncType distFunc;
MetricType metric;
// The degree threshold of a node that will start to trigger shrinking during insertions. Thus,
// it is also the max degree of a node in the graph before shrinking.
int64_t degreeThresholdToShrink;
Expand All @@ -97,10 +97,9 @@ struct InMemHNSWLayerInfo {
double alpha;
int64_t efc;

InMemHNSWLayerInfo(common::offset_t numNodes, InMemEmbeddings* embeddings,
DistFuncType distFunc, int64_t degreeThresholdToShrink, int64_t maxDegree, double alpha,
int64_t efc)
: numNodes{numNodes}, embeddings{embeddings}, distFunc{distFunc},
InMemHNSWLayerInfo(common::offset_t numNodes, InMemEmbeddings* embeddings, MetricType metric,
int64_t degreeThresholdToShrink, int64_t maxDegree, double alpha, int64_t efc)
: numNodes{numNodes}, embeddings{embeddings}, metric{metric},
degreeThresholdToShrink{degreeThresholdToShrink}, maxDegree{maxDegree}, alpha{alpha},
efc{efc} {}
};
Expand Down Expand Up @@ -257,8 +256,6 @@ class OnDiskHNSWIndex final : public HNSWIndex {
max_node_priority_queue_t& results) const;

private:
static constexpr double BLIND_SEARCH_UP_SEL_THRESHOLD = 0.08;
static constexpr double DIRECTED_SEARCH_UP_SEL_THRESHOLD = 0.4;
static constexpr uint64_t FILTERED_SEARCH_INITIAL_CANDIDATES = 10;

common::table_id_t nodeTableID;
Expand Down
2 changes: 1 addition & 1 deletion extension/vector/src/include/index/hnsw_index_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct HNSWIndexUtils {
static void validateAutoTransaction(const main::ClientContext& context,
const std::string& funcName);

static double computeDistance(DistFuncType funcType, const float* left, const float* right,
static double computeDistance(MetricType funcType, const float* left, const float* right,
uint32_t dimension);

static void validateColumnType(const catalog::TableCatalogEntry& tableEntry,
Expand Down
103 changes: 66 additions & 37 deletions extension/vector/src/index/hnsw_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ void Pu::validate(double value) {
}
}

void DistFunc::validate(const std::string& distFuncName) {
const auto lowerCaseFuncName = common::StringUtils::getLower(distFuncName);
if (lowerCaseFuncName != "cosine" && lowerCaseFuncName != "l2" && lowerCaseFuncName != "l2sq" &&
lowerCaseFuncName != "dotproduct") {
throw common::BinderException{"DistFunc must be one of COSINE, L2, L2SQ or DOTPRODUCT."};
void Metric::validate(const std::string& metric) {
const auto lowerCaseMetric = common::StringUtils::getLower(metric);
if (lowerCaseMetric != "cosine" && lowerCaseMetric != "l2" && lowerCaseMetric != "l2sq" &&
lowerCaseMetric != "dotproduct" && lowerCaseMetric != "dot_product") {
throw common::BinderException{"Metric must be one of COSINE, L2, L2SQ or DOTPRODUCT."};
}
}

Expand All @@ -54,6 +54,20 @@ void Efs::validate(int64_t value) {
}
}

void BlindSearchUpSelThreshold::validate(double value) {
if (value < 0 || value > 1) {
throw common::BinderException{
"Blind search upper selectivity threshold must be a double between 0 and 1."};
}
}

void DirectedSearchUpSelThreshold::validate(double value) {
if (value < 0 || value > 1) {
throw common::BinderException{
"Directed search upper selectivity threshold must be a double between 0 and 1."};
}
}

HNSWIndexConfig::HNSWIndexConfig(const function::optional_params_t& optionalParams) {
for (auto& [name, value] : optionalParams) {
auto lowerCaseName = common::StringUtils::getLower(name);
Expand All @@ -69,11 +83,11 @@ HNSWIndexConfig::HNSWIndexConfig(const function::optional_params_t& optionalPara
value.validateType(Pu::TYPE);
pu = value.getValue<double>();
Pu::validate(pu);
} else if (DistFunc::NAME == lowerCaseName) {
value.validateType(DistFunc::TYPE);
} else if (Metric::NAME == lowerCaseName) {
value.validateType(Metric::TYPE);
auto funcName = value.getValue<std::string>();
DistFunc::validate(funcName);
distFunc = getDistFuncType(funcName);
Metric::validate(funcName);
metric = getMetricType(funcName);
} else if (Alpha::NAME == lowerCaseName) {
value.validateType(Alpha::TYPE);
alpha = value.getValue<double>();
Expand All @@ -89,23 +103,23 @@ HNSWIndexConfig::HNSWIndexConfig(const function::optional_params_t& optionalPara
}
}

std::string HNSWIndexConfig::distFuncToString(DistFuncType distFunc) {
switch (distFunc) {
case DistFuncType::Cosine: {
std::string HNSWIndexConfig::metricToString(MetricType metric) {
switch (metric) {
case MetricType::Cosine: {
return "cosine";
}
case DistFuncType::L2: {
case MetricType::L2: {
return "l2";
}
case DistFuncType::L2_SQUARE: {
case MetricType::L2_SQUARE: {
return "l2sq";
}
case DistFuncType::DotProduct: {
case MetricType::DotProduct: {
return "dotproduct";
}
default: {
throw common::RuntimeException(common::stringFormat("Unknown distance function type {}.",
static_cast<int64_t>(distFunc)));
static_cast<int64_t>(metric)));
}
}
}
Expand All @@ -115,8 +129,8 @@ void HNSWIndexConfig::serialize(common::Serializer& ser) const {
ser.serializeValue(mu);
ser.writeDebuggingInfo("degreeInLowerLayer");
ser.serializeValue(ml);
ser.writeDebuggingInfo("distFunc");
ser.serializeValue<uint8_t>(static_cast<uint8_t>(distFunc));
ser.writeDebuggingInfo("metric");
ser.serializeValue<uint8_t>(static_cast<uint8_t>(metric));
ser.writeDebuggingInfo("alpha");
ser.serializeValue(alpha);
ser.writeDebuggingInfo("efc");
Expand All @@ -125,35 +139,35 @@ void HNSWIndexConfig::serialize(common::Serializer& ser) const {

HNSWIndexConfig HNSWIndexConfig::deserialize(common::Deserializer& deSer) {
auto config = HNSWIndexConfig{};
std::string debugginInfo;
deSer.validateDebuggingInfo(debugginInfo, "degreeInUpperLayer");
std::string debuggingInfo;
deSer.validateDebuggingInfo(debuggingInfo, "degreeInUpperLayer");
deSer.deserializeValue(config.mu);
deSer.validateDebuggingInfo(debugginInfo, "degreeInLowerLayer");
deSer.validateDebuggingInfo(debuggingInfo, "degreeInLowerLayer");
deSer.deserializeValue(config.ml);
deSer.validateDebuggingInfo(debugginInfo, "distFunc");
uint8_t distFunc = 0;
deSer.deserializeValue(distFunc);
config.distFunc = static_cast<DistFuncType>(distFunc);
deSer.validateDebuggingInfo(debugginInfo, "alpha");
deSer.validateDebuggingInfo(debuggingInfo, "metric");
uint8_t metric = 0;
deSer.deserializeValue(metric);
config.metric = static_cast<MetricType>(metric);
deSer.validateDebuggingInfo(debuggingInfo, "alpha");
deSer.deserializeValue(config.alpha);
deSer.validateDebuggingInfo(debugginInfo, "efc");
deSer.validateDebuggingInfo(debuggingInfo, "efc");
deSer.deserializeValue(config.efc);
return config;
}

DistFuncType HNSWIndexConfig::getDistFuncType(const std::string& funcName) {
const auto lowerFuncName = common::StringUtils::getLower(funcName);
if (lowerFuncName == "cosine") {
return DistFuncType::Cosine;
MetricType HNSWIndexConfig::getMetricType(const std::string& metricName) {
const auto lowerMetricName = common::StringUtils::getLower(metricName);
if (lowerMetricName == "cosine") {
return MetricType::Cosine;
}
if (lowerFuncName == "l2") {
return DistFuncType::L2;
if (lowerMetricName == "l2") {
return MetricType::L2;
}
if (lowerFuncName == "l2sq") {
return DistFuncType::L2_SQUARE;
if (lowerMetricName == "l2sq") {
return MetricType::L2_SQUARE;
}
if (lowerFuncName == "dotproduct") {
return DistFuncType::DotProduct;
if (lowerMetricName == "dot_product" || lowerMetricName == "dotproduct") {
return MetricType::DotProduct;
}
KU_UNREACHABLE;
}
Expand All @@ -165,11 +179,26 @@ QueryHNSWConfig::QueryHNSWConfig(const function::optional_params_t& optionalPara
value.validateType(Efs::TYPE);
efs = value.getValue<int64_t>();
Efs::validate(efs);
} else if (BlindSearchUpSelThreshold::NAME == lowerCaseName) {
value.validateType(BlindSearchUpSelThreshold::TYPE);
blindSearchUpSelThreshold = value.getValue<double>();
BlindSearchUpSelThreshold::validate(blindSearchUpSelThreshold);
} else if (DirectedSearchUpSelThreshold::NAME == lowerCaseName) {
value.validateType(DirectedSearchUpSelThreshold::TYPE);
directedSearchUpSelThreshold = value.getValue<double>();
DirectedSearchUpSelThreshold::validate(directedSearchUpSelThreshold);
} else {
throw common::BinderException{common::stringFormat(
"Unrecognized optional parameter {} in {}.", name, QueryHNSWIndexFunction::name)};
}
}
if (blindSearchUpSelThreshold >= directedSearchUpSelThreshold) {
throw common::BinderException{common::stringFormat(
"Blind search upper selectivity threshold is set to {}, but the directed search upper "
"selectivity threshold is set to {}. The blind search upper selectivity threshold must "
"be less than the directed search upper selectivity threshold.",
blindSearchUpSelThreshold, directedSearchUpSelThreshold)};
}
}

} // namespace vector_extension
Expand Down
Loading