Skip to content

Commit 76bc25c

Browse files
committed
Rework parameter handling
1 parent 7924145 commit 76bc25c

File tree

13 files changed

+93
-61
lines changed

13 files changed

+93
-61
lines changed

src/binder/bind_expression/bind_parameter_expression.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "binder/expression/parameter_expression.h"
22
#include "binder/expression_binder.h"
33
#include "parser/expression/parsed_parameter_expression.h"
4+
#include "common/exception/binder.h"
45

56
using namespace kuzu::common;
67
using namespace kuzu::parser;
@@ -12,14 +13,12 @@ std::shared_ptr<Expression> ExpressionBinder::bindParameterExpression(
1213
const ParsedExpression& parsedExpression) {
1314
auto& parsedParameterExpression = parsedExpression.constCast<ParsedParameterExpression>();
1415
auto parameterName = parsedParameterExpression.getParameterName();
15-
parsedParameters.insert(parameterName);
16-
if (parameterMap.contains(parameterName)) {
17-
return make_shared<ParameterExpression>(parameterName, *parameterMap.at(parameterName));
18-
} else {
19-
auto value = std::make_shared<Value>(Value::createNullValue());
20-
parameterMap.insert({parameterName, value});
21-
return std::make_shared<ParameterExpression>(parameterName, *value);
16+
if (knownParameters.contains(parameterName)) {
17+
return make_shared<ParameterExpression>(parameterName, *knownParameters.at(parameterName));
2218
}
19+
// LCOV_EXCL_START
20+
throw BinderException(stringFormat("Cannot find parameter {}. This should not happen.", parameterName));
21+
// LCOV_EXCL_STOP
2322
}
2423

2524
} // namespace binder

src/binder/binder.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,13 +274,5 @@ TableFunction Binder::getScanFunction(const FileTypeInfo& typeInfo,
274274
return *func->ptrCast<TableFunction>();
275275
}
276276

277-
void Binder::validateAllInputParametersParsed() const {
278-
for (const auto& [name, _] : expressionBinder.parameterMap) {
279-
if (!expressionBinder.parsedParameters.contains(name)) {
280-
throw Exception("Parameter " + name + " not found.");
281-
}
282-
}
283-
}
284-
285277
} // namespace binder
286278
} // namespace kuzu

src/binder/expression_binder.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindExpression(
3535
bool allParamExist = true;
3636
for (auto& parsedExpr : collector.getParamExprs()) {
3737
auto name = parsedExpr->constCast<ParsedParameterExpression>().getParameterName();
38-
if (!parameterMap.contains(name)) {
39-
auto value = std::make_shared<Value>(Value::createNullValue());
40-
parameterMap.insert({name, value});
41-
parsedParameters.insert(name);
38+
if (!knownParameters.contains(name)) {
39+
unknownParameters.insert(name);
4240
allParamExist = false;
4341
}
4442
}
@@ -147,5 +145,11 @@ std::string ExpressionBinder::getUniqueName(const std::string& name) const {
147145
return binder->getUniqueExpressionName(name);
148146
}
149147

148+
void ExpressionBinder::addParameter(const std::string& name, std::shared_ptr<Value> value) {
149+
KU_ASSERT(!knownParameters.contains(name));
150+
knownParameters[name] = value;
151+
}
152+
153+
150154
} // namespace binder
151155
} // namespace kuzu

src/include/binder/binder.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,6 @@ class Binder {
7575

7676
KUZU_API std::unique_ptr<BoundStatement> bind(const parser::Statement& statement);
7777

78-
void setInputParameters(
79-
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameters) {
80-
expressionBinder.parameterMap = std::move(parameters);
81-
}
82-
83-
std::unordered_map<std::string, std::shared_ptr<common::Value>> getParameterMap() {
84-
return expressionBinder.parameterMap;
85-
}
86-
8778
KUZU_API std::shared_ptr<Expression> createVariable(const std::string& name,
8879
const common::LogicalType& dataType);
8980
KUZU_API std::shared_ptr<Expression> createInvisibleVariable(const std::string& name,
@@ -311,7 +302,6 @@ class Binder {
311302
KUZU_API static void validateColumnExistence(const catalog::TableCatalogEntry* entry,
312303
const std::string& columnName);
313304

314-
void validateAllInputParametersParsed() const;
315305
/*** helpers ***/
316306
std::string getUniqueExpressionName(const std::string& name);
317307

src/include/binder/expression_binder.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,24 @@ class ExpressionBinder {
129129
std::shared_ptr<Expression> forceCast(const std::shared_ptr<Expression>& expression,
130130
const common::LogicalType& targetType);
131131

132+
// Parameter
133+
void addParameter(const std::string& name, std::shared_ptr<common::Value> value);
134+
const std::unordered_set<std::string>& getUnknownParameters() const {
135+
return unknownParameters;
136+
}
137+
const std::unordered_map<std::string, std::shared_ptr<common::Value>>& getKnownParameters() const {
138+
return knownParameters;
139+
}
140+
132141
std::string getUniqueName(const std::string& name) const;
133142

134143
const ExpressionBinderConfig& getConfig() { return config; }
135144

136145
private:
137146
Binder* binder;
138147
main::ClientContext* context;
139-
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameterMap;
140-
std::unordered_set<std::string> parsedParameters;
148+
std::unordered_set<std::string> unknownParameters;
149+
std::unordered_map<std::string, std::shared_ptr<common::Value>> knownParameters;
141150
ExpressionBinderConfig config;
142151
};
143152

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
//#pragma once
2+
//
3+
//namespace common {
4+
//
5+
//}

src/include/main/client_context.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,7 @@ class KUZU_API ClientContext {
208208

209209
PrepareResult prepareNoLock(std::shared_ptr<parser::Statement> parsedStatement,
210210
bool shouldCommitNewTransaction,
211-
std::optional<std::unordered_map<std::string, std::shared_ptr<common::Value>>> inputParams =
212-
std::nullopt);
211+
std::unordered_map<std::string, std::shared_ptr<common::Value>> inputParams = {});
213212

214213
template<typename T, typename... Args>
215214
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,

src/include/main/prepared_statement.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,18 @@ class PreparedStatement {
6262
*/
6363
KUZU_API bool isReadOnly() const;
6464

65-
std::unordered_map<std::string, std::shared_ptr<common::Value>>& getParameterMapUnsafe() {
66-
return parameterMap;
65+
const std::unordered_set<std::string>& getUnknownParameters() const {
66+
return unknownParameters;
6767
}
68+
std::unordered_set<std::string> getKnownParameters();
69+
void updateParameter(const std::string& name, common::Value* value);
70+
void addParameter(const std::string& name, common::Value* value);
6871

6972
std::string getName() const { return cachedPreparedStatementName; }
7073

7174
common::StatementType getStatementType() const;
7275

73-
void validateExecuteParam(const std::string& paramName, common::Value* param) const;
76+
void validateParam(const std::string& paramName, common::Value* param) const;
7477

7578
static std::unique_ptr<PreparedStatement> getPreparedStatementWithError(
7679
const std::string& errorMessage);
@@ -81,6 +84,7 @@ class PreparedStatement {
8184
std::string errMsg;
8285
PreparedSummary preparedSummary;
8386
std::string cachedPreparedStatementName;
87+
std::unordered_set<std::string> unknownParameters;
8488
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameterMap;
8589
};
8690

src/main/client_context.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -323,17 +323,19 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareWithParams(std::string_
323323

324324
static void bindParametersNoLock(PreparedStatement& preparedStatement,
325325
const std::unordered_map<std::string, std::unique_ptr<Value>>& inputParams) {
326-
auto& parameterMap = preparedStatement.getParameterMapUnsafe();
327-
for (auto& [name, value] : inputParams) {
328-
if (!parameterMap.contains(name)) {
329-
throw Exception("Parameter " + name + " not found.");
326+
for (auto& key : preparedStatement.getKnownParameters()) {
327+
if (inputParams.contains(key)) {
328+
// Found input. Update parameter map.
329+
preparedStatement.updateParameter(key, inputParams.at(key).get());
330330
}
331-
preparedStatement.validateExecuteParam(name, value.get());
332-
// The much more natural `parameterMap.at(name) = std::move(v)` fails.
333-
// The reason is that other parts of the code rely on the existing Value object to be
334-
// modified in-place, not replaced in this map.
335-
*parameterMap.at(name) = std::move(*value);
336331
}
332+
for (auto& key : preparedStatement.getUnknownParameters()) {
333+
if (!inputParams.contains(key)) {
334+
throw Exception("Parameter " + key + " not found.");
335+
}
336+
preparedStatement.addParameter(key, inputParams.at(key).get());
337+
}
338+
337339
}
338340

339341
std::unique_ptr<QueryResult> ClientContext::executeWithParams(PreparedStatement* preparedStatement,
@@ -469,7 +471,7 @@ void ClientContext::validateTransaction(bool readOnly, bool requireTransaction)
469471

470472
ClientContext::PrepareResult ClientContext::prepareNoLock(
471473
std::shared_ptr<Statement> parsedStatement, bool shouldCommitNewTransaction,
472-
std::optional<std::unordered_map<std::string, std::shared_ptr<Value>>> inputParams) {
474+
std::unordered_map<std::string, std::shared_ptr<Value>> inputParams) {
473475
auto preparedStatement = std::make_unique<PreparedStatement>();
474476
auto cachedStatement = std::make_unique<CachedPreparedStatement>();
475477
cachedStatement->parsedStatement = parsedStatement;
@@ -489,12 +491,13 @@ ClientContext::PrepareResult ClientContext::prepareNoLock(
489491
*transactionContext,
490492
[&]() -> void {
491493
auto binder = Binder(this, localDatabase->getBinderExtensions());
492-
if (inputParams) {
493-
binder.setInputParameters(*inputParams);
494+
auto expressionBinder = binder.getExpressionBinder();
495+
for (auto& [name, value] : inputParams) {
496+
expressionBinder->addParameter(name, value);
494497
}
495498
const auto boundStatement = binder.bind(*parsedStatement);
496-
binder.validateAllInputParametersParsed();
497-
preparedStatement->parameterMap = binder.getParameterMap();
499+
preparedStatement->unknownParameters = expressionBinder->getUnknownParameters();
500+
preparedStatement->parameterMap = expressionBinder->getKnownParameters();
498501
cachedStatement->columns = boundStatement->getStatementResult()->getColumns();
499502
auto planner = Planner(this);
500503
auto bestPlan = planner.planStatement(*boundStatement);

src/main/prepared_statement.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ StatementType PreparedStatement::getStatementType() const {
4545
return preparedSummary.statementType;
4646
}
4747

48-
void PreparedStatement::validateExecuteParam(const std::string& paramName,
49-
common::Value* param) const {
48+
void PreparedStatement::validateParam(const std::string& paramName, Value* param) const {
5049
if (param->getDataType().getLogicalTypeID() == LogicalTypeID::POINTER &&
5150
(!parameterMap.contains(paramName) ||
5251
parameterMap.at(paramName)->getValue<uint8_t*>() != param->getValue<uint8_t*>())) {
@@ -59,6 +58,23 @@ void PreparedStatement::validateExecuteParam(const std::string& paramName,
5958
}
6059
}
6160

61+
std::unordered_set<std::string> PreparedStatement::getKnownParameters() {
62+
std::unordered_set<std::string> result;
63+
for (auto& [k, _] : parameterMap) {
64+
result.insert(k);
65+
}
66+
return result;
67+
}
68+
69+
void PreparedStatement::updateParameter(const std::string& name, Value* value) {
70+
validateParam(name, value);
71+
*parameterMap.at(name) = std::move(*value);
72+
}
73+
74+
void PreparedStatement::addParameter(const std::string& name, Value* value) {
75+
parameterMap.insert({name, std::make_shared<Value>(*value)});
76+
}
77+
6278
PreparedStatement::~PreparedStatement() = default;
6379

6480
std::unique_ptr<PreparedStatement> PreparedStatement::getPreparedStatementWithError(

0 commit comments

Comments
 (0)