Skip to content

Commit d41216c

Browse files
authored
Added Support for Symbolic func attribute (#2367)
1 parent 8aaa4a7 commit d41216c

File tree

5 files changed

+297
-10
lines changed

5 files changed

+297
-10
lines changed

integration_tests/symbolics_02.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,71 @@
1-
from sympy import Symbol, pi
1+
from sympy import Symbol, pi, Add, Mul, Pow
22
from lpython import S
33

44
def test_symbolic_operations():
55
x: S = Symbol('x')
66
y: S = Symbol('y')
7-
p1: S = pi
8-
p2: S = pi
7+
pi1: S = pi
8+
pi2: S = pi
99

1010
# Addition
1111
z: S = x + y
12+
z1: bool = z.func == Add
13+
z2: bool = z.func == Mul
1214
assert(z == x + y)
15+
assert(z1 == True)
16+
assert(z2 == False)
17+
if z.func == Add:
18+
assert True
19+
else:
20+
assert False
1321
print(z)
1422

1523
# Subtraction
1624
w: S = x - y
25+
w1: bool = w.func == Add
1726
assert(w == x - y)
27+
assert(w1 == True)
28+
if w.func == Add:
29+
assert True
30+
else:
31+
assert False
1832
print(w)
1933

2034
# Multiplication
2135
u: S = x * y
36+
u1: bool = u.func == Mul
2237
assert(u == x * y)
38+
assert(u1 == True)
39+
if u.func == Mul:
40+
assert True
41+
else:
42+
assert False
2343
print(u)
2444

2545
# Division
2646
v: S = x / y
47+
v1: bool = v.func == Mul
2748
assert(v == x / y)
49+
assert(v1 == True)
50+
if v.func == Mul:
51+
assert True
52+
else:
53+
assert False
2854
print(v)
2955

3056
# Power
3157
p: S = x ** y
58+
p1: bool = p.func == Pow
59+
p2: bool = p.func == Add
60+
p3: bool = p.func == Mul
3261
assert(p == x ** y)
62+
assert(p1 == True)
63+
assert(p2 == False)
64+
assert(p3 == False)
65+
if p.func == Pow:
66+
assert True
67+
else:
68+
assert False
3369
print(p)
3470

3571
# Casting
@@ -40,13 +76,13 @@ def test_symbolic_operations():
4076
print(c)
4177

4278
# Comparison
43-
b1: bool = p1 == p2
79+
b1: bool = pi1 == pi2
4480
print(b1)
4581
assert(b1 == True)
46-
b2: bool = p1 != pi
82+
b2: bool = pi1 != pi
4783
print(b2)
4884
assert(b2 == False)
49-
b3: bool = p1 != x
85+
b3: bool = pi1 != x
5086
print(b3)
5187
assert(b3 == True)
5288
b4: bool = pi == Symbol("x")

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ enum class IntrinsicScalarFunctions : int64_t {
7979
SymbolicExp,
8080
SymbolicAbs,
8181
SymbolicHasSymbolQ,
82+
SymbolicAddQ,
83+
SymbolicMulQ,
84+
SymbolicPowQ,
8285
// ...
8386
};
8487

@@ -140,6 +143,9 @@ inline std::string get_intrinsic_name(int x) {
140143
INTRINSIC_NAME_CASE(SymbolicExp)
141144
INTRINSIC_NAME_CASE(SymbolicAbs)
142145
INTRINSIC_NAME_CASE(SymbolicHasSymbolQ)
146+
INTRINSIC_NAME_CASE(SymbolicAddQ)
147+
INTRINSIC_NAME_CASE(SymbolicMulQ)
148+
INTRINSIC_NAME_CASE(SymbolicPowQ)
143149
default : {
144150
throw LCompilersException("pickle: intrinsic_id not implemented");
145151
}
@@ -3100,6 +3106,48 @@ namespace SymbolicHasSymbolQ {
31003106
}
31013107
} // namespace SymbolicHasSymbolQ
31023108

3109+
#define create_symbolic_query_macro(X) \
3110+
namespace X { \
3111+
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \
3112+
diag::Diagnostics& diagnostics) { \
3113+
const Location& loc = x.base.base.loc; \
3114+
ASRUtils::require_impl(x.n_args == 1, \
3115+
#X " must have exactly 1 input argument", loc, diagnostics); \
3116+
\
3117+
ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]); \
3118+
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*input_type), \
3119+
#X " expects an argument of type SymbolicExpression", loc, diagnostics); \
3120+
} \
3121+
\
3122+
static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \
3123+
ASR::ttype_t *, Vec<ASR::expr_t*> &/*args*/) { \
3124+
/*TODO*/ \
3125+
return nullptr; \
3126+
} \
3127+
\
3128+
static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \
3129+
Vec<ASR::expr_t*>& args, \
3130+
const std::function<void (const std::string &, const Location &)> err) { \
3131+
if (args.size() != 1) { \
3132+
err("Intrinsic " #X " function accepts exactly 1 argument", loc); \
3133+
} \
3134+
\
3135+
ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]); \
3136+
if (!ASR::is_a<ASR::SymbolicExpression_t>(*argtype)) { \
3137+
err("Argument of " #X " function must be of type SymbolicExpression", \
3138+
args[0]->base.loc); \
3139+
} \
3140+
\
3141+
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_##X, \
3142+
static_cast<int64_t>(IntrinsicScalarFunctions::X), 0, logical); \
3143+
} \
3144+
} // namespace X
3145+
3146+
create_symbolic_query_macro(SymbolicAddQ)
3147+
create_symbolic_query_macro(SymbolicMulQ)
3148+
create_symbolic_query_macro(SymbolicPowQ)
3149+
3150+
31033151
#define create_symbolic_unary_macro(X) \
31043152
namespace X { \
31053153
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \
@@ -3253,6 +3301,12 @@ namespace IntrinsicScalarFunctionRegistry {
32533301
{nullptr, &SymbolicAbs::verify_args}},
32543302
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicHasSymbolQ),
32553303
{nullptr, &SymbolicHasSymbolQ::verify_args}},
3304+
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicAddQ),
3305+
{nullptr, &SymbolicAddQ::verify_args}},
3306+
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicMulQ),
3307+
{nullptr, &SymbolicMulQ::verify_args}},
3308+
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicPowQ),
3309+
{nullptr, &SymbolicPowQ::verify_args}},
32563310
};
32573311

32583312
static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
@@ -3357,6 +3411,12 @@ namespace IntrinsicScalarFunctionRegistry {
33573411
"SymbolicAbs"},
33583412
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicHasSymbolQ),
33593413
"SymbolicHasSymbolQ"},
3414+
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicAddQ),
3415+
"SymbolicAddQ"},
3416+
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicMulQ),
3417+
"SymbolicMulQ"},
3418+
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicPowQ),
3419+
"SymbolicPowQ"},
33603420
};
33613421

33623422

@@ -3412,6 +3472,9 @@ namespace IntrinsicScalarFunctionRegistry {
34123472
{"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}},
34133473
{"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}},
34143474
{"has", {&SymbolicHasSymbolQ::create_SymbolicHasSymbolQ, &SymbolicHasSymbolQ::eval_SymbolicHasSymbolQ}},
3475+
{"AddQ", {&SymbolicAddQ::create_SymbolicAddQ, &SymbolicAddQ::eval_SymbolicAddQ}},
3476+
{"MulQ", {&SymbolicMulQ::create_SymbolicMulQ, &SymbolicMulQ::eval_SymbolicMulQ}},
3477+
{"PowQ", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}},
34153478
};
34163479

34173480
static inline bool is_intrinsic_function(const std::string& name) {

src/libasr/pass/replace_symbolic.cpp

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
672672
return module_scope->get_symbol(name);
673673
}
674674

675+
ASR::symbol_t* declare_basic_get_type_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
676+
std::string name = "basic_get_type";
677+
symbolic_dependencies.push_back(name);
678+
if (!module_scope->get_symbol(name)) {
679+
std::string header = "symengine/cwrapper.h";
680+
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);
681+
682+
Vec<ASR::expr_t*> args;
683+
args.reserve(al, 1);
684+
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
685+
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
686+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)),
687+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
688+
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
689+
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
690+
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
691+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
692+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
693+
fn_symtab->add_symbol(s2c(al, "x"), arg2);
694+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));
695+
696+
Vec<ASR::stmt_t*> body;
697+
body.reserve(al, 1);
698+
699+
Vec<char*> dep;
700+
dep.reserve(al, 1);
701+
702+
ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
703+
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
704+
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
705+
return_var, ASR::abiType::BindC, ASR::accessType::Public,
706+
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
707+
false, false, nullptr, 0, false, false, false, s2c(al, header));
708+
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
709+
module_scope->add_symbol(s2c(al, name), symbol);
710+
}
711+
return module_scope->get_symbol(name);
712+
}
713+
675714
ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
676715
std::string name = "basic_eq";
677716
symbolic_dependencies.push_back(name);
@@ -828,6 +867,60 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
828867
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr));
829868
break;
830869
}
870+
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAddQ: {
871+
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
872+
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
873+
Vec<ASR::call_arg_t> call_args;
874+
call_args.reserve(al, 1);
875+
ASR::call_arg_t call_arg;
876+
call_arg.loc = loc;
877+
call_arg.m_value = value1;
878+
call_args.push_back(al, call_arg);
879+
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
880+
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
881+
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
882+
// Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM
883+
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
884+
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 16, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
885+
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
886+
break;
887+
}
888+
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ: {
889+
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
890+
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
891+
Vec<ASR::call_arg_t> call_args;
892+
call_args.reserve(al, 1);
893+
ASR::call_arg_t call_arg;
894+
call_arg.loc = loc;
895+
call_arg.m_value = value1;
896+
call_args.push_back(al, call_arg);
897+
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
898+
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
899+
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
900+
// Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM
901+
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
902+
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 15, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
903+
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
904+
break;
905+
}
906+
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ: {
907+
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
908+
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
909+
Vec<ASR::call_arg_t> call_args;
910+
call_args.reserve(al, 1);
911+
ASR::call_arg_t call_arg;
912+
call_arg.loc = loc;
913+
call_arg.m_value = value1;
914+
call_args.push_back(al, call_arg);
915+
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
916+
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
917+
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
918+
// Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM
919+
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
920+
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
921+
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
922+
break;
923+
}
831924
default: {
832925
throw LCompilersException("IntrinsicFunction: `"
833926
+ ASRUtils::get_intrinsic_name(intrinsic_id)
@@ -998,6 +1091,20 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
9981091
}
9991092
}
10001093

1094+
void visit_If(const ASR::If_t& x) {
1095+
ASR::If_t& xx = const_cast<ASR::If_t&>(x);
1096+
transform_stmts(xx.m_body, xx.n_body);
1097+
transform_stmts(xx.m_orelse, xx.n_orelse);
1098+
SymbolTable* module_scope = current_scope->parent;
1099+
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*xx.m_test)) {
1100+
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(xx.m_test);
1101+
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
1102+
ASR::expr_t* function_call = process_attributes(al, xx.base.base.loc, xx.m_test, module_scope);
1103+
xx.m_test = function_call;
1104+
}
1105+
}
1106+
}
1107+
10011108
void visit_SubroutineCall(const ASR::SubroutineCall_t &x) {
10021109
SymbolTable* module_scope = current_scope->parent;
10031110
Vec<ASR::call_arg_t> call_args;
@@ -1298,7 +1405,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
12981405

12991406
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
13001407
pass_result.push_back(al, assert_stmt);
1301-
} else if(ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) {
1408+
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) {
13021409
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test);
13031410
SymbolTable* module_scope = current_scope->parent;
13041411
ASR::expr_t* left_tmp = nullptr;

0 commit comments

Comments
 (0)