Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "binder/expression/literal_expression.h"
#include "binder/expression/path_expression.h"
#include "binder/expression/property_expression.h"
#include "binder/expression/scalar_function_expression.h"
#include "binder/expression_visitor.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
Expand All @@ -11,8 +12,10 @@
#include "common/exception/binder.h"
#include "common/string_format.h"
#include "common/utils.h"
#include "function/built_in_function_utils.h"
#include "function/cast/functions/cast_from_string_functions.h"
#include "function/rewrite_function.h"
#include "function/scalar_function.h"
#include "function/schema/vector_node_rel_functions.h"
#include "main/client_context.h"

Expand Down Expand Up @@ -56,6 +59,106 @@ QueryGraph Binder::bindPatternElement(const PatternElement& patternElement) {
nodeAndRels.push_back(rightNode);
leftNode = rightNode;
}
if (clientContext->getClientConfig()->recursivePatternSemantic != common::PathSemantic::WALK) {
if (queryGraph.hasRecursiveRel()) {
// The only one recursive rel doesn't need to be handled because RecursiveJoin
// already implements path semantics. Like (a)-[b*2]-(c)
// What needs to be handled is the recursive
// rels with other nodes/rels. Like (a)-[b*2]-(c)-[d*2]-(e) or (a)-[b*2]-(c)-[d]-(e)
auto rels = queryGraph.getQueryRels();
if (rels.size() > 1) {
std::vector<LogicalType> childrenTypes;
binder::expression_vector childrenExpressions;
std::string funcName;
if (clientContext->getClientConfig()->recursivePatternSemantic ==
common::PathSemantic::ACYCLIC) {
auto nodes = queryGraph.getQueryNodes();
for (uint32_t j = 0; j < nodes.size(); ++j) {
childrenTypes.push_back(nodes[j]->getInternalID()->dataType.copy());
childrenExpressions.push_back(nodes[j]->getInternalID());
}
for (uint32_t j = 0; j < rels.size(); ++j) {
if (rels[j]->isRecursive()) {
childrenTypes.push_back(rels[j]->dataType.copy());
childrenExpressions.push_back(rels[j]);
}
}
funcName = function::IsNodeDistinctFunction::name;
} else {
for (uint32_t j = 0; j < rels.size(); ++j) {
if (rels[j]->isRecursive()) {
childrenTypes.push_back(rels[j]->dataType.copy());
childrenExpressions.push_back(rels[j]);
} else {
childrenTypes.push_back(
rels[j]->getInternalIDProperty()->dataType.copy());
childrenExpressions.push_back(rels[j]->getInternalIDProperty());
}
}
funcName = function::IsRelDistinctFunction::name;
}
auto catalog = clientContext->getCatalog();
auto transaction = clientContext->getTransaction();
auto functionEntry = catalog->getFunctionEntry(transaction, funcName);
auto function = function::BuiltInFunctionsUtils::matchFunction(funcName,
childrenTypes, functionEntry->ptrCast<FunctionCatalogEntry>())
->ptrCast<function::ScalarFunction>()
->copy();
std::unique_ptr<function::FunctionBindData> bindData =
std::make_unique<function::FunctionBindData>(
LogicalType(function->returnTypeID));
auto uniqueExpressionName = binder::ScalarFunctionExpression::getUniqueName(
function->name, childrenExpressions);
auto functionExpression = std::make_shared<binder::ScalarFunctionExpression>(
ExpressionType::FUNCTION, std::move(function), std::move(bindData),
childrenExpressions, uniqueExpressionName);
queryGraph.addSemanticExpression(functionExpression);
}
} else {
// not have recursive rel. Like: (a)-[b]-(c)-[d]-(e)
std::vector<LogicalType> childrenTypes;
binder::expression_vector childrenExpressions;
if (clientContext->getClientConfig()->recursivePatternSemantic ==
common::PathSemantic::ACYCLIC) {
auto nodes = queryGraph.getQueryNodes();
for (uint32_t j = 0; j < nodes.size(); ++j) {
childrenTypes.push_back(nodes[j]->getInternalID()->dataType.copy());
childrenExpressions.push_back(nodes[j]->getInternalID());
}
} else {
auto rels = queryGraph.getQueryRels();
for (uint32_t j = 0; j < rels.size(); ++j) {
childrenTypes.push_back(rels[j]->getInternalIDProperty()->dataType.copy());
childrenExpressions.push_back(rels[j]->getInternalIDProperty());
}
}
if (childrenExpressions.size() > 2) {
auto catalog = clientContext->getCatalog();
auto transaction = clientContext->getTransaction();
auto functionEntry =
catalog->getFunctionEntry(transaction, function::IsIDDistinctFunction::name);
auto function = function::BuiltInFunctionsUtils::matchFunction(
function::IsIDDistinctFunction::name, childrenTypes,
functionEntry->ptrCast<FunctionCatalogEntry>())
->ptrCast<function::ScalarFunction>()
->copy();
std::unique_ptr<function::FunctionBindData> bindData =
std::make_unique<function::FunctionBindData>(
LogicalType(function->returnTypeID));
auto uniqueExpressionName = binder::ScalarFunctionExpression::getUniqueName(
function->name, childrenExpressions);
auto functionExpression = std::make_shared<binder::ScalarFunctionExpression>(
ExpressionType::FUNCTION, std::move(function), std::move(bindData),
childrenExpressions, uniqueExpressionName);
queryGraph.addSemanticExpression(functionExpression);
} else if (childrenExpressions.size() == 2) {
// when only two children use no_equal
auto noEquals = expressionBinder.bindComparisonExpression(
kuzu::common::ExpressionType::NOT_EQUALS, childrenExpressions);
queryGraph.addSemanticExpression(noEquals);
}
}
}
if (patternElement.hasPathName()) {
auto pathName = patternElement.getPathName();
auto pathExpression = createPath(pathName, nodeAndRels);
Expand Down
28 changes: 26 additions & 2 deletions src/binder/bind/read/bind_match.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
#include "binder/binder.h"
#include "binder/expression/scalar_function_expression.h"
#include "binder/query/reading_clause/bound_match_clause.h"
#include "catalog/catalog.h"
#include "common/exception/binder.h"
#include "function/built_in_function_utils.h"
#include "function/scalar_function.h"
#include "function/schema/vector_node_rel_functions.h"
#include "main/client_context.h"
#include "parser/query/reading_clause/match_clause.h"

using namespace kuzu::common;
using namespace kuzu::parser;

Expand Down Expand Up @@ -37,8 +42,27 @@ static void validateHintCompleteness(const BoundJoinHintNode& root, const QueryG
std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause& readingClause) {
auto& matchClause = readingClause.constCast<MatchClause>();
auto boundGraphPattern = bindGraphPattern(matchClause.getPatternElementsRef());
if (matchClause.hasWherePredicate()) {
std::shared_ptr<Expression> semanticExpression;
auto queryGraphsNum = boundGraphPattern.queryGraphCollection.getNumQueryGraphs();
for (uint32_t i = 0; i < queryGraphsNum; ++i){
for(auto & expr: boundGraphPattern.queryGraphCollection.getQueryGraph(i)->getSemanticExpressions()) {
if(!semanticExpression){
semanticExpression = expr;
} else {
semanticExpression = expressionBinder.bindBooleanExpression(kuzu::common::ExpressionType::AND,
binder::expression_vector{semanticExpression, expr});
}
}
}
if (matchClause.hasWherePredicate() && semanticExpression) {
boundGraphPattern.where =
expressionBinder.bindBooleanExpression(kuzu::common::ExpressionType::AND,
binder::expression_vector{semanticExpression,
bindWhereExpression(*matchClause.getWherePredicate())});
} else if (matchClause.hasWherePredicate()) {
boundGraphPattern.where = bindWhereExpression(*matchClause.getWherePredicate());
} else if (semanticExpression) {
boundGraphPattern.where = semanticExpression;
}
rewriteMatchPattern(boundGraphPattern);
auto boundMatch = std::make_unique<BoundMatchClause>(
Expand Down
12 changes: 12 additions & 0 deletions src/binder/query/query_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ void QueryGraph::merge(const QueryGraph& other) {
for (auto& otherRel : other.queryRels) {
addQueryRel(otherRel);
}
for (auto& otherSemantic : other.semanticExpressions) {
semanticExpressions.push_back(otherSemantic);
}
}

bool QueryGraph::canProjectExpression(const std::shared_ptr<Expression>& expression) const {
Expand All @@ -235,6 +238,15 @@ bool QueryGraph::isConnected(const QueryGraph& other) const {
return false;
}

bool QueryGraph::hasRecursiveRel() {
for (auto& rel : queryRels) {
if (rel->isRecursive()) {
return true;
}
}
return false;
}

void QueryGraphCollection::addAndMergeQueryGraphIfConnected(QueryGraph queryGraphToAdd) {
auto newQueryGraphSet = std::vector<QueryGraph>();
for (auto i = 0u; i < queryGraphs.size(); i++) {
Expand Down
2 changes: 2 additions & 0 deletions src/function/function_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ FunctionCollection* FunctionCollection::getFunctions() {
SCALAR_FUNCTION(OffsetFunction), REWRITE_FUNCTION(IDFunction),
REWRITE_FUNCTION(StartNodeFunction), REWRITE_FUNCTION(EndNodeFunction),
REWRITE_FUNCTION(LabelFunction),
SCALAR_FUNCTION(IsIDDistinctFunction),SCALAR_FUNCTION(IsNodeDistinctFunction),
SCALAR_FUNCTION(IsRelDistinctFunction),

// Path functions
SCALAR_FUNCTION(NodesFunction), SCALAR_FUNCTION(RelsFunction),
Expand Down
Loading