Skip to content

Commit cdfdda8

Browse files
committed
Rework parameter handling
1 parent bcbfabd commit cdfdda8

File tree

12 files changed

+93
-65
lines changed

12 files changed

+93
-65
lines changed

src/binder/bind_expression/bind_parameter_expression.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "binder/expression/parameter_expression.h"
22
#include "binder/expression_binder.h"
33
#include "parser/expression/parsed_parameter_expression.h"
4+
#include "common/exception/binder.h"
45

56
using namespace kuzu::common;
67
using namespace kuzu::parser;
@@ -12,14 +13,12 @@ std::shared_ptr<Expression> ExpressionBinder::bindParameterExpression(
1213
const ParsedExpression& parsedExpression) {
1314
auto& parsedParameterExpression = parsedExpression.constCast<ParsedParameterExpression>();
1415
auto parameterName = parsedParameterExpression.getParameterName();
15-
parsedParameters.insert(parameterName);
16-
if (parameterMap.contains(parameterName)) {
17-
return make_shared<ParameterExpression>(parameterName, *parameterMap.at(parameterName));
18-
} else {
19-
auto value = std::make_shared<Value>(Value::createNullValue());
20-
parameterMap.insert({parameterName, value});
21-
return std::make_shared<ParameterExpression>(parameterName, *value);
16+
if (knownParameters.contains(parameterName)) {
17+
return make_shared<ParameterExpression>(parameterName, *knownParameters.at(parameterName));
2218
}
19+
// LCOV_EXCL_START
20+
throw BinderException(stringFormat("Cannot find parameter {}. This should not happen.", parameterName));
21+
// LCOV_EXCL_STOP
2322
}
2423

2524
} // namespace binder

src/binder/binder.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,13 +274,5 @@ TableFunction Binder::getScanFunction(const FileTypeInfo& typeInfo,
274274
return *func->ptrCast<TableFunction>();
275275
}
276276

277-
void Binder::validateAllInputParametersParsed() const {
278-
for (const auto& [name, _] : expressionBinder.parameterMap) {
279-
if (!expressionBinder.parsedParameters.contains(name)) {
280-
throw Exception("Parameter " + name + " not found.");
281-
}
282-
}
283-
}
284-
285277
} // namespace binder
286278
} // namespace kuzu

src/binder/expression_binder.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindExpression(
3535
bool allParamExist = true;
3636
for (auto& parsedExpr : collector.getParamExprs()) {
3737
auto name = parsedExpr->constCast<ParsedParameterExpression>().getParameterName();
38-
if (!parameterMap.contains(name)) {
39-
auto value = std::make_shared<Value>(Value::createNullValue());
40-
parameterMap.insert({name, value});
41-
parsedParameters.insert(name);
38+
if (!knownParameters.contains(name)) {
39+
unknownParameters.insert(name);
4240
allParamExist = false;
4341
}
4442
}
@@ -147,5 +145,11 @@ std::string ExpressionBinder::getUniqueName(const std::string& name) const {
147145
return binder->getUniqueExpressionName(name);
148146
}
149147

148+
void ExpressionBinder::addParameter(const std::string& name, std::shared_ptr<Value> value) {
149+
KU_ASSERT(!knownParameters.contains(name));
150+
knownParameters[name] = value;
151+
}
152+
153+
150154
} // namespace binder
151155
} // namespace kuzu

src/include/binder/binder.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,6 @@ class Binder {
7575

7676
KUZU_API std::unique_ptr<BoundStatement> bind(const parser::Statement& statement);
7777

78-
void setInputParameters(
79-
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameters) {
80-
expressionBinder.parameterMap = std::move(parameters);
81-
}
82-
83-
std::unordered_map<std::string, std::shared_ptr<common::Value>> getParameterMap() {
84-
return expressionBinder.parameterMap;
85-
}
86-
8778
KUZU_API std::shared_ptr<Expression> createVariable(const std::string& name,
8879
const common::LogicalType& dataType);
8980
KUZU_API std::shared_ptr<Expression> createInvisibleVariable(const std::string& name,
@@ -311,7 +302,6 @@ class Binder {
311302
KUZU_API static void validateColumnExistence(const catalog::TableCatalogEntry* entry,
312303
const std::string& columnName);
313304

314-
void validateAllInputParametersParsed() const;
315305
/*** helpers ***/
316306
std::string getUniqueExpressionName(const std::string& name);
317307

src/include/binder/expression_binder.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,24 @@ class ExpressionBinder {
129129
std::shared_ptr<Expression> forceCast(const std::shared_ptr<Expression>& expression,
130130
const common::LogicalType& targetType);
131131

132+
// Parameter
133+
void addParameter(const std::string& name, std::shared_ptr<common::Value> value);
134+
const std::unordered_set<std::string>& getUnknownParameters() const {
135+
return unknownParameters;
136+
}
137+
const std::unordered_map<std::string, std::shared_ptr<common::Value>>& getKnownParameters() const {
138+
return knownParameters;
139+
}
140+
132141
std::string getUniqueName(const std::string& name) const;
133142

134143
const ExpressionBinderConfig& getConfig() { return config; }
135144

136145
private:
137146
Binder* binder;
138147
main::ClientContext* context;
139-
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameterMap;
140-
std::unordered_set<std::string> parsedParameters;
148+
std::unordered_set<std::string> unknownParameters;
149+
std::unordered_map<std::string, std::shared_ptr<common::Value>> knownParameters;
141150
ExpressionBinderConfig config;
142151
};
143152

src/include/main/client_context.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,7 @@ class KUZU_API ClientContext {
203203

204204
PrepareResult prepareNoLock(std::shared_ptr<parser::Statement> parsedStatement,
205205
bool shouldCommitNewTransaction,
206-
std::optional<std::unordered_map<std::string, std::shared_ptr<common::Value>>> inputParams =
207-
std::nullopt);
206+
std::unordered_map<std::string, std::shared_ptr<common::Value>> inputParams = {});
208207

209208
template<typename T, typename... Args>
210209
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,

src/include/main/prepared_statement.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,17 @@ class PreparedStatement {
6262
*/
6363
KUZU_API bool isReadOnly() const;
6464

65-
std::unordered_map<std::string, std::shared_ptr<common::Value>>& getParameterMapUnsafe() {
66-
return parameterMap;
65+
const std::unordered_set<std::string>& getUnknownParameters() const {
66+
return unknownParameters;
6767
}
68+
std::unordered_set<std::string> getKnownParameters();
69+
void updateParameter(const std::string& name, common::Value* value);
70+
void addParameter(const std::string& name, common::Value* value);
6871

6972
std::string getName() const { return cachedPreparedStatementName; }
7073

7174
common::StatementType getStatementType() const;
7275

73-
void validateExecuteParam(const std::string& paramName, common::Value* param) const;
74-
7576
static std::unique_ptr<PreparedStatement> getPreparedStatementWithError(
7677
const std::string& errorMessage);
7778

@@ -81,6 +82,7 @@ class PreparedStatement {
8182
std::string errMsg;
8283
PreparedSummary preparedSummary;
8384
std::string cachedPreparedStatementName;
85+
std::unordered_set<std::string> unknownParameters;
8486
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameterMap;
8587
};
8688

src/main/client_context.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -307,17 +307,19 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareWithParams(std::string_
307307

308308
static void bindParametersNoLock(PreparedStatement& preparedStatement,
309309
const std::unordered_map<std::string, std::unique_ptr<Value>>& inputParams) {
310-
auto& parameterMap = preparedStatement.getParameterMapUnsafe();
311-
for (auto& [name, value] : inputParams) {
312-
if (!parameterMap.contains(name)) {
313-
throw Exception("Parameter " + name + " not found.");
310+
for (auto& key : preparedStatement.getKnownParameters()) {
311+
if (inputParams.contains(key)) {
312+
// Found input. Update parameter map.
313+
preparedStatement.updateParameter(key, inputParams.at(key).get());
314314
}
315-
preparedStatement.validateExecuteParam(name, value.get());
316-
// The much more natural `parameterMap.at(name) = std::move(v)` fails.
317-
// The reason is that other parts of the code rely on the existing Value object to be
318-
// modified in-place, not replaced in this map.
319-
*parameterMap.at(name) = std::move(*value);
320315
}
316+
for (auto& key : preparedStatement.getUnknownParameters()) {
317+
if (!inputParams.contains(key)) {
318+
throw Exception("Parameter " + key + " not found.");
319+
}
320+
preparedStatement.addParameter(key, inputParams.at(key).get());
321+
}
322+
321323
}
322324

323325
std::unique_ptr<QueryResult> ClientContext::executeWithParams(PreparedStatement* preparedStatement,
@@ -453,7 +455,7 @@ void ClientContext::validateTransaction(bool readOnly, bool requireTransaction)
453455

454456
ClientContext::PrepareResult ClientContext::prepareNoLock(
455457
std::shared_ptr<Statement> parsedStatement, bool shouldCommitNewTransaction,
456-
std::optional<std::unordered_map<std::string, std::shared_ptr<Value>>> inputParams) {
458+
std::unordered_map<std::string, std::shared_ptr<Value>> inputParams) {
457459
auto preparedStatement = std::make_unique<PreparedStatement>();
458460
auto cachedStatement = std::make_unique<CachedPreparedStatement>();
459461
cachedStatement->parsedStatement = parsedStatement;
@@ -473,12 +475,13 @@ ClientContext::PrepareResult ClientContext::prepareNoLock(
473475
*transactionContext,
474476
[&]() -> void {
475477
auto binder = Binder(this, localDatabase->getBinderExtensions());
476-
if (inputParams) {
477-
binder.setInputParameters(*inputParams);
478+
auto expressionBinder = binder.getExpressionBinder();
479+
for (auto& [name, value] : inputParams) {
480+
expressionBinder->addParameter(name, value);
478481
}
479482
const auto boundStatement = binder.bind(*parsedStatement);
480-
binder.validateAllInputParametersParsed();
481-
preparedStatement->parameterMap = binder.getParameterMap();
483+
preparedStatement->unknownParameters = expressionBinder->getUnknownParameters();
484+
preparedStatement->parameterMap = expressionBinder->getKnownParameters();
482485
cachedStatement->columns = boundStatement->getStatementResult()->getColumns();
483486
auto planner = Planner(this);
484487
auto bestPlan = planner.planStatement(*boundStatement);

src/main/prepared_statement.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,9 @@ StatementType PreparedStatement::getStatementType() const {
4545
return preparedSummary.statementType;
4646
}
4747

48-
void PreparedStatement::validateExecuteParam(const std::string& paramName,
49-
common::Value* param) const {
50-
if (param->getDataType().getLogicalTypeID() == LogicalTypeID::POINTER &&
51-
(!parameterMap.contains(paramName) ||
52-
parameterMap.at(paramName)->getValue<uint8_t*>() != param->getValue<uint8_t*>())) {
48+
static void validateParam(const std::string& paramName, Value* newVal, Value* oldVal) {
49+
if (newVal->getDataType().getLogicalTypeID() == LogicalTypeID::POINTER &&
50+
newVal->getValue<uint8_t*>() != oldVal->getValue<uint8_t*>()) {
5351
throw BinderException(stringFormat(
5452
"When preparing the current statement the dataframe passed into parameter "
5553
"'{}' was different from the one provided during prepare. Dataframes parameters "
@@ -59,6 +57,24 @@ void PreparedStatement::validateExecuteParam(const std::string& paramName,
5957
}
6058
}
6159

60+
std::unordered_set<std::string> PreparedStatement::getKnownParameters() {
61+
std::unordered_set<std::string> result;
62+
for (auto& [k, _] : parameterMap) {
63+
result.insert(k);
64+
}
65+
return result;
66+
}
67+
68+
void PreparedStatement::updateParameter(const std::string& name, Value* value) {
69+
KU_ASSERT(parameterMap.contains(name));
70+
validateParam(name, value, parameterMap.at(name).get());
71+
*parameterMap.at(name) = std::move(*value);
72+
}
73+
74+
void PreparedStatement::addParameter(const std::string& name, Value* value) {
75+
parameterMap.insert({name, std::make_shared<Value>(*value)});
76+
}
77+
6278
PreparedStatement::~PreparedStatement() = default;
6379

6480
std::unique_ptr<PreparedStatement> PreparedStatement::getPreparedStatementWithError(

test/api/api_test.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -251,19 +251,22 @@ TEST_F(ApiTest, PrepareWithSkipLimitError) {
251251
auto prepared = conn->prepare("MATCH (p:person) RETURN p.ID skip $sp");
252252
auto result = conn->execute(prepared.get());
253253
ASSERT_FALSE(result->isSuccess());
254-
ASSERT_EQ(result->toString(), "Runtime exception: Cannot evaluate $sp as a valid skip number.");
254+
ASSERT_EQ(result->toString(), "Parameter sp not found.");
255+
256+
result = conn->execute(prepared.get(), std::make_pair(std::string("sp"), "abc"));
257+
ASSERT_FALSE(result->isSuccess());
258+
ASSERT_EQ(result->toString(),
259+
"Runtime exception: The number of rows to skip/limit must be a non-negative integer.");
255260

256261
prepared = conn->prepare("MATCH (p:person) RETURN p.ID limit $sp");
257262
result = conn->execute(prepared.get());
258263
ASSERT_FALSE(result->isSuccess());
259-
ASSERT_EQ(result->toString(),
260-
"Runtime exception: Cannot evaluate $sp as a valid limit number.");
264+
ASSERT_EQ(result->toString(), "Parameter sp not found.");
261265

262266
prepared = conn->prepare("MATCH (p:person) RETURN p.ID skip $s limit $sp");
263267
result = conn->execute(prepared.get(), std::make_pair(std::string("s"), 3));
264268
ASSERT_FALSE(result->isSuccess());
265-
ASSERT_EQ(result->toString(),
266-
"Runtime exception: Cannot evaluate $sp as a valid limit number.");
269+
ASSERT_EQ(result->toString(), "Parameter sp not found.");
267270

268271
prepared = conn->prepare("MATCH (p:person) RETURN p.ID skip $s");
269272
result = conn->execute(prepared.get(), std::make_pair(std::string("s"), 3.4));
@@ -316,8 +319,16 @@ TEST_F(ApiTest, MissingParam) {
316319
std::unordered_map<std::string, std::unique_ptr<Value>> params;
317320
params["val1"] = std::make_unique<Value>(Value::createValue(3));
318321
auto prep = conn->prepareWithParams("RETURN $val1 + $val2", std::move(params));
319-
ASSERT_FALSE(prep->isSuccess());
320-
ASSERT_STREQ("Parameter val1 not found.", prep->getErrorMessage().c_str());
322+
ASSERT_TRUE(prep->isSuccess());
323+
auto result = conn->execute(prep.get(), std::make_pair(std::string("s"), 3));
324+
ASSERT_FALSE(result->isSuccess());
325+
ASSERT_STREQ("Parameter val2 not found.", result->getErrorMessage().c_str());
326+
result = conn->execute(prep.get(), std::make_pair(std::string("val2"), 1.1));
327+
ASSERT_TRUE(result->isSuccess());
328+
ASSERT_STREQ("4.100000\n", result->getNext()->toString().c_str());
329+
result = conn->execute(prep.get(), std::make_pair(std::string("val2"), 1.1), std::make_pair(std::string("val1"), 1.1));
330+
ASSERT_TRUE(result->isSuccess());
331+
ASSERT_STREQ("2.200000\n", result->getNext()->toString().c_str());
321332
}
322333

323334
TEST_F(ApiTest, CloseDatabaseBeforeQueryResultAndConnection) {

0 commit comments

Comments
 (0)