Skip to content
Open
27 changes: 26 additions & 1 deletion mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sys
from collections.abc import Sequence
from typing import Callable, Final, Optional
from typing_extensions import TypeGuard

from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind
Expand Down Expand Up @@ -185,6 +186,7 @@
from mypyc.primitives.str_ops import (
str_check_if_true,
str_eq,
str_eq_literal,
str_ssize_t_size_op,
unicode_compare,
)
Expand Down Expand Up @@ -1550,10 +1552,33 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -

def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
"""Compare two strings"""

def is_string_literal(value: Value) -> TypeGuard[LoadLiteral]:
return isinstance(value, LoadLiteral) and is_str_rprimitive(value.type)

if op == "==":
if is_string_literal(lhs):
if is_string_literal(rhs):
# we can optimize out the check entirely in some Final cases
return self.true() if lhs.value == rhs.value else self.false()
literal_length = Integer(len(lhs.value), c_pyssize_t_rprimitive, line) # type: ignore [arg-type]
return self.primitive_op(str_eq_literal, [rhs, lhs, literal_length], line)
elif is_string_literal(rhs):
literal_length = Integer(len(rhs.value), c_pyssize_t_rprimitive, line) # type: ignore [arg-type]
return self.primitive_op(str_eq_literal, [lhs, rhs, literal_length], line)
return self.primitive_op(str_eq, [lhs, rhs], line)
elif op == "!=":
eq = self.primitive_op(str_eq, [lhs, rhs], line)
if is_string_literal(lhs):
if is_string_literal(rhs):
# we can optimize out the check entirely in some Final cases
return self.true() if lhs.value != rhs.value else self.false()
literal_length = Integer(len(lhs.value), c_pyssize_t_rprimitive, line) # type: ignore [arg-type]
eq = self.primitive_op(str_eq_literal, [rhs, lhs, literal_length], line)
elif is_string_literal(rhs):
literal_length = Integer(len(rhs.value), c_pyssize_t_rprimitive, line) # type: ignore [arg-type]
eq = self.primitive_op(str_eq_literal, [lhs, rhs, literal_length], line)
else:
eq = self.primitive_op(str_eq, [lhs, rhs], line)
return self.add(ComparisonOp(eq, self.false(), ComparisonOp.EQ, line))

# TODO: modify 'str' to use same interface as 'compare_bytes' as it would avoid
Expand Down
1 change: 1 addition & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, Py_ssize_t size) {
#define BOTHSTRIP 2

char CPyStr_Equal(PyObject *str1, PyObject *str2);
char CPyStr_EqualLiteral(PyObject *str, PyObject *literal_str, Py_ssize_t literal_length);
PyObject *CPyStr_Build(Py_ssize_t len, ...);
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
PyObject *CPyStr_GetItemUnsafe(PyObject *str, Py_ssize_t index);
Expand Down
29 changes: 21 additions & 8 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,33 @@ make_bloom_mask(int kind, const void* ptr, Py_ssize_t len)
#undef BLOOM_UPDATE
}

// Adapted from CPython 3.13.1 (_PyUnicode_Equal)
char CPyStr_Equal(PyObject *str1, PyObject *str2) {
if (str1 == str2) {
return 1;
}
Py_ssize_t len = PyUnicode_GET_LENGTH(str1);
if (PyUnicode_GET_LENGTH(str2) != len)
static char _CPyStr_Equal_NoIdentCheck(PyObject *str1, PyObject *str2, Py_ssize_t str2_length) {
// This helper function only exists to deduplicate code in CPyStr_Equal and CPyStr_EqualLiteral
Py_ssize_t str1_length = PyUnicode_GET_LENGTH(str1);
if (str1_length != str2_length)
return 0;
int kind = PyUnicode_KIND(str1);
if (PyUnicode_KIND(str2) != kind)
return 0;
const void *data1 = PyUnicode_DATA(str1);
const void *data2 = PyUnicode_DATA(str2);
return memcmp(data1, data2, len * kind) == 0;
return memcmp(data1, data2, str1_length * kind) == 0;
}

// Adapted from CPython 3.13.1 (_PyUnicode_Equal)
char CPyStr_Equal(PyObject *str1, PyObject *str2) {
if (str1 == str2) {
return 1;
}
Py_ssize_t str2_length = PyUnicode_GET_LENGTH(str2);
return _CPyStr_Equal_NoIdentCheck(str1, str2, str2_length);
}

char CPyStr_EqualLiteral(PyObject *str, PyObject *literal_str, Py_ssize_t literal_length) {
if (str == literal_str) {
return 1;
}
return _CPyStr_Equal_NoIdentCheck(str, literal_str, literal_length);
}

PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {
Expand Down
8 changes: 8 additions & 0 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@
error_kind=ERR_NEVER,
)

str_eq_literal = custom_primitive_op(
name="str_eq_literal",
c_function_name="CPyStr_EqualLiteral",
arg_types=[str_rprimitive, str_rprimitive, c_pyssize_t_rprimitive],
return_type=bool_rprimitive,
error_kind=ERR_NEVER,
)

unicode_compare = custom_op(
arg_types=[str_rprimitive, str_rprimitive],
return_type=c_int_rprimitive,
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/irbuild-dict.test
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ L2:
k = r8
v = r7
r9 = 'name'
r10 = CPyStr_Equal(k, r9)
r10 = CPyStr_EqualLiteral(k, r9, 4)
if r10 goto L3 else goto L4 :: bool
L3:
name = v
Expand Down
4 changes: 2 additions & 2 deletions mypyc/test-data/irbuild-unreachable.test
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ L0:
r2 = CPyObject_GetAttr(r0, r1)
r3 = cast(str, r2)
r4 = 'x'
r5 = CPyStr_Equal(r3, r4)
r5 = CPyStr_EqualLiteral(r3, r4, 1)
if r5 goto L2 else goto L1 :: bool
L1:
r6 = r5
Expand Down Expand Up @@ -54,7 +54,7 @@ L0:
r2 = CPyObject_GetAttr(r0, r1)
r3 = cast(str, r2)
r4 = 'x'
r5 = CPyStr_Equal(r3, r4)
r5 = CPyStr_EqualLiteral(r3, r4, 1)
if r5 goto L2 else goto L1 :: bool
L1:
r6 = r5
Expand Down
Loading