@@ -672,6 +672,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
672
672
return module_scope->get_symbol (name);
673
673
}
674
674
675
+ ASR::symbol_t * declare_basic_get_type_function (Allocator& al, const Location& loc, SymbolTable* module_scope) {
676
+ std::string name = " basic_get_type" ;
677
+ symbolic_dependencies.push_back (name);
678
+ if (!module_scope->get_symbol (name)) {
679
+ std::string header = " symengine/cwrapper.h" ;
680
+ SymbolTable* fn_symtab = al.make_new <SymbolTable>(module_scope);
681
+
682
+ Vec<ASR::expr_t *> args;
683
+ args.reserve (al, 1 );
684
+ ASR::symbol_t * arg1 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
685
+ al, loc, fn_symtab, s2c (al, " _lpython_return_variable" ), nullptr , 0 , ASR::intentType::ReturnVar,
686
+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )),
687
+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false ));
688
+ fn_symtab->add_symbol (s2c (al, " _lpython_return_variable" ), arg1);
689
+ ASR::symbol_t * arg2 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
690
+ al, loc, fn_symtab, s2c (al, " x" ), nullptr , 0 , ASR::intentType::In,
691
+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_CPtr_t (al, loc)),
692
+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true ));
693
+ fn_symtab->add_symbol (s2c (al, " x" ), arg2);
694
+ args.push_back (al, ASRUtils::EXPR (ASR::make_Var_t (al, loc, arg2)));
695
+
696
+ Vec<ASR::stmt_t *> body;
697
+ body.reserve (al, 1 );
698
+
699
+ Vec<char *> dep;
700
+ dep.reserve (al, 1 );
701
+
702
+ ASR::expr_t * return_var = ASRUtils::EXPR (ASR::make_Var_t (al, loc, fn_symtab->get_symbol (" _lpython_return_variable" )));
703
+ ASR::asr_t * subrout = ASRUtils::make_Function_t_util (al, loc,
704
+ fn_symtab, s2c (al, name), dep.p , dep.n , args.p , args.n , body.p , body.n ,
705
+ return_var, ASR::abiType::BindC, ASR::accessType::Public,
706
+ ASR::deftypeType::Interface, s2c (al, name), false , false , false ,
707
+ false , false , nullptr , 0 , false , false , false , s2c (al, header));
708
+ ASR::symbol_t * symbol = ASR::down_cast<ASR::symbol_t >(subrout);
709
+ module_scope->add_symbol (s2c (al, name), symbol);
710
+ }
711
+ return module_scope->get_symbol (name);
712
+ }
713
+
675
714
ASR::symbol_t * declare_basic_eq_function (Allocator& al, const Location& loc, SymbolTable* module_scope) {
676
715
std::string name = " basic_eq" ;
677
716
symbolic_dependencies.push_back (name);
@@ -828,6 +867,60 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
828
867
ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )), nullptr , nullptr ));
829
868
break ;
830
869
}
870
+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAddQ: {
871
+ ASR::symbol_t * basic_get_type_sym = declare_basic_get_type_function (al, loc, module_scope);
872
+ ASR::expr_t * value1 = handle_argument (al, loc, intrinsic_func->m_args [0 ]);
873
+ Vec<ASR::call_arg_t > call_args;
874
+ call_args.reserve (al, 1 );
875
+ ASR::call_arg_t call_arg;
876
+ call_arg.loc = loc;
877
+ call_arg.m_value = value1;
878
+ call_args.push_back (al, call_arg);
879
+ ASR::expr_t * function_call = ASRUtils::EXPR (ASRUtils::make_FunctionCall_t_util (al, loc,
880
+ basic_get_type_sym, basic_get_type_sym, call_args.p , call_args.n ,
881
+ ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )), nullptr , nullptr ));
882
+ // Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM
883
+ return ASRUtils::EXPR (ASR::make_IntegerCompare_t (al, loc, function_call, ASR::cmpopType::Eq,
884
+ ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, loc, 16 , ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )))),
885
+ ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )), nullptr ));
886
+ break ;
887
+ }
888
+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ: {
889
+ ASR::symbol_t * basic_get_type_sym = declare_basic_get_type_function (al, loc, module_scope);
890
+ ASR::expr_t * value1 = handle_argument (al, loc, intrinsic_func->m_args [0 ]);
891
+ Vec<ASR::call_arg_t > call_args;
892
+ call_args.reserve (al, 1 );
893
+ ASR::call_arg_t call_arg;
894
+ call_arg.loc = loc;
895
+ call_arg.m_value = value1;
896
+ call_args.push_back (al, call_arg);
897
+ ASR::expr_t * function_call = ASRUtils::EXPR (ASRUtils::make_FunctionCall_t_util (al, loc,
898
+ basic_get_type_sym, basic_get_type_sym, call_args.p , call_args.n ,
899
+ ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )), nullptr , nullptr ));
900
+ // Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM
901
+ return ASRUtils::EXPR (ASR::make_IntegerCompare_t (al, loc, function_call, ASR::cmpopType::Eq,
902
+ ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, loc, 15 , ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )))),
903
+ ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )), nullptr ));
904
+ break ;
905
+ }
906
+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ: {
907
+ ASR::symbol_t * basic_get_type_sym = declare_basic_get_type_function (al, loc, module_scope);
908
+ ASR::expr_t * value1 = handle_argument (al, loc, intrinsic_func->m_args [0 ]);
909
+ Vec<ASR::call_arg_t > call_args;
910
+ call_args.reserve (al, 1 );
911
+ ASR::call_arg_t call_arg;
912
+ call_arg.loc = loc;
913
+ call_arg.m_value = value1;
914
+ call_args.push_back (al, call_arg);
915
+ ASR::expr_t * function_call = ASRUtils::EXPR (ASRUtils::make_FunctionCall_t_util (al, loc,
916
+ basic_get_type_sym, basic_get_type_sym, call_args.p , call_args.n ,
917
+ ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )), nullptr , nullptr ));
918
+ // Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM
919
+ return ASRUtils::EXPR (ASR::make_IntegerCompare_t (al, loc, function_call, ASR::cmpopType::Eq,
920
+ ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, loc, 17 , ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )))),
921
+ ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )), nullptr ));
922
+ break ;
923
+ }
831
924
default : {
832
925
throw LCompilersException (" IntrinsicFunction: `"
833
926
+ ASRUtils::get_intrinsic_name (intrinsic_id)
@@ -998,6 +1091,20 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
998
1091
}
999
1092
}
1000
1093
1094
+ void visit_If (const ASR::If_t& x) {
1095
+ ASR::If_t& xx = const_cast <ASR::If_t&>(x);
1096
+ transform_stmts (xx.m_body , xx.n_body );
1097
+ transform_stmts (xx.m_orelse , xx.n_orelse );
1098
+ SymbolTable* module_scope = current_scope->parent ;
1099
+ if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*xx.m_test )) {
1100
+ ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(xx.m_test );
1101
+ if (intrinsic_func->m_type ->type == ASR::ttypeType::Logical) {
1102
+ ASR::expr_t * function_call = process_attributes (al, xx.base .base .loc , xx.m_test , module_scope);
1103
+ xx.m_test = function_call;
1104
+ }
1105
+ }
1106
+ }
1107
+
1001
1108
void visit_SubroutineCall (const ASR::SubroutineCall_t &x) {
1002
1109
SymbolTable* module_scope = current_scope->parent ;
1003
1110
Vec<ASR::call_arg_t > call_args;
@@ -1298,7 +1405,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
1298
1405
1299
1406
ASR::stmt_t *assert_stmt = ASRUtils::STMT (ASR::make_Assert_t (al, x.base .base .loc , test, x.m_msg ));
1300
1407
pass_result.push_back (al, assert_stmt);
1301
- } else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test )) {
1408
+ } else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test )) {
1302
1409
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test );
1303
1410
SymbolTable* module_scope = current_scope->parent ;
1304
1411
ASR::expr_t * left_tmp = nullptr ;
0 commit comments