10
10
#include " function/gds/gds_utils.h"
11
11
#include " function/query_fts_bind_data.h"
12
12
#include " processor/execution_context.h"
13
- #include " processor/operator/gds_call_shared_state.h"
14
13
#include " storage/index/index_utils.h"
15
14
#include " storage/storage_manager.h"
16
15
#include " storage/store/node_table.h"
@@ -153,7 +152,7 @@ void QFTSOutputWriter::write(processor::FactorizedTable& scoreFT, nodeID_t docNo
153
152
154
153
class QFTSVertexCompute final : public VertexCompute {
155
154
public:
156
- QFTSVertexCompute (MemoryManager* mm, processor::GDSCallSharedState * sharedState,
155
+ QFTSVertexCompute (MemoryManager* mm, GDSFuncSharedState * sharedState,
157
156
std::unique_ptr<QFTSOutputWriter> writer)
158
157
: mm{mm}, sharedState{sharedState}, writer{std::move (writer)} {
159
158
localFT = sharedState->factorizedTablePool .claimLocalTable (mm);
@@ -175,18 +174,24 @@ class QFTSVertexCompute final : public VertexCompute {
175
174
176
175
private:
177
176
MemoryManager* mm;
178
- processor::GDSCallSharedState * sharedState;
177
+ GDSFuncSharedState * sharedState;
179
178
processor::FactorizedTable* localFT;
180
179
std::unique_ptr<QFTSOutputWriter> writer;
181
180
};
182
181
182
+ static constexpr char SCORE_PROP_NAME[] = " score" ;
183
+ static constexpr char DOC_FREQUENCY_PROP_NAME[] = " df" ;
184
+ static constexpr char TERM_FREQUENCY_PROP_NAME[] = " tf" ;
185
+ static constexpr char DOC_LEN_PROP_NAME[] = " len" ;
186
+ static constexpr char DOC_ID_PROP_NAME[] = " docID" ;
187
+
183
188
static std::unordered_map<offset_t , uint64_t > getDFs (main::ClientContext& context,
184
189
const catalog::NodeTableCatalogEntry& termsEntry, const std::vector<std::string>& terms) {
185
190
auto storageManager = context.getStorageManager ();
186
191
auto tableID = termsEntry.getTableID ();
187
192
auto & termsNodeTable = storageManager->getTable (tableID)->cast <NodeTable>();
188
193
auto tx = context.getTransaction ();
189
- auto dfColumnID = termsEntry.getColumnID (QueryFTSAlgorithm:: DOC_FREQUENCY_PROP_NAME);
194
+ auto dfColumnID = termsEntry.getColumnID (DOC_FREQUENCY_PROP_NAME);
190
195
std::vector<LogicalType> vectorTypes;
191
196
vectorTypes.push_back (LogicalType::INTERNAL_ID ());
192
197
vectorTypes.push_back (LogicalType::UINT64 ());
@@ -231,21 +236,22 @@ static uint64_t getSparseFrontierSize(uint64_t numRows) {
231
236
return size;
232
237
}
233
238
234
- void QueryFTSAlgorithm::exec (processor::ExecutionContext* executionContext ) {
235
- auto clientContext = executionContext ->clientContext ;
239
+ static common:: offset_t tableFunc ( const TableFuncInput& input, TableFuncOutput& ) {
240
+ auto clientContext = input. context ->clientContext ;
236
241
auto transaction = clientContext->getTransaction ();
242
+ auto sharedState = input.sharedState ->ptrCast <GDSFuncSharedState>();
237
243
auto graph = sharedState->graph .get ();
238
244
auto graphEntry = graph->getGraphEntry ();
239
- auto qFTSBindData = bindData->ptrCast <QueryFTSBindData>();
245
+ auto qFTSBindData = input. bindData ->constPtrCast <QueryFTSBindData>();
240
246
auto & termsEntry = graphEntry->nodeInfos [0 ].entry ->constCast <catalog::NodeTableCatalogEntry>();
241
- auto terms = bindData-> ptrCast <QueryFTSBindData>()-> getTerms (*executionContext ->clientContext );
242
- auto dfs = getDFs (*executionContext ->clientContext , termsEntry, terms);
247
+ auto terms = qFTSBindData-> getTerms (*input. context ->clientContext );
248
+ auto dfs = getDFs (*input. context ->clientContext , termsEntry, terms);
243
249
// Do edge compute to extend terms -> docs and save the term frequency and document frequency
244
250
// for each term-doc pair. The reason why we store the term frequency and document frequency
245
251
// is that: we need the `len` property from the docs table which is only available during the
246
252
// vertex compute.
247
- auto currentFrontier = PathLengths::getUnvisitedFrontier (executionContext , graph);
248
- auto nextFrontier = PathLengths::getUnvisitedFrontier (executionContext , graph);
253
+ auto currentFrontier = PathLengths::getUnvisitedFrontier (input. context , graph);
254
+ auto nextFrontier = PathLengths::getUnvisitedFrontier (input. context , graph);
249
255
auto frontierPair =
250
256
std::make_unique<DoublePathLengthsFrontierPair>(currentFrontier, nextFrontier);
251
257
auto termsTableID = termsEntry.getTableID ();
@@ -263,15 +269,15 @@ void QueryFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
263
269
auto auxiliaryState = std::make_unique<EmptyGDSAuxiliaryState>();
264
270
auto compState = GDSComputeState (std::move (frontierPair), std::move (edgeCompute),
265
271
std::move (auxiliaryState), nullptr /* outputNodeMask */ );
266
- GDSUtils::runFrontiersUntilConvergence (executionContext , compState, graph, ExtendDirection::FWD,
272
+ GDSUtils::runFrontiersUntilConvergence (input. context , compState, graph, ExtendDirection::FWD,
267
273
1 /* maxIters */ , TERM_FREQUENCY_PROP_NAME);
268
274
269
275
// Do vertex compute to calculate the score for doc with the length property.
270
276
auto mm = clientContext->getMemoryManager ();
271
277
auto numUniqueTerms = getNumUniqueTerms (terms);
272
278
auto writer = std::make_unique<QFTSOutputWriter>(scores, mm, qFTSBindData->getConfig (),
273
279
*qFTSBindData, numUniqueTerms);
274
- auto vc = std::make_unique<QFTSVertexCompute>(mm, sharedState. get () , std::move (writer));
280
+ auto vc = std::make_unique<QFTSVertexCompute>(mm, sharedState, std::move (writer));
275
281
auto vertexPropertiesToScan = std::vector<std::string>{DOC_LEN_PROP_NAME, DOC_ID_PROP_NAME};
276
282
auto docsEntry = graphEntry->nodeInfos [1 ].entry ;
277
283
auto numDocs = storageManager->getTable (docsEntry->getTableID ())->getNumTotalRows (transaction);
@@ -284,65 +290,75 @@ void QueryFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
284
290
}
285
291
}
286
292
} else {
287
- GDSUtils::runVertexCompute (executionContext , graph, *vc, docsEntry, vertexPropertiesToScan);
293
+ GDSUtils::runVertexCompute (input. context , graph, *vc, docsEntry, vertexPropertiesToScan);
288
294
}
289
295
sharedState->factorizedTablePool .mergeLocalTables ();
296
+ return 0 ;
290
297
}
291
298
292
- expression_vector QueryFTSAlgorithm::getResultColumns (const GDSBindInput& bindInput) const {
293
- expression_vector columns;
294
- auto & docsNode = bindData->getNodeOutput ()->constCast <NodeExpression>();
295
- columns.push_back (docsNode.getInternalID ());
296
- std::string scoreColumnName = SCORE_PROP_NAME;
297
- if (!bindInput.yieldVariables .empty ()) {
298
- scoreColumnName = bindColumnName (bindInput.yieldVariables [1 ], scoreColumnName);
299
- }
300
- auto scoreColumn = bindInput.binder ->createVariable (scoreColumnName, LogicalType::DOUBLE ());
301
- columns.push_back (scoreColumn);
302
- return columns;
303
- }
304
-
305
- static std::string getParamVal (const GDSBindInput& input, idx_t idx) {
299
+ static std::string getParamVal (const TableFuncBindInput& input, idx_t idx) {
306
300
if (input.getParam (idx)->expressionType != ExpressionType::LITERAL) {
307
301
throw BinderException{" The table and index name must be literal expressions." };
308
302
}
309
303
return ExpressionUtil::getLiteralValue<std::string>(
310
304
input.getParam (idx)->constCast <LiteralExpression>());
311
305
}
312
306
313
- void QueryFTSAlgorithm::bind (const GDSBindInput& input, main::ClientContext& context) {
314
- context.setUseInternalCatalogEntry (true /* useInternalCatalogEntry */ );
307
+ static std::unique_ptr<TableFuncBindData> bindFunc (main::ClientContext* context,
308
+ const TableFuncBindInput* input) {
309
+ context->setUseInternalCatalogEntry (true /* useInternalCatalogEntry */ );
315
310
// For queryFTS, the table and index name must be given at compile time while the user
316
311
// can give the query at runtime.
317
- auto inputTableName = getParamVal (input, 0 );
318
- auto indexName = getParamVal (input, 1 );
319
- auto query = input. getParam (2 );
312
+ auto inputTableName = getParamVal (* input, 0 );
313
+ auto indexName = getParamVal (* input, 1 );
314
+ auto query = input-> getParam (2 );
320
315
321
316
auto tableEntry =
322
- IndexUtils::bindTable (context, inputTableName, indexName, IndexOperation::QUERY);
323
- auto ftsIndexEntry = context. getCatalog ()-> getIndex (context. getTransaction (),
324
- tableEntry-> getTableID (), indexName );
325
- auto entry =
326
- context. getCatalog () ->getTableCatalogEntry (context. getTransaction () , inputTableName);
327
- auto nodeOutput = bindNodeOutput (input, {entry});
328
-
329
- auto termsEntry = context. getCatalog () ->getTableCatalogEntry (context. getTransaction () ,
317
+ IndexUtils::bindTable (* context, inputTableName, indexName, IndexOperation::QUERY);
318
+ auto catalog = context-> getCatalog ();
319
+ auto transaction = context-> getTransaction ( );
320
+ auto ftsIndexEntry = catalog-> getIndex (transaction, tableEntry-> getTableID (), indexName);
321
+ auto entry = catalog ->getTableCatalogEntry (transaction , inputTableName);
322
+ auto nodeOutput = GDSFunction:: bindNodeOutput (* input, {entry});
323
+
324
+ auto termsEntry = catalog ->getTableCatalogEntry (transaction ,
330
325
FTSUtils::getTermsTableName (tableEntry->getTableID (), indexName));
331
- auto docsEntry = context. getCatalog () ->getTableCatalogEntry (context. getTransaction () ,
326
+ auto docsEntry = catalog ->getTableCatalogEntry (transaction ,
332
327
FTSUtils::getDocsTableName (tableEntry->getTableID (), indexName));
333
- auto appearsInEntry = context. getCatalog () ->getTableCatalogEntry (context. getTransaction () ,
328
+ auto appearsInEntry = catalog ->getTableCatalogEntry (transaction ,
334
329
FTSUtils::getAppearsInTableName (tableEntry->getTableID (), indexName));
335
330
auto graphEntry = graph::GraphEntry ({termsEntry, docsEntry}, {appearsInEntry});
336
- bindData = std::make_unique<QueryFTSBindData>(std::move (graphEntry), nodeOutput,
337
- std::move (query), *ftsIndexEntry, QueryFTSOptionalParams{input.optionalParams });
338
- context.setUseInternalCatalogEntry (false /* useInternalCatalogEntry */ );
331
+
332
+ expression_vector columns;
333
+ auto & docsNode = nodeOutput->constCast <NodeExpression>();
334
+ columns.push_back (docsNode.getInternalID ());
335
+ std::string scoreColumnName = SCORE_PROP_NAME;
336
+ if (!input->yieldVariables .empty ()) {
337
+ scoreColumnName = GDSFunction::bindColumnName (input->yieldVariables [1 ], scoreColumnName);
338
+ }
339
+ auto scoreColumn = input->binder ->createVariable (scoreColumnName, LogicalType::DOUBLE ());
340
+ columns.push_back (scoreColumn);
341
+ auto bindData =
342
+ std::make_unique<QueryFTSBindData>(std::move (columns), std::move (graphEntry), nodeOutput,
343
+ std::move (query), *ftsIndexEntry, QueryFTSOptionalParams{input->optionalParamsLegacy });
344
+ context->setUseInternalCatalogEntry (false /* useInternalCatalogEntry */ );
345
+ return bindData;
339
346
}
340
347
341
348
function_set QueryFTSFunction::getFunctionSet () {
342
349
function_set result;
343
- auto algo = std::make_unique<QueryFTSAlgorithm>();
344
- result.push_back (
345
- std::make_unique<GDSFunction>(name, algo->getParameterTypeIDs (), std::move (algo)));
350
+ // inputs are tableName, indexName, query
351
+ auto func = std::make_unique<TableFunction>(QueryFTSFunction::name,
352
+ std::vector<LogicalTypeID>{LogicalTypeID::STRING, LogicalTypeID::STRING,
353
+ LogicalTypeID::STRING});
354
+ func->bindFunc = bindFunc;
355
+ func->tableFunc = tableFunc;
356
+ func->initSharedStateFunc = GDSFunction::initSharedState;
357
+ func->initLocalStateFunc = TableFunction::initEmptyLocalState;
358
+ func->canParallelFunc = [] { return false ; };
359
+ func->getLogicalPlanFunc = GDSFunction::getLogicalPlan;
360
+ func->getPhysicalPlanFunc = GDSFunction::getPhysicalPlan;
361
+ result.push_back (std::move (func));
346
362
return result;
347
363
}
348
364
0 commit comments