@@ -938,6 +938,71 @@ class TransactionStateManagerTest {
938
938
assertEquals(0 , transactionManager.loadingPartitions.size)
939
939
}
940
940
941
+ private def createEmptyBatch (baseOffset : Long , lastOffset : Long ): MemoryRecords = {
942
+ val buffer = ByteBuffer .allocate(DefaultRecordBatch .RECORD_BATCH_OVERHEAD )
943
+ DefaultRecordBatch .writeEmptyHeader(buffer, RecordBatch .CURRENT_MAGIC_VALUE , RecordBatch .NO_PRODUCER_ID ,
944
+ RecordBatch .NO_PRODUCER_EPOCH , RecordBatch .NO_SEQUENCE , baseOffset, lastOffset, RecordBatch .NO_PARTITION_LEADER_EPOCH ,
945
+ TimestampType .CREATE_TIME , System .currentTimeMillis, false , false )
946
+ buffer.flip
947
+ MemoryRecords .readableRecords(buffer)
948
+ }
949
+
950
+ @ Test
951
+ def testLoadTransactionMetadataContainingSegmentEndingWithEmptyBatch (): Unit = {
952
+ // Simulate a case where a log contains two segments and the first segment ending with an empty batch.
953
+ txnMetadata1.state = PrepareCommit
954
+ txnMetadata1.addPartitions(Set [TopicPartition ](new TopicPartition (" topic1" , 0 )))
955
+ txnMetadata2.state = Ongoing
956
+ txnMetadata2.addPartitions(Set [TopicPartition ](new TopicPartition (" topic2" , 0 )))
957
+
958
+ // Create the first segment which contains two batches.
959
+ // The first batch has one transactional record
960
+ val txnRecords1 = new SimpleRecord (txnMessageKeyBytes1, TransactionLog .valueToBytes(txnMetadata1.prepareNoTransit(), TV_2 ))
961
+ val records1 = MemoryRecords .withRecords(RecordBatch .MAGIC_VALUE_V2 , 0L , Compression .NONE , TimestampType .CREATE_TIME , txnRecords1)
962
+ // The second batch is an empty batch.
963
+ val records2 = createEmptyBatch(1L , 1L )
964
+
965
+ val combinedBuffer = ByteBuffer .allocate(records1.buffer.limit + records2.buffer.limit)
966
+ combinedBuffer.put(records1.buffer)
967
+ combinedBuffer.put(records2.buffer)
968
+ combinedBuffer.flip
969
+ val firstSegmentRecords = MemoryRecords .readableRecords(combinedBuffer)
970
+
971
+ // Create the second segment which contains one batch
972
+ val txnRecords3 = new SimpleRecord (txnMessageKeyBytes2, TransactionLog .valueToBytes(txnMetadata2.prepareNoTransit(), TV_2 ))
973
+ val secondSegmentRecords = MemoryRecords .withRecords(RecordBatch .MAGIC_VALUE_V2 , 2L , Compression .NONE , TimestampType .CREATE_TIME , txnRecords3)
974
+
975
+ // Prepare a txn log
976
+ reset(replicaManager)
977
+
978
+ val logMock = mock(classOf [UnifiedLog ])
979
+ when(replicaManager.getLog(topicPartition)).thenReturn(Some (logMock))
980
+ when(replicaManager.getLogEndOffset(topicPartition)).thenReturn(Some (3L ))
981
+
982
+ when(logMock.logStartOffset).thenReturn(0L )
983
+ when(logMock.read(ArgumentMatchers .eq(0L ),
984
+ maxLength = anyInt(),
985
+ isolation = ArgumentMatchers .eq(FetchIsolation .LOG_END ),
986
+ minOneMessage = ArgumentMatchers .eq(true )))
987
+ .thenReturn(new FetchDataInfo (new LogOffsetMetadata (0L ), firstSegmentRecords))
988
+ when(logMock.read(ArgumentMatchers .eq(2L ),
989
+ maxLength = anyInt(),
990
+ isolation = ArgumentMatchers .eq(FetchIsolation .LOG_END ),
991
+ minOneMessage = ArgumentMatchers .eq(true )))
992
+ .thenReturn(new FetchDataInfo (new LogOffsetMetadata (2L ), secondSegmentRecords))
993
+
994
+ // Load transactions should not stuck.
995
+ transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch = 1 , (_, _, _, _) => ())
996
+ assertEquals(0 , transactionManager.loadingPartitions.size)
997
+ assertEquals(1 , transactionManager.transactionMetadataCache.size)
998
+ assertTrue(transactionManager.transactionMetadataCache.contains(partitionId))
999
+ // all transactions should have been loaded
1000
+ val txnMetadataPool = transactionManager.transactionMetadataCache(partitionId).metadataPerTransactionalId
1001
+ assertEquals(2 , txnMetadataPool.size)
1002
+ assertTrue(txnMetadataPool.contains(transactionalId1))
1003
+ assertTrue(txnMetadataPool.contains(transactionalId2))
1004
+ }
1005
+
941
1006
private def verifyMetadataDoesExistAndIsUsable (transactionalId : String ): Unit = {
942
1007
transactionManager.getTransactionState(transactionalId) match {
943
1008
case Left (_) => fail(" shouldn't have been any errors" )
0 commit comments