Skip to content

Commit 584617f

Browse files
committed
Updated register class and register lists bindings.
1 parent ef29b0b commit 584617f

File tree

4 files changed

+185
-102
lines changed

4 files changed

+185
-102
lines changed

arch/riscv/src/lib.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ use binaryninja::{
1515
UnusedRegisterStackInfo,
1616
},
1717
binary_view::{BinaryView, BinaryViewExt},
18-
calling_convention::{
19-
register_calling_convention, CallingConvention, ConventionBuilder,
20-
},
18+
calling_convention::{register_calling_convention, CallingConvention, ConventionBuilder},
2119
custom_binary_view::{BinaryViewType, BinaryViewTypeExt},
2220
disassembly::{InstructionTextToken, InstructionTextTokenKind},
2321
function::Function,

callingconvention.cpp

Lines changed: 135 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -358,23 +358,78 @@ vector<uint32_t> CallingConvention::GetRegisterArgumentListRegs(uint32_t regList
358358
return vector<uint32_t>();
359359
}
360360

361+
361362
vector<Variable> CallingConvention::GetVariablesForParameters(
362363
const vector<FunctionParameter>& params, const std::optional<set<uint32_t>>& permittedRegs)
363364
{
364-
vector<uint32_t> intArgs = GetIntegerArgumentRegisters();
365-
vector<uint32_t> floatArgs = GetFloatArgumentRegisters();
365+
vector<uint32_t> classes = GetRegisterArgumentClasses();
366+
367+
// Build register lists for all classes
368+
// The order of iterators matter here, for register class and register list
369+
// we have assumed the INTEGER_SEMANTICS should be the first ones to be processed
370+
vector<vector<uint32_t>> allRegLists;
371+
vector<BNRegisterListKind> allListKinds;
372+
vector<vector<uint32_t>::iterator> allIterators;
373+
vector<vector<uint32_t>::iterator> allEndIterators;
374+
bool hasSharedIndex = false;
375+
376+
for (uint32_t classId : classes)
377+
{
378+
vector<uint32_t> registerLists = GetRegisterArgumentClassLists(classId);
379+
if (registerLists.size() > 1)
380+
hasSharedIndex = true;
381+
382+
for (uint32_t regListId : registerLists)
383+
{
384+
vector<uint32_t> regs = GetRegisterArgumentListRegs(regListId);
385+
BNRegisterListKind kind = GetRegisterArgumentListKind(regListId);
386+
387+
allRegLists.push_back(regs);
388+
allListKinds.push_back(kind);
389+
allIterators.push_back(allRegLists.back().begin());
390+
allEndIterators.push_back(allRegLists.back().end());
391+
}
392+
}
393+
394+
// Fallback to legacy API if no register classes defined
395+
if (allRegLists.empty())
396+
{
397+
vector<uint32_t> intArgs = GetIntegerArgumentRegisters();
398+
vector<uint32_t> floatArgs = GetFloatArgumentRegisters();
399+
400+
if (!intArgs.empty())
401+
{
402+
allRegLists.push_back(intArgs);
403+
allListKinds.push_back(REGISTER_LIST_KIND_INTEGER_SEMANTICS);
404+
allIterators.push_back(allRegLists.back().begin());
405+
allEndIterators.push_back(allRegLists.back().end());
406+
}
407+
408+
if (!floatArgs.empty())
409+
{
410+
allRegLists.push_back(floatArgs);
411+
allListKinds.push_back(REGISTER_LIST_KIND_FLOAT_SEMANTICS);
412+
allIterators.push_back(allRegLists.back().begin());
413+
allEndIterators.push_back(allRegLists.back().end());
414+
}
415+
416+
hasSharedIndex = AreArgumentRegistersSharedIndex();
417+
}
366418

367419
vector<Variable> result;
368-
auto intArgIter = intArgs.begin();
369-
auto floatArgIter = floatArgs.begin();
370420
size_t addrSize = GetArchitecture()->GetAddressSize();
371421
int64_t stackOffset = 0;
372-
bool sharedIndex = AreArgumentRegistersSharedIndex();
422+
373423
if (GetArchitecture()->GetLinkRegister() == BN_INVALID_REGISTER)
374424
stackOffset = addrSize;
375425
if (IsStackReservedForArgumentRegisters())
376-
stackOffset += intArgs.size() * addrSize;
377-
426+
{
427+
// Count total registers for stack reservation
428+
size_t totalRegs = 0;
429+
for (const auto& list : allRegLists)
430+
totalRegs = std::max(totalRegs, list.size());
431+
stackOffset += totalRegs * addrSize;
432+
}
378433

379434
// TODO: Structure in register and multi-reg parameters
380435
for (auto& param : params)
@@ -385,22 +440,26 @@ vector<Variable> CallingConvention::GetVariablesForParameters(
385440
{
386441
// Parameter not storage in a normal location, use custom variable
387442
result.push_back(param.location);
443+
388444
if (param.location.type == RegisterVariableSourceType)
389445
{
390-
// If non-default location matches the next register in the register parameter
391-
// lists, advance the iterators. It may just be a type mismatch, and we still
392-
// want to maintain the state for future parameters.
393-
if (intArgIter != intArgs.end() && *intArgIter == param.location.storage)
394-
{
395-
intArgIter++;
396-
if (sharedIndex && floatArgIter != floatArgs.end())
397-
floatArgIter++;
398-
}
399-
else if (floatArgIter != floatArgs.end() && *floatArgIter == param.location.storage)
446+
for (size_t i = 0; i < allIterators.size(); ++i)
400447
{
401-
floatArgIter++;
402-
if (sharedIndex && intArgIter != intArgs.end())
403-
intArgIter++;
448+
if (allIterators[i] != allEndIterators[i] && *allIterators[i] == param.location.storage)
449+
{
450+
allIterators[i]++;
451+
452+
// Advance all other iterators if shared index
453+
if (hasSharedIndex)
454+
{
455+
for (size_t j = i + 1; j < allIterators.size(); ++j)
456+
{
457+
if (allIterators[j] != allEndIterators[j])
458+
allIterators[j]++;
459+
}
460+
}
461+
break;
462+
}
404463
}
405464
}
406465
else if (param.location.type == StackVariableSourceType)
@@ -416,62 +475,74 @@ vector<Variable> CallingConvention::GetVariablesForParameters(
416475
continue;
417476
}
418477

419-
if (param.type->IsFloat())
478+
// Try to find a suitable register for this parameter
479+
bool paramPlaced = false;
480+
481+
for (size_t i = 0; i < allIterators.size(); ++i)
420482
{
421-
if (permittedRegs.has_value() && floatArgIter != floatArgs.end()
422-
&& permittedRegs.value().count(*floatArgIter) == 0)
423-
{
424-
// Disallowed register parameter, start spilling to stack. This is used in calling
425-
// conventions that place all variable argument parameters on the stack.
426-
floatArgIter = floatArgs.end();
427-
if (sharedIndex)
428-
intArgIter = intArgs.end();
429-
}
430-
else if (floatArgIter != floatArgs.end())
483+
if (allIterators[i] == allEndIterators[i])
484+
continue;
485+
486+
// Check if this register is permitted
487+
if (permittedRegs.has_value() && permittedRegs.value().count(*allIterators[i]) == 0)
431488
{
432-
BNRegisterInfo regInfo = GetArchitecture()->GetRegisterInfo(*floatArgIter);
433-
if (width <= regInfo.size)
489+
// Disallowed register parameter, mark this list as exhausted
490+
allIterators[i] = allEndIterators[i];
491+
if (hasSharedIndex)
434492
{
435-
result.emplace_back(RegisterVariableSourceType, 0, *floatArgIter);
436-
floatArgIter++;
437-
if (sharedIndex && intArgIter != intArgs.end())
438-
intArgIter++;
439-
continue;
493+
// Mark all lists as exhausted when shared index
494+
for (size_t j = 0; j < allIterators.size(); ++j)
495+
allIterators[j] = allEndIterators[j];
440496
}
497+
continue;
441498
}
442-
}
443-
else
444-
{
445-
if (permittedRegs.has_value() && intArgIter != intArgs.end()
446-
&& permittedRegs.value().count(*intArgIter) == 0)
447-
{
448-
// Disallowed register parameter, start spilling to stack. This is used in calling
449-
// conventions that place all variable argument parameters on the stack.
450-
intArgIter = intArgs.end();
451-
if (sharedIndex)
452-
floatArgIter = floatArgs.end();
453-
}
454-
else if (intArgIter != intArgs.end())
499+
500+
// Check if the type matches the register semantics
501+
bool typeMatches = false;
502+
BNRegisterListKind kind = allListKinds[i];
503+
504+
if (kind == REGISTER_LIST_KIND_INTEGER_SEMANTICS && !param.type->IsFloat())
505+
typeMatches = true;
506+
else if (kind == REGISTER_LIST_KIND_FLOAT_SEMANTICS && param.type->IsFloat())
507+
typeMatches = true;
508+
else if (kind == REGISTER_LIST_KIND_POINTER_SEMANTICS && param.type->IsPointer())
509+
typeMatches = true;
510+
511+
if (typeMatches)
455512
{
456-
BNRegisterInfo regInfo = GetArchitecture()->GetRegisterInfo(*intArgIter);
513+
BNRegisterInfo regInfo = GetArchitecture()->GetRegisterInfo(*allIterators[i]);
457514
if (width <= regInfo.size)
458515
{
459-
result.emplace_back(RegisterVariableSourceType, 0, *intArgIter);
460-
intArgIter++;
461-
if (sharedIndex && floatArgIter != floatArgs.end())
462-
floatArgIter++;
463-
continue;
516+
result.emplace_back(RegisterVariableSourceType, 0, *allIterators[i]);
517+
allIterators[i]++;
518+
519+
// Advance all other iterators if shared index
520+
if (hasSharedIndex)
521+
{
522+
for (size_t j = i + 1; j < allIterators.size(); ++j)
523+
{
524+
if (allIterators[j] != allEndIterators[j])
525+
allIterators[j]++;
526+
}
527+
}
528+
529+
paramPlaced = true;
530+
break;
464531
}
465532
}
466533
}
534+
535+
// If not placed in register, place on stack
536+
if (!paramPlaced)
537+
{
538+
result.emplace_back(StackVariableSourceType, 0, stackOffset);
467539

468-
result.emplace_back(StackVariableSourceType, 0, stackOffset);
469-
470-
if (width < addrSize)
471-
width = addrSize;
472-
else if ((width % addrSize) != 0)
473-
width += addrSize - (width % addrSize);
474-
stackOffset += width;
540+
if (width < addrSize)
541+
width = addrSize;
542+
else if ((width % addrSize) != 0)
543+
width += addrSize - (width % addrSize);
544+
stackOffset += width;
545+
}
475546
}
476547

477548
return result;

python/callingconvention.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ class CallingConvention:
6363
register_argument_list_regs = {}
6464
register_argument_list_kinds = {}
6565

66+
variables_for_parameters = []
67+
68+
6669
_registered_calling_conventions = []
6770

6871
def __init__(
@@ -686,8 +689,8 @@ def perform_get_variables_for_parameters(self, param_types, permitted_regs: Opti
686689
stack_offset = 0
687690
addr_size = self.arch.address_size
688691

689-
# If there's a link register, start stack after it
690-
if self.arch.link_reg is not None:
692+
# If no link register, start stack after return address
693+
if self.arch.link_reg is None:
691694
stack_offset = addr_size
692695

693696
# Reserve stack space for argument registers if needed
@@ -713,28 +716,28 @@ def is_reg_permitted(reg_name):
713716

714717
allocated = False
715718

716-
# Try to allocate in appropriate register type
719+
# Try float registers first for float types (if available)
720+
# TODO: Add proper float type detection when type system supports it
721+
# For now, treat all parameters as potentially using integer registers
722+
723+
# Try integer registers (for all parameter types when no float regs, or non-float types)
717724
try:
718-
# For now, assume integer allocation (TODO: add float type detection)
719-
try:
720-
reg_name = next(int_arg_iter)
721-
if is_reg_permitted(reg_name):
722-
reg_index = self.arch.regs[reg_name].index
723-
result.append(variable.Variable(variable.RegisterVariableSourceType, 0, reg_index))
724-
allocated = True
725-
if shared_index:
726-
try:
727-
next(float_arg_iter) # Advance float iterator too
728-
except StopIteration:
729-
pass
730-
else:
731-
# Register not permitted, spill to stack
732-
int_arg_iter = iter([]) # Empty the iterator
733-
if shared_index:
734-
float_arg_iter = iter([])
735-
except StopIteration:
736-
pass
737-
except:
725+
reg_name = next(int_arg_iter)
726+
if is_reg_permitted(reg_name):
727+
reg_index = self.arch.regs[reg_name].index
728+
result.append(variable.Variable(variable.RegisterVariableSourceType, 0, reg_index))
729+
allocated = True
730+
if shared_index:
731+
try:
732+
next(float_arg_iter) # Advance float iterator too
733+
except StopIteration:
734+
pass
735+
else:
736+
# Register not permitted, spill to stack
737+
int_arg_iter = iter([]) # Empty the iterator
738+
if shared_index:
739+
float_arg_iter = iter([])
740+
except StopIteration:
738741
pass
739742

740743
if not allocated:

0 commit comments

Comments
 (0)