Skip to content

Commit 982bb61

Browse files
authored
feat: support if()/ifnull()/nullif()/coalesce()/Case ... When ... (#143)
1 parent 3907fcd commit 982bb61

File tree

19 files changed

+926
-100
lines changed

19 files changed

+926
-100
lines changed

src/binder/aggregate.rs

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,52 @@ impl<'a, T: Transaction> Binder<'a, T> {
139139
ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (),
140140
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
141141
ScalarExpression::Tuple(args)
142-
| ScalarExpression::Function(ScalarFunction { args, .. }) => {
142+
| ScalarExpression::Function(ScalarFunction { args, .. })
143+
| ScalarExpression::Coalesce { exprs: args, .. } => {
143144
for expr in args {
144145
self.visit_column_agg_expr(expr)?;
145146
}
146147
}
148+
ScalarExpression::If {
149+
condition,
150+
left_expr,
151+
right_expr,
152+
..
153+
} => {
154+
self.visit_column_agg_expr(condition)?;
155+
self.visit_column_agg_expr(left_expr)?;
156+
self.visit_column_agg_expr(right_expr)?;
157+
}
158+
ScalarExpression::IfNull {
159+
left_expr,
160+
right_expr,
161+
..
162+
}
163+
| ScalarExpression::NullIf {
164+
left_expr,
165+
right_expr,
166+
..
167+
} => {
168+
self.visit_column_agg_expr(left_expr)?;
169+
self.visit_column_agg_expr(right_expr)?;
170+
}
171+
ScalarExpression::CaseWhen {
172+
operand_expr,
173+
expr_pairs,
174+
else_expr,
175+
..
176+
} => {
177+
if let Some(expr) = operand_expr {
178+
self.visit_column_agg_expr(expr)?;
179+
}
180+
for (expr_1, expr_2) in expr_pairs {
181+
self.visit_column_agg_expr(expr_1)?;
182+
self.visit_column_agg_expr(expr_2)?;
183+
}
184+
if let Some(expr) = else_expr {
185+
self.visit_column_agg_expr(expr)?;
186+
}
187+
}
147188
}
148189

149190
Ok(())
@@ -318,12 +359,59 @@ impl<'a, T: Transaction> Binder<'a, T> {
318359
ScalarExpression::Constant(_) => Ok(()),
319360
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
320361
ScalarExpression::Tuple(args)
321-
| ScalarExpression::Function(ScalarFunction { args, .. }) => {
362+
| ScalarExpression::Function(ScalarFunction { args, .. })
363+
| ScalarExpression::Coalesce { exprs: args, .. } => {
322364
for expr in args {
323365
self.validate_having_orderby(expr)?;
324366
}
325367
Ok(())
326368
}
369+
ScalarExpression::If {
370+
condition,
371+
left_expr,
372+
right_expr,
373+
..
374+
} => {
375+
self.validate_having_orderby(condition)?;
376+
self.validate_having_orderby(left_expr)?;
377+
self.validate_having_orderby(right_expr)?;
378+
379+
Ok(())
380+
}
381+
ScalarExpression::IfNull {
382+
left_expr,
383+
right_expr,
384+
..
385+
}
386+
| ScalarExpression::NullIf {
387+
left_expr,
388+
right_expr,
389+
..
390+
} => {
391+
self.validate_having_orderby(left_expr)?;
392+
self.validate_having_orderby(right_expr)?;
393+
394+
Ok(())
395+
}
396+
ScalarExpression::CaseWhen {
397+
operand_expr,
398+
expr_pairs,
399+
else_expr,
400+
..
401+
} => {
402+
if let Some(expr) = operand_expr {
403+
self.validate_having_orderby(expr)?;
404+
}
405+
for (expr_1, expr_2) in expr_pairs {
406+
self.validate_having_orderby(expr_1)?;
407+
self.validate_having_orderby(expr_2)?;
408+
}
409+
if let Some(expr) = else_expr {
410+
self.validate_having_orderby(expr)?;
411+
}
412+
413+
Ok(())
414+
}
327415
}
328416
}
329417
}

src/binder/expr.rs

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,44 @@ impl<'a, T: Transaction> Binder<'a, T> {
122122
}
123123
Ok(ScalarExpression::Tuple(bond_exprs))
124124
}
125+
Expr::Case {
126+
operand,
127+
conditions,
128+
results,
129+
else_result,
130+
} => {
131+
let mut operand_expr = None;
132+
let mut ty = LogicalType::SqlNull;
133+
if let Some(expr) = operand {
134+
operand_expr = Some(Box::new(self.bind_expr(expr)?));
135+
}
136+
let mut expr_pairs = Vec::with_capacity(conditions.len());
137+
for i in 0..conditions.len() {
138+
let result = self.bind_expr(&results[i])?;
139+
let result_ty = result.return_type();
140+
141+
if result_ty != LogicalType::SqlNull {
142+
if ty == LogicalType::SqlNull {
143+
ty = result_ty;
144+
} else if ty != result_ty {
145+
return Err(DatabaseError::Incomparable(ty, result_ty));
146+
}
147+
}
148+
expr_pairs.push((self.bind_expr(&conditions[i])?, result))
149+
}
150+
151+
let mut else_expr = None;
152+
if let Some(expr) = else_result {
153+
else_expr = Some(Box::new(self.bind_expr(expr)?));
154+
}
155+
156+
Ok(ScalarExpression::CaseWhen {
157+
operand_expr,
158+
expr_pairs,
159+
else_expr,
160+
ty,
161+
})
162+
}
125163
_ => {
126164
todo!()
127165
}
@@ -272,14 +310,20 @@ impl<'a, T: Transaction> Binder<'a, T> {
272310

273311
match function_name.as_str() {
274312
"count" => {
313+
if args.len() != 1 {
314+
return Err(DatabaseError::MisMatch("number of count() parameters", "1"));
315+
}
275316
return Ok(ScalarExpression::AggCall {
276317
distinct: func.distinct,
277318
kind: AggKind::Count,
278319
args,
279320
ty: LogicalType::Integer,
280-
})
321+
});
281322
}
282323
"sum" => {
324+
if args.len() != 1 {
325+
return Err(DatabaseError::MisMatch("number of sum() parameters", "1"));
326+
}
283327
let ty = args[0].return_type();
284328

285329
return Ok(ScalarExpression::AggCall {
@@ -290,6 +334,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
290334
});
291335
}
292336
"min" => {
337+
if args.len() != 1 {
338+
return Err(DatabaseError::MisMatch("number of min() parameters", "1"));
339+
}
293340
let ty = args[0].return_type();
294341

295342
return Ok(ScalarExpression::AggCall {
@@ -300,6 +347,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
300347
});
301348
}
302349
"max" => {
350+
if args.len() != 1 {
351+
return Err(DatabaseError::MisMatch("number of max() parameters", "1"));
352+
}
303353
let ty = args[0].return_type();
304354

305355
return Ok(ScalarExpression::AggCall {
@@ -310,6 +360,9 @@ impl<'a, T: Transaction> Binder<'a, T> {
310360
});
311361
}
312362
"avg" => {
363+
if args.len() != 1 {
364+
return Err(DatabaseError::MisMatch("number of avg() parameters", "1"));
365+
}
313366
let ty = args[0].return_type();
314367

315368
return Ok(ScalarExpression::AggCall {
@@ -319,6 +372,77 @@ impl<'a, T: Transaction> Binder<'a, T> {
319372
ty,
320373
});
321374
}
375+
"if" => {
376+
if args.len() != 3 {
377+
return Err(DatabaseError::MisMatch("number of if() parameters", "3"));
378+
}
379+
let ty = Self::return_type(&args[1], &args[2])?;
380+
let right_expr = Box::new(args.pop().unwrap());
381+
let left_expr = Box::new(args.pop().unwrap());
382+
let condition = Box::new(args.pop().unwrap());
383+
384+
return Ok(ScalarExpression::If {
385+
condition,
386+
left_expr,
387+
right_expr,
388+
ty,
389+
});
390+
}
391+
"nullif" => {
392+
if args.len() != 2 {
393+
return Err(DatabaseError::MisMatch(
394+
"number of nullif() parameters",
395+
"3",
396+
));
397+
}
398+
let ty = Self::return_type(&args[0], &args[1])?;
399+
let right_expr = Box::new(args.pop().unwrap());
400+
let left_expr = Box::new(args.pop().unwrap());
401+
402+
return Ok(ScalarExpression::NullIf {
403+
left_expr,
404+
right_expr,
405+
ty,
406+
});
407+
}
408+
"ifnull" => {
409+
if args.len() != 2 {
410+
return Err(DatabaseError::MisMatch(
411+
"number of ifnull() parameters",
412+
"3",
413+
));
414+
}
415+
let ty = Self::return_type(&args[0], &args[1])?;
416+
let right_expr = Box::new(args.pop().unwrap());
417+
let left_expr = Box::new(args.pop().unwrap());
418+
419+
return Ok(ScalarExpression::IfNull {
420+
left_expr,
421+
right_expr,
422+
ty,
423+
});
424+
}
425+
"coalesce" => {
426+
let mut ty = LogicalType::SqlNull;
427+
428+
if !args.is_empty() {
429+
ty = args[0].return_type();
430+
431+
for arg in args.iter() {
432+
let temp_ty = arg.return_type();
433+
434+
if temp_ty == LogicalType::SqlNull {
435+
continue;
436+
}
437+
if ty == LogicalType::SqlNull && temp_ty != LogicalType::SqlNull {
438+
ty = temp_ty;
439+
} else if ty != temp_ty {
440+
return Err(DatabaseError::Incomparable(ty, temp_ty));
441+
}
442+
}
443+
}
444+
return Ok(ScalarExpression::Coalesce { exprs: args, ty });
445+
}
322446
_ => (),
323447
}
324448
let arg_types = args.iter().map(ScalarExpression::return_type).collect_vec();
@@ -336,6 +460,20 @@ impl<'a, T: Transaction> Binder<'a, T> {
336460
Err(DatabaseError::NotFound("function", summary.name))
337461
}
338462

463+
fn return_type(
464+
expr_1: &ScalarExpression,
465+
expr_2: &ScalarExpression,
466+
) -> Result<LogicalType, DatabaseError> {
467+
let temp_ty_1 = expr_1.return_type();
468+
let temp_ty_2 = expr_2.return_type();
469+
470+
match (temp_ty_1, temp_ty_2) {
471+
(LogicalType::SqlNull, LogicalType::SqlNull) => Ok(LogicalType::SqlNull),
472+
(ty, LogicalType::SqlNull) | (LogicalType::SqlNull, ty) => Ok(ty),
473+
(ty_1, ty_2) => LogicalType::max_logical_type(&ty_1, &ty_2),
474+
}
475+
}
476+
339477
fn bind_is_null(
340478
&mut self,
341479
expr: &Expr,

src/binder/select.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,12 +273,16 @@ impl<'a, T: Transaction> Binder<'a, T> {
273273
columns: alias_column,
274274
}) = alias
275275
{
276-
let table_alias = Arc::new(name.value.to_lowercase());
277-
278276
if tables.len() > 1 {
279277
todo!("Implement virtual tables for multiple table aliases");
280278
}
281-
self.register_alias(alias_column, table_alias.to_string(), tables.remove(0))?;
279+
let table_alias = Arc::new(name.value.to_lowercase());
280+
281+
self.register_alias(
282+
alias_column,
283+
table_alias.to_string(),
284+
tables.pop().unwrap(),
285+
)?;
282286

283287
(Some(table_alias), plan)
284288
} else {

src/db.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use ahash::HashMap;
2-
use sqlparser::ast::Statement;
32
use std::path::PathBuf;
43
use std::sync::Arc;
54

@@ -101,8 +100,7 @@ impl<S: Storage> Database<S> {
101100
/// Run SQL queries.
102101
pub async fn run<T: AsRef<str>>(&self, sql: T) -> Result<Vec<Tuple>, DatabaseError> {
103102
let transaction = self.storage.transaction().await?;
104-
let (plan, _) =
105-
Self::build_plan::<T, S::TransactionType>(sql, &transaction, &self.functions)?;
103+
let plan = Self::build_plan::<T, S::TransactionType>(sql, &transaction, &self.functions)?;
106104

107105
Self::run_volcano(transaction, plan).await
108106
}
@@ -133,9 +131,9 @@ impl<S: Storage> Database<S> {
133131
sql: V,
134132
transaction: &<S as Storage>::TransactionType,
135133
functions: &Functions,
136-
) -> Result<(LogicalPlan, Statement), DatabaseError> {
134+
) -> Result<LogicalPlan, DatabaseError> {
137135
// parse
138-
let mut stmts = parse_sql(sql)?;
136+
let stmts = parse_sql(sql)?;
139137
if stmts.is_empty() {
140138
return Err(DatabaseError::EmptyStatement);
141139
}
@@ -154,7 +152,7 @@ impl<S: Storage> Database<S> {
154152
Self::default_optimizer(source_plan).find_best(Some(&transaction.meta_loader()))?;
155153
// println!("best_plan plan: {:#?}", best_plan);
156154

157-
Ok((best_plan, stmts.remove(0)))
155+
Ok(best_plan)
158156
}
159157

160158
pub(crate) fn default_optimizer(source_plan: LogicalPlan) -> HepOptimizer {
@@ -241,7 +239,7 @@ pub struct DBTransaction<S: Storage> {
241239

242240
impl<S: Storage> DBTransaction<S> {
243241
pub async fn run<T: AsRef<str>>(&mut self, sql: T) -> Result<Vec<Tuple>, DatabaseError> {
244-
let (plan, _) =
242+
let plan =
245243
Database::<S>::build_plan::<T, S::TransactionType>(sql, &self.inner, &self.functions)?;
246244
let mut stream = build_write(plan, &mut self.inner);
247245

0 commit comments

Comments
 (0)