diff --git a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReading.kt b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReading.kt
index 0ddcdf3fb9..dac1fe680e 100644
--- a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReading.kt
+++ b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReading.kt
@@ -1,6 +1,7 @@
 package org.jetbrains.kotlinx.dataframe.io

 import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.ipc.ArrowReader
 import org.apache.commons.compress.utils.SeekableInMemoryByteChannel
 import org.jetbrains.kotlinx.dataframe.AnyFrame
 import org.jetbrains.kotlinx.dataframe.DataFrame
@@ -170,3 +171,18 @@ public fun DataFrame.Companion.readArrowFeather(
     } else {
         readArrowFeather(File(path), nullability)
     }
+
+/**
+ * Read Arrow data in [any supported IPC format](https://arrow.apache.org/docs/java/ipc.html#reading-writing-ipc-formats) from an existing [reader]
+ */
+public fun DataFrame.Companion.readArrow(
+    reader: ArrowReader,
+    nullability: NullabilityOptions = NullabilityOptions.Infer
+): AnyFrame = readArrowImpl(reader, nullability)
+
+/**
+ * Read Arrow data in [any supported IPC format](https://arrow.apache.org/docs/java/ipc.html#reading-writing-ipc-formats) from an existing [ArrowReader]
+ */
+public fun ArrowReader.toDataFrame(
+    nullability: NullabilityOptions = NullabilityOptions.Infer
+): AnyFrame = DataFrame.Companion.readArrowImpl(this, nullability)
diff --git a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt
index 5c09295179..d7d67b6a87 100644
--- a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt
+++ b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt
@@ -32,6 +32,7 @@ import org.apache.arrow.vector.VarCharVector
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.complex.StructVector
 import org.apache.arrow.vector.ipc.ArrowFileReader
+import org.apache.arrow.vector.ipc.ArrowReader
 import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.apache.arrow.vector.types.pojo.Field
 import org.apache.arrow.vector.util.DateUtility
@@ -262,17 +263,7 @@ internal fun DataFrame.Companion.readArrowIPCImpl(
     allocator: RootAllocator = Allocator.ROOT,
     nullability: NullabilityOptions = NullabilityOptions.Infer,
 ): AnyFrame {
-    ArrowStreamReader(channel, allocator).use { reader ->
-        val flattened = buildList {
-            val root = reader.vectorSchemaRoot
-            val schema = root.schema
-            while (reader.loadNextBatch()) {
-                val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
-                add(df)
-            }
-        }
-        return flattened.concatKeepingSchema()
-    }
+    return readArrowImpl(ArrowStreamReader(channel, allocator), nullability)
 }

 /**
@@ -283,14 +274,36 @@ internal fun DataFrame.Companion.readArrowFeatherImpl(
     allocator: RootAllocator = Allocator.ROOT,
     nullability: NullabilityOptions = NullabilityOptions.Infer,
 ): AnyFrame {
-    ArrowFileReader(channel, allocator).use { reader ->
+    return readArrowImpl(ArrowFileReader(channel, allocator), nullability)
+}
+
+/**
+ * Read Arrow data in [any supported IPC format](https://arrow.apache.org/docs/java/ipc.html#reading-writing-ipc-formats) from an existing [reader]
+ */
+internal fun DataFrame.Companion.readArrowImpl(
+    reader: ArrowReader,
+    nullability: NullabilityOptions = NullabilityOptions.Infer
+): AnyFrame {
+    reader.use {
         val flattened = buildList {
-            reader.recordBlocks.forEach { block ->
-                reader.loadRecordBatch(block)
-                val root = reader.vectorSchemaRoot
-                val schema = root.schema
-                val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
-                add(df)
+            when (reader) {
+                is ArrowFileReader -> {
+                    reader.recordBlocks.forEach { block ->
+                        reader.loadRecordBatch(block)
+                        val root = reader.vectorSchemaRoot
+                        val schema = root.schema
+                        val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
+                        add(df)
+                    }
+                }
+                is ArrowStreamReader -> {
+                    val root = reader.vectorSchemaRoot
+                    val schema = root.schema
+                    while (reader.loadNextBatch()) {
+                        val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
+                        add(df)
+                    }
+                }
             }
         }
         return flattened.concatKeepingSchema()
diff --git a/dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt b/dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt
index 500780ec54..513a46b21d 100644
--- a/dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt
+++ b/dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt
@@ -9,7 +9,9 @@ import org.apache.arrow.vector.TimeStampMilliVector
 import org.apache.arrow.vector.TimeStampNanoVector
 import org.apache.arrow.vector.TimeStampSecVector
 import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowFileReader
 import org.apache.arrow.vector.ipc.ArrowFileWriter
+import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 import org.apache.arrow.vector.types.FloatingPointPrecision
 import org.apache.arrow.vector.types.TimeUnit
@@ -17,6 +19,7 @@ import org.apache.arrow.vector.types.pojo.ArrowType
 import org.apache.arrow.vector.types.pojo.Field
 import org.apache.arrow.vector.types.pojo.FieldType
 import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
 import org.apache.arrow.vector.util.Text
 import org.jetbrains.kotlinx.dataframe.DataColumn
 import org.jetbrains.kotlinx.dataframe.DataFrame
@@ -32,6 +35,7 @@ import org.jetbrains.kotlinx.dataframe.api.remove
 import org.jetbrains.kotlinx.dataframe.api.toColumn
 import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
 import org.junit.Test
+import java.io.ByteArrayInputStream
 import java.io.ByteArrayOutputStream
 import java.io.File
 import java.net.URL
@@ -553,4 +557,30 @@ internal class ArrowKtTest {
             }
         }
     }
+
+    @Test
+    fun testArrowReaderExtension() {
+        val dates = listOf(
+            LocalDateTime.of(2023, 11, 23, 9, 30, 25),
+            LocalDateTime.of(2015, 5, 25, 14, 20, 13),
+            LocalDateTime.of(2013, 6, 19, 11, 20, 13),
+            LocalDateTime.of(2000, 1, 1, 0, 0, 0)
+        )
+
+        val expected = dataFrameOf(
+            "string" to listOf("a", "b", "c", "d"),
+            "int" to listOf(1, 2, 3, 4),
+            "float" to listOf(1.0f, 2.0f, 3.0f, 4.0f),
+            "double" to listOf(1.0, 2.0, 3.0, 4.0),
+            "datetime" to dates
+        )
+
+        val featherChannel = ByteArrayReadableSeekableByteChannel(expected.saveArrowFeatherToByteArray())
+        val arrowFileReader = ArrowFileReader(featherChannel, RootAllocator())
+        arrowFileReader.toDataFrame() shouldBe expected
+
+        val ipcInputStream = ByteArrayInputStream(expected.saveArrowIPCToByteArray())
+        val arrowStreamReader = ArrowStreamReader(ipcInputStream, RootAllocator())
+        arrowStreamReader.toDataFrame() shouldBe expected
+    }
 }
diff --git a/docs/StardustDocs/topics/read.md b/docs/StardustDocs/topics/read.md
index e9e056d718..e732ce25da 100644
--- a/docs/StardustDocs/topics/read.md
+++ b/docs/StardustDocs/topics/read.md
@@ -445,7 +445,7 @@ val df = DataFrame.readArrowFeather(file)
 [`DataFrame`](DataFrame.md) supports reading
 [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format)
 and [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files)
-from raw Channel (ReadableByteChannel for streaming and SeekableByteChannel for random access), InputStream, File or ByteArray.
+from raw Channel (ReadableByteChannel for streaming and SeekableByteChannel for random access), ArrowReader, InputStream, File or ByteArray.

 > If you use Java 9+, follow the [Apache Arrow Java compatibility](https://arrow.apache.org/docs/java/install.html#java-compatibility) guide.
 >
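For anyone trying out the new entry points, a minimal usage sketch follows. Everything in it comes from this patch except the `data.arrow` path, which is hypothetical; both functions also take an optional `nullability` argument (default `NullabilityOptions.Infer`).

```kotlin
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.jetbrains.kotlinx.dataframe.io.toDataFrame
import java.io.FileInputStream

fun main() {
    // Any ArrowReader works here: ArrowStreamReader for the streaming format,
    // ArrowFileReader for the random access (Feather) format.
    val reader = ArrowStreamReader(FileInputStream("data.arrow"), RootAllocator())

    // Reads every record batch, concatenates them keeping the schema,
    // and closes the reader. DataFrame.readArrow(reader) is equivalent.
    val df = reader.toDataFrame()

    println(df)
}
```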