Skip to content

Commit 98231bb

Browse files
authored
feat: Base64 Encoding (#5158)
1 parent 71d75f8 commit 98231bb

File tree

8 files changed: +50 additions, -10 deletions

.github/workflows/pr-test-suite.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1254,7 +1254,7 @@ jobs:
12541254
- name: Python And Rust Style Check
12551255
run: |
12561256
source .venv/bin/activate
1257-
pre-commit run --all-files
1257+
pre-commit run --all-files -v
12581258
12591259
- name: Send Slack notification on failure
12601260
uses: slackapi/slack-github-action@[version tag lost to e-mail obfuscation in this capture]

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ test: .venv build ## Run tests
7171
HYPOTHESIS_MAX_EXAMPLES=$(HYPOTHESIS_MAX_EXAMPLES) $(VENV_BIN)/pytest --hypothesis-seed=$(HYPOTHESIS_SEED) --ignore tests/integration $(EXTRA_ARGS)
7272

7373
.PHONY: doctests
74-
doctests:
75-
DAFT_BOLD_TABLE_HEADERS=0 pytest --doctest-modules --continue-on-collection-errors --ignore=daft/functions/llm.py daft/dataframe/dataframe.py daft/expressions/expressions.py daft/convert.py daft/udf/__init__.py daft/functions/ daft/datatype.py
74+
doctests: .venv
75+
DAFT_BOLD_TABLE_HEADERS=0 DAFT_PROGRESS_BAR=0 $(VENV_BIN)/pytest --doctest-modules --continue-on-collection-errors --ignore=daft/functions/llm.py daft/dataframe/dataframe.py daft/expressions/expressions.py daft/convert.py daft/udf/__init__.py daft/functions/ daft/datatype.py
7676

7777
.PHONY: dsdgen
7878
dsdgen: .venv ## Generate TPC-DS data

daft/expressions/expressions.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from daft.udf.legacy import BoundUDFArgs, InitArgsType, UninitializedUdf
4848
from daft.window import Window
4949

50-
EncodingCodec = Literal["deflate", "gzip", "gz", "utf-8", "utf8" "zlib"]
50+
EncodingCodec = Literal["base64", "deflate", "gzip", "gz", "utf-8", "utf8", "zlib"]
5151

5252

5353
def lit(value: object) -> Expression:
@@ -1595,7 +1595,7 @@ def encode(self, codec: EncodingCodec) -> Expression:
15951595
r"""Encodes the expression (binary strings) using the specified codec.
15961596
15971597
Args:
1598-
codec (str): encoding codec (deflate, gzip, zlib)
1598+
codec (str): encoding codec (base64, deflate, gzip, zlib)
15991599
16001600
Returns:
16011601
Expression: A new expression, of type `binary`, with the encoded value.
@@ -1641,7 +1641,7 @@ def decode(self, codec: EncodingCodec) -> Expression:
16411641
"""Decodes the expression (binary strings) using the specified codec.
16421642
16431643
Args:
1644-
codec (str): decoding codec (deflate, gzip, zlib)
1644+
codec (str): decoding codec (base64, deflate, gzip, zlib)
16451645
16461646
Returns:
16471647
Expression: A new expression with the decoded values.
@@ -1651,6 +1651,20 @@ def decode(self, codec: EncodingCodec) -> Expression:
16511651
only decoding with 'utf-8' returns a string.
16521652
16531653
Examples:
1654+
>>> import daft
1655+
>>> from daft import col
1656+
>>> df = daft.from_pydict({"bytes": [b"aGVsbG8sIHdvcmxkIQ=="]})
1657+
>>> df.select(col("bytes").decode("base64")).show()
1658+
╭──────────────────╮
1659+
│ bytes            │
1660+
│ ---              │
1661+
│ Binary           │
1662+
╞══════════════════╡
1663+
│ b"hello, world!" │
1664+
╰──────────────────╯
1665+
<BLANKLINE>
1666+
(Showing first 1 of 1 rows)
1667+
16541668
>>> import daft
16551669
>>> import zlib
16561670
>>> from daft import col

src/daft-functions-binary/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[dependencies]
22
arrow2 = {workspace = true}
3+
base64 = {workspace = true}
34
common-error = {path = "../common/error", default-features = false}
45
common-macros = {path = "../common/macros"}
56
daft-core = {path = "../daft-core", default-features = false}

src/daft-functions-binary/src/codecs.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use simdutf8::basic::from_utf8;
1111
/// Supported codecs for the decode and encode functions.
1212
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
1313
pub enum Codec {
14+
Base64,
1415
Deflate,
1516
Gzip,
1617
Utf8,
@@ -20,6 +21,7 @@ pub enum Codec {
2021
impl Display for Codec {
2122
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2223
f.write_str(match self {
24+
Self::Base64 => "base64",
2325
Self::Deflate => "deflate",
2426
Self::Gzip => "gzip",
2527
Self::Utf8 => "utf8",
@@ -61,6 +63,7 @@ pub(crate) type Transform = fn(input: &[u8]) -> DaftResult<Vec<u8>>;
6163
impl Codec {
6264
pub(crate) fn encoder(&self) -> Transform {
6365
match self {
66+
Self::Base64 => base64_encoder,
6467
Self::Deflate => deflate_encoder,
6568
Self::Gzip => gzip_encoder,
6669
Self::Utf8 => utf8_encoder,
@@ -70,6 +73,7 @@ impl Codec {
7073

7174
pub(crate) fn decoder(&self) -> Transform {
7275
match self {
76+
Self::Base64 => base64_decoder,
7377
Self::Deflate => deflate_decoder,
7478
Self::Gzip => gzip_decoder,
7579
Self::Utf8 => utf8_decoder,
@@ -79,6 +83,7 @@ impl Codec {
7983

8084
pub(crate) fn kind(&self) -> CodecKind {
8185
match self {
86+
Self::Base64 => CodecKind::Binary,
8287
Self::Deflate => CodecKind::Binary,
8388
Self::Gzip => CodecKind::Binary,
8489
Self::Utf8 => CodecKind::Text,
@@ -99,6 +104,7 @@ impl FromStr for Codec {
99104

100105
fn from_str(s: &str) -> Result<Self, Self::Err> {
101106
match s.to_lowercase().as_str() {
107+
"base64" => Ok(Self::Base64),
102108
"deflate" => Ok(Self::Deflate),
103109
"gzip" | "gz" => Ok(Self::Gzip),
104110
"zlib" => Ok(Self::Zlib),
@@ -115,6 +121,13 @@ impl FromStr for Codec {
115121
// ENCODERS
116122
//
117123

124+
#[inline]
125+
fn base64_encoder(input: &[u8]) -> DaftResult<Vec<u8>> {
126+
use base64::{Engine, engine::general_purpose::STANDARD};
127+
128+
Ok(STANDARD.encode(input).into_bytes())
129+
}
130+
118131
#[inline]
119132
fn deflate_encoder(input: &[u8]) -> DaftResult<Vec<u8>> {
120133
use std::io::Write;
@@ -160,6 +173,14 @@ fn zlib_encoder(input: &[u8]) -> DaftResult<Vec<u8>> {
160173
// DECODERS
161174
//
162175

176+
#[inline]
177+
fn base64_decoder(input: &[u8]) -> DaftResult<Vec<u8>> {
178+
use base64::{Engine, engine::general_purpose::STANDARD};
179+
STANDARD
180+
.decode(input)
181+
.map_err(|e| DaftError::ValueError(format!("Invalid base64 input: {}", e)))
182+
}
183+
163184
#[inline]
164185
fn deflate_decoder(input: &[u8]) -> DaftResult<Vec<u8>> {
165186
use std::io::Read;
@@ -223,6 +244,8 @@ mod tests {
223244
assert_eq!("zlib".parse::<Codec>().unwrap(), Codec::Zlib);
224245
assert_eq!("ZLIB".parse::<Codec>().unwrap(), Codec::Zlib);
225246
assert_eq!("ZlIb".parse::<Codec>().unwrap(), Codec::Zlib);
247+
assert_eq!("base64".parse::<Codec>().unwrap(), Codec::Base64);
248+
assert_eq!("BASE64".parse::<Codec>().unwrap(), Codec::Base64);
226249
assert!("unknown".parse::<Codec>().is_err());
227250
}
228251
}

src/daft-functions-binary/src/encode.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ impl ScalarUDF for BinaryEncode {
3232
inputs: FunctionArgs<ExprRef>,
3333
schema: &Schema,
3434
) -> DaftResult<Field> {
35-
let Args { input, codec: _ } = inputs.try_into()?;
35+
let Args { input, codec } = inputs.try_into()?;
3636
let input = input.to_field(schema)?;
3737

3838
ensure!(
@@ -44,7 +44,7 @@ impl ScalarUDF for BinaryEncode {
4444
input.dtype
4545
);
4646

47-
Ok(Field::new(input.name, DataType::Binary))
47+
Ok(Field::new(input.name, codec.returns()))
4848
}
4949

5050
fn call(&self, inputs: daft_dsl::functions::FunctionArgs<Series>) -> DaftResult<Series> {

tests/functions/test_codecs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,9 @@ def test_codec_zlib():
105105

106106

107107
def test_codec_base64():
108-
with pytest.raises(Exception, match="unsupported codec"):
109-
_test_codec("base64", None)
108+
import base64
109+
110+
_test_codec("base64", buff=base64.b64encode(TEXT))
110111

111112

112113
def test_codec_zstd():

0 commit comments

Comments
 (0)