Skip to content

Commit 0ddf628

Browse files
authored
Parallel CreateHNSWIndex finalize (#4938)
1 parent c031db9 commit 0ddf628

File tree

9 files changed

+301
-216
lines changed

9 files changed

+301
-216
lines changed

src/function/function_collection.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ FunctionCollection* FunctionCollection::getFunctions() {
231231
// TODO(Guodong): Move this from builtin to extension and also move _CreateHNSWIndexFunction
232232
// and _DropHNSWIndexFunction to private functions.
233233
STANDALONE_TABLE_FUNCTION(InternalCreateHNSWIndexFunction),
234+
STANDALONE_TABLE_FUNCTION(InternalFinalizeHNSWIndexFunction),
234235
STANDALONE_TABLE_FUNCTION(CreateHNSWIndexFunction),
235236
STANDALONE_TABLE_FUNCTION(InternalDropHNSWIndexFunction),
236237
STANDALONE_TABLE_FUNCTION(DropHNSWIndexFunction),

src/function/table/hnsw/create_hnsw_index.cpp

Lines changed: 197 additions & 102 deletions
Large diffs are not rendered by default.

src/include/function/table/hnsw/hnsw_index_functions.h

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,40 @@ struct CreateHNSWIndexBindData final : TableFuncBindData {
2929
}
3030
};
3131

32-
struct CreateHNSWSharedState final : TableFuncSharedState {
32+
struct CreateInMemHNSWSharedState final : TableFuncSharedState {
3333
std::string name;
34-
std::unique_ptr<storage::InMemHNSWIndex> hnswIndex;
34+
std::shared_ptr<storage::InMemHNSWIndex> hnswIndex;
3535
storage::NodeTable& nodeTable;
3636
common::offset_t numNodes;
3737
std::atomic<common::offset_t> numNodesInserted = 0;
3838

3939
const CreateHNSWIndexBindData* bindData;
40-
std::shared_ptr<storage::HNSWIndexPartitionerSharedState> partitionerSharedState;
4140

42-
explicit CreateHNSWSharedState(const CreateHNSWIndexBindData& bindData);
41+
explicit CreateInMemHNSWSharedState(const CreateHNSWIndexBindData& bindData);
4342
};
4443

45-
struct CreateHNSWLocalState final : TableFuncLocalState {
44+
struct CreateInMemHNSWLocalState final : TableFuncLocalState {
4645
storage::VisitedState upperVisited;
4746
storage::VisitedState lowerVisited;
4847

49-
explicit CreateHNSWLocalState(common::offset_t numNodes)
48+
explicit CreateInMemHNSWLocalState(common::offset_t numNodes)
5049
: upperVisited{numNodes}, lowerVisited{numNodes} {}
5150
};
5251

52+
struct FinalizeHNSWSharedState final : TableFuncSharedState {
53+
std::shared_ptr<storage::InMemHNSWIndex> hnswIndex;
54+
std::shared_ptr<storage::HNSWIndexPartitionerSharedState> partitionerSharedState;
55+
std::unique_ptr<TableFuncBindData> bindData;
56+
57+
std::atomic<common::node_group_idx_t> numNodeGroupsFinalized = 0;
58+
59+
explicit FinalizeHNSWSharedState(storage::MemoryManager& mm) {
60+
partitionerSharedState = std::make_shared<storage::HNSWIndexPartitionerSharedState>(mm);
61+
}
62+
63+
TableFuncMorsel getMorsel() override;
64+
};
65+
5366
struct BoundQueryHNSWIndexInput {
5467
catalog::NodeTableCatalogEntry* nodeTableEntry;
5568
catalog::IndexCatalogEntry* indexEntry;
@@ -105,6 +118,12 @@ struct InternalCreateHNSWIndexFunction final {
105118
static function_set getFunctionSet();
106119
};
107120

121+
struct InternalFinalizeHNSWIndexFunction final {
122+
static constexpr const char* name = "_FINALIZE_HNSW_INDEX";
123+
124+
static function_set getFunctionSet();
125+
};
126+
108127
struct CreateHNSWIndexFunction final {
109128
static constexpr const char* name = "CREATE_HNSW_INDEX";
110129

src/include/storage/index/hnsw_graph.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ class InMemEmbeddings final : public EmbeddingColumn {
3434
const common::LogicalType& columnType);
3535

3636
float* getEmbedding(common::offset_t offset) const;
37-
ColumnChunkData& getData() const { return *data; }
3837

3938
void initialize(main::ClientContext* context, NodeTable& table,
4039
common::column_id_t columnID) override;
@@ -97,10 +96,8 @@ class InMemHNSWGraph {
9796
resetCSRLengthAndDstNodes();
9897
}
9998

100-
common::offset_vec_t getNeighbors(transaction::Transaction* transaction,
101-
common::offset_t nodeOffset) const;
102-
common::offset_vec_t getNeighbors(transaction::Transaction* transaction,
103-
common::offset_t nodeOffset, common::length_t numNbrs) const;
99+
common::offset_vec_t getNeighbors(common::offset_t nodeOffset) const;
100+
common::offset_vec_t getNeighbors(common::offset_t nodeOffset, common::length_t numNbrs) const;
104101

105102
common::length_t getMaxDegree() const { return maxDegree; }
106103

@@ -120,14 +117,14 @@ class InMemHNSWGraph {
120117
dstNodes[csrOffset].store(dstNode, std::memory_order_relaxed);
121118
}
122119

123-
void finalize(MemoryManager& mm,
120+
void finalize(MemoryManager& mm, common::node_group_idx_t nodeGroupIdx,
124121
const processor::PartitionerSharedState& partitionerSharedState);
125122

126123
private:
127124
void resetCSRLengthAndDstNodes();
128125

129126
void finalizeNodeGroup(MemoryManager& mm, common::node_group_idx_t nodeGroupIdx,
130-
common::table_id_t srcNodeTableID, common::table_id_t dstNodeTableID,
127+
uint64_t numRels, common::table_id_t srcNodeTableID, common::table_id_t dstNodeTableID,
131128
common::table_id_t relTableID, InMemChunkedNodeGroupCollection& partition) const;
132129

133130
common::offset_t getDstNode(common::offset_t csrOffset) const {
@@ -142,8 +139,6 @@ class InMemHNSWGraph {
142139
std::atomic<common::offset_t>* dstNodes;
143140
// Max allowed degree of a node in the graph before shrinking.
144141
common::length_t maxDegree;
145-
146-
std::vector<common::length_t> numRelsPerNodeGroup;
147142
};
148143

149144
} // namespace storage

src/include/storage/index/hnsw_index.h

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -107,23 +107,18 @@ class InMemHNSWLayer {
107107
void setEntryPoint(common::offset_t offset) { entryPoint.store(offset); }
108108
common::offset_t getEntryPoint() const { return entryPoint.load(); }
109109

110-
void insert(transaction::Transaction* transaction, common::offset_t offset,
111-
common::offset_t entryPoint_, VisitedState& visited);
112-
common::offset_t searchNN(transaction::Transaction* transaction, common::offset_t node,
113-
common::offset_t entryNode) const;
114-
void shrink(transaction::Transaction* transaction);
115-
void finalize(MemoryManager& mm,
110+
void insert(common::offset_t offset, common::offset_t entryPoint_, VisitedState& visited);
111+
common::offset_t searchNN(common::offset_t node, common::offset_t entryNode) const;
112+
void finalize(MemoryManager& mm, common::node_group_idx_t nodeGroupIdx,
116113
const processor::PartitionerSharedState& partitionerSharedState) const;
117114

118115
private:
119-
std::vector<NodeWithDistance> searchKNN(transaction::Transaction* transaction,
120-
const float* queryVector, common::offset_t entryNode, common::length_t k,
121-
uint64_t configuredEf, VisitedState& visited) const;
122-
static void shrinkForNode(transaction::Transaction* transaction, const InMemHNSWLayerInfo& info,
123-
InMemHNSWGraph* graph, common::offset_t nodeOffset, common::length_t numNbrs);
116+
std::vector<NodeWithDistance> searchKNN(const float* queryVector, common::offset_t entryNode,
117+
common::length_t k, uint64_t configuredEf, VisitedState& visited) const;
118+
static void shrinkForNode(const InMemHNSWLayerInfo& info, InMemHNSWGraph* graph,
119+
common::offset_t nodeOffset, common::length_t numNbrs);
124120

125-
void insertRel(transaction::Transaction* transaction, common::offset_t srcNode,
126-
common::offset_t dstNode);
121+
void insertRel(common::offset_t srcNode, common::offset_t dstNode);
127122

128123
private:
129124
std::atomic<common::offset_t> entryPoint;
@@ -145,10 +140,11 @@ class InMemHNSWIndex final : public HNSWIndex {
145140
common::offset_t getLowerEntryPoint() const override { return lowerLayer->getEntryPoint(); }
146141

147142
// Note that the input is only `offset`, as we assume embeddings are already cached in memory.
148-
void insert(common::offset_t offset, transaction::Transaction* transaction,
149-
VisitedState& upperVisited, VisitedState& lowerVisited);
150-
void shrink(transaction::Transaction* transaction);
151-
void finalize(MemoryManager& mm, const HNSWIndexPartitionerSharedState& partitionerSharedState);
143+
void insert(common::offset_t offset, VisitedState& upperVisited, VisitedState& lowerVisited);
144+
void finalize(MemoryManager& mm, common::node_group_idx_t nodeGroupIdx,
145+
const HNSWIndexPartitionerSharedState& partitionerSharedState);
146+
147+
void resetEmbeddings() { embeddings.reset(); }
152148

153149
private:
154150
static constexpr int64_t INSERT_TO_UPPER_LAYER_RAND_UPPER_BOUND = 100;

src/include/storage/index/hnsw_index_utils.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ class TableCatalogEntry;
99
} // namespace catalog
1010

1111
namespace storage {
12-
class EmbeddingColumn;
1312

1413
struct HNSWIndexUtils {
14+
15+
static double computeDistance(DistFuncType funcType, const float* left, const float* right,
16+
uint32_t dimension);
17+
1518
static void validateColumnType(const catalog::TableCatalogEntry& tableEntry,
1619
const std::string& columnName);
1720

@@ -29,9 +32,6 @@ struct HNSWIndexUtils {
2932
return &listChunk.getDataColumnChunk()->getData<T>()[offset * dimension];
3033
}
3134

32-
static double computeDistance(DistFuncType funcType, const float* left, const float* right,
33-
uint32_t dimension);
34-
3535
private:
3636
static void validateColumnType(const common::LogicalType& type);
3737
};

src/storage/index/hnsw_graph.cpp

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -87,32 +87,24 @@ float* OnDiskEmbeddings::getEmbedding(transaction::Transaction* transaction,
8787
return reinterpret_cast<float*>(dataVector->getData()) + value.offset;
8888
}
8989

90-
void InMemHNSWGraph::finalize(MemoryManager& mm,
90+
// NOLINTNEXTLINE(readability-make-member-function-const): Semantically non-const function.
91+
void InMemHNSWGraph::finalize(MemoryManager& mm, common::node_group_idx_t nodeGroupIdx,
9192
const processor::PartitionerSharedState& partitionerSharedState) {
9293
const auto& partitionBuffers = partitionerSharedState.partitioningBuffers[0]->partitions;
93-
const auto numNodeGroups = (numNodes + common::StorageConfig::NODE_GROUP_SIZE - 1) /
94-
common::StorageConfig::NODE_GROUP_SIZE;
95-
KU_ASSERT(numNodeGroups == partitionerSharedState.numPartitions[0]);
96-
numRelsPerNodeGroup.resize(numNodeGroups);
97-
for (auto nodeGroupIdx = 0u; nodeGroupIdx < numNodeGroups; nodeGroupIdx++) {
98-
auto numRels = 0u;
99-
const auto startNodeOffset = StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx);
100-
const auto numNodesInGroup =
101-
std::min(common::StorageConfig::NODE_GROUP_SIZE, numNodes - startNodeOffset);
102-
for (auto i = 0u; i < numNodesInGroup; i++) {
103-
numRels += getCSRLength(startNodeOffset + i);
104-
}
105-
numRelsPerNodeGroup[nodeGroupIdx] = numRels;
106-
}
107-
for (auto nodeGroupIdx = 0u; nodeGroupIdx < numNodeGroups; nodeGroupIdx++) {
108-
finalizeNodeGroup(mm, nodeGroupIdx, partitionerSharedState.srcNodeTable->getTableID(),
109-
partitionerSharedState.dstNodeTable->getTableID(),
110-
partitionerSharedState.relTable->getTableID(), *partitionBuffers[nodeGroupIdx]);
94+
auto numRels = 0u;
95+
const auto startNodeOffset = StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx);
96+
const auto numNodesInGroup =
97+
std::min(common::StorageConfig::NODE_GROUP_SIZE, numNodes - startNodeOffset);
98+
for (auto i = 0u; i < numNodesInGroup; i++) {
99+
numRels += getCSRLength(startNodeOffset + i);
111100
}
101+
finalizeNodeGroup(mm, nodeGroupIdx, numRels, partitionerSharedState.srcNodeTable->getTableID(),
102+
partitionerSharedState.dstNodeTable->getTableID(),
103+
partitionerSharedState.relTable->getTableID(), *partitionBuffers[nodeGroupIdx]);
112104
}
113105

114106
void InMemHNSWGraph::finalizeNodeGroup(MemoryManager& mm, common::node_group_idx_t nodeGroupIdx,
115-
common::table_id_t srcNodeTableID, common::table_id_t dstNodeTableID,
107+
uint64_t numRels, common::table_id_t srcNodeTableID, common::table_id_t dstNodeTableID,
116108
common::table_id_t relTableID, InMemChunkedNodeGroupCollection& partition) const {
117109
const auto startNodeOffset = StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx);
118110
const auto numNodesInGroup =
@@ -122,10 +114,8 @@ void InMemHNSWGraph::finalizeNodeGroup(MemoryManager& mm, common::node_group_idx
122114
columnTypes.push_back(common::LogicalType::INTERNAL_ID());
123115
columnTypes.push_back(common::LogicalType::INTERNAL_ID());
124116
columnTypes.push_back(common::LogicalType::INTERNAL_ID());
125-
auto numRelsInGroup = numRelsPerNodeGroup[nodeGroupIdx];
126-
auto chunkedNodeGroup =
127-
std::make_unique<ChunkedNodeGroup>(mm, columnTypes, false /* enableCompression */,
128-
numRelsInGroup, 0 /* startRowIdx */, ResidencyState::IN_MEMORY);
117+
auto chunkedNodeGroup = std::make_unique<ChunkedNodeGroup>(mm, columnTypes,
118+
false /* enableCompression */, numRels, 0 /* startRowIdx */, ResidencyState::IN_MEMORY);
129119

130120
auto currNumRels = 0u;
131121
auto& boundColumnChunk = chunkedNodeGroup->getColumnChunk(0).getData();
@@ -158,23 +148,22 @@ void InMemHNSWGraph::finalizeNodeGroup(MemoryManager& mm, common::node_group_idx
158148
partition.merge(std::move(chunkedNodeGroup));
159149
}
160150

161-
common::offset_vec_t InMemHNSWGraph::getNeighbors(transaction::Transaction* transaction,
162-
common::offset_t nodeOffset) const {
151+
common::offset_vec_t InMemHNSWGraph::getNeighbors(common::offset_t nodeOffset) const {
163152
const auto numNbrs = getCSRLength(nodeOffset);
164-
return getNeighbors(transaction, nodeOffset, numNbrs);
153+
return getNeighbors(nodeOffset, numNbrs);
165154
}
166155

167-
common::offset_vec_t InMemHNSWGraph::getNeighbors(transaction::Transaction*,
168-
common::offset_t nodeOffset, common::length_t numNbrs) const {
156+
common::offset_vec_t InMemHNSWGraph::getNeighbors(common::offset_t nodeOffset,
157+
common::length_t numNbrs) const {
169158
common::offset_vec_t neighbors;
170159
neighbors.reserve(numNbrs);
171160
for (common::offset_t i = 0; i < numNbrs; i++) {
172161
auto nbr = getDstNode(nodeOffset * maxDegree + i);
173162
// Note: we might have INVALID_OFFSET at the end of the array of neighbor nodes. This is due
174-
// to that when we append a new neighbor node to node x, we don't exclusively lock the x,
175-
// instead, we increment the csrLength first, then set the dstNode. This design eases lock
176-
// contentions. However, if this function (`getNeighbors`) is called before the dstNode is
177-
// set, we will get INVALID_OFFSET. As csrLength is always synchorized, this design
163+
// to that when we append a new neighbor node to node x, we don't exclusively lock the
164+
// x, instead, we increment the csrLength first, then set the dstNode. This design eases
165+
// lock contentions. However, if this function (`getNeighbors`) is called before the dstNode
166+
// is set, we will get INVALID_OFFSET. As csrLength is always synchronized, this design
178167
// shouldn't have correctness issue.
179168
if (nbr == common::INVALID_OFFSET) {
180169
continue;

0 commit comments

Comments
 (0)