Skip to content

Commit 8757747

Browse files
Add support for assert_type (#12584)
See python/cpython#30843. The implementation mostly follows that of cast(). It relies on `mypy.sametypes.is_same_type()`.
1 parent 44993e6 commit 8757747

24 files changed

+175
-12
lines changed

mypy/checkexpr.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
get_proper_types, flatten_nested_unions, LITERAL_TYPE_NAMES,
2424
)
2525
from mypy.nodes import (
26-
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
26+
AssertTypeExpr, NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
2727
MemberExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr,
2828
OpExpr, UnaryExpr, IndexExpr, CastExpr, RevealExpr, TypeApplication, ListExpr,
2929
TupleExpr, DictExpr, LambdaExpr, SuperExpr, SliceExpr, Context, Expression,
@@ -3144,6 +3144,14 @@ def visit_cast_expr(self, expr: CastExpr) -> Type:
31443144
context=expr)
31453145
return target_type
31463146

3147+
def visit_assert_type_expr(self, expr: AssertTypeExpr) -> Type:
3148+
source_type = self.accept(expr.expr, type_context=AnyType(TypeOfAny.special_form),
3149+
allow_none_return=True, always_allow_any=True)
3150+
target_type = expr.type
3151+
if not is_same_type(source_type, target_type):
3152+
self.msg.assert_type_fail(source_type, target_type, expr)
3153+
return source_type
3154+
31473155
def visit_reveal_expr(self, expr: RevealExpr) -> Type:
31483156
"""Type check a reveal_type expression."""
31493157
if expr.kind == REVEAL_TYPE:

mypy/errorcodes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def __str__(self) -> str:
113113
REDUNDANT_CAST: Final = ErrorCode(
114114
"redundant-cast", "Check that cast changes type of expression", "General"
115115
)
116+
ASSERT_TYPE: Final = ErrorCode(
117+
"assert-type", "Check that assert_type() call succeeds", "General"
118+
)
116119
COMPARISON_OVERLAP: Final = ErrorCode(
117120
"comparison-overlap", "Check that types in comparisons and 'in' expressions overlap", "General"
118121
)

mypy/literals.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
ConditionalExpr, EllipsisExpr, YieldFromExpr, YieldExpr, RevealExpr, SuperExpr,
99
TypeApplication, LambdaExpr, ListComprehension, SetComprehension, DictionaryComprehension,
1010
GeneratorExpr, BackquoteExpr, TypeVarExpr, TypeAliasExpr, NamedTupleExpr, EnumCallExpr,
11-
TypedDictExpr, NewTypeExpr, PromoteExpr, AwaitExpr, TempNode, AssignmentExpr, ParamSpecExpr
11+
TypedDictExpr, NewTypeExpr, PromoteExpr, AwaitExpr, TempNode, AssignmentExpr, ParamSpecExpr,
12+
AssertTypeExpr,
1213
)
1314
from mypy.visitor import ExpressionVisitor
1415

@@ -175,6 +176,9 @@ def visit_slice_expr(self, e: SliceExpr) -> None:
175176
def visit_cast_expr(self, e: CastExpr) -> None:
176177
return None
177178

179+
def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
180+
return None
181+
178182
def visit_conditional_expr(self, e: ConditionalExpr) -> None:
179183
return None
180184

mypy/messages.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,11 @@ def redundant_cast(self, typ: Type, context: Context) -> None:
12131213
self.fail('Redundant cast to {}'.format(format_type(typ)), context,
12141214
code=codes.REDUNDANT_CAST)
12151215

1216+
def assert_type_fail(self, source_type: Type, target_type: Type, context: Context) -> None:
1217+
self.fail(f"Expression is of type {format_type(source_type)}, "
1218+
f"not {format_type(target_type)}", context,
1219+
code=codes.ASSERT_TYPE)
1220+
12161221
def unimported_type_becomes_any(self, prefix: str, typ: Type, ctx: Context) -> None:
12171222
self.fail("{} becomes {} due to an unfollowed import".format(prefix, format_type(typ)),
12181223
ctx, code=codes.NO_ANY_UNIMPORTED)

mypy/mixedtraverser.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional
22

33
from mypy.nodes import (
4-
Var, FuncItem, ClassDef, AssignmentStmt, ForStmt, WithStmt,
4+
AssertTypeExpr, Var, FuncItem, ClassDef, AssignmentStmt, ForStmt, WithStmt,
55
CastExpr, TypeApplication, TypeAliasExpr, TypeVarExpr, TypedDictExpr, NamedTupleExpr,
66
PromoteExpr, NewTypeExpr
77
)
@@ -79,6 +79,10 @@ def visit_cast_expr(self, o: CastExpr) -> None:
7979
super().visit_cast_expr(o)
8080
o.type.accept(self)
8181

82+
def visit_assert_type_expr(self, o: AssertTypeExpr) -> None:
83+
super().visit_assert_type_expr(o)
84+
o.type.accept(self)
85+
8286
def visit_type_application(self, o: TypeApplication) -> None:
8387
super().visit_type_application(o)
8488
for t in o.types:

mypy/nodes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1945,6 +1945,22 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
19451945
return visitor.visit_cast_expr(self)
19461946

19471947

1948+
class AssertTypeExpr(Expression):
1949+
"""Represents a typing.assert_type(expr, type) call."""
1950+
__slots__ = ('expr', 'type')
1951+
1952+
expr: Expression
1953+
type: "mypy.types.Type"
1954+
1955+
def __init__(self, expr: Expression, typ: 'mypy.types.Type') -> None:
1956+
super().__init__()
1957+
self.expr = expr
1958+
self.type = typ
1959+
1960+
def accept(self, visitor: ExpressionVisitor[T]) -> T:
1961+
return visitor.visit_assert_type_expr(self)
1962+
1963+
19481964
class RevealExpr(Expression):
19491965
"""Reveal type expression reveal_type(expr) or reveal_locals() expression."""
19501966

mypy/semanal.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from typing_extensions import Final, TypeAlias as _TypeAlias
5757

5858
from mypy.nodes import (
59-
MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef,
59+
AssertTypeExpr, MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef,
6060
ClassDef, Var, GDEF, FuncItem, Import, Expression, Lvalue,
6161
ImportFrom, ImportAll, Block, LDEF, NameExpr, MemberExpr,
6262
IndexExpr, TupleExpr, ListExpr, ExpressionStmt, ReturnStmt,
@@ -99,7 +99,7 @@
9999
TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType,
100100
get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, Parameters, ParamSpecType,
101101
PROTOCOL_NAMES, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, FINAL_DECORATOR_NAMES, REVEAL_TYPE_NAMES,
102-
is_named_instance,
102+
ASSERT_TYPE_NAMES, is_named_instance,
103103
)
104104
from mypy.typeops import function_type, get_type_vars
105105
from mypy.type_visitor import TypeQuery
@@ -3897,6 +3897,19 @@ def visit_call_expr(self, expr: CallExpr) -> None:
38973897
expr.analyzed.line = expr.line
38983898
expr.analyzed.column = expr.column
38993899
expr.analyzed.accept(self)
3900+
elif refers_to_fullname(expr.callee, ASSERT_TYPE_NAMES):
3901+
if not self.check_fixed_args(expr, 2, 'assert_type'):
3902+
return
3903+
# Translate second argument to an unanalyzed type.
3904+
try:
3905+
target = self.expr_to_unanalyzed_type(expr.args[1])
3906+
except TypeTranslationError:
3907+
self.fail('assert_type() type is not a type', expr)
3908+
return
3909+
expr.analyzed = AssertTypeExpr(expr.args[0], target)
3910+
expr.analyzed.line = expr.line
3911+
expr.analyzed.column = expr.column
3912+
expr.analyzed.accept(self)
39003913
elif refers_to_fullname(expr.callee, REVEAL_TYPE_NAMES):
39013914
if not self.check_fixed_args(expr, 1, 'reveal_type'):
39023915
return
@@ -4200,6 +4213,12 @@ def visit_cast_expr(self, expr: CastExpr) -> None:
42004213
if analyzed is not None:
42014214
expr.type = analyzed
42024215

4216+
def visit_assert_type_expr(self, expr: AssertTypeExpr) -> None:
4217+
expr.expr.accept(self)
4218+
analyzed = self.anal_type(expr.type)
4219+
if analyzed is not None:
4220+
expr.type = analyzed
4221+
42034222
def visit_reveal_expr(self, expr: RevealExpr) -> None:
42044223
if expr.kind == REVEAL_TYPE:
42054224
if expr.expr is not None:

mypy/server/astmerge.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
MypyFile, SymbolTable, Block, AssignmentStmt, NameExpr, MemberExpr, RefExpr, TypeInfo,
5252
FuncDef, ClassDef, NamedTupleExpr, SymbolNode, Var, Statement, SuperExpr, NewTypeExpr,
5353
OverloadedFuncDef, LambdaExpr, TypedDictExpr, EnumCallExpr, FuncBase, TypeAliasExpr, CallExpr,
54-
CastExpr, TypeAlias,
54+
CastExpr, TypeAlias, AssertTypeExpr,
5555
MDEF
5656
)
5757
from mypy.traverser import TraverserVisitor
@@ -226,6 +226,10 @@ def visit_cast_expr(self, node: CastExpr) -> None:
226226
super().visit_cast_expr(node)
227227
self.fixup_type(node.type)
228228

229+
def visit_assert_type_expr(self, node: AssertTypeExpr) -> None:
230+
super().visit_assert_type_expr(node)
231+
self.fixup_type(node.type)
232+
229233
def visit_super_expr(self, node: SuperExpr) -> None:
230234
super().visit_super_expr(node)
231235
if node.info is not None:

mypy/server/deps.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a
8989
ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt,
9090
TupleExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block,
9191
TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr,
92-
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr
92+
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr,
93+
AssertTypeExpr,
9394
)
9495
from mypy.operators import (
9596
op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods
@@ -686,6 +687,10 @@ def visit_cast_expr(self, e: CastExpr) -> None:
686687
super().visit_cast_expr(e)
687688
self.add_type_dependencies(e.type)
688689

690+
def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
691+
super().visit_assert_type_expr(e)
692+
self.add_type_dependencies(e.type)
693+
689694
def visit_type_application(self, e: TypeApplication) -> None:
690695
super().visit_type_application(e)
691696
for typ in e.types:

mypy/server/subexpr.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
SliceExpr, CastExpr, RevealExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr,
88
IndexExpr, GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension,
99
ConditionalExpr, TypeApplication, LambdaExpr, StarExpr, BackquoteExpr, AwaitExpr,
10-
AssignmentExpr,
10+
AssignmentExpr, AssertTypeExpr,
1111
)
1212
from mypy.traverser import TraverserVisitor
1313

@@ -99,6 +99,10 @@ def visit_cast_expr(self, e: CastExpr) -> None:
9999
self.add(e)
100100
super().visit_cast_expr(e)
101101

102+
def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
103+
self.add(e)
104+
super().visit_assert_type_expr(e)
105+
102106
def visit_reveal_expr(self, e: RevealExpr) -> None:
103107
self.add(e)
104108
super().visit_reveal_expr(e)

0 commit comments

Comments
 (0)