Skip to content

Commit 37146db

Browse files
andyfengHKU authored and
royi-luo committed
Add dispatch test script
Update testing script Update script Remove testing script Update rust build Update tests so that it passes in wasm Add testing script Fix script Fix indentation Fix script Fix indentation Fix script Add CI workflow for simsimd dispatch test Move job to build and deploy workflow Fix yaml Add test trigger Add dependency on build job Move CI job to separate workflow
1 parent 8cfb655 commit 37146db

File tree

10 files changed

+151
-31
lines changed

10 files changed

+151
-31
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
name: SimSIMD Dispatch Test
2+
on:
3+
# TODO(Royi) remove once done testing
4+
pull_request:
5+
branches:
6+
- master
7+
schedule:
8+
- cron: "0 5 * * *"
9+
10+
workflow_dispatch:
11+
12+
jobs:
13+
build-precompiled-bin-linux:
14+
if: ${{ github.event_name == 'schedule' || github.event.inputs.skipBinaries != 'true' }}
15+
uses: ./.github/workflows/linux-precompiled-bin-workflow.yml
16+
with:
17+
isNightly: true
18+
secrets: inherit
19+
20+
simsimd-dispatch-test:
21+
name: simsimd-dispatch-test
22+
needs: build-precompiled-bin-linux
23+
runs-on: kuzu-self-hosted-testing
24+
env:
25+
NUM_THREADS: 32
26+
GEN: Ninja
27+
CC: gcc
28+
CXX: g++
29+
steps:
30+
- name: Download nightly build
31+
uses: actions/download-artifact@v4
32+
with:
33+
name: kuzu_cli-linux-x86_64
34+
35+
- name: Extract kuzu shell
36+
run: |
37+
tar xf kuzu_cli-linux-x86_64.tar.gz
38+
39+
- name: Test
40+
run: gdb --batch -x scripts/test-simsimd-dispatch.py --args ./kuzu

scripts/simd-dispatch-test.cypher

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
CREATE NODE TABLE embeddings (id int64, vec FLOAT[8], PRIMARY KEY (id));
2+
COPY embeddings FROM "dataset/embeddings/embeddings-8-1k.csv" (deLim=',');
3+
CALL CREATE_HNSW_INDEX('e_hnsw_index', 'embeddings', 'vec', distFunc := 'l2');
4+
CALL QUERY_HNSW_INDEX('e_hnsw_index', 'embeddings', CAST([0.1521,0.3021,0.5366,0.2774,0.5593,0.5589,0.1365,0.8557],'FLOAT[8]'), 3) RETURN nn.id ORDER BY _distance;

scripts/test-simsimd-dispatch.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import gdb
2+
import subprocess
3+
4+
5+
def get_machine_architecture():
    """Return the SimSIMD dispatch target expected on this machine.

    Reads the first ``flags`` line of ``/proc/cpuinfo`` and maps the CPU
    features listed there to one of ``"skylake"``, ``"haswell"`` or
    ``"serial"``.

    Raises:
        AssertionError: if ``/proc/cpuinfo`` contains no ``flags`` line.
    """
    output = subprocess.check_output(["cat", "/proc/cpuinfo"])
    flags_str = ""
    for line in output.decode("utf-8").split("\n"):
        line = line.strip()
        if line.startswith("flags"):
            flags_str = line
            break
    # Validate the captured flags line, not the loop variable: the loop
    # variable only reflects whichever line happened to be visited last.
    assert len(flags_str) > 0, "no 'flags' line found in /proc/cpuinfo"

    # Compare whole feature tokens so that e.g. 'fma4' is not mistaken for 'fma'.
    flags = set(flags_str.partition(":")[2].split())

    # Detects supported features in the same way as _simsimd_capabilities_x86() in simsimd.h.
    # The only simsimd functions we use in the vector index are cos/l2/l2_sq/dot, which only
    # branch for the targets 'skylake', 'haswell', and 'serial', so we only test for those.
    if "avx512f" in flags:
        return "skylake"
    if {"avx2", "f16c", "fma"} <= flags:
        return "haswell"
    return "serial"
23+
24+
25+
bp = gdb.Breakpoint(f"simsimd_l2_f32_{get_machine_architecture()}")
26+
27+
gdb.execute("run < scripts/simd-dispatch-test.cypher")
28+
29+
try:
30+
gdb.execute("continue")
31+
# we only care if the breakpoint is hit at all
32+
# disable it now to prevent the test from needing to execute 'continue' many times to reach completion
33+
bp.enabled = False
34+
except gdb.error:
35+
# the program has terminated
36+
pass
37+
38+
# Check if the breakpoint was hit
39+
if bp.hit_count == 0:
40+
print(
41+
f"Error: did not hit the expected simsimd function for machine architecture '{get_machine_architecture()}'"
42+
)
43+
gdb.execute("quit 1")
44+
45+
gdb.execute("quit")

src/function/sequence/sequence_functions.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,12 @@ function_set CurrValFunction::getFunctionSet() {
4545

4646
function_set NextValFunction::getFunctionSet() {
4747
function_set functionSet;
48-
functionSet.push_back(make_unique<ScalarFunction>(name,
49-
std::vector<LogicalTypeID>{LogicalTypeID::STRING}, LogicalTypeID::INT64,
48+
auto func = make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::STRING},
49+
LogicalTypeID::INT64,
5050
ScalarFunction::UnarySequenceExecFunction<common::ku_string_t, common::ValueVector,
51-
NextVal>));
51+
NextVal>);
52+
func->isReadOnly = false;
53+
functionSet.push_back(std::move(func));
5254
return functionSet;
5355
}
5456

src/include/function/function.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,12 @@ struct KUZU_API Function {
6363
std::vector<common::LogicalTypeID> parameterTypeIDs;
6464
// Currently we only have one variable-length function which is list creation. The expectation is
6565
// that all parameters must have the same type as parameterTypes[0].
66-
bool isVarLength;
67-
bool isListLambda;
66+
// Whether this is a variable-length function.
67+
bool isVarLength = false;
68+
bool isListLambda = false;
69+
bool isReadOnly = true;
6870

69-
Function() : isVarLength{false}, isListLambda{false} {};
71+
Function() : isVarLength{false}, isListLambda{false}, isReadOnly{true} {};
7072
Function(std::string name, std::vector<common::LogicalTypeID> parameterTypeIDs)
7173
: name{std::move(name)}, parameterTypeIDs{std::move(parameterTypeIDs)}, isVarLength{false},
7274
isListLambda{false} {}

src/include/parser/expression/parsed_expression_visitor.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,17 @@ class ParsedParamExprCollector : public ParsedExpressionVisitor {
4646
std::vector<const ParsedExpression*> paramExprs;
4747
};
4848

49-
class ParsedSequenceFunctionCollector : public ParsedExpressionVisitor {
49+
class ReadWriteExprAnalyzer : public ParsedExpressionVisitor {
5050
public:
51-
bool hasSeqUpdate() const { return hasSeqUpdate_; }
51+
explicit ReadWriteExprAnalyzer(main::ClientContext* context)
52+
: ParsedExpressionVisitor{}, context{context} {}
53+
54+
bool isReadOnly() const { return readOnly; }
5255
void visitFunctionExpr(const ParsedExpression* expr) override;
5356

5457
private:
55-
bool hasSeqUpdate_ = false;
58+
main::ClientContext* context;
59+
bool readOnly = true;
5660
};
5761

5862
class MacroParameterReplacer : public ParsedExpressionVisitor {

src/include/parser/visitor/statement_read_write_analyzer.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
#pragma once
22

3+
#include "parser/expression/parsed_expression.h"
34
#include "parser/parsed_statement_visitor.h"
45

56
namespace kuzu {
67
namespace parser {
78

89
class StatementReadWriteAnalyzer final : public StatementVisitor {
910
public:
10-
StatementReadWriteAnalyzer() : StatementVisitor{}, readOnly{true} {}
11+
explicit StatementReadWriteAnalyzer(main::ClientContext* context)
12+
: StatementVisitor{}, readOnly{true}, context{context} {}
1113

12-
bool isReadOnly(const Statement& statement);
14+
bool isReadOnly() const { return readOnly; }
1315

1416
private:
1517
void visitCreateSequence(const Statement& /*statement*/) override { readOnly = false; }
@@ -31,8 +33,11 @@ class StatementReadWriteAnalyzer final : public StatementVisitor {
3133
readOnly = false;
3234
}
3335

36+
bool isExprReadOnly(const ParsedExpression* expr);
37+
3438
private:
3539
bool readOnly;
40+
main::ClientContext* context;
3641
};
3742

3843
} // namespace parser

src/main/client_context.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,12 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
462462
prepareTimer.start();
463463
try {
464464
preparedStatement->preparedSummary.statementType = parsedStatement->getStatementType();
465-
preparedStatement->readOnly = StatementReadWriteAnalyzer().isReadOnly(*parsedStatement);
465+
auto readWriteAnalyzer = StatementReadWriteAnalyzer(this);
466+
TransactionHelper::runFuncInTransaction(
467+
*transactionContext, [&]() -> void { readWriteAnalyzer.visit(*parsedStatement); },
468+
true /* readOnly */, false /* */,
469+
TransactionHelper::TransactionCommitAction::COMMIT_IF_NEW);
470+
preparedStatement->readOnly = readWriteAnalyzer.isReadOnly();
466471
preparedStatement->parsedStatement = std::move(parsedStatement);
467472
validateTransaction(*preparedStatement);
468473

src/parser/expression/parsed_expression_visitor.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
#include "parser/expression/parsed_expression_visitor.h"
22

3+
#include "catalog/catalog.h"
4+
#include "catalog/catalog_entry/function_catalog_entry.h"
35
#include "common/exception/not_implemented.h"
4-
#include "function/sequence/sequence_functions.h"
6+
#include "main/client_context.h"
57
#include "parser/expression/parsed_case_expression.h"
68
#include "parser/expression/parsed_function_expression.h"
79
#include "parser/expression/parsed_lambda_expression.h"
810

911
using namespace kuzu::common;
12+
using namespace kuzu::catalog;
1013

1114
namespace kuzu {
1215
namespace parser {
@@ -141,13 +144,28 @@ void ParsedExpressionVisitor::visitCaseChildrenUnsafe(ParsedExpression& expr) {
141144
}
142145
}
143146

144-
void ParsedSequenceFunctionCollector::visitFunctionExpr(const ParsedExpression* expr) {
147+
void ReadWriteExprAnalyzer::visitFunctionExpr(const ParsedExpression* expr) {
145148
if (expr->getExpressionType() != ExpressionType::FUNCTION) {
149+
// Can be AND/OR/..., which are guaranteed to be read-only.
146150
return;
147151
}
148-
auto funName = expr->constCast<ParsedFunctionExpression>().getFunctionName();
149-
if (StringUtils::getUpper(funName) == function::NextValFunction::name) {
150-
hasSeqUpdate_ = true;
152+
auto funcName = expr->constCast<ParsedFunctionExpression>().getFunctionName();
153+
auto catalog = context->getCatalog();
154+
// Assume users cannot add functions with side effects, i.e. all non-read-only functions are
155+
// registered when the database starts.
156+
auto transaction = &transaction::DUMMY_TRANSACTION;
157+
if (!catalog->containsFunction(transaction, funcName)) {
158+
return;
159+
}
160+
auto entry = catalog->getFunctionEntry(transaction, funcName);
161+
if (entry->getType() != CatalogEntryType::SCALAR_FUNCTION_ENTRY) {
162+
// Can be a macro function, which is guaranteed to be read-only.
163+
return;
164+
}
165+
auto& funcSet = entry->constPtrCast<FunctionCatalogEntry>()->getFunctionSet();
166+
KU_ASSERT(!funcSet.empty());
167+
if (!funcSet[0]->isReadOnly) {
168+
readOnly = false;
151169
}
152170
}
153171

src/parser/visitor/statement_read_write_analyzer.cpp

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,17 @@
77
namespace kuzu {
88
namespace parser {
99

10-
bool StatementReadWriteAnalyzer::isReadOnly(const Statement& statement) {
11-
visit(statement);
12-
return readOnly;
13-
}
14-
15-
static bool hasSequenceUpdate(const ParsedExpression* expr) {
16-
auto collector = ParsedSequenceFunctionCollector();
17-
collector.visit(expr);
18-
return collector.hasSeqUpdate();
19-
}
20-
2110
void StatementReadWriteAnalyzer::visitReadingClause(const ReadingClause* readingClause) {
2211
if (readingClause->hasWherePredicate()) {
23-
if (hasSequenceUpdate(readingClause->getWherePredicate())) {
12+
if (!isExprReadOnly(readingClause->getWherePredicate())) {
2413
readOnly = false;
2514
}
2615
}
2716
}
2817

2918
void StatementReadWriteAnalyzer::visitWithClause(const WithClause* withClause) {
3019
for (auto& expr : withClause->getProjectionBody()->getProjectionExpressions()) {
31-
if (hasSequenceUpdate(expr.get())) {
20+
if (!isExprReadOnly(expr.get())) {
3221
readOnly = false;
3322
return;
3423
}
@@ -37,12 +26,18 @@ void StatementReadWriteAnalyzer::visitWithClause(const WithClause* withClause) {
3726

3827
void StatementReadWriteAnalyzer::visitReturnClause(const ReturnClause* returnClause) {
3928
for (auto& expr : returnClause->getProjectionBody()->getProjectionExpressions()) {
40-
if (hasSequenceUpdate(expr.get())) {
29+
if (!isExprReadOnly(expr.get())) {
4130
readOnly = false;
4231
return;
4332
}
4433
}
4534
}
4635

36+
bool StatementReadWriteAnalyzer::isExprReadOnly(const ParsedExpression* expr) {
37+
auto analyzer = ReadWriteExprAnalyzer(context);
38+
analyzer.visit(expr);
39+
return analyzer.isReadOnly();
40+
}
41+
4742
} // namespace parser
4843
} // namespace kuzu

0 commit comments

Comments
 (0)