Skip to content

[mypyc] Constant fold int operations and str concat #11194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Oct 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,14 +483,20 @@ def visit_comparison_op(self, op: ComparisonOp) -> None:
rhs = self.reg(op.rhs)
lhs_cast = ""
rhs_cast = ""
signed_op = {ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE}
unsigned_op = {ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE}
if op.op in signed_op:
if op.op in (ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE):
# Always signed comparison op
lhs_cast = self.emit_signed_int_cast(op.lhs.type)
rhs_cast = self.emit_signed_int_cast(op.rhs.type)
elif op.op in unsigned_op:
elif op.op in (ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE):
# Always unsigned comparison op
lhs_cast = self.emit_unsigned_int_cast(op.lhs.type)
rhs_cast = self.emit_unsigned_int_cast(op.rhs.type)
elif isinstance(op.lhs, Integer) and op.lhs.value < 0:
# Force signed ==/!= with negative operand
rhs_cast = self.emit_signed_int_cast(op.rhs.type)
elif isinstance(op.rhs, Integer) and op.rhs.value < 0:
# Force signed ==/!= with negative operand
lhs_cast = self.emit_signed_int_cast(op.lhs.type)
self.emit_line('%s = %s%s %s %s%s;' % (dest, lhs_cast, lhs,
op.op_str[op.op], rhs_cast, rhs))

Expand Down Expand Up @@ -542,7 +548,12 @@ def reg(self, reg: Value) -> str:
s = str(val)
if val >= (1 << 31):
# Avoid overflowing signed 32-bit int
s += 'U'
s += 'ULL'
elif val == -(1 << 63):
# Avoid overflowing C integer literal
s = '(-9223372036854775807LL - 1)'
elif val <= -(1 << 31):
s += 'LL'
return s
else:
return self.emitter.reg(reg)
Expand Down
1 change: 1 addition & 0 deletions mypyc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
# Note: Assume that the compiled code uses the same bit width as mypyc, except for
# Python 3.5 on macOS.
MAX_LITERAL_SHORT_INT: Final = sys.maxsize >> 1 if not IS_MIXED_32_64_BIT_BUILD else 2 ** 30 - 1
MIN_LITERAL_SHORT_INT: Final = -MAX_LITERAL_SHORT_INT - 1

# Runtime C library files
RUNTIME_C_FILES: Final = [
Expand Down
99 changes: 99 additions & 0 deletions mypyc/irbuild/constant_fold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Constant folding of IR values.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this file be in the transform folder?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I put the file here since transform contains IR-to-IR transforms, and this is basically an AST analysis pass, which seems closer to what is happening in irbuild overall.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, then I don't have anything else blocking this from merging.


For example, 3 + 5 can be constant folded into 8.
"""

from typing import Optional, Union
from typing_extensions import Final

from mypy.nodes import Expression, IntExpr, StrExpr, OpExpr, UnaryExpr, NameExpr, MemberExpr, Var
from mypyc.irbuild.builder import IRBuilder


# All possible result types of constant folding
ConstantValue = Union[int, str]
CONST_TYPES: Final = (int, str)


def constant_fold_expr(builder: IRBuilder, expr: Expression) -> Optional[ConstantValue]:
    """Try to evaluate an expression to a compile-time int or str constant.

    Handled expressions are int and str literals, references to final
    names and final attributes whose final value is an int or str, and
    unary/binary operations whose operands can themselves be folded.

    Return the constant value, or None if the expression can't be folded.
    """
    if isinstance(expr, (IntExpr, StrExpr)):
        # Literals are their own constant value.
        return expr.value
    elif isinstance(expr, NameExpr):
        var = expr.node
        if isinstance(var, Var) and var.is_final:
            name_value = var.final_value
            if isinstance(name_value, CONST_TYPES):
                return name_value
    elif isinstance(expr, MemberExpr):
        final_ref = builder.get_final_ref(expr)
        if final_ref is not None:
            # Only the variable component of the reference is needed here.
            _fn, final_var, _native = final_ref
            if final_var.is_final:
                attr_value = final_var.final_value
                if isinstance(attr_value, CONST_TYPES):
                    return attr_value
    elif isinstance(expr, OpExpr):
        # Fold both operands first; mixed int/str operands are not folded.
        lhs = constant_fold_expr(builder, expr.left)
        rhs = constant_fold_expr(builder, expr.right)
        if isinstance(lhs, int) and isinstance(rhs, int):
            return constant_fold_binary_int_op(expr.op, lhs, rhs)
        if isinstance(lhs, str) and isinstance(rhs, str):
            return constant_fold_binary_str_op(expr.op, lhs, rhs)
    elif isinstance(expr, UnaryExpr):
        operand = constant_fold_expr(builder, expr.expr)
        if isinstance(operand, int):
            return constant_fold_unary_int_op(expr.op, operand)
    return None


# Upper bound on the (estimated) bit length of a result produced by folding
# '<<' or '**'. Larger results are left to be computed at runtime so that
# the compiler never has to materialize an enormous integer.
_MAX_FOLDED_INT_BITS = 4096


def constant_fold_binary_int_op(op: str, left: int, right: int) -> Optional[int]:
    """Evaluate a binary operation on two int constants, if it can be folded.

    Args:
        op: Operator as a string, such as '+' or '<<'.
        left: Value of the left operand.
        right: Value of the right operand.

    Return the int result, or None if the operation is not folded. None is
    returned for unsupported operators, for operand values that would not
    produce an int result (division or modulo by zero, negative shift counts,
    negative exponents), and for '<<' and '**' when the result could be
    longer than _MAX_FOLDED_INT_BITS bits.
    """
    if op == '+':
        return left + right
    elif op == '-':
        return left - right
    elif op == '*':
        return left * right
    elif op == '//':
        # Division by zero is not folded.
        if right != 0:
            return left // right
    elif op == '%':
        if right != 0:
            return left % right
    elif op == '&':
        return left & right
    elif op == '|':
        return left | right
    elif op == '^':
        return left ^ right
    elif op == '<<':
        # The result has exactly left.bit_length() + right bits (for left != 0),
        # so check the size before doing the shift.
        if right >= 0 and left.bit_length() + right <= _MAX_FOLDED_INT_BITS:
            return left << right
    elif op == '>>':
        if right >= 0:
            return left >> right
    elif op == '**':
        # left.bit_length() * right is an upper bound on the bit length of
        # the result, so check the size before exponentiating.
        if right >= 0 and left.bit_length() * right <= _MAX_FOLDED_INT_BITS:
            return left ** right
    return None


def constant_fold_unary_int_op(op: str, value: int) -> Optional[int]:
    """Evaluate a unary operation on an int constant.

    The supported operators are '+', '-' and '~'. Return None for any
    other operator.
    """
    if op == '+':
        # Unary plus leaves the operand unchanged.
        return value
    if op == '-':
        return -value
    if op == '~':
        return ~value
    return None


def constant_fold_binary_str_op(op: str, left: str, right: str) -> Optional[str]:
    """Evaluate a binary operation on two str constants.

    Only concatenation ('+') is supported. Return None for any other
    operator.
    """
    if op != '+':
        return None
    return left + right
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A random thought: add support for "A" * 20?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's a good idea. We already have an item about this (and a few other things) in the issue mypyc/mypyc#772. My intention is to add support for this and a few more cases in follow-up PRs (e.g. float arithmetic).

22 changes: 22 additions & 0 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
translate_list_comprehension, translate_set_comprehension,
comprehension_helper
)
from mypyc.irbuild.constant_fold import constant_fold_expr


# Name and attribute references
Expand Down Expand Up @@ -378,6 +379,10 @@ def translate_cast_expr(builder: IRBuilder, expr: CastExpr) -> Value:


def transform_unary_expr(builder: IRBuilder, expr: UnaryExpr) -> Value:
    """Generate IR for a unary expression.

    If the whole expression can be constant folded, return the folded
    constant. Otherwise evaluate the operand and apply the unary operator
    to it.
    """
    folded = try_constant_fold(builder, expr)
    # Compare against None explicitly: the result is Optional[Value], and
    # the decision to use it must not depend on the truthiness of a Value.
    if folded is not None:
        return folded

    return builder.unary_op(builder.accept(expr.expr), expr.op, expr.line)


Expand All @@ -391,6 +396,10 @@ def transform_op_expr(builder: IRBuilder, expr: OpExpr) -> Value:
if ret is not None:
return ret

folded = try_constant_fold(builder, expr)
if folded:
return folded

return builder.binary_op(
builder.accept(expr.left), builder.accept(expr.right), expr.op, expr.line
)
Expand All @@ -413,6 +422,19 @@ def transform_index_expr(builder: IRBuilder, expr: IndexExpr) -> Value:
base, '__getitem__', [index_reg], builder.node_type(expr), expr.line)


def try_constant_fold(builder: IRBuilder, expr: Expression) -> Optional[Value]:
    """Constant fold an expression and load the resulting constant.

    An int result is loaded via builder.load_int() and a str result via
    builder.load_str().

    Return None if the expression can't be constant folded.
    """
    folded = constant_fold_expr(builder, expr)
    if isinstance(folded, int):
        return builder.load_int(folded)
    if isinstance(folded, str):
        return builder.load_str(folded)
    return None


def try_gen_slice_op(builder: IRBuilder, base: Value, index: SliceExpr) -> Optional[Value]:
"""Generate specialized slice op for some index expressions.

Expand Down
6 changes: 3 additions & 3 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
from mypyc.ir.func_ir import FuncDecl, FuncSignature
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
from mypyc.common import (
FAST_ISINSTANCE_MAX_SUBCLASSES, MAX_LITERAL_SHORT_INT, PLATFORM_SIZE, use_vectorcall,
use_method_vectorcall
FAST_ISINSTANCE_MAX_SUBCLASSES, MAX_LITERAL_SHORT_INT, MIN_LITERAL_SHORT_INT, PLATFORM_SIZE,
use_vectorcall, use_method_vectorcall
)
from mypyc.primitives.registry import (
method_call_ops, CFunctionDescription, function_ops,
Expand Down Expand Up @@ -789,7 +789,7 @@ def none_object(self) -> Value:

def load_int(self, value: int) -> Value:
"""Load a tagged (Python) integer literal value."""
if abs(value) > MAX_LITERAL_SHORT_INT:
if value > MAX_LITERAL_SHORT_INT or value < MIN_LITERAL_SHORT_INT:
return self.add(LoadLiteral(value, int_rprimitive))
else:
return Integer(value)
Expand Down
39 changes: 17 additions & 22 deletions mypyc/test-data/analysis.test
Original file line number Diff line number Diff line change
Expand Up @@ -536,10 +536,8 @@ def lol(x):
r2 :: object
r3 :: str
r4 :: object
r5 :: bit
r6 :: int
r7 :: bit
r8, r9 :: int
r5, r6 :: bit
r7, r8 :: int
L0:
L1:
r0 = CPyTagged_Id(x)
Expand All @@ -555,9 +553,8 @@ L3:
r5 = CPy_ExceptionMatches(r4)
if r5 goto L4 else goto L5 :: bool
L4:
r6 = CPyTagged_Negate(2)
CPy_RestoreExcInfo(r1)
return r6
return -2
L5:
CPy_Reraise()
if not 0 goto L8 else goto L6 :: bool
Expand All @@ -568,16 +565,16 @@ L7:
goto L10
L8:
CPy_RestoreExcInfo(r1)
r7 = CPy_KeepPropagating()
if not r7 goto L11 else goto L9 :: bool
r6 = CPy_KeepPropagating()
if not r6 goto L11 else goto L9 :: bool
L9:
unreachable
L10:
r8 = CPyTagged_Add(st, 2)
return r8
r7 = CPyTagged_Add(st, 2)
return r7
L11:
r9 = <error> :: int
return r9
r8 = <error> :: int
return r8
(0, 0) {x} {x}
(1, 0) {x} {r0}
(1, 1) {r0} {st}
Expand All @@ -589,20 +586,18 @@ L11:
(2, 4) {r1, r4} {r1, r4}
(3, 0) {r1, r4} {r1, r5}
(3, 1) {r1, r5} {r1}
(4, 0) {r1} {r1, r6}
(4, 1) {r1, r6} {r6}
(4, 2) {r6} {}
(4, 0) {r1} {}
(4, 1) {} {}
(5, 0) {r1} {r1}
(5, 1) {r1} {r1}
(6, 0) {} {}
(7, 0) {r1, st} {st}
(7, 1) {st} {st}
(8, 0) {r1} {}
(8, 1) {} {r7}
(8, 2) {r7} {}
(8, 1) {} {r6}
(8, 2) {r6} {}
(9, 0) {} {}
(10, 0) {st} {r8}
(10, 1) {r8} {}
(11, 0) {} {r9}
(11, 1) {r9} {}

(10, 0) {st} {r7}
(10, 1) {r7} {}
(11, 0) {} {r8}
(11, 1) {r8} {}
5 changes: 5 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __sub__(self, n: int) -> int: pass
def __mul__(self, n: int) -> int: pass
def __pow__(self, n: int, modulo: Optional[int] = None) -> int: pass
def __floordiv__(self, x: int) -> int: pass
def __truediv__(self, x: float) -> float: pass
def __mod__(self, x: int) -> int: pass
def __neg__(self) -> int: pass
def __pos__(self) -> int: pass
Expand Down Expand Up @@ -271,6 +272,10 @@ class NotImplementedError(RuntimeError): pass
class StopIteration(Exception):
value: Any

class ArithmeticError(Exception): pass

class ZeroDivisionError(Exception): pass

def any(i: Iterable[T]) -> bool: pass
def all(i: Iterable[T]) -> bool: pass
def reversed(object: Sequence[T]) -> Iterator[T]: ...
Expand Down
4 changes: 2 additions & 2 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -581,12 +581,12 @@ L8:

[case testUnaryMinus]
def f(n: int) -> int:
return -1
return -n
[out]
def f(n):
n, r0 :: int
L0:
r0 = CPyTagged_Negate(2)
r0 = CPyTagged_Negate(n)
return r0

[case testConditionalExpr]
Expand Down
Loading