Skip to content

Commit da8d25b

Browse files
authored
Merge pull request #2399 from anutosh491/Second_method_for_args
Implementing get argument through design 1
2 parents 5072c12 + 5455dd0 commit da8d25b

File tree

6 files changed

+360
-3
lines changed

6 files changed

+360
-3
lines changed

integration_tests/symbolics_02.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def test_symbolic_operations():
1919
else:
2020
assert False
2121
assert(z.func == Add)
22+
assert(z.args[0] == x or z.args[0] == y)
23+
assert(z.args[1] == y or z.args[1] == x)
2224
print(z)
2325

2426
# Subtraction
@@ -43,6 +45,8 @@ def test_symbolic_operations():
4345
else:
4446
assert False
4547
assert(u.func == Mul)
48+
assert(u.args[0] == x)
49+
assert(u.args[1] == y)
4650
print(u)
4751

4852
# Division

integration_tests/symbolics_05.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,12 @@ def test_operations():
3232
assert((sin(x) + cos(x)).diff(x) == S(-1)*c + d)
3333
assert((sin(x) + cos(x) + exp(x) + pi).diff(x).expand().diff(x) == exp(x) + S(-1)*c + S(-1)*d)
3434

35+
# test args
36+
assert(a.args[0] == x + y)
37+
assert(a.args[1] == S(2))
38+
assert(b.args[0] == x + y + z)
39+
assert(b.args[1] == S(3))
40+
assert(c.args[0] == x)
41+
assert(d.args[0] == x)
3542

3643
test_operations()

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ enum class IntrinsicScalarFunctions : int64_t {
8585
SymbolicPowQ,
8686
SymbolicLogQ,
8787
SymbolicSinQ,
88+
SymbolicGetArgument,
8889
// ...
8990
};
9091

@@ -152,6 +153,7 @@ inline std::string get_intrinsic_name(int x) {
152153
INTRINSIC_NAME_CASE(SymbolicPowQ)
153154
INTRINSIC_NAME_CASE(SymbolicLogQ)
154155
INTRINSIC_NAME_CASE(SymbolicSinQ)
156+
INTRINSIC_NAME_CASE(SymbolicGetArgument)
155157
default : {
156158
throw LCompilersException("pickle: intrinsic_id not implemented");
157159
}
@@ -3116,6 +3118,54 @@ namespace SymbolicHasSymbolQ {
31163118
}
31173119
} // namespace SymbolicHasSymbolQ
31183120

3121+
namespace SymbolicGetArgument {
3122+
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x,
3123+
diag::Diagnostics& diagnostics) {
3124+
ASRUtils::require_impl(x.n_args == 2, "Intrinsic function SymbolicGetArgument"
3125+
"accepts exactly 2 argument", x.base.base.loc, diagnostics);
3126+
3127+
ASR::ttype_t* arg1_type = ASRUtils::expr_type(x.m_args[0]);
3128+
ASR::ttype_t* arg2_type = ASRUtils::expr_type(x.m_args[1]);
3129+
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*arg1_type),
3130+
"SymbolicGetArgument expects the first argument to be of type SymbolicExpression",
3131+
x.base.base.loc, diagnostics);
3132+
ASRUtils::require_impl(ASR::is_a<ASR::Integer_t>(*arg2_type),
3133+
"SymbolicGetArgument expects the second argument to be of type Integer",
3134+
x.base.base.loc, diagnostics);
3135+
}
3136+
3137+
static inline ASR::expr_t* eval_SymbolicGetArgument(Allocator &/*al*/,
3138+
const Location &/*loc*/, ASR::ttype_t *, Vec<ASR::expr_t*> &/*args*/) {
3139+
/*TODO*/
3140+
return nullptr;
3141+
}
3142+
3143+
static inline ASR::asr_t* create_SymbolicGetArgument(Allocator& al,
3144+
const Location& loc, Vec<ASR::expr_t*>& args,
3145+
const std::function<void (const std::string &, const Location &)> err) {
3146+
3147+
if (args.size() != 2) {
3148+
err("Intrinsic function SymbolicGetArguments accepts exactly 2 argument", loc);
3149+
}
3150+
3151+
ASR::ttype_t* arg1_type = ASRUtils::expr_type(args[0]);
3152+
ASR::ttype_t* arg2_type = ASRUtils::expr_type(args[1]);
3153+
if (!ASR::is_a<ASR::SymbolicExpression_t>(*arg1_type)) {
3154+
err("The first argument of SymbolicGetArgument function must be of type SymbolicExpression",
3155+
args[0]->base.loc);
3156+
}
3157+
if (!ASR::is_a<ASR::Integer_t>(*arg2_type)) {
3158+
err("The second argument of SymbolicGetArgument function must be of type Integer",
3159+
args[1]->base.loc);
3160+
}
3161+
3162+
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc));
3163+
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_SymbolicGetArgument,
3164+
static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicGetArgument),
3165+
0, to_type);
3166+
}
3167+
} // namespace SymbolicGetArgument
3168+
31193169
#define create_symbolic_query_macro(X) \
31203170
namespace X { \
31213171
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \
@@ -3325,6 +3375,8 @@ namespace IntrinsicScalarFunctionRegistry {
33253375
{nullptr, &SymbolicLogQ::verify_args}},
33263376
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicSinQ),
33273377
{nullptr, &SymbolicSinQ::verify_args}},
3378+
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicGetArgument),
3379+
{nullptr, &SymbolicGetArgument::verify_args}},
33283380
};
33293381

33303382
static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
@@ -3441,6 +3493,8 @@ namespace IntrinsicScalarFunctionRegistry {
34413493
"SymbolicLogQ"},
34423494
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicSinQ),
34433495
"SymbolicSinQ"},
3496+
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicGetArgument),
3497+
"SymbolicGetArgument"},
34443498
};
34453499

34463500

@@ -3502,6 +3556,7 @@ namespace IntrinsicScalarFunctionRegistry {
35023556
{"PowQ", {&SymbolicPowQ::create_SymbolicPowQ, &SymbolicPowQ::eval_SymbolicPowQ}},
35033557
{"LogQ", {&SymbolicLogQ::create_SymbolicLogQ, &SymbolicLogQ::eval_SymbolicLogQ}},
35043558
{"SinQ", {&SymbolicSinQ::create_SymbolicSinQ, &SymbolicSinQ::eval_SymbolicSinQ}},
3559+
{"GetArgument", {&SymbolicGetArgument::create_SymbolicGetArgument, &SymbolicGetArgument::eval_SymbolicGetArgument}},
35053560
};
35063561

35073562
static inline bool is_intrinsic_function(const std::string& name) {

0 commit comments

Comments
 (0)