From e10dec5c6b9718349e53c4e97dc55ef28a47adc4 Mon Sep 17 00:00:00 2001 From: Valentin Stanciu <250871+svalentin@users.noreply.github.com> Date: Mon, 2 Dec 2024 18:12:46 +0000 Subject: [PATCH 1/3] Optimize str.encode with specializations for common used encodings --- mypyc/irbuild/specialize.py | 44 ++++++++++++++++++++++++++++++++ mypyc/primitives/str_ops.py | 24 +++++++++++++++++ mypyc/test-data/fixtures/ir.py | 2 +- mypyc/test-data/irbuild-str.test | 32 +++++++++++++++-------- 4 files changed, 90 insertions(+), 12 deletions(-) diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 7c5958457886..26d00211ab86 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -90,6 +90,11 @@ ) from mypyc.primitives.list_ops import new_list_set_item_op from mypyc.primitives.tuple_ops import new_tuple_set_item_op +from mypyc.primitives.str_ops import ( + str_encode_utf8_strict, + str_encode_ascii_strict, + str_encode_latin1_strict, +) # Specializers are attempted before compiling the arguments to the # function. Specializers can return None to indicate that they failed @@ -682,6 +687,45 @@ def translate_fstring(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Va return None +@specialize_function("encode", str_rprimitive) +def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: + """Specialize common cases of str.encode for most used encodings and strict errors.""" + + if not isinstance(callee, MemberExpr): + return None + + # We can only specialize strict errors + if ( + len(expr.arg_kinds) > 1 + and isinstance(expr.args[1], StrExpr) + and expr.args[1].value != "strict" + ): + return None + + if ( + len(expr.args) > 0 + and expr.arg_kinds[0] == ARG_NAMED + and expr.arg_names[0] == "errors" + and isinstance(expr.args[0], StrExpr) + and expr.args[0].value != "strict" + ): + return None + + encoding = "utf8" + if len(expr.args) > 0 and isinstance(expr.args[0], StrExpr): + encoding = expr.args[0].value.lower().replace("-", "_") + + # Specialized encodings and their accepted aliases + if encoding in ['u8', 'utf', 'utf8', 'utf_8', 'cp65001']: + return builder.call_c(str_encode_utf8_strict, [builder.accept(callee.expr)], expr.line) + elif encoding in ["ascii", "646", "us_ascii"]: + return builder.call_c(str_encode_ascii_strict, [builder.accept(callee.expr)], expr.line) + elif encoding in ['iso_8859_1', 'iso8859_1', '8859', 'cp819', 'latin', 'latin1', 'latin_1', 'l1']: + return builder.call_c(str_encode_latin1_strict, [builder.accept(callee.expr)], expr.line) + + return None + + @specialize_function("mypy_extensions.i64") def translate_i64(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS: diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index 2ff1fbdb4b3e..3a5495e21c1b 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -219,6 +219,30 @@ extra_int_constants=[(0, pointer_rprimitive)], ) +# str.encode(encoding) - utf8 strict specialization +str_encode_utf8_strict = custom_op( + arg_types=[str_rprimitive], + return_type=bytes_rprimitive, + c_function_name="PyUnicode_AsUTF8String", + error_kind=ERR_MAGIC, +) + +# str.encode(encoding) - ascii strict specialization +str_encode_ascii_strict = custom_op( + arg_types=[str_rprimitive], + return_type=bytes_rprimitive, + c_function_name="PyUnicode_AsASCIIString", + error_kind=ERR_MAGIC, +) + +# str.encode(encoding) - latin1 strict specialization +str_encode_latin1_strict = custom_op( + arg_types=[str_rprimitive], + return_type=bytes_rprimitive, + c_function_name="PyUnicode_AsLatin1String", + error_kind=ERR_MAGIC, +) + # str.encode(encoding, errors) method_op( name="encode", diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index ac95ffe2c047..be66307286fc 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -110,7 +110,7 @@ def upper(self) -> str: ... def startswith(self, x: str, start: int=..., end: int=...) -> bool: ... def endswith(self, x: str, start: int=..., end: int=...) -> bool: ... def replace(self, old: str, new: str, maxcount: int=...) -> str: ... - def encode(self, x: str=..., y: str=...) -> bytes: ... + def encode(self, encoding: str=..., errors: str=...) -> bytes: ... class float: def __init__(self, x: object) -> None: pass diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index 771dcc4c0e68..1e8265b91edb 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -293,20 +293,30 @@ L0: def f(s: str) -> None: s.encode() s.encode('utf-8') + s.encode('utf-8', 'strict') + s.encode('utf-8', errors='strict') + s.encode('utf-8', 'backslashreplace') + s.encode(encoding='ascii') s.encode('ascii', 'backslashreplace') [out] def f(s): s :: str - r0 :: bytes - r1 :: str - r2 :: bytes - r3, r4 :: str - r5 :: bytes + r0, r1, r2, r3 :: bytes + r4, r5 :: str + r6, r7 :: bytes + r8, r9 :: str + r10 :: bytes L0: - r0 = CPy_Encode(s, 0, 0) - r1 = 'utf-8' - r2 = CPy_Encode(s, r1, 0) - r3 = 'ascii' - r4 = 'backslashreplace' - r5 = CPy_Encode(s, r3, r4) + r0 = PyUnicode_AsUTF8String(s) + r1 = PyUnicode_AsUTF8String(s) + r2 = PyUnicode_AsUTF8String(s) + r3 = PyUnicode_AsUTF8String(s) + r4 = 'utf-8' + r5 = 'backslashreplace' + r6 = CPy_Encode(s, r4, r5) + r7 = PyUnicode_AsASCIIString(s) + r8 = 'ascii' + r9 = 'backslashreplace' + r10 = CPy_Encode(s, r8, r9) return 1 + From 61cd52284f64c76665be27c1162c02f38afdda74 Mon Sep 17 00:00:00 2001 From: Valentin Stanciu <250871+svalentin@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:25:13 +0000 Subject: [PATCH 2/3] Update arg logic and add more tests --- mypyc/irbuild/specialize.py | 51 ++++++++++------- mypyc/test-data/irbuild-str.test | 96 ++++++++++++++++++++++++++------ 2 files changed, 112 insertions(+), 35 deletions(-) diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 26d00211ab86..544e158a79bd 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -694,33 +694,46 @@ def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> if not isinstance(callee, MemberExpr): return None - # We can only specialize strict errors - if ( - len(expr.arg_kinds) > 1 - and isinstance(expr.args[1], StrExpr) - and expr.args[1].value != "strict" - ): + # We can only specialize if we have string literals as args + if len(expr.arg_kinds) > 0 and not isinstance(expr.args[0], StrExpr): return None - - if ( - len(expr.args) > 0 - and expr.arg_kinds[0] == ARG_NAMED - and expr.arg_names[0] == "errors" - and isinstance(expr.args[0], StrExpr) - and expr.args[0].value != "strict" - ): + if len(expr.arg_kinds) > 1 and not isinstance(expr.args[1], StrExpr): return None encoding = "utf8" - if len(expr.args) > 0 and isinstance(expr.args[0], StrExpr): - encoding = expr.args[0].value.lower().replace("-", "_") + errors = "strict" + if len(expr.arg_kinds) > 0 and isinstance(expr.args[0], StrExpr): + if expr.arg_kinds[0] == ARG_NAMED: + if expr.arg_names[0] == "encoding": + encoding = expr.args[0].value + elif expr.arg_names[0] == "errors": + errors = expr.args[0].value + elif expr.arg_kinds[0] == ARG_POS: + encoding = expr.args[0].value + else: + return None + if len(expr.arg_kinds) > 1 and isinstance(expr.args[1], StrExpr): + if expr.arg_kinds[1] == ARG_NAMED: + if expr.arg_names[1] == "encoding": + encoding = expr.args[1].value + elif expr.arg_names[1] == "errors": + errors = expr.args[1].value + elif expr.arg_kinds[1] == ARG_POS: + errors = expr.args[1].value + else: + return None + + if errors != "strict": + # We can only specialize strict errors + return None + encoding = encoding.lower().replace("-", "").replace("_", "") # normalize # Specialized encodings and their accepted aliases - if encoding in ['u8', 'utf', 'utf8', 'utf_8', 'cp65001']: + if encoding in ["u8", "utf", "utf8", "cp65001"]: return builder.call_c(str_encode_utf8_strict, [builder.accept(callee.expr)], expr.line) - elif encoding in ["ascii", "646", "us_ascii"]: + elif encoding in ["646", "ascii", "usascii"]: return builder.call_c(str_encode_ascii_strict, [builder.accept(callee.expr)], expr.line) - elif encoding in ['iso_8859_1', 'iso8859_1', '8859', 'cp819', 'latin', 'latin1', 'latin_1', 'l1']: + elif encoding in ["iso88591", "8859", "cp819", "latin", "latin1", "l1"]: return builder.call_c(str_encode_latin1_strict, [builder.accept(callee.expr)], expr.line) return None diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index 1e8265b91edb..1993574d7b66 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -293,30 +293,94 @@ L0: def f(s: str) -> None: s.encode() s.encode('utf-8') - s.encode('utf-8', 'strict') - s.encode('utf-8', errors='strict') - s.encode('utf-8', 'backslashreplace') + s.encode('utf8', 'strict') + s.encode('latin1', errors='strict') s.encode(encoding='ascii') + s.encode(errors='strict', encoding='latin-1') + s.encode('utf-8', 'backslashreplace') s.encode('ascii', 'backslashreplace') + encoding = 'utf8' + s.encode(encoding) + errors = 'strict' + s.encode('utf8', errors) + s.encode('utf8', errors=errors) + s.encode(errors=errors) + s.encode(encoding=encoding, errors=errors) + s.encode('latin2') + [out] def f(s): s :: str - r0, r1, r2, r3 :: bytes - r4, r5 :: str - r6, r7 :: bytes - r8, r9 :: str - r10 :: bytes + r0, r1, r2, r3, r4, r5 :: bytes + r6, r7 :: str + r8 :: bytes + r9, r10 :: str + r11 :: bytes + r12, encoding :: str + r13 :: bytes + r14, errors, r15 :: str + r16 :: bytes + r17, r18 :: str + r19 :: object + r20 :: str + r21 :: tuple + r22 :: dict + r23 :: object + r24 :: str + r25 :: object + r26 :: str + r27 :: tuple + r28 :: dict + r29 :: object + r30 :: str + r31 :: object + r32, r33 :: str + r34 :: tuple + r35 :: dict + r36 :: object + r37 :: str + r38 :: bytes L0: r0 = PyUnicode_AsUTF8String(s) r1 = PyUnicode_AsUTF8String(s) r2 = PyUnicode_AsUTF8String(s) - r3 = PyUnicode_AsUTF8String(s) - r4 = 'utf-8' - r5 = 'backslashreplace' - r6 = CPy_Encode(s, r4, r5) - r7 = PyUnicode_AsASCIIString(s) - r8 = 'ascii' - r9 = 'backslashreplace' - r10 = CPy_Encode(s, r8, r9) + r3 = PyUnicode_AsLatin1String(s) + r4 = PyUnicode_AsASCIIString(s) + r5 = PyUnicode_AsLatin1String(s) + r6 = 'utf-8' + r7 = 'backslashreplace' + r8 = CPy_Encode(s, r6, r7) + r9 = 'ascii' + r10 = 'backslashreplace' + r11 = CPy_Encode(s, r9, r10) + r12 = 'utf8' + encoding = r12 + r13 = CPy_Encode(s, encoding, 0) + r14 = 'strict' + errors = r14 + r15 = 'utf8' + r16 = CPy_Encode(s, r15, errors) + r17 = 'utf8' + r18 = 'encode' + r19 = CPyObject_GetAttr(s, r18) + r20 = 'errors' + r21 = PyTuple_Pack(1, r17) + r22 = CPyDict_Build(1, r20, errors) + r23 = PyObject_Call(r19, r21, r22) + r24 = 'encode' + r25 = CPyObject_GetAttr(s, r24) + r26 = 'errors' + r27 = PyTuple_Pack(0) + r28 = CPyDict_Build(1, r26, errors) + r29 = PyObject_Call(r25, r27, r28) + r30 = 'encode' + r31 = CPyObject_GetAttr(s, r30) + r32 = 'encoding' + r33 = 'errors' + r34 = PyTuple_Pack(0) + r35 = CPyDict_Build(2, r32, encoding, r33, errors) + r36 = PyObject_Call(r31, r34, r35) + r37 = 'latin2' + r38 = CPy_Encode(s, r37, 0) return 1 From 5c641165e36f35055ac969a60fb52285ea697b34 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:26:11 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypyc/irbuild/specialize.py | 4 ++-- mypyc/test-data/irbuild-str.test | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 544e158a79bd..cb69852af1ce 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -89,12 +89,12 @@ dict_values_op, ) from mypyc.primitives.list_ops import new_list_set_item_op -from mypyc.primitives.tuple_ops import new_tuple_set_item_op from mypyc.primitives.str_ops import ( - str_encode_utf8_strict, str_encode_ascii_strict, str_encode_latin1_strict, + str_encode_utf8_strict, ) +from mypyc.primitives.tuple_ops import new_tuple_set_item_op # Specializers are attempted before compiling the arguments to the # function. Specializers can return None to indicate that they failed diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index 1993574d7b66..61e5a42cf3ef 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -383,4 +383,3 @@ L0: r37 = 'latin2' r38 = CPy_Encode(s, r37, 0) return 1 -