Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,9 @@ RUN(NAME test_dict_12 LABELS cpython llvm c)
RUN(NAME test_dict_13 LABELS cpython llvm c)
RUN(NAME test_dict_bool LABELS cpython llvm)
RUN(NAME test_dict_increment LABELS cpython llvm)
RUN(NAME test_set_len LABELS cpython llvm)
RUN(NAME test_set_add LABELS cpython llvm)
RUN(NAME test_set_remove LABELS cpython llvm)
RUN(NAME test_for_loop LABELS cpython llvm c)
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
Expand Down
34 changes: 34 additions & 0 deletions integration_tests/test_set_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from lpython import i32

def test_set_add():
s1: set[i32]
s2: set[tuple[i32, tuple[i32, i32], str]]
s3: set[str]
st1: str
i: i32
j: i32

s1 = {0}
s2 = {(0, (1, 2), 'a')}
for i in range(20):
j = i % 10
s1.add(j)
s2.add((j, (j + 1, j + 2), 'a'))
assert len(s1) == len(s2)
if i < 10:
assert len(s1) == i + 1
else:
assert len(s1) == 10

st1 = 'a'
s3 = {st1}
for i in range(20):
s3.add(st1)
if i < 10:
if i > 0:
assert len(s3) == i
st1 += 'a'
else:
assert len(s3) == 10

test_set_add()
8 changes: 8 additions & 0 deletions integration_tests/test_set_len.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from lpython import i32

def test_set():
s: set[i32]
s = {1, 2, 22, 2, -1, 1}
assert len(s) == 4

test_set()
47 changes: 47 additions & 0 deletions integration_tests/test_set_remove.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from lpython import i32

def test_set_add():
s1: set[i32]
s2: set[tuple[i32, tuple[i32, i32], str]]
s3: set[str]
st1: str
i: i32
j: i32
k: i32

for k in range(2):
s1 = {0}
s2 = {(0, (1, 2), 'a')}
for i in range(20):
j = i % 10
s1.add(j)
s2.add((j, (j + 1, j + 2), 'a'))

for i in range(10):
s1.remove(i)
s2.remove((i, (i + 1, i + 2), 'a'))
# assert len(s1) == 10 - 1 - i
# assert len(s1) == len(s2)

st1 = 'a'
s3 = {st1}
for i in range(20):
s3.add(st1)
if i < 10:
if i > 0:
st1 += 'a'

st1 = 'a'
for i in range(10):
s3.remove(st1)
assert len(s3) == 10 - 1 - i
if i < 10:
st1 += 'a'

for i in range(20):
s1.add(i)
if i % 2 == 0:
s1.remove(i)
assert len(s1) == (i + 1) // 2

test_set_add()
2 changes: 0 additions & 2 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,6 @@ stmt
| SelectType(expr selector, type_stmt* body, stmt* default)
| CPtrToPointer(expr cptr, expr ptr, expr? shape, expr? lower_bounds)
| BlockCall(int label, symbol m)
| SetInsert(expr a, expr ele)
| SetRemove(expr a, expr ele)
| ListInsert(expr a, expr pos, expr ele)
| ListRemove(expr a, expr ele)
| ListClear(expr a)
Expand Down
86 changes: 86 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
std::unique_ptr<LLVMTuple> tuple_api;
std::unique_ptr<LLVMDictInterface> dict_api_lp;
std::unique_ptr<LLVMDictInterface> dict_api_sc;
std::unique_ptr<LLVMSetInterface> set_api; // linear probing
std::unique_ptr<LLVMArrUtils::Descriptor> arr_descr;

ASRToLLVMVisitor(Allocator &al, llvm::LLVMContext &context, std::string infile,
Expand All @@ -199,13 +200,15 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tuple_api(std::make_unique<LLVMTuple>(context, llvm_utils.get(), builder.get())),
dict_api_lp(std::make_unique<LLVMDictOptimizedLinearProbing>(context, llvm_utils.get(), builder.get())),
dict_api_sc(std::make_unique<LLVMDictSeparateChaining>(context, llvm_utils.get(), builder.get())),
set_api(std::make_unique<LLVMSetLinearProbing>(context, llvm_utils.get(), builder.get())),
arr_descr(LLVMArrUtils::Descriptor::get_descriptor(context,
builder.get(), llvm_utils.get(),
LLVMArrUtils::DESCR_TYPE::_SimpleCMODescriptor))
{
llvm_utils->tuple_api = tuple_api.get();
llvm_utils->list_api = list_api.get();
llvm_utils->dict_api = nullptr;
llvm_utils->set_api = set_api.get();
llvm_utils->arr_api = arr_descr.get();
llvm_utils->dict_api_lp = dict_api_lp.get();
llvm_utils->dict_api_sc = dict_api_sc.get();
Expand Down Expand Up @@ -1149,6 +1152,25 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = const_dict;
}

void visit_SetConstant(const ASR::SetConstant_t& x) {
llvm::Type* const_set_type = llvm_utils->get_set_type(x.m_type, module.get());
llvm::Value* const_set = builder->CreateAlloca(const_set_type, nullptr, "const_set");
ASR::Set_t* x_set = ASR::down_cast<ASR::Set_t>(x.m_type);
std::string el_type_code = ASRUtils::get_type_code(x_set->m_type);
llvm_utils->set_api->set_init(el_type_code, const_set, module.get(), x.n_elements);
int64_t ptr_loads_el = !LLVM::is_llvm_struct(x_set->m_type);
int64_t ptr_loads_copy = ptr_loads;
for( size_t i = 0; i < x.n_elements; i++ ) {
ptr_loads = ptr_loads_el;
visit_expr_wrapper(x.m_elements[i], true);
llvm::Value* element = tmp;
llvm_utils->set_api->write_item(const_set, element, module.get(),
x_set->m_type, name2memidx);
}
ptr_loads = ptr_loads_copy;
tmp = const_set;
}

void visit_TupleConstant(const ASR::TupleConstant_t& x) {
ASR::Tuple_t* tuple_type = ASR::down_cast<ASR::Tuple_t>(x.m_type);
std::string type_code = ASRUtils::get_type_code(tuple_type->m_type,
Expand Down Expand Up @@ -1487,6 +1509,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = llvm_utils->dict_api->len(pdict);
}

void visit_SetLen(const ASR::SetLen_t& x) {
if (x.m_value) {
this->visit_expr(*x.m_value);
return ;
}

int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*x.m_arg);
ptr_loads = ptr_loads_copy;
llvm::Value* pset = tmp;
tmp = llvm_utils->set_api->len(pset);
}

void visit_ListInsert(const ASR::ListInsert_t& x) {
ASR::List_t* asr_list = ASR::down_cast<ASR::List_t>(
ASRUtils::expr_type(x.m_a));
Expand Down Expand Up @@ -1648,6 +1684,34 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx);
}

void generate_SetAdd(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*m_arg);
llvm::Value* pset = tmp;

ptr_loads = 2;
this->visit_expr_wrapper(m_ele, true);
ptr_loads = ptr_loads_copy;
llvm::Value *el = tmp;
set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx);
}

void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*m_arg);
llvm::Value* pset = tmp;

ptr_loads = 2;
this->visit_expr_wrapper(m_ele, true);
ptr_loads = ptr_loads_copy;
llvm::Value *el = tmp;
set_api->remove_item(pset, el, *module, asr_el_type);
}

void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) {
switch (static_cast<ASRUtils::IntrinsicFunctions>(x.m_intrinsic_id)) {
case ASRUtils::IntrinsicFunctions::ListIndex: {
Expand Down Expand Up @@ -1691,6 +1755,14 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
break;
}
case ASRUtils::IntrinsicFunctions::SetAdd: {
generate_SetAdd(x.m_args[0], x.m_args[1]);
break;
}
case ASRUtils::IntrinsicFunctions::SetRemove: {
generate_SetRemove(x.m_args[0], x.m_args[1]);
break;
}
case ASRUtils::IntrinsicFunctions::Exp: {
switch (x.m_overload_id) {
case 0: {
Expand Down Expand Up @@ -3945,6 +4017,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
bool is_value_tuple = ASR::is_a<ASR::Tuple_t>(*asr_value_type);
bool is_target_dict = ASR::is_a<ASR::Dict_t>(*asr_target_type);
bool is_value_dict = ASR::is_a<ASR::Dict_t>(*asr_value_type);
bool is_target_set = ASR::is_a<ASR::Set_t>(*asr_target_type);
bool is_value_set = ASR::is_a<ASR::Set_t>(*asr_value_type);
bool is_target_struct = ASR::is_a<ASR::Struct_t>(*asr_target_type);
bool is_value_struct = ASR::is_a<ASR::Struct_t>(*asr_value_type);
if (ASR::is_a<ASR::StringSection_t>(*x.m_target)) {
Expand Down Expand Up @@ -4034,6 +4108,18 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm_utils->dict_api->dict_deepcopy(value_dict, target_dict,
value_dict_type, module.get(), name2memidx);
return ;
} else if( is_target_set && is_value_set ) {
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*x.m_value);
llvm::Value* value_set = tmp;
this->visit_expr(*x.m_target);
llvm::Value* target_set = tmp;
ptr_loads = ptr_loads_copy;
ASR::Set_t* value_set_type = ASR::down_cast<ASR::Set_t>(asr_value_type);
llvm_utils->set_api->set_deepcopy(value_set, target_set,
value_set_type, module.get(), name2memidx);
return ;
} else if( is_target_struct && is_value_struct ) {
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
Expand Down
Loading