Skip to content

[mypyc] Optimize str.encode with specializations for common used encodings #18232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@
dict_values_op,
)
from mypyc.primitives.list_ops import new_list_set_item_op
from mypyc.primitives.str_ops import (
str_encode_ascii_strict,
str_encode_latin1_strict,
str_encode_utf8_strict,
)
from mypyc.primitives.tuple_ops import new_tuple_set_item_op

# Specializers are attempted before compiling the arguments to the
Expand Down Expand Up @@ -682,6 +687,58 @@ def translate_fstring(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Va
return None


@specialize_function("encode", str_rprimitive)
def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
"""Specialize common cases of str.encode for most used encodings and strict errors."""

if not isinstance(callee, MemberExpr):
return None

# We can only specialize if we have string literals as args
if len(expr.arg_kinds) > 0 and not isinstance(expr.args[0], StrExpr):
return None
if len(expr.arg_kinds) > 1 and not isinstance(expr.args[1], StrExpr):
return None

encoding = "utf8"
errors = "strict"
if len(expr.arg_kinds) > 0 and isinstance(expr.args[0], StrExpr):
if expr.arg_kinds[0] == ARG_NAMED:
if expr.arg_names[0] == "encoding":
encoding = expr.args[0].value
elif expr.arg_names[0] == "errors":
errors = expr.args[0].value
elif expr.arg_kinds[0] == ARG_POS:
encoding = expr.args[0].value
else:
return None
if len(expr.arg_kinds) > 1 and isinstance(expr.args[1], StrExpr):
if expr.arg_kinds[1] == ARG_NAMED:
if expr.arg_names[1] == "encoding":
encoding = expr.args[1].value
elif expr.arg_names[1] == "errors":
errors = expr.args[1].value
elif expr.arg_kinds[1] == ARG_POS:
errors = expr.args[1].value
else:
return None

if errors != "strict":
# We can only specialize strict errors
return None

encoding = encoding.lower().replace("-", "").replace("_", "") # normalize
# Specialized encodings and their accepted aliases
if encoding in ["u8", "utf", "utf8", "cp65001"]:
return builder.call_c(str_encode_utf8_strict, [builder.accept(callee.expr)], expr.line)
elif encoding in ["646", "ascii", "usascii"]:
return builder.call_c(str_encode_ascii_strict, [builder.accept(callee.expr)], expr.line)
elif encoding in ["iso88591", "8859", "cp819", "latin", "latin1", "l1"]:
return builder.call_c(str_encode_latin1_strict, [builder.accept(callee.expr)], expr.line)

return None


@specialize_function("mypy_extensions.i64")
def translate_i64(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS:
Expand Down
24 changes: 24 additions & 0 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,30 @@
extra_int_constants=[(0, pointer_rprimitive)],
)

# str.encode(encoding) - utf8 strict specialization
str_encode_utf8_strict = custom_op(
arg_types=[str_rprimitive],
return_type=bytes_rprimitive,
c_function_name="PyUnicode_AsUTF8String",
error_kind=ERR_MAGIC,
)

# str.encode(encoding) - ascii strict specialization
str_encode_ascii_strict = custom_op(
arg_types=[str_rprimitive],
return_type=bytes_rprimitive,
c_function_name="PyUnicode_AsASCIIString",
error_kind=ERR_MAGIC,
)

# str.encode(encoding) - latin1 strict specialization
str_encode_latin1_strict = custom_op(
arg_types=[str_rprimitive],
return_type=bytes_rprimitive,
c_function_name="PyUnicode_AsLatin1String",
error_kind=ERR_MAGIC,
)

# str.encode(encoding, errors)
method_op(
name="encode",
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def upper(self) -> str: ...
def startswith(self, x: str, start: int=..., end: int=...) -> bool: ...
def endswith(self, x: str, start: int=..., end: int=...) -> bool: ...
def replace(self, old: str, new: str, maxcount: int=...) -> str: ...
def encode(self, x: str=..., y: str=...) -> bytes: ...
def encode(self, encoding: str=..., errors: str=...) -> bytes: ...

class float:
def __init__(self, x: object) -> None: pass
Expand Down
95 changes: 84 additions & 11 deletions mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
Expand Up @@ -293,20 +293,93 @@ L0:
def f(s: str) -> None:
s.encode()
s.encode('utf-8')
s.encode('utf8', 'strict')
s.encode('latin1', errors='strict')
s.encode(encoding='ascii')
s.encode(errors='strict', encoding='latin-1')
s.encode('utf-8', 'backslashreplace')
s.encode('ascii', 'backslashreplace')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also test cases where the specialization shouldn't be applied. Examples: s.encode(x), s.encode('a', x), s.encode('utf8', errors=x) and s.encode(errors=x) where x is not a literal.

Test cases where we have two keyword args: s.encode(encoding=..., errors=...) and s.encode(errors=..., encoding=...).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated the logic to work out the args better and added more tests. Please take another look!

encoding = 'utf8'
s.encode(encoding)
errors = 'strict'
s.encode('utf8', errors)
s.encode('utf8', errors=errors)
s.encode(errors=errors)
s.encode(encoding=encoding, errors=errors)
s.encode('latin2')

[out]
def f(s):
s :: str
r0 :: bytes
r1 :: str
r2 :: bytes
r3, r4 :: str
r5 :: bytes
r0, r1, r2, r3, r4, r5 :: bytes
r6, r7 :: str
r8 :: bytes
r9, r10 :: str
r11 :: bytes
r12, encoding :: str
r13 :: bytes
r14, errors, r15 :: str
r16 :: bytes
r17, r18 :: str
r19 :: object
r20 :: str
r21 :: tuple
r22 :: dict
r23 :: object
r24 :: str
r25 :: object
r26 :: str
r27 :: tuple
r28 :: dict
r29 :: object
r30 :: str
r31 :: object
r32, r33 :: str
r34 :: tuple
r35 :: dict
r36 :: object
r37 :: str
r38 :: bytes
L0:
r0 = CPy_Encode(s, 0, 0)
r1 = 'utf-8'
r2 = CPy_Encode(s, r1, 0)
r3 = 'ascii'
r4 = 'backslashreplace'
r5 = CPy_Encode(s, r3, r4)
r0 = PyUnicode_AsUTF8String(s)
r1 = PyUnicode_AsUTF8String(s)
r2 = PyUnicode_AsUTF8String(s)
r3 = PyUnicode_AsLatin1String(s)
r4 = PyUnicode_AsASCIIString(s)
r5 = PyUnicode_AsLatin1String(s)
r6 = 'utf-8'
r7 = 'backslashreplace'
r8 = CPy_Encode(s, r6, r7)
r9 = 'ascii'
r10 = 'backslashreplace'
r11 = CPy_Encode(s, r9, r10)
r12 = 'utf8'
encoding = r12
r13 = CPy_Encode(s, encoding, 0)
r14 = 'strict'
errors = r14
r15 = 'utf8'
r16 = CPy_Encode(s, r15, errors)
r17 = 'utf8'
r18 = 'encode'
r19 = CPyObject_GetAttr(s, r18)
r20 = 'errors'
r21 = PyTuple_Pack(1, r17)
r22 = CPyDict_Build(1, r20, errors)
r23 = PyObject_Call(r19, r21, r22)
r24 = 'encode'
r25 = CPyObject_GetAttr(s, r24)
r26 = 'errors'
r27 = PyTuple_Pack(0)
r28 = CPyDict_Build(1, r26, errors)
r29 = PyObject_Call(r25, r27, r28)
r30 = 'encode'
r31 = CPyObject_GetAttr(s, r30)
r32 = 'encoding'
r33 = 'errors'
r34 = PyTuple_Pack(0)
r35 = CPyDict_Build(2, r32, encoding, r33, errors)
r36 = PyObject_Call(r31, r34, r35)
r37 = 'latin2'
r38 = CPy_Encode(s, r37, 0)
return 1
Loading