Skip to content

Commit 5b945fc

Browse files
authored
Added support for freeing variables in the ASR symbolic pass (#2268)
1 parent ddd49cc commit 5b945fc

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

integration_tests/symbolics_08.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
def basic_new_stack(x: CPtr) -> None:
66
pass
77

8+
@ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib")
9+
def basic_free_stack(x: CPtr) -> None:
10+
pass
11+
812
@ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib")
913
def basic_const_pi(x: CPtr) -> None:
1014
pass
@@ -22,5 +26,6 @@ def main0():
2226
s: str = basic_str(x)
2327
print(s)
2428
assert s == "pi"
29+
basic_free_stack(x)
2530

2631
main0()

src/libasr/pass/replace_symbolic.cpp

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
5454
ASR::Function_t &xx = const_cast<ASR::Function_t&>(x);
5555
SymbolTable* current_scope_copy = this->current_scope;
5656
this->current_scope = xx.m_symtab;
57+
SymbolTable* module_scope = this->current_scope->parent;
5758
for (auto &item : x.m_symtab->get_scope()) {
5859
if (ASR::is_a<ASR::Variable_t>(*item.second)) {
5960
ASR::Variable_t *s = ASR::down_cast<ASR::Variable_t>(item.second);
@@ -75,6 +76,28 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
7576
xx.n_dependencies = function_dependencies.size();
7677
xx.m_dependencies = function_dependencies.p;
7778
this->current_scope = current_scope_copy;
79+
80+
// freeing out variables
81+
std::string new_name = "basic_free_stack";
82+
ASR::symbol_t* basic_free_stack_sym = module_scope->get_symbol(new_name);
83+
Vec<ASR::stmt_t*> func_body;
84+
func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body);
85+
86+
for (ASR::symbol_t* symbol : symbolic_vars) {
87+
Vec<ASR::call_arg_t> call_args;
88+
call_args.reserve(al, 1);
89+
ASR::call_arg_t call_arg;
90+
call_arg.loc = xx.base.base.loc;
91+
call_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, symbol));
92+
call_args.push_back(al, call_arg);
93+
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_free_stack_sym,
94+
basic_free_stack_sym, call_args.p, call_args.n, nullptr));
95+
func_body.push_back(al, stmt);
96+
}
97+
98+
xx.n_body = func_body.size();
99+
xx.m_body = func_body.p;
100+
symbolic_vars.clear();
78101
}
79102

80103
void visit_Variable(const ASR::Variable_t& x) {
@@ -132,6 +155,38 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
132155
module_scope->add_symbol(new_name, new_symbol);
133156
}
134157

158+
new_name = "basic_free_stack";
159+
symbolic_dependencies.push_back(new_name);
160+
if (!module_scope->get_symbol(new_name)) {
161+
std::string header = "symengine/cwrapper.h";
162+
SymbolTable *fn_symtab = al.make_new<SymbolTable>(module_scope);
163+
164+
Vec<ASR::expr_t*> args;
165+
{
166+
args.reserve(al, 1);
167+
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
168+
al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
169+
nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr,
170+
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
171+
fn_symtab->add_symbol(s2c(al, "x"), arg);
172+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg)));
173+
}
174+
175+
Vec<ASR::stmt_t*> body;
176+
body.reserve(al, 1);
177+
178+
Vec<char *> dep;
179+
dep.reserve(al, 1);
180+
181+
ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc,
182+
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
183+
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
184+
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
185+
false, false, nullptr, 0, false, false, false, s2c(al, header));
186+
ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
187+
module_scope->add_symbol(new_name, new_symbol);
188+
}
189+
135190
ASR::symbol_t* var_sym = current_scope->get_symbol(var_name);
136191
ASR::symbol_t* placeholder_sym = current_scope->get_symbol(placeholder);
137192
ASR::expr_t* target1 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, placeholder_sym));
@@ -154,7 +209,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
154209
type1, nullptr));
155210

156211
// statement 4
157-
ASR::symbol_t* basic_new_stack_sym = module_scope->get_symbol(new_name);
212+
ASR::symbol_t* basic_new_stack_sym = module_scope->get_symbol("basic_new_stack");
158213
Vec<ASR::call_arg_t> call_args;
159214
call_args.reserve(al, 1);
160215
ASR::call_arg_t call_arg;

0 commit comments

Comments
 (0)