diff --git a/helix-db/src/grammar.pest b/helix-db/src/grammar.pest index a7a0dff5..9f85c460 100644 --- a/helix-db/src/grammar.pest +++ b/helix-db/src/grammar.pest @@ -127,7 +127,7 @@ evaluates_to_number = { // --------------------------------------------------------------------- // Return statement // --------------------------------------------------------------------- -return_stmt = { "RETURN" ~ evaluates_to_anything ~ ("," ~ evaluates_to_anything)* } +return_stmt = { "RETURN" ~ (evaluates_to_anything ~ ("," ~ evaluates_to_anything)* | array_creation | object_creation) } // --------------------------------------------------------------------- // Creation steps @@ -226,6 +226,10 @@ exclude_field = { "!" ~ "{" ~ identifier ~ ("," ~ identifier)* ~ "}" } closure_step = { "|" ~ identifier ~ "|" ~ object_step } spread_object = { ".." ~ ","?} mapping_field = { (identifier ~ (":" ~ (anonymous_traversal | evaluates_to_anything | object_step))) | identifier } +array_creation = { "[" ~ (identifier | object_creation ) ~ ("," ~ (identifier | object_creation))* ~ ","? ~ "]" } +object_creation = { "{" ~ object_inner ~ ("," ~ object_inner)* ~ ","? ~ "}" } +object_inner = { identifier ~ ":" ~ object_field } +object_field = { (object_creation | array_creation | evaluates_to_anything) } // --------------------------------------------------------------------- // Macros diff --git a/helix-db/src/helixc/analyzer/analyzer.rs b/helix-db/src/helixc/analyzer/analyzer.rs index 5f06befb..ddba9c80 100644 --- a/helix-db/src/helixc/analyzer/analyzer.rs +++ b/helix-db/src/helixc/analyzer/analyzer.rs @@ -5,12 +5,12 @@ use crate::helixc::{ methods::{ migration_validation::validate_migration, query_validation::validate_query, - schema_methods::{SchemaVersionMap, build_field_lookups, check_schema}, + schema_methods::{build_field_lookups, check_schema, SchemaVersionMap}, }, types::Type, }, generator::Source as GeneratedSource, - parser::helix_parser::{EdgeSchema, ExpressionType, Field, Query, Source}, + parser::helix_parser::{EdgeSchema, ExpressionType, Field, Query, ReturnType, Source}, }; use itertools::Itertools; use serde::Serialize; @@ -250,8 +250,12 @@ impl QueryData { .return_values .iter() .flat_map(|e| { - if let ExpressionType::Identifier(ident) = &e.expr { - Some(ident.clone()) + if let ReturnType::Expression(expr) = e { + if let ExpressionType::Identifier(ident) = &expr.expr { + Some(ident.clone()) + } else { + None + } } else { None } diff --git a/helix-db/src/helixc/analyzer/methods/query_validation.rs b/helix-db/src/helixc/analyzer/methods/query_validation.rs index f9151e82..77832ef4 100644 --- a/helix-db/src/helixc/analyzer/methods/query_validation.rs +++ b/helix-db/src/helixc/analyzer/methods/query_validation.rs @@ -39,18 +39,19 @@ pub(crate) fn validate_query<'a>(ctx: &mut Ctx<'a>, original_query: &'a Query) { // ------------------------------------------------- for param in &original_query.parameters { if let FieldType::Identifier(ref id) = param.param_type.1 - && is_valid_identifier(ctx, original_query, param.param_type.0.clone(), id.as_str()) { - // TODO: add support for edges - if !ctx.node_set.contains(id.as_str()) { - generate_error!( - ctx, - original_query, - param.param_type.0.clone(), - E209, - &id, - ¶m.name.1 - ); - } + && is_valid_identifier(ctx, original_query, param.param_type.0.clone(), id.as_str()) + { + // TODO: add support for edges + if !ctx.node_set.contains(id.as_str()) { + generate_error!( + ctx, + original_query, + param.param_type.0.clone(), + E209, + &id, + ¶m.name.1 + ); + } } // constructs parameters and sub‑parameters for generator GeneratedParameter::unwrap_param( @@ -101,98 +102,7 @@ pub(crate) fn validate_query<'a>(ctx: &mut Ctx<'a>, original_query: &'a Query) { ); } for ret in &original_query.return_values { - let (_, stmt) = infer_expr_type(ctx, ret, &mut scope, original_query, None, &mut query); - - assert!(stmt.is_some(), "RETURN value should be a valid expression"); - match stmt.unwrap() { - GeneratedStatement::Traversal(traversal) => { - match &traversal.source_step.inner() { - SourceStep::Identifier(v) => { - is_valid_identifier( - ctx, - original_query, - ret.loc.clone(), - v.inner().as_str(), - ); - - // if is single object, need to handle it as a single object - // if is array, need to handle it as an array - match traversal.should_collect { - ShouldCollect::ToVec => { - query.return_values.push(ReturnValue::new_named( - GeneratedValue::Literal(GenRef::Literal(v.inner().clone())), - ReturnValueExpr::Traversal(traversal.clone()), - )); - } - ShouldCollect::ToVal => { - query.return_values.push(ReturnValue::new_single_named( - GeneratedValue::Literal(GenRef::Literal(v.inner().clone())), - ReturnValueExpr::Traversal(traversal.clone()), - )); - } - _ => { - unreachable!() - } - } - } - _ => { - query.return_values.push(ReturnValue::new_unnamed( - ReturnValueExpr::Traversal(traversal.clone()), - )); - } - } - } - GeneratedStatement::Identifier(id) => { - is_valid_identifier(ctx, original_query, ret.loc.clone(), id.inner().as_str()); - let identifier_end_type = match scope.get(id.inner().as_str()) { - Some(t) => t.clone(), - None => { - generate_error!( - ctx, - original_query, - ret.loc.clone(), - E301, - id.inner().as_str() - ); - Type::Unknown - } - }; - let value = - gen_identifier_or_param(original_query, id.inner().as_str(), false, true); - - match identifier_end_type { - Type::Scalar(_) | Type::Boolean => { - query.return_values.push(ReturnValue::new_named_literal( - GeneratedValue::Literal(GenRef::Literal(id.inner().clone())), - value, - )); - } - Type::Node(_) | Type::Vector(_) | Type::Edge(_) => { - query.return_values.push(ReturnValue::new_single_named( - GeneratedValue::Literal(GenRef::Literal(id.inner().clone())), - ReturnValueExpr::Identifier(value), - )); - } - _ => { - query.return_values.push(ReturnValue::new_named( - GeneratedValue::Literal(GenRef::Literal(id.inner().clone())), - ReturnValueExpr::Identifier(value), - )); - } - } - } - GeneratedStatement::Literal(l) => { - query.return_values.push(ReturnValue::new_literal( - GeneratedValue::Literal(l.clone()), - GeneratedValue::Literal(l.clone()), - )); - } - GeneratedStatement::Empty => query.return_values = vec![], - - // given all erroneous statements are caught by the analyzer, this should never happen - // all malformed statements (not gramatically correct) should be caught by the parser - _ => unreachable!(), - } + analyze_return_expr(ctx, original_query, &mut scope, &mut query, ret); } if let Some(BuiltInMacro::MCP) = &original_query.built_in_macro { @@ -211,3 +121,220 @@ pub(crate) fn validate_query<'a>(ctx: &mut Ctx<'a>, original_query: &'a Query) { ctx.output.queries.push(query); } + +fn analyze_return_expr<'a>( + ctx: &mut Ctx<'a>, + original_query: &'a Query, + scope: &mut HashMap<&'a str, Type>, + query: &mut GeneratedQuery, + ret: &'a ReturnType, +) { + match ret { + ReturnType::Expression(expr) => { + let (_, stmt) = infer_expr_type(ctx, expr, scope, original_query, None, query); + + match stmt.unwrap() { + GeneratedStatement::Traversal(traversal) => { + match &traversal.source_step.inner() { + SourceStep::Identifier(v) => { + is_valid_identifier( + ctx, + original_query, + expr.loc.clone(), + v.inner().as_str(), + ); + + // if is single object, need to handle it as a single object + // if is array, need to handle it as an array + match traversal.should_collect { + ShouldCollect::ToVec => { + query.return_values.push(ReturnValue::new_named( + GeneratedValue::Literal(GenRef::Literal(v.inner().clone())), + ReturnValueExpr::Traversal(traversal.clone()), + )); + } + ShouldCollect::ToVal => { + query.return_values.push(ReturnValue::new_single_named( + GeneratedValue::Literal(GenRef::Literal(v.inner().clone())), + ReturnValueExpr::Traversal(traversal.clone()), + )); + } + _ => { + unreachable!() + } + } + } + _ => { + query.return_values.push(ReturnValue::new_unnamed( + ReturnValueExpr::Traversal(traversal.clone()), + )); + } + } + } + GeneratedStatement::Identifier(id) => { + is_valid_identifier(ctx, original_query, expr.loc.clone(), id.inner().as_str()); + let identifier_end_type = match scope.get(id.inner().as_str()) { + Some(t) => t.clone(), + None => { + generate_error!( + ctx, + original_query, + expr.loc.clone(), + E301, + id.inner().as_str() + ); + Type::Unknown + } + }; + let value = + gen_identifier_or_param(original_query, id.inner().as_str(), false, true); + + match identifier_end_type { + Type::Scalar(_) | Type::Boolean => { + query.return_values.push(ReturnValue::new_named_literal( + GeneratedValue::Literal(GenRef::Literal(id.inner().clone())), + value, + )); + } + Type::Node(_) | Type::Vector(_) | Type::Edge(_) => { + query.return_values.push(ReturnValue::new_single_named( + GeneratedValue::Literal(GenRef::Literal(id.inner().clone())), + ReturnValueExpr::Identifier(value), + )); + } + _ => { + query.return_values.push(ReturnValue::new_named( + GeneratedValue::Literal(GenRef::Literal(id.inner().clone())), + ReturnValueExpr::Identifier(value), + )); + } + } + } + GeneratedStatement::Literal(l) => { + query.return_values.push(ReturnValue::new_literal( + GeneratedValue::Literal(l.clone()), + GeneratedValue::Literal(l.clone()), + )); + } + GeneratedStatement::Empty => query.return_values = vec![], + + // given all erroneous statements are caught by the analyzer, this should never happen + // all malformed statements (not gramatically correct) should be caught by the parser + _ => unreachable!(), + } + } + ReturnType::Array(values) => { + let values = values + .iter() + .map(|object| process_return_object(ctx, original_query, scope, object, query)) + .collect::>(); + query.return_values.push(ReturnValue::new_array(values)); + } + ReturnType::Object(values) => { + let values = values + .iter() + .map(|(key, value)| { + ( + key.clone(), + process_return_object(ctx, original_query, scope, value, query), + ) + }) + .collect::>(); + query.return_values.push(ReturnValue::new_object(values)); + } + ReturnType::Empty => {} + } +} + +fn process_return_object<'a>( + ctx: &mut Ctx<'a>, + original_query: &'a Query, + scope: &mut HashMap<&'a str, Type>, + return_type: &'a ReturnType, + query: &mut GeneratedQuery, +) -> ReturnValueExpr { + match return_type { + ReturnType::Expression(expr) => { + let (_, stmt) = infer_expr_type(ctx, expr, scope, original_query, None, query); + match stmt.unwrap() { + GeneratedStatement::Traversal(traversal) => { + match &traversal.source_step.inner() { + SourceStep::Identifier(v) => { + is_valid_identifier( + ctx, + original_query, + expr.loc.clone(), + v.inner().as_str(), + ); + + // if is single object, need to handle it as a single object + // if is array, need to handle it as an array + match traversal.should_collect { + ShouldCollect::ToVec => { + ReturnValueExpr::Traversal(traversal.clone()) + } + ShouldCollect::ToVal => { + ReturnValueExpr::Traversal(traversal.clone()) + } + _ => { + unreachable!() + } + } + } + _ => ReturnValueExpr::Traversal(traversal.clone()), + } + } + GeneratedStatement::Identifier(id) => { + is_valid_identifier(ctx, original_query, expr.loc.clone(), id.inner().as_str()); + let identifier_end_type = match scope.get(id.inner().as_str()) { + Some(t) => t.clone(), + None => { + generate_error!( + ctx, + original_query, + expr.loc.clone(), + E301, + id.inner().as_str() + ); + Type::Unknown + } + }; + let value = + gen_identifier_or_param(original_query, id.inner().as_str(), false, true); + + match identifier_end_type { + Type::Scalar(_) | Type::Boolean => ReturnValueExpr::Identifier(value), + Type::Node(_) | Type::Vector(_) | Type::Edge(_) => { + ReturnValueExpr::Identifier(value) + } + _ => ReturnValueExpr::Identifier(value), + } + } + GeneratedStatement::Literal(l) => { + ReturnValueExpr::Value(GeneratedValue::Literal(l.clone())) + } + _ => unreachable!(), + } + } + ReturnType::Array(values) => { + let values = values + .iter() + .map(|value| process_return_object(ctx, original_query, scope, value, query)) + .collect::>(); + ReturnValueExpr::Array(values) + } + ReturnType::Object(values) => { + let values = values + .iter() + .map(|(key, value)| { + ( + key.clone(), + process_return_object(ctx, original_query, scope, value, query), + ) + }) + .collect::>(); + ReturnValueExpr::Object(values) + } + _ => unreachable!(), + } +} diff --git a/helix-db/src/helixc/generator/queries.rs b/helix-db/src/helixc/generator/queries.rs index 3709a1e5..0718aeb0 100644 --- a/helix-db/src/helixc/generator/queries.rs +++ b/helix-db/src/helixc/generator/queries.rs @@ -84,8 +84,10 @@ impl Query { fn print_query(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // prints the function signature - self.print_input_struct(f)?; - self.print_parameters(f)?; + if !self.parameters.is_empty() { + self.print_input_struct(f)?; + self.print_parameters(f)?; + } self.print_handler(f)?; writeln!( f, @@ -95,17 +97,19 @@ impl Query { // print the db boilerplate writeln!(f, "let db = Arc::clone(&input.graph.storage);")?; - match self.hoisted_embedding_calls.is_empty() { - true => writeln!( - f, - "let data = input.request.in_fmt.deserialize::<{}Input>(&input.request.body)?;", - self.name - )?, - false => writeln!( - f, - "let data = input.request.in_fmt.deserialize::<{}Input>(&input.request.body)?.into_owned();", - self.name - )?, + if !self.parameters.is_empty() { + match self.hoisted_embedding_calls.is_empty() { + true => writeln!( + f, + "let data = input.request.in_fmt.deserialize::<{}Input>(&input.request.body)?;", + self.name + )?, + false => writeln!( + f, + "let data = input.request.in_fmt.deserialize::<{}Input>(&input.request.body)?.into_owned();", + self.name + )?, + } } // print embedding calls @@ -221,10 +225,7 @@ impl Query { )?; writeln!(f, "connection.iter = result.into_iter();")?; - writeln!( - f, - "let mut connections = connections.lock().unwrap();" - )?; + writeln!(f, "let mut connections = connections.lock().unwrap();")?; writeln!(f, "connections.add_connection(connection);")?; writeln!(f, "drop(connections);")?; writeln!( diff --git a/helix-db/src/helixc/generator/return_values.rs b/helix-db/src/helixc/generator/return_values.rs index dfdc6eb2..c92a419d 100644 --- a/helix-db/src/helixc/generator/return_values.rs +++ b/helix-db/src/helixc/generator/return_values.rs @@ -1,9 +1,8 @@ use core::fmt; -use std::fmt::Display; +use std::{collections::HashMap, fmt::Display}; use crate::helixc::generator::{traversal_steps::Traversal, utils::GeneratedValue}; - pub struct ReturnValue { pub value: ReturnValueExpr, pub return_type: ReturnType, @@ -40,8 +39,25 @@ impl Display for ReturnValue { ) } ReturnType::UnnamedExpr => { - write!(f, "// need to implement unnamed return value\n todo!()")?; - panic!("Unnamed return value is not supported"); + writeln!( + f, + " return_vals.insert(\"data\".to_string(), ReturnValue::from_traversal_value_array_with_mixin({}.clone(), remapping_vals.borrow_mut()));", + self.value + ) + } + ReturnType::HashMap => { + writeln!( + f, + " return_vals.insert(\"data\".to_string(), ReturnValue::from({}));", + self.value + ) + } + ReturnType::Array => { + writeln!( + f, + " return_vals.insert(\"data\".to_string(), ReturnValue::from({}));", + self.value + ) } } } @@ -55,6 +71,8 @@ impl ReturnValue { ReturnType::NamedExpr(name) => name.inner().inner().to_string(), ReturnType::SingleExpr(name) => name.inner().inner().to_string(), ReturnType::UnnamedExpr => todo!(), + ReturnType::HashMap => todo!(), + ReturnType::Array => todo!(), } } @@ -88,6 +106,18 @@ impl ReturnValue { return_type: ReturnType::UnnamedExpr, } } + pub fn new_array(values: Vec) -> Self { + Self { + value: ReturnValueExpr::Array(values), + return_type: ReturnType::Array, + } + } + pub fn new_object(values: HashMap) -> Self { + Self { + value: ReturnValueExpr::Object(values), + return_type: ReturnType::HashMap, + } + } } #[derive(Clone)] @@ -97,12 +127,16 @@ pub enum ReturnType { NamedExpr(GeneratedValue), SingleExpr(GeneratedValue), UnnamedExpr, + HashMap, + Array, } #[derive(Clone)] pub enum ReturnValueExpr { Traversal(Traversal), Identifier(GeneratedValue), Value(GeneratedValue), + Array(Vec), + Object(HashMap), } impl Display for ReturnValueExpr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -110,6 +144,22 @@ impl Display for ReturnValueExpr { ReturnValueExpr::Traversal(traversal) => write!(f, "{traversal}"), ReturnValueExpr::Identifier(identifier) => write!(f, "{identifier}"), ReturnValueExpr::Value(value) => write!(f, "{value}"), + ReturnValueExpr::Array(values) => { + write!(f, "vec![")?; + // if traversal then use the other from functions + for value in values { + write!(f, "ReturnValue::from({value}),")?; + } + write!(f, "]") + } + ReturnValueExpr::Object(values) => { + write!(f, "HashMap::from([")?; + // if traversal then use the other from functions + for (key, value) in values { + write!(f, "(String::from(\"{key}\"), ReturnValue::from({value})),")?; + } + write!(f, "])") + } } } } diff --git a/helix-db/src/helixc/parser/helix_parser.rs b/helix-db/src/helixc/parser/helix_parser.rs index 471409f8..eebb2698 100644 --- a/helix-db/src/helixc/parser/helix_parser.rs +++ b/helix-db/src/helixc/parser/helix_parser.rs @@ -411,7 +411,7 @@ pub struct Query { pub name: String, pub parameters: Vec, pub statements: Vec, - pub return_values: Vec, + pub return_values: Vec, pub loc: Loc, } @@ -501,6 +501,15 @@ pub enum ExpressionType { BM25Search(BM25Search), Empty, } + +#[derive(Debug, Clone)] +pub enum ReturnType { + Array(Vec), + Object(HashMap), + Expression(Expression), + Empty, +} + impl Debug for ExpressionType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -2212,11 +2221,91 @@ impl HelixParser { }) } - fn parse_return_statement(&self, pair: Pair) -> Result, ParserError> { + fn parse_return_statement(&self, pair: Pair) -> Result, ParserError> { // println!("pair: {:?}", pair.clone().into_inner()); + let inner = pair.into_inner(); + let mut return_types = Vec::new(); + for pair in inner { + match pair.as_rule() { + Rule::array_creation => { + return_types.push(ReturnType::Array(self.parse_array_creation(pair)?)); + } + Rule::object_creation => { + return_types.push(ReturnType::Object(self.parse_object_creation(pair)?)); + } + Rule::evaluates_to_anything => { + return_types.push(ReturnType::Expression(self.parse_expression(pair)?)); + } + _ => { + return Err(ParserError::from(format!( + "Unexpected rule in return statement: {:?}", + pair.as_rule() + ))); + } + } + } + Ok(return_types) + } + + fn parse_array_creation(&self, pair: Pair) -> Result, ParserError> { + let pairs = pair.into_inner(); + let mut objects = Vec::new(); + for p in pairs { + match p.as_rule() { + Rule::identifier => { + objects.push(ReturnType::Expression(Expression { + loc: p.loc(), + expr: ExpressionType::Identifier(p.as_str().to_string()), + })); + } + _ => { + objects.push(ReturnType::Object(self.parse_object_creation(p)?)); + } + } + } + Ok(objects) + } + + fn parse_object_creation( + &self, + pair: Pair, + ) -> Result, ParserError> { pair.into_inner() - .map(|p| self.parse_expression(p)) - .collect() + .map(|p| { + let mut object_inner = p.into_inner(); + let key = object_inner + .next() + .ok_or_else(|| ParserError::from("Missing object inner"))?; + let value = object_inner + .next() + .ok_or_else(|| ParserError::from("Missing object inner"))?; + let value = self.parse_object_inner(value)?; + Ok((key.as_str().to_string(), value)) + }) + .collect::, _>>() + } + + fn parse_object_inner(&self, object_field: Pair) -> Result { + let object_field_inner = object_field + .into_inner() + .next() + .ok_or_else(|| ParserError::from("Missing object inner"))?; + + match object_field_inner.as_rule() { + Rule::evaluates_to_anything => Ok(ReturnType::Expression( + self.parse_expression(object_field_inner)?, + )), + Rule::object_creation => Ok(ReturnType::Object( + self.parse_object_creation(object_field_inner)?, + )), + Rule::array_creation => Ok(ReturnType::Array( + self.parse_array_creation(object_field_inner)?, + )), + _ => Err(ParserError::from(format!( + "Unexpected rule in parse_object_inner: {:?}", + object_field_inner.as_rule() + ))), + } } fn parse_expression_vec(&self, pairs: Pairs) -> Result, ParserError> { diff --git a/helix-db/src/protocol/return_values.rs b/helix-db/src/protocol/return_values.rs index 1e37c2e0..91a2907d 100644 --- a/helix-db/src/protocol/return_values.rs +++ b/helix-db/src/protocol/return_values.rs @@ -3,10 +3,12 @@ use super::{ value::Value, }; use crate::{debug_println, helix_engine::traversal_core::traversal_value::TraversalValue}; -use crate::utils::{ - count::Count, - filterable::{Filterable, FilterableType}, - items::{Edge, Node}, +use crate::{ + utils::{ + count::Count, + filterable::{Filterable, FilterableType}, + items::{Edge, Node}, + }, }; use sonic_rs::{Deserialize, Serialize}; use std::{cell::RefMut, collections::HashMap}; @@ -197,31 +199,6 @@ impl ReturnValue { return_value } - #[inline(always)] - #[allow(unused_attributes)] - #[ignore = "No use for this function yet, however, I believe it may be useful in the future so I'm keeping it here"] - pub fn mixin_other(&self, item: I, secondary_properties: ResponseRemapping) -> Self - where - I: Filterable + Clone, - { - let mut return_val = ReturnValue::default(); - if !secondary_properties.should_spread { - match item.type_name() { - FilterableType::Node => { - return_val = ReturnValue::from(item); - } - FilterableType::Edge => { - return_val = ReturnValue::from(item); - } - FilterableType::Vector => { - return_val = ReturnValue::from(item); - } - } - } - return_val = return_val.mixin_remapping(secondary_properties.remappings); - return_val - } - #[inline] pub fn from_traversal_value( item: T, @@ -278,6 +255,7 @@ impl ReturnValue { ReturnValue::Object(properties) } + } impl From for ReturnValue { @@ -461,6 +439,12 @@ impl From> for ReturnValue { } } +impl From> for ReturnValue { + fn from(array: Vec) -> Self { + ReturnValue::Array(array) + } +} + impl From for ReturnValue { fn from(val: TraversalValue) -> Self { match val { diff --git a/test_simple.hx b/test_simple.hx new file mode 100644 index 00000000..0519ecba --- /dev/null +++ b/test_simple.hx @@ -0,0 +1 @@ + \ No newline at end of file