@@ -941,6 +941,71 @@ class TransactionStateManagerTest {
941
941
assertEquals(0 , transactionManager.loadingPartitions.size)
942
942
}
943
943
944
+ private def createEmptyBatch (baseOffset : Long , lastOffset : Long ): MemoryRecords = {
945
+ val buffer = ByteBuffer .allocate(DefaultRecordBatch .RECORD_BATCH_OVERHEAD )
946
+ DefaultRecordBatch .writeEmptyHeader(buffer, RecordBatch .CURRENT_MAGIC_VALUE , RecordBatch .NO_PRODUCER_ID ,
947
+ RecordBatch .NO_PRODUCER_EPOCH , RecordBatch .NO_SEQUENCE , baseOffset, lastOffset, RecordBatch .NO_PARTITION_LEADER_EPOCH ,
948
+ TimestampType .CREATE_TIME , System .currentTimeMillis, false , false )
949
+ buffer.flip
950
+ MemoryRecords .readableRecords(buffer)
951
+ }
952
+
953
+ @ Test
954
+ def testLoadTransactionMetadataContainingSegmentEndingWithEmptyBatch (): Unit = {
955
+ // Simulate a case where a log contains two segments and the first segment ending with an empty batch.
956
+ txnMetadata1.state = PrepareCommit
957
+ txnMetadata1.addPartitions(Set [TopicPartition ](new TopicPartition (" topic1" , 0 )))
958
+ txnMetadata2.state = Ongoing
959
+ txnMetadata2.addPartitions(Set [TopicPartition ](new TopicPartition (" topic2" , 0 )))
960
+
961
+ // Create the first segment which contains two batches.
962
+ // The first batch has one transactional record
963
+ val txnRecords1 = new SimpleRecord (txnMessageKeyBytes1, TransactionLog .valueToBytes(txnMetadata1.prepareNoTransit(), true ))
964
+ val records1 = MemoryRecords .withRecords(RecordBatch .MAGIC_VALUE_V2 , 0L , Compression .NONE , TimestampType .CREATE_TIME , txnRecords1)
965
+ // The second batch is an empty batch.
966
+ val records2 = createEmptyBatch(1L , 1L )
967
+
968
+ val combinedBuffer = ByteBuffer .allocate(records1.buffer.limit + records2.buffer.limit)
969
+ combinedBuffer.put(records1.buffer)
970
+ combinedBuffer.put(records2.buffer)
971
+ combinedBuffer.flip
972
+ val firstSegmentRecords = MemoryRecords .readableRecords(combinedBuffer)
973
+
974
+ // Create the second segment which contains one batch
975
+ val txnRecords3 = new SimpleRecord (txnMessageKeyBytes2, TransactionLog .valueToBytes(txnMetadata2.prepareNoTransit(), true ))
976
+ val secondSegmentRecords = MemoryRecords .withRecords(RecordBatch .MAGIC_VALUE_V2 , 2L , Compression .NONE , TimestampType .CREATE_TIME , txnRecords3)
977
+
978
+ // Prepare a txn log
979
+ reset(replicaManager)
980
+
981
+ val logMock = mock(classOf [UnifiedLog ])
982
+ when(replicaManager.getLog(topicPartition)).thenReturn(Some (logMock))
983
+ when(replicaManager.getLogEndOffset(topicPartition)).thenReturn(Some (3L ))
984
+
985
+ when(logMock.logStartOffset).thenReturn(0L )
986
+ when(logMock.read(ArgumentMatchers .eq(0L ),
987
+ maxLength = anyInt(),
988
+ isolation = ArgumentMatchers .eq(FetchIsolation .LOG_END ),
989
+ minOneMessage = ArgumentMatchers .eq(true )))
990
+ .thenReturn(new FetchDataInfo (new LogOffsetMetadata (0L ), firstSegmentRecords))
991
+ when(logMock.read(ArgumentMatchers .eq(2L ),
992
+ maxLength = anyInt(),
993
+ isolation = ArgumentMatchers .eq(FetchIsolation .LOG_END ),
994
+ minOneMessage = ArgumentMatchers .eq(true )))
995
+ .thenReturn(new FetchDataInfo (new LogOffsetMetadata (2L ), secondSegmentRecords))
996
+
997
+ // Load transactions should not stuck.
998
+ transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch = 1 , (_, _, _, _) => ())
999
+ assertEquals(0 , transactionManager.loadingPartitions.size)
1000
+ assertEquals(1 , transactionManager.transactionMetadataCache.size)
1001
+ assertTrue(transactionManager.transactionMetadataCache.contains(partitionId))
1002
+ // all transactions should have been loaded
1003
+ val txnMetadataPool = transactionManager.transactionMetadataCache(partitionId).metadataPerTransactionalId
1004
+ assertEquals(2 , txnMetadataPool.size)
1005
+ assertTrue(txnMetadataPool.contains(transactionalId1))
1006
+ assertTrue(txnMetadataPool.contains(transactionalId2))
1007
+ }
1008
+
944
1009
private def verifyMetadataDoesExistAndIsUsable (transactionalId : String ): Unit = {
945
1010
transactionManager.getTransactionState(transactionalId) match {
946
1011
case Left (_) => fail(" shouldn't have been any errors" )
0 commit comments