diff --git a/src/include/storage/table/update_info.h b/src/include/storage/table/update_info.h index 7d36a5d3981..720675fa03e 100644 --- a/src/include/storage/table/update_info.h +++ b/src/include/storage/table/update_info.h @@ -53,6 +53,7 @@ struct UpdateNode { UpdateNode() : info{nullptr} {} UpdateNode(UpdateNode&& other) noexcept : info{std::move(other.info)} {} + UpdateNode(const UpdateNode& other) = delete; bool isEmpty() const { std::shared_lock lock{mtx}; @@ -74,7 +75,7 @@ class UpdateInfo { void clearVectorInfo(common::idx_t vectorIdx) { std::unique_lock lock{mtx}; - updates[vectorIdx].clear(); + updates[vectorIdx]->clear(); } common::idx_t getNumVectors() const { @@ -95,7 +96,7 @@ class UpdateInfo { const std::function& func) const; void commit(common::idx_t vectorIdx, VectorUpdateInfo* info, common::transaction_t commitTS); - void rollback(common::idx_t vectorIdx, VectorUpdateInfo* info); + void rollback(common::idx_t vectorIdx, common::transaction_t version); common::row_idx_t getNumUpdatedRows(const transaction::Transaction* transaction) const; @@ -124,7 +125,7 @@ class UpdateInfo { private: mutable std::shared_mutex mtx; - std::vector updates; + std::vector> updates; }; } // namespace storage diff --git a/src/include/storage/undo_buffer.h b/src/include/storage/undo_buffer.h index 0fdcf0cb97b..8999808d12c 100644 --- a/src/include/storage/undo_buffer.h +++ b/src/include/storage/undo_buffer.h @@ -88,7 +88,7 @@ class UndoBuffer { void createDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, const VersionRecordHandler* versionRecordHandler); void createVectorUpdateInfo(UpdateInfo* updateInfo, common::idx_t vectorIdx, - VectorUpdateInfo* vectorUpdateInfo); + VectorUpdateInfo* vectorUpdateInfo, common::transaction_t version); void commit(common::transaction_t commitTS) const; void rollback(main::ClientContext* context) const; diff --git a/src/include/transaction/transaction.h b/src/include/transaction/transaction.h index d3b6c6b4f7c..d6a181b0b95 100644 --- a/src/include/transaction/transaction.h +++ b/src/include/transaction/transaction.h @@ -141,7 +141,7 @@ class KUZU_API Transaction { void pushDeleteInfo(common::node_group_idx_t nodeGroupIdx, common::row_idx_t startRow, common::row_idx_t numRows, const storage::VersionRecordHandler* versionRecordHandler) const; void pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, common::idx_t vectorIdx, - storage::VectorUpdateInfo& vectorUpdateInfo) const; + storage::VectorUpdateInfo& vectorUpdateInfo, common::transaction_t version) const; static Transaction* Get(const main::ClientContext& context); diff --git a/src/storage/table/column_chunk.cpp b/src/storage/table/column_chunk.cpp index 3a0374431ca..b98bdecf5c6 100644 --- a/src/storage/table/column_chunk.cpp +++ b/src/storage/table/column_chunk.cpp @@ -113,7 +113,8 @@ void ColumnChunk::update(const Transaction* transaction, offset_t offsetInChunk, const auto rowIdxInVector = offsetInChunk % DEFAULT_VECTOR_CAPACITY; auto& vectorUpdateInfo = updateInfo.update(data->getMemoryManager(), transaction, vectorIdx, rowIdxInVector, values); - transaction->pushVectorUpdateInfo(updateInfo, vectorIdx, vectorUpdateInfo); + transaction->pushVectorUpdateInfo(updateInfo, vectorIdx, vectorUpdateInfo, + transaction->getID()); } MergedColumnChunkStats ColumnChunk::getMergedColumnChunkStats() const { diff --git a/src/storage/table/update_info.cpp b/src/storage/table/update_info.cpp index a59d5d94641..94c3eea528d 100644 --- a/src/storage/table/update_info.cpp +++ b/src/storage/table/update_info.cpp @@ -116,10 +116,10 @@ void UpdateInfo::iterateVectorInfo(const Transaction* transaction, idx_t idx, const UpdateNode* head = nullptr; { std::shared_lock lock{mtx}; - if (idx >= updates.size() || !updates[idx].isEmpty()) { + if (idx >= updates.size() || !updates[idx]->isEmpty()) { return; } - head = &updates[idx]; + head = updates[idx].get(); } // We lock the head of the chain to ensure that we can safely read from any part of the // chain. @@ -161,14 +161,14 @@ void UpdateInfo::commit(idx_t vectorIdx, VectorUpdateInfo* info, transaction_t c info->version = commitTS; } -void UpdateInfo::rollback(idx_t vectorIdx, VectorUpdateInfo* info) { +void UpdateInfo::rollback(idx_t vectorIdx, transaction_t version) { UpdateNode* header = nullptr; // Note that we lock the entire UpdateInfo structure here because we might modify the // head of the version chain. This is just a simplification and should be optimized later. { std::unique_lock lock{mtx}; KU_ASSERT(updates.size() > vectorIdx); - header = &updates[vectorIdx]; + header = updates[vectorIdx].get(); } KU_ASSERT(header); std::unique_lock chainLock{header->mtx}; @@ -177,24 +177,26 @@ void UpdateInfo::rollback(idx_t vectorIdx, VectorUpdateInfo* info) { // TODO(Guodong): This will be optimized by moving VectorUpdateInfo into UndoBuffer. auto current = header->info.get(); while (current) { - if (current != info) { - current = current->getPrev(); - continue; - } - if (info->next) { - // Has newer version. Remove this from the version chain. - const auto newerVersion = info->next; - auto prevVersion = info->movePrev(); - if (prevVersion) { - prevVersion->next = newerVersion; + if (current->version == version) { + auto prevVersion = current->movePrev(); + if (current->next) { + // Has newer version. Remove this from the version chain. + const auto newerVersion = current->next; + if (prevVersion) { + prevVersion->next = newerVersion; + } + newerVersion->setPrev(std::move(prevVersion)); + } else { + KU_ASSERT(header->info.get() == current); + // This is the beginning of the version chain. + if (prevVersion) { + prevVersion->next = nullptr; + } + header->info = std::move(prevVersion); } - newerVersion->setPrev(std::move(prevVersion)); - } else { - KU_ASSERT(header->info.get() == info); - // This is the beginning of the version chain. - header->info = std::move(info->prev); + break; } - break; + current = current->getPrev(); } } @@ -224,15 +226,20 @@ UpdateNode& UpdateInfo::getUpdateNode(idx_t vectorIdx) { throw InternalException( "UpdateInfo does not have update node for vector index: " + std::to_string(vectorIdx)); } - return updates[vectorIdx]; + return *updates[vectorIdx]; } UpdateNode& UpdateInfo::getOrCreateUpdateNode(idx_t vectorIdx) { std::unique_lock lock{mtx}; if (vectorIdx >= updates.size()) { updates.resize(vectorIdx + 1); + for (auto i = 0u; i < updates.size(); i++) { + if (!updates[i]) { + updates[i] = std::make_unique(); + } + } } - return updates[vectorIdx]; + return *updates[vectorIdx]; } void UpdateInfo::iterateScan(const Transaction* transaction, uint64_t startOffsetToScan, diff --git a/src/storage/undo_buffer.cpp b/src/storage/undo_buffer.cpp index 5ee0dcb781d..2ae40e1f72a 100644 --- a/src/storage/undo_buffer.cpp +++ b/src/storage/undo_buffer.cpp @@ -49,6 +49,7 @@ struct VectorUpdateRecord { UpdateInfo* updateInfo; idx_t vectorIdx; VectorUpdateInfo* vectorUpdateInfo; + transaction_t version; // This is used during roll back. }; template @@ -135,12 +136,12 @@ void UndoBuffer::createVersionInfo(const UndoRecordType recordType, row_idx_t st } void UndoBuffer::createVectorUpdateInfo(UpdateInfo* updateInfo, const idx_t vectorIdx, - VectorUpdateInfo* vectorUpdateInfo) { + VectorUpdateInfo* vectorUpdateInfo, transaction_t version) { auto buffer = createUndoRecord(sizeof(UndoRecordHeader) + sizeof(VectorUpdateRecord)); const UndoRecordHeader recordHeader{UndoRecordType::UPDATE_INFO, sizeof(VectorUpdateRecord)}; *reinterpret_cast(buffer) = recordHeader; buffer += sizeof(UndoRecordHeader); - const VectorUpdateRecord vectorUpdateRecord{updateInfo, vectorIdx, vectorUpdateInfo}; + const VectorUpdateRecord vectorUpdateRecord{updateInfo, vectorIdx, vectorUpdateInfo, version}; *reinterpret_cast(buffer) = vectorUpdateRecord; } @@ -301,7 +302,7 @@ void UndoBuffer::rollbackVersionInfo(ClientContext* context, UndoRecordType reco void UndoBuffer::rollbackVectorUpdateInfo(const uint8_t* record) { auto& undoRecord = *reinterpret_cast(record); KU_ASSERT(undoRecord.updateInfo); - undoRecord.updateInfo->rollback(undoRecord.vectorIdx, undoRecord.vectorUpdateInfo); + undoRecord.updateInfo->rollback(undoRecord.vectorIdx, undoRecord.version); } } // namespace storage diff --git a/src/transaction/transaction.cpp b/src/transaction/transaction.cpp index da1f78ee4ad..52819b73447 100644 --- a/src/transaction/transaction.cpp +++ b/src/transaction/transaction.cpp @@ -194,8 +194,9 @@ void Transaction::pushDeleteInfo(common::node_group_idx_t nodeGroupIdx, common:: } void Transaction::pushVectorUpdateInfo(storage::UpdateInfo& updateInfo, - const common::idx_t vectorIdx, storage::VectorUpdateInfo& vectorUpdateInfo) const { - undoBuffer->createVectorUpdateInfo(&updateInfo, vectorIdx, &vectorUpdateInfo); + const common::idx_t vectorIdx, storage::VectorUpdateInfo& vectorUpdateInfo, + common::transaction_t version) const { + undoBuffer->createVectorUpdateInfo(&updateInfo, vectorIdx, &vectorUpdateInfo, version); } Transaction::~Transaction() = default; diff --git a/test/test_files/transaction/concurrency/dml_empty_serial_execution.test b/test/test_files/transaction/concurrency/dml_empty_serial_execution.test index c093fc11720..c2e1c9ea3a9 100644 --- a/test/test_files/transaction/concurrency/dml_empty_serial_execution.test +++ b/test/test_files/transaction/concurrency/dml_empty_serial_execution.test @@ -64,6 +64,87 @@ Runtime exception: Write-write conflict of updating the same row. ---- 1 21 +-CASE NodeUpdatesRollback +-STATEMENT CALL debug_enable_multi_writes=true; +---- ok +-STATEMENT CALL auto_checkpoint=false; +---- ok +-CREATE_CONNECTION conn2 +-STATEMENT CREATE NODE TABLE test(id INT64, val INT64, PRIMARY KEY(id)); +---- ok +-STATEMENT COPY test FROM (UNWIND RANGE(0, 1000) AS id RETURN id, id + 1 AS val); +---- ok +-STATEMENT BEGIN TRANSACTION; +---- ok +-LOOP i 1 100 +-STATEMENT MATCH (p:test) WHERE p.ID=${i} SET p.val=p.val+1000; +---- ok +-ENDLOOP +-STATEMENT [conn2] BEGIN TRANSACTION; +---- ok +-LOOP i 101 200 +-STATEMENT [conn2] MATCH (p:test) WHERE p.ID=${i} SET p.val=p.val+2000; +---- ok +-ENDLOOP +-STATEMENT [conn2] ROLLBACK; +---- ok +-STATEMENT ROLLBACK; +---- ok +-STATEMENT MATCH (p:test) WHERE p.ID>1000 RETURN COUNT(*); +---- 1 +0 + +-CASE NodeUpdatesMixedCommitAndRollback +-STATEMENT CALL debug_enable_multi_writes=true; +---- ok +-STATEMENT CALL auto_checkpoint=false; +---- ok +-CREATE_CONNECTION conn2 +-CREATE_CONNECTION conn3 +-CREATE_CONNECTION conn4 +-STATEMENT CREATE NODE TABLE test(id INT64, val INT64, PRIMARY KEY(id)); +---- ok +-STATEMENT COPY test FROM (UNWIND RANGE(0, 1000) AS id RETURN id, id + 1 AS val); +---- ok +-STATEMENT BEGIN TRANSACTION; +---- ok +-LOOP i 1 100 +-STATEMENT MATCH (p:test) WHERE p.ID=${i} SET p.val=p.val+1000; +---- ok +-ENDLOOP +-STATEMENT [conn2] BEGIN TRANSACTION; +---- ok +-LOOP i 101 200 +-STATEMENT [conn2] MATCH (p:test) WHERE p.ID=${i} SET p.val=p.val+2000; +---- ok +-ENDLOOP +-STATEMENT [conn3] BEGIN TRANSACTION; +---- ok +-LOOP i 201 300 +-STATEMENT [conn3] MATCH (p:test) WHERE p.ID=${i} SET p.val=p.val+3000; +---- ok +-ENDLOOP +-STATEMENT [conn4] BEGIN TRANSACTION; +---- ok +-LOOP i 301 400 +-STATEMENT [conn4] MATCH (p:test) WHERE p.ID=${i} SET p.val=p.val+4000; +---- ok +-ENDLOOP +-STATEMENT [conn4] COMMIT; +---- ok +-STATEMENT [conn2] ROLLBACK; +---- ok +-STATEMENT ROLLBACK; +---- ok +-STATEMENT [conn3] COMMIT; +---- ok +-STATEMENT MATCH (p:test) WHERE p.val>3000 RETURN COUNT(*); +---- 1 +200 +-STATEMENT MATCH (p:test) WHERE p.val>4000 RETURN COUNT(*); +---- 1 +100 + -CASE WWConflictNodeCopyDelete -STATEMENT CALL debug_enable_multi_writes=true; ---- ok