Commit 1368e0d

feat: Add Arrow IPC conversion for RecordBatches (#5143)
## Changes Made

Moves the Arrow IPC (de)serialization for `RecordBatch` into its own PR: I need it for both a process UDF PR and for the dashboard work, so splitting it out lets the others rebase on top of it. Let me know if you want a dedicated test for this specifically, but it is tested end-to-end in the other PRs.
1 parent 2e7a6f7 · commit 1368e0d

File tree

2 files changed: +53 -1 lines changed

src/daft-recordbatch/src/lib.rs

Lines changed: 39 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@ use std::{
     collections::{HashMap, HashSet},
     fmt::{Display, Formatter, Result},
     hash::{Hash, Hasher},
+    io::Cursor,
     sync::Arc,
 };
 
@@ -1068,6 +1069,44 @@ impl RecordBatch {
     pub fn to_chunk(&self) -> Chunk<Box<dyn Array>> {
         Chunk::new(self.columns.iter().map(|s| s.to_arrow()).collect())
     }
+
+    pub fn to_ipc_stream(&self) -> DaftResult<Vec<u8>> {
+        let buffer = Vec::with_capacity(self.size_bytes());
+        let schema = self.schema.to_arrow()?;
+        let options = arrow2::io::ipc::write::WriteOptions { compression: None };
+        let mut writer = arrow2::io::ipc::write::StreamWriter::new(buffer, options);
+        writer.start(&schema, None)?;
+
+        let chunk = self.to_chunk();
+        writer.write(&chunk, None)?;
+
+        writer.finish()?;
+        let mut finished_buffer = writer.into_inner();
+        finished_buffer.shrink_to_fit();
+        Ok(finished_buffer)
+    }
+
+    pub fn from_ipc_stream(buffer: &[u8]) -> DaftResult<Self> {
+        let mut cursor = Cursor::new(buffer);
+        let stream_metadata = arrow2::io::ipc::read::read_stream_metadata(&mut cursor)?;
+        let schema = Arc::new(Schema::from(stream_metadata.schema.clone()));
+        let reader = arrow2::io::ipc::read::StreamReader::new(cursor, stream_metadata, None);
+
+        let mut tables = reader
+            .into_iter()
+            .map(|state| {
+                let state = state?;
+                let arrow_chunk = match state {
+                    arrow2::io::ipc::read::StreamState::Some(chunk) => chunk,
+                    _ => panic!("State should not be waiting when reading from IPC buffer"),
+                };
+                Self::from_arrow(schema.clone(), arrow_chunk.into_arrays())
+            })
+            .collect::<DaftResult<Vec<_>>>()?;
+
+        assert_eq!(tables.len(), 1);
+        Ok(tables.pop().expect("Expected exactly one table"))
+    }
 }
 
 #[cfg(feature = "arrow")]
```
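For reference, the two methods above compose arrow2's `StreamWriter` and `StreamReader` into a plain Arrow IPC stream round trip. Below is a minimal, self-contained sketch of the same flow against arrow2 directly; the schema, column name, and sample values are illustrative and not taken from this commit:

```rust
use std::io::Cursor;

use arrow2::{
    array::Int64Array,
    chunk::Chunk,
    datatypes::{DataType, Field, Schema},
    error::Result,
    io::ipc::{
        read::{StreamReader, StreamState, read_stream_metadata},
        write::{StreamWriter, WriteOptions},
    },
};

fn main() -> Result<()> {
    // Illustrative one-column batch to round-trip.
    let schema = Schema::from(vec![Field::new("a", DataType::Int64, false)]);
    let chunk = Chunk::new(vec![Int64Array::from_slice([1, 2, 3]).boxed()]);

    // Write: a schema message, one record batch, then the end-of-stream marker.
    let mut writer = StreamWriter::new(Vec::new(), WriteOptions { compression: None });
    writer.start(&schema, None)?;
    writer.write(&chunk, None)?;
    writer.finish()?;
    let bytes = writer.into_inner();

    // Read: stream metadata (the schema) first, then iterate the batches.
    let mut cursor = Cursor::new(bytes.as_slice());
    let metadata = read_stream_metadata(&mut cursor)?;
    for state in StreamReader::new(cursor, metadata, None) {
        match state? {
            StreamState::Some(read_chunk) => assert_eq!(read_chunk.len(), 3),
            // A fully buffered stream never reports a partial read.
            StreamState::Waiting => unreachable!("in-memory buffer is complete"),
        }
    }
    Ok(())
}
```

An IPC stream may carry any number of record batches; `to_ipc_stream` writes exactly one, and `from_ipc_stream` asserts that invariant on the way back in.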

src/daft-recordbatch/src/python.rs

Lines changed: 14 additions & 1 deletion
```diff
@@ -12,7 +12,7 @@ use daft_dsl::{
     python::PyExpr,
 };
 use indexmap::IndexMap;
-use pyo3::{exceptions::PyValueError, prelude::*};
+use pyo3::{exceptions::PyValueError, prelude::*, types::PyBytes};
 
 use crate::{
     RecordBatch, ffi,
@@ -536,6 +536,19 @@ impl PyRecordBatch {
         let table: RecordBatch = file_infos.try_into()?;
         Ok(table.into())
     }
+
+    #[staticmethod]
+    pub fn from_ipc_stream(bytes: Bound<'_, PyBytes>, py: Python) -> PyResult<Self> {
+        let buffer = bytes.as_bytes();
+        let record_batch = py.allow_threads(|| RecordBatch::from_ipc_stream(buffer))?;
+        Ok(record_batch.into())
+    }
+
+    pub fn to_ipc_stream<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyBytes>> {
+        let buffer = py.allow_threads(|| self.record_batch.to_ipc_stream())?;
+        let bytes = PyBytes::new(py, &buffer);
+        Ok(bytes)
+    }
 }
 
 impl From<RecordBatch> for PyRecordBatch {
```
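Both bindings use a common PyO3 pattern: run the CPU-bound (de)serialization with the GIL released via `py.allow_threads`, then re-acquire the GIL only to build the Python object (`PyBytes::new` copies the finished buffer into Python-owned memory). A minimal standalone sketch of that pattern, with a hypothetical function name:

```rust
use pyo3::{prelude::*, types::PyBytes};

/// Illustrative sketch (not part of this commit): run the heavy encoding
/// step with the GIL released, then copy the result into Python `bytes`.
#[pyfunction]
fn encode<'py>(py: Python<'py>, payload: Vec<u8>) -> PyResult<Bound<'py, PyBytes>> {
    let encoded = py.allow_threads(move || {
        // Stand-in for RecordBatch::to_ipc_stream: any CPU-bound work
        // that does not touch Python objects can run here.
        payload
    });
    // GIL re-acquired: copy the bytes into a Python-owned object.
    Ok(PyBytes::new(py, &encoded))
}
```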
