Skip to content

Commit 5cfad24

Browse files
authored
Make GDS table function (#5048)
1 parent 1068d07 commit 5cfad24

File tree

60 files changed

+651
-1153
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+651
-1153
lines changed

extension/fts/src/fts_extension.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ static void initFTSEntries(const transaction::Transaction* transaction, catalog:
2525
void FTSExtension::load(main::ClientContext* context) {
2626
auto& db = *context->getDatabase();
2727
ExtensionUtils::addScalarFunc<StemFunction>(db);
28-
ExtensionUtils::addGDSFunc<QueryFTSFunction>(db);
28+
ExtensionUtils::addTableFunc<QueryFTSFunction>(db);
2929
ExtensionUtils::addStandaloneTableFunc<CreateFTSFunction>(db);
3030
ExtensionUtils::addStandaloneTableFunc<InternalCreateFTSFunction>(db, true /* isInternal */);
3131
ExtensionUtils::addStandaloneTableFunc<DropFTSFunction>(db);

extension/fts/src/function/query_fts_index.cpp

Lines changed: 65 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include "function/gds/gds_utils.h"
1111
#include "function/query_fts_bind_data.h"
1212
#include "processor/execution_context.h"
13-
#include "processor/operator/gds_call_shared_state.h"
1413
#include "storage/index/index_utils.h"
1514
#include "storage/storage_manager.h"
1615
#include "storage/store/node_table.h"
@@ -153,7 +152,7 @@ void QFTSOutputWriter::write(processor::FactorizedTable& scoreFT, nodeID_t docNo
153152

154153
class QFTSVertexCompute final : public VertexCompute {
155154
public:
156-
QFTSVertexCompute(MemoryManager* mm, processor::GDSCallSharedState* sharedState,
155+
QFTSVertexCompute(MemoryManager* mm, GDSFuncSharedState* sharedState,
157156
std::unique_ptr<QFTSOutputWriter> writer)
158157
: mm{mm}, sharedState{sharedState}, writer{std::move(writer)} {
159158
localFT = sharedState->factorizedTablePool.claimLocalTable(mm);
@@ -175,18 +174,24 @@ class QFTSVertexCompute final : public VertexCompute {
175174

176175
private:
177176
MemoryManager* mm;
178-
processor::GDSCallSharedState* sharedState;
177+
GDSFuncSharedState* sharedState;
179178
processor::FactorizedTable* localFT;
180179
std::unique_ptr<QFTSOutputWriter> writer;
181180
};
182181

182+
static constexpr char SCORE_PROP_NAME[] = "score";
183+
static constexpr char DOC_FREQUENCY_PROP_NAME[] = "df";
184+
static constexpr char TERM_FREQUENCY_PROP_NAME[] = "tf";
185+
static constexpr char DOC_LEN_PROP_NAME[] = "len";
186+
static constexpr char DOC_ID_PROP_NAME[] = "docID";
187+
183188
static std::unordered_map<offset_t, uint64_t> getDFs(main::ClientContext& context,
184189
const catalog::NodeTableCatalogEntry& termsEntry, const std::vector<std::string>& terms) {
185190
auto storageManager = context.getStorageManager();
186191
auto tableID = termsEntry.getTableID();
187192
auto& termsNodeTable = storageManager->getTable(tableID)->cast<NodeTable>();
188193
auto tx = context.getTransaction();
189-
auto dfColumnID = termsEntry.getColumnID(QueryFTSAlgorithm::DOC_FREQUENCY_PROP_NAME);
194+
auto dfColumnID = termsEntry.getColumnID(DOC_FREQUENCY_PROP_NAME);
190195
std::vector<LogicalType> vectorTypes;
191196
vectorTypes.push_back(LogicalType::INTERNAL_ID());
192197
vectorTypes.push_back(LogicalType::UINT64());
@@ -231,21 +236,22 @@ static uint64_t getSparseFrontierSize(uint64_t numRows) {
231236
return size;
232237
}
233238

234-
void QueryFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
235-
auto clientContext = executionContext->clientContext;
239+
static common::offset_t tableFunc(const TableFuncInput& input, TableFuncOutput&) {
240+
auto clientContext = input.context->clientContext;
236241
auto transaction = clientContext->getTransaction();
242+
auto sharedState = input.sharedState->ptrCast<GDSFuncSharedState>();
237243
auto graph = sharedState->graph.get();
238244
auto graphEntry = graph->getGraphEntry();
239-
auto qFTSBindData = bindData->ptrCast<QueryFTSBindData>();
245+
auto qFTSBindData = input.bindData->constPtrCast<QueryFTSBindData>();
240246
auto& termsEntry = graphEntry->nodeInfos[0].entry->constCast<catalog::NodeTableCatalogEntry>();
241-
auto terms = bindData->ptrCast<QueryFTSBindData>()->getTerms(*executionContext->clientContext);
242-
auto dfs = getDFs(*executionContext->clientContext, termsEntry, terms);
247+
auto terms = qFTSBindData->getTerms(*input.context->clientContext);
248+
auto dfs = getDFs(*input.context->clientContext, termsEntry, terms);
243249
// Do edge compute to extend terms -> docs and save the term frequency and document frequency
244250
// for each term-doc pair. The reason why we store the term frequency and document frequency
245251
// is that: we need the `len` property from the docs table which is only available during the
246252
// vertex compute.
247-
auto currentFrontier = PathLengths::getUnvisitedFrontier(executionContext, graph);
248-
auto nextFrontier = PathLengths::getUnvisitedFrontier(executionContext, graph);
253+
auto currentFrontier = PathLengths::getUnvisitedFrontier(input.context, graph);
254+
auto nextFrontier = PathLengths::getUnvisitedFrontier(input.context, graph);
249255
auto frontierPair =
250256
std::make_unique<DoublePathLengthsFrontierPair>(currentFrontier, nextFrontier);
251257
auto termsTableID = termsEntry.getTableID();
@@ -263,15 +269,15 @@ void QueryFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
263269
auto auxiliaryState = std::make_unique<EmptyGDSAuxiliaryState>();
264270
auto compState = GDSComputeState(std::move(frontierPair), std::move(edgeCompute),
265271
std::move(auxiliaryState), nullptr /* outputNodeMask */);
266-
GDSUtils::runFrontiersUntilConvergence(executionContext, compState, graph, ExtendDirection::FWD,
272+
GDSUtils::runFrontiersUntilConvergence(input.context, compState, graph, ExtendDirection::FWD,
267273
1 /* maxIters */, TERM_FREQUENCY_PROP_NAME);
268274

269275
// Do vertex compute to calculate the score for doc with the length property.
270276
auto mm = clientContext->getMemoryManager();
271277
auto numUniqueTerms = getNumUniqueTerms(terms);
272278
auto writer = std::make_unique<QFTSOutputWriter>(scores, mm, qFTSBindData->getConfig(),
273279
*qFTSBindData, numUniqueTerms);
274-
auto vc = std::make_unique<QFTSVertexCompute>(mm, sharedState.get(), std::move(writer));
280+
auto vc = std::make_unique<QFTSVertexCompute>(mm, sharedState, std::move(writer));
275281
auto vertexPropertiesToScan = std::vector<std::string>{DOC_LEN_PROP_NAME, DOC_ID_PROP_NAME};
276282
auto docsEntry = graphEntry->nodeInfos[1].entry;
277283
auto numDocs = storageManager->getTable(docsEntry->getTableID())->getNumTotalRows(transaction);
@@ -284,65 +290,75 @@ void QueryFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
284290
}
285291
}
286292
} else {
287-
GDSUtils::runVertexCompute(executionContext, graph, *vc, docsEntry, vertexPropertiesToScan);
293+
GDSUtils::runVertexCompute(input.context, graph, *vc, docsEntry, vertexPropertiesToScan);
288294
}
289295
sharedState->factorizedTablePool.mergeLocalTables();
296+
return 0;
290297
}
291298

292-
expression_vector QueryFTSAlgorithm::getResultColumns(const GDSBindInput& bindInput) const {
293-
expression_vector columns;
294-
auto& docsNode = bindData->getNodeOutput()->constCast<NodeExpression>();
295-
columns.push_back(docsNode.getInternalID());
296-
std::string scoreColumnName = SCORE_PROP_NAME;
297-
if (!bindInput.yieldVariables.empty()) {
298-
scoreColumnName = bindColumnName(bindInput.yieldVariables[1], scoreColumnName);
299-
}
300-
auto scoreColumn = bindInput.binder->createVariable(scoreColumnName, LogicalType::DOUBLE());
301-
columns.push_back(scoreColumn);
302-
return columns;
303-
}
304-
305-
static std::string getParamVal(const GDSBindInput& input, idx_t idx) {
299+
static std::string getParamVal(const TableFuncBindInput& input, idx_t idx) {
306300
if (input.getParam(idx)->expressionType != ExpressionType::LITERAL) {
307301
throw BinderException{"The table and index name must be literal expressions."};
308302
}
309303
return ExpressionUtil::getLiteralValue<std::string>(
310304
input.getParam(idx)->constCast<LiteralExpression>());
311305
}
312306

313-
void QueryFTSAlgorithm::bind(const GDSBindInput& input, main::ClientContext& context) {
314-
context.setUseInternalCatalogEntry(true /* useInternalCatalogEntry */);
307+
static std::unique_ptr<TableFuncBindData> bindFunc(main::ClientContext* context,
308+
const TableFuncBindInput* input) {
309+
context->setUseInternalCatalogEntry(true /* useInternalCatalogEntry */);
315310
// For queryFTS, the table and index name must be given at compile time while the user
316311
// can give the query at runtime.
317-
auto inputTableName = getParamVal(input, 0);
318-
auto indexName = getParamVal(input, 1);
319-
auto query = input.getParam(2);
312+
auto inputTableName = getParamVal(*input, 0);
313+
auto indexName = getParamVal(*input, 1);
314+
auto query = input->getParam(2);
320315

321316
auto tableEntry =
322-
IndexUtils::bindTable(context, inputTableName, indexName, IndexOperation::QUERY);
323-
auto ftsIndexEntry = context.getCatalog()->getIndex(context.getTransaction(),
324-
tableEntry->getTableID(), indexName);
325-
auto entry =
326-
context.getCatalog()->getTableCatalogEntry(context.getTransaction(), inputTableName);
327-
auto nodeOutput = bindNodeOutput(input, {entry});
328-
329-
auto termsEntry = context.getCatalog()->getTableCatalogEntry(context.getTransaction(),
317+
IndexUtils::bindTable(*context, inputTableName, indexName, IndexOperation::QUERY);
318+
auto catalog = context->getCatalog();
319+
auto transaction = context->getTransaction();
320+
auto ftsIndexEntry = catalog->getIndex(transaction, tableEntry->getTableID(), indexName);
321+
auto entry = catalog->getTableCatalogEntry(transaction, inputTableName);
322+
auto nodeOutput = GDSFunction::bindNodeOutput(*input, {entry});
323+
324+
auto termsEntry = catalog->getTableCatalogEntry(transaction,
330325
FTSUtils::getTermsTableName(tableEntry->getTableID(), indexName));
331-
auto docsEntry = context.getCatalog()->getTableCatalogEntry(context.getTransaction(),
326+
auto docsEntry = catalog->getTableCatalogEntry(transaction,
332327
FTSUtils::getDocsTableName(tableEntry->getTableID(), indexName));
333-
auto appearsInEntry = context.getCatalog()->getTableCatalogEntry(context.getTransaction(),
328+
auto appearsInEntry = catalog->getTableCatalogEntry(transaction,
334329
FTSUtils::getAppearsInTableName(tableEntry->getTableID(), indexName));
335330
auto graphEntry = graph::GraphEntry({termsEntry, docsEntry}, {appearsInEntry});
336-
bindData = std::make_unique<QueryFTSBindData>(std::move(graphEntry), nodeOutput,
337-
std::move(query), *ftsIndexEntry, QueryFTSOptionalParams{input.optionalParams});
338-
context.setUseInternalCatalogEntry(false /* useInternalCatalogEntry */);
331+
332+
expression_vector columns;
333+
auto& docsNode = nodeOutput->constCast<NodeExpression>();
334+
columns.push_back(docsNode.getInternalID());
335+
std::string scoreColumnName = SCORE_PROP_NAME;
336+
if (!input->yieldVariables.empty()) {
337+
scoreColumnName = GDSFunction::bindColumnName(input->yieldVariables[1], scoreColumnName);
338+
}
339+
auto scoreColumn = input->binder->createVariable(scoreColumnName, LogicalType::DOUBLE());
340+
columns.push_back(scoreColumn);
341+
auto bindData =
342+
std::make_unique<QueryFTSBindData>(std::move(columns), std::move(graphEntry), nodeOutput,
343+
std::move(query), *ftsIndexEntry, QueryFTSOptionalParams{input->optionalParamsLegacy});
344+
context->setUseInternalCatalogEntry(false /* useInternalCatalogEntry */);
345+
return bindData;
339346
}
340347

341348
function_set QueryFTSFunction::getFunctionSet() {
342349
function_set result;
343-
auto algo = std::make_unique<QueryFTSAlgorithm>();
344-
result.push_back(
345-
std::make_unique<GDSFunction>(name, algo->getParameterTypeIDs(), std::move(algo)));
350+
// inputs are tableName, indexName, query
351+
auto func = std::make_unique<TableFunction>(QueryFTSFunction::name,
352+
std::vector<LogicalTypeID>{LogicalTypeID::STRING, LogicalTypeID::STRING,
353+
LogicalTypeID::STRING});
354+
func->bindFunc = bindFunc;
355+
func->tableFunc = tableFunc;
356+
func->initSharedStateFunc = GDSFunction::initSharedState;
357+
func->initLocalStateFunc = TableFunction::initEmptyLocalState;
358+
func->canParallelFunc = [] { return false; };
359+
func->getLogicalPlanFunc = GDSFunction::getLogicalPlan;
360+
func->getPhysicalPlanFunc = GDSFunction::getPhysicalPlan;
361+
result.push_back(std::move(func));
346362
return result;
347363
}
348364

extension/fts/src/include/function/query_fts_bind_data.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ struct QueryFTSBindData final : function::GDSBindData {
2424
QueryFTSOptionalParams optionalParams;
2525
common::table_id_t outputTableID;
2626

27-
QueryFTSBindData(graph::GraphEntry graphEntry, std::shared_ptr<binder::Expression> docs,
28-
std::shared_ptr<binder::Expression> query, const catalog::IndexCatalogEntry& entry,
29-
QueryFTSOptionalParams optionalParams)
30-
: GDSBindData{std::move(graphEntry), std::move(docs)}, query{std::move(query)},
31-
entry{entry}, optionalParams{std::move(optionalParams)},
27+
QueryFTSBindData(binder::expression_vector columns, graph::GraphEntry graphEntry,
28+
std::shared_ptr<binder::Expression> docs, std::shared_ptr<binder::Expression> query,
29+
const catalog::IndexCatalogEntry& entry, QueryFTSOptionalParams optionalParams)
30+
: GDSBindData{std::move(columns), std::move(graphEntry), std::move(docs)},
31+
query{std::move(query)}, entry{entry}, optionalParams{std::move(optionalParams)},
3232
outputTableID{
3333
nodeOutput->constCast<binder::NodeExpression>().getSingleEntry()->getTableID()} {}
3434
QueryFTSBindData(const QueryFTSBindData& other)
@@ -39,7 +39,7 @@ struct QueryFTSBindData final : function::GDSBindData {
3939

4040
QueryFTSConfig getConfig() const { return optionalParams.getConfig(); }
4141

42-
std::unique_ptr<GDSBindData> copy() const override {
42+
std::unique_ptr<function::TableFuncBindData> copy() const override {
4343
return std::make_unique<QueryFTSBindData>(*this);
4444
}
4545
};

extension/fts/src/include/function/query_fts_index.h

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,10 @@
22

33
#include "function/gds/gds.h"
44
#include "function/gds/gds_frontier.h"
5-
#include "function/gds_function.h"
65

76
namespace kuzu {
87
namespace fts_extension {
98

10-
class QueryFTSAlgorithm : public function::GDSAlgorithm {
11-
public:
12-
static constexpr char SCORE_PROP_NAME[] = "score";
13-
static constexpr char DOC_FREQUENCY_PROP_NAME[] = "df";
14-
static constexpr char TERM_FREQUENCY_PROP_NAME[] = "tf";
15-
static constexpr char DOC_LEN_PROP_NAME[] = "len";
16-
static constexpr char DOC_ID_PROP_NAME[] = "docID";
17-
18-
public:
19-
QueryFTSAlgorithm() = default;
20-
QueryFTSAlgorithm(const QueryFTSAlgorithm& other) : GDSAlgorithm{other} {}
21-
22-
/*
23-
* Inputs include the following:
24-
*
25-
* graph::ANY
26-
* srcNode::NODE
27-
* queryString: STRING
28-
*/
29-
std::vector<common::LogicalTypeID> getParameterTypeIDs() const override {
30-
return {common::LogicalTypeID::STRING /* tableName */,
31-
common::LogicalTypeID::STRING /* indexName */,
32-
common::LogicalTypeID::STRING /* query */};
33-
}
34-
35-
void exec(processor::ExecutionContext* executionContext) override;
36-
37-
std::unique_ptr<GDSAlgorithm> copy() const override {
38-
return std::make_unique<QueryFTSAlgorithm>(*this);
39-
}
40-
41-
binder::expression_vector getResultColumns(
42-
const function::GDSBindInput& bindInput) const override;
43-
44-
void bind(const function::GDSBindInput& input, main::ClientContext&) override;
45-
};
46-
479
struct QueryFTSFunction {
4810
static constexpr const char* name = "QUERY_FTS_INDEX";
4911

src/binder/bind/bind_table_function.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,35 @@ using namespace kuzu::function;
1212
namespace kuzu {
1313
namespace binder {
1414

15-
static void validateParameterType(const expression_vector& positionalParams) {
16-
for (auto& param : positionalParams) {
17-
ExpressionUtil::validateExpressionType(*param,
18-
{ExpressionType::LITERAL, ExpressionType::PARAMETER});
19-
}
20-
}
21-
2215
BoundTableScanInfo Binder::bindTableFunc(const std::string& tableFuncName,
2316
const parser::ParsedExpression& expr, std::vector<parser::YieldVariable> yieldVariables) {
2417
auto entry = clientContext->getCatalog()->getFunctionEntry(clientContext->getTransaction(),
2518
tableFuncName, clientContext->useInternalCatalogEntry());
2619
expression_vector positionalParams;
2720
std::vector<LogicalType> positionalParamTypes;
2821
optional_params_t optionalParams;
22+
expression_vector optionalParamsLegacy;
2923
for (auto i = 0u; i < expr.getNumChildren(); i++) {
3024
auto& childExpr = *expr.getChild(i);
3125
auto param = expressionBinder.bindExpression(childExpr);
3226
if (!childExpr.hasAlias()) {
27+
ExpressionUtil::validateExpressionType(*param,
28+
{ExpressionType::LITERAL, ExpressionType::PARAMETER});
3329
positionalParams.push_back(param);
3430
positionalParamTypes.push_back(param->getDataType().copy());
3531
} else {
36-
ExpressionUtil::validateExpressionType(*param, ExpressionType::LITERAL);
37-
auto literalExpr = param->constPtrCast<LiteralExpression>();
38-
optionalParams.emplace(childExpr.getAlias(), literalExpr->getValue());
32+
ExpressionUtil::validateExpressionType(*param,
33+
{ExpressionType::LITERAL, ExpressionType::PARAMETER});
34+
if (param->expressionType == ExpressionType::LITERAL) {
35+
auto literalExpr = param->constPtrCast<LiteralExpression>();
36+
optionalParams.emplace(childExpr.getAlias(), literalExpr->getValue());
37+
}
38+
param->setAlias(expr.getChild(i)->getAlias());
39+
optionalParamsLegacy.push_back(param);
3940
}
4041
}
4142
auto func = BuiltInFunctionsUtils::matchFunction(tableFuncName, positionalParamTypes,
4243
entry->ptrCast<catalog::FunctionCatalogEntry>());
43-
validateParameterType(positionalParams);
4444
auto tableFunc = func->constPtrCast<TableFunction>();
4545
std::vector<common::LogicalType> inputTypes;
4646
if (tableFunc->inferInputTypes) {
@@ -66,6 +66,7 @@ BoundTableScanInfo Binder::bindTableFunc(const std::string& tableFuncName,
6666
auto bindInput = TableFuncBindInput();
6767
bindInput.params = std::move(positionalParams);
6868
bindInput.optionalParams = std::move(optionalParams);
69+
bindInput.optionalParamsLegacy = std::move(optionalParamsLegacy);
6970
bindInput.binder = this;
7071
bindInput.yieldVariables = std::move(yieldVariables);
7172
return BoundTableScanInfo{*tableFunc, tableFunc->bindFunc(clientContext, &bindInput)};

src/binder/bind/read/bind_in_query_call.cpp

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
#include "binder/binder.h"
2-
#include "binder/expression/expression_util.h"
3-
#include "binder/query/reading_clause/bound_gds_call.h"
42
#include "binder/query/reading_clause/bound_table_function_call.h"
53
#include "catalog/catalog.h"
64
#include "common/exception/binder.h"
7-
#include "function/built_in_function_utils.h"
8-
#include "function/gds_function.h"
95
#include "main/client_context.h"
106
#include "parser/expression/parsed_function_expression.h"
117
#include "parser/query/reading_clause/in_query_call_clause.h"
@@ -34,35 +30,6 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
3430
boundReadingClause =
3531
std::make_unique<BoundTableFunctionCall>(std::move(boundTableFunction));
3632
} break;
37-
case CatalogEntryType::GDS_FUNCTION_ENTRY: {
38-
expression_vector children;
39-
std::vector<LogicalType> childrenTypes;
40-
expression_vector optionalParams;
41-
for (auto i = 0u; i < functionExpr->getNumChildren(); i++) {
42-
auto child = expressionBinder.bindExpression(*functionExpr->getChild(i));
43-
if (!functionExpr->getChild(i)->hasAlias()) {
44-
children.push_back(child);
45-
childrenTypes.push_back(child->getDataType().copy());
46-
} else {
47-
ExpressionUtil::validateExpressionType(*child,
48-
{ExpressionType::LITERAL, ExpressionType::PARAMETER});
49-
child->setAlias(functionExpr->getChild(i)->getAlias());
50-
optionalParams.push_back(child);
51-
}
52-
}
53-
auto func = BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes,
54-
entry->ptrCast<FunctionCatalogEntry>());
55-
auto gdsFunc = func->constPtrCast<GDSFunction>()->copy();
56-
auto input = GDSBindInput();
57-
input.params = children;
58-
input.binder = this;
59-
input.optionalParams = std::move(optionalParams);
60-
input.yieldVariables = call.getYieldVariables();
61-
gdsFunc.gds->bind(input, *clientContext);
62-
auto columns = gdsFunc.gds->getResultColumns(input);
63-
auto info = BoundGDSCallInfo(gdsFunc.copy(), std::move(columns));
64-
boundReadingClause = std::make_unique<BoundGDSCall>(std::move(info));
65-
} break;
6633
default:
6734
throw BinderException(
6835
stringFormat("{} is not a table or algorithm function.", functionName));

0 commit comments

Comments
 (0)