Skip to content

Commit a67222a

Browse files
committed
Add set.pop method
1 parent 0f16696 commit a67222a

File tree

5 files changed

+223
-0
lines changed

5 files changed

+223
-0
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,7 @@ RUN(NAME test_set_add LABELS cpython llvm llvm_jit)
590590
RUN(NAME test_set_remove LABELS cpython llvm llvm_jit)
591591
RUN(NAME test_set_discard LABELS cpython llvm llvm_jit)
592592
RUN(NAME test_set_clear LABELS cpython llvm)
593+
RUN(NAME test_set_pop LABELS cpython llvm)
593594
RUN(NAME test_global_set LABELS cpython llvm llvm_jit)
594595
RUN(NAME test_for_loop LABELS cpython llvm llvm_jit c)
595596
RUN(NAME modules_01 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x64)

integration_tests/test_set_pop.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
def set_pop_str():
2+
s: set[str] = {'a', 'b', 'c'}
3+
4+
assert s.pop() in {'a', 'b', 'c'}
5+
assert len(s) == 2
6+
assert s.pop() in {'a', 'b', 'c'}
7+
assert s.pop() in {'a', 'b', 'c'}
8+
assert len(s) == 0
9+
10+
s.add('d')
11+
assert s.pop() == 'd'
12+
13+
def set_pop_int():
14+
s: set[i32] = {1, 2, 3}
15+
16+
assert s.pop() in {1, 2, 3}
17+
assert len(s) == 2
18+
assert s.pop() in {1, 2, 3}
19+
assert s.pop() in {1, 2, 3}
20+
assert len(s) == 0
21+
22+
s.add(4)
23+
assert s.pop() == 4
24+
25+
set_pop_str()
26+
set_pop_int()

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,6 +1588,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
15881588
LLVM::is_llvm_struct(dict_type->m_value_type));
15891589
}
15901590

1591+
void visit_SetPop(const ASR::SetPop_t& x) {
1592+
ASR::Set_t* set_type = ASR::down_cast<ASR::Set_t>(
1593+
ASRUtils::expr_type(x.m_a));
1594+
int64_t ptr_loads_copy = ptr_loads;
1595+
ptr_loads = 0;
1596+
this->visit_expr(*x.m_a);
1597+
llvm::Value* pset = tmp;
1598+
1599+
ptr_loads = ptr_loads_copy;
1600+
1601+
llvm_utils->set_set_api(set_type);
1602+
tmp = llvm_utils->set_api->pop_item(pset, *module, set_type->m_type);
1603+
}
1604+
1605+
15911606
void visit_ListLen(const ASR::ListLen_t& x) {
15921607
if (x.m_value) {
15931608
this->visit_expr(*x.m_value);

src/libasr/codegen/llvm_utils.cpp

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6709,6 +6709,180 @@ namespace LCompilers {
67096709
LLVM::CreateStore(*builder, occupancy, occupancy_ptr);
67106710
}
67116711

6712+
llvm::Value* LLVMSetLinearProbing::pop_item(llvm::Value *set, llvm::Module &module,
6713+
ASR::ttype_t *el_asr_type) {
6714+
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
6715+
6716+
llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set);
6717+
llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr);
6718+
llvm_utils->create_if_else(builder->CreateICmpNE(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0)), [=]() {}, [&]() {
6719+
std::string message = "The set is empty";
6720+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
6721+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
6722+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
6723+
int exit_code_int = 1;
6724+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
6725+
llvm::APInt(32, exit_code_int));
6726+
exit(context, module, *builder, exit_code);
6727+
});
6728+
get_builder0();
6729+
llvm::AllocaInst *pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
6730+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos_ptr);
6731+
6732+
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
6733+
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
6734+
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
6735+
6736+
llvm::Value *el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set));
6737+
llvm::Value *el_list = get_el_list(set);
6738+
6739+
// head
6740+
llvm_utils->start_new_block(loophead);
6741+
{
6742+
llvm::Value *cond = builder->CreateICmpSGT(
6743+
current_capacity,
6744+
LLVM::CreateLoad(*builder, pos_ptr)
6745+
);
6746+
builder->CreateCondBr(cond, loopbody, loopend);
6747+
}
6748+
6749+
// body
6750+
llvm_utils->start_new_block(loopbody);
6751+
{
6752+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
6753+
llvm::Value* el_mask_value = LLVM::CreateLoad(*builder,
6754+
llvm_utils->create_ptr_gep(el_mask, pos));
6755+
llvm::Value* is_el_skip = builder->CreateICmpEQ(el_mask_value,
6756+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3)));
6757+
llvm::Value* is_el_set = builder->CreateICmpNE(el_mask_value,
6758+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)));
6759+
llvm::Value* is_el = builder->CreateAnd(is_el_set,
6760+
builder->CreateNot(is_el_skip));
6761+
6762+
llvm_utils->create_if_else(is_el, [&]() {
6763+
llvm::Value* el_mask_i = llvm_utils->create_ptr_gep(el_mask, pos);
6764+
llvm::Value* tombstone_marker = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3));
6765+
LLVM::CreateStore(*builder, tombstone_marker, el_mask_i);
6766+
occupancy = builder->CreateSub(occupancy, llvm::ConstantInt::get(
6767+
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1)));
6768+
LLVM::CreateStore(*builder, occupancy, occupancy_ptr);
6769+
}, [=]() {
6770+
LLVM::CreateStore(*builder, builder->CreateAdd(pos, llvm::ConstantInt::get(
6771+
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))), pos_ptr);
6772+
});
6773+
builder->CreateCondBr(is_el, loopend, loophead);
6774+
}
6775+
6776+
// end
6777+
llvm_utils->start_new_block(loopend);
6778+
6779+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
6780+
llvm::Value *el = llvm_utils->list_api->read_item(el_list, pos, false, module,
6781+
LLVM::is_llvm_struct(el_asr_type));
6782+
return el;
6783+
}
6784+
6785+
llvm::Value* LLVMSetSeparateChaining::pop_item(llvm::Value *set, llvm::Module &module,
6786+
ASR::ttype_t *el_asr_type) {
6787+
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
6788+
llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set);
6789+
llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr);
6790+
llvm_utils->create_if_else(builder->CreateICmpNE(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0)), []() {}, [&]() {
6791+
std::string message = "The set is empty";
6792+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
6793+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
6794+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
6795+
int exit_code_int = 1;
6796+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
6797+
llvm::APInt(32, exit_code_int));
6798+
exit(context, module, *builder, exit_code);
6799+
});
6800+
6801+
get_builder0();
6802+
llvm::AllocaInst* chain_itr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr);
6803+
llvm::AllocaInst* found_ptr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr);
6804+
llvm::AllocaInst* pos = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
6805+
LLVM::CreateStore(*builder,
6806+
llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos);
6807+
6808+
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
6809+
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
6810+
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
6811+
6812+
llvm::Value* elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(set));
6813+
llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set));
6814+
6815+
// head
6816+
llvm_utils->start_new_block(loophead);
6817+
{
6818+
llvm::Value *cond = builder->CreateICmpSGT(
6819+
current_capacity,
6820+
LLVM::CreateLoad(*builder, pos_ptr)
6821+
);
6822+
builder->CreateCondBr(cond, loopbody, loopend);
6823+
}
6824+
6825+
// body
6826+
llvm_utils->start_new_block(loopbody);
6827+
{
6828+
llvm::Value *el_mask_value = LLVM::CreateLoad(*builder,
6829+
llvm_utils->create_ptr_gep(el_mask, LLVM::CreateLoad(*builder, pos)));
6830+
llvm::Value* el_linked_list = llvm_utils->create_ptr_gep(elems, LLVM::CreateLoad(*builder, pos));
6831+
6832+
llvm::Value *is_el = builder->CreateICmpEQ(el_mask_value,
6833+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
6834+
llvm_utils->create_if_else(is_el, [&]() {
6835+
llvm::Value* el_ll_i8 = builder->CreateBitCast(el_linked_list, llvm::Type::getInt8PtrTy(context));
6836+
LLVM::CreateStore(*builder, el_ll_i8, chain_itr);
6837+
llvm::Value* el_struct_i8 = LLVM::CreateLoad(*builder, chain_itr);
6838+
llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(el_asr_type)];
6839+
llvm::Value* el_struct = builder->CreateBitCast(el_struct_i8, el_struct_type->getPointerTo());
6840+
llvm::Value* next_el_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(el_struct, 1));
6841+
llvm::Value *cond = builder->CreateICmpNE(
6842+
next_el_struct,
6843+
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))
6844+
);
6845+
6846+
llvm_utils->create_if_else(cond, [&](){
6847+
llvm::Value *found = LLVM::CreateLoad(*builder, next_el_struct);
6848+
llvm::Value *prev = LLVM::CreateLoad(*builder, chain_itr);
6849+
found = builder->CreateBitCast(found, el_struct_type->getPointerTo());
6850+
llvm::Value* found_next = LLVM::CreateLoad(*builder, llvm_utils->create_gep(found, 1));
6851+
prev = builder->CreateBitCast(prev, el_struct_type->getPointerTo());
6852+
LLVM::CreateStore(*builder, found_next, llvm_utils->create_gep(prev, 1));
6853+
LLVM::CreateStore(*builder, found, found_ptr);
6854+
}, [&](){
6855+
llvm::Value *found = LLVM::CreateLoad(*builder, chain_itr);
6856+
llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(el_asr_type)];
6857+
found = builder->CreateBitCast(found, el_struct_type->getPointerTo());
6858+
LLVM::CreateStore(*builder, found, found_ptr);
6859+
LLVM::CreateStore(
6860+
*builder,
6861+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)),
6862+
llvm_utils->create_ptr_gep(el_mask, LLVM::CreateLoad(*builder, pos))
6863+
);
6864+
llvm::Value* num_buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set);
6865+
llvm::Value* num_buckets_filled = LLVM::CreateLoad(*builder, num_buckets_filled_ptr);
6866+
num_buckets_filled = builder->CreateSub(num_buckets_filled, llvm::ConstantInt::get(
6867+
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1)));
6868+
LLVM::CreateStore(*builder, num_buckets_filled, num_buckets_filled_ptr);
6869+
});
6870+
}, [&]() {
6871+
});
6872+
LLVM::CreateStore(*builder, builder->CreateAdd(pos, llvm::ConstantInt::get(
6873+
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))), pos_ptr);
6874+
builder->CreateCondBr(is_el, loopend, loophead);
6875+
}
6876+
6877+
llvm::Value *el = llvm_utils->create_ptr_gep(LLVM::CreateLoad(*builder, pos_ptr), 0);
6878+
6879+
if (LLVM::is_llvm_struct(el_asr_type)) {
6880+
return el;
6881+
} else {
6882+
return LLVM::CreateLoad(*builder, el);
6883+
}
6884+
}
6885+
67126886
void LLVMSetLinearProbing::set_deepcopy(
67136887
llvm::Value* src, llvm::Value* dest,
67146888
ASR::Set_t* set_type, llvm::Module* module,

src/libasr/codegen/llvm_utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,9 @@ namespace LCompilers {
10031003
llvm::Value* set, llvm::Value* el,
10041004
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) = 0;
10051005

1006+
virtual
1007+
llvm::Value* pop_item(llvm::Value* set, llvm::Module& module, ASR::ttype_t* el_asr_type) = 0;
1008+
10061009
virtual
10071010
void set_deepcopy(
10081011
llvm::Value* src, llvm::Value* dest,
@@ -1076,6 +1079,8 @@ namespace LCompilers {
10761079
llvm::Value* set, llvm::Value* el,
10771080
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);
10781081

1082+
llvm::Value* pop_item(llvm::Value* set, llvm::Module& module, ASR::ttype_t* el_asr_type);
1083+
10791084
void set_deepcopy(
10801085
llvm::Value* src, llvm::Value* dest,
10811086
ASR::Set_t* set_type, llvm::Module* module,
@@ -1159,6 +1164,8 @@ namespace LCompilers {
11591164
llvm::Value* set, llvm::Value* el,
11601165
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);
11611166

1167+
llvm::Value* pop_item(llvm::Value* set, llvm::Module& module, ASR::ttype_t* el_asr_type);
1168+
11621169
void set_deepcopy(
11631170
llvm::Value* src, llvm::Value* dest,
11641171
ASR::Set_t* set_type, llvm::Module* module,

0 commit comments

Comments
 (0)