22
22
Context , Decorator , PrintStmt , BreakStmt , PassStmt , ContinueStmt ,
23
23
ComparisonExpr , StarExpr , EllipsisExpr , RefExpr , PromoteExpr ,
24
24
Import , ImportFrom , ImportAll , ImportBase , TypeAlias ,
25
- ARG_POS , ARG_STAR , LITERAL_TYPE , MDEF , GDEF ,
25
+ ARG_POS , ARG_STAR , LITERAL_TYPE , LDEF , MDEF , GDEF ,
26
26
CONTRAVARIANT , COVARIANT , INVARIANT , TypeVarExpr , AssignmentExpr ,
27
27
is_final_node ,
28
28
ARG_NAMED )
29
29
from mypy import nodes
30
30
from mypy import operators
31
31
from mypy .literals import literal , literal_hash , Key
32
- from mypy .typeanal import has_any_from_unimported_type , check_for_explicit_any
32
+ from mypy .typeanal import has_any_from_unimported_type , check_for_explicit_any , make_optional_type
33
33
from mypy .types import (
34
34
Type , AnyType , CallableType , FunctionLike , Overloaded , TupleType , TypedDictType ,
35
35
Instance , NoneType , strip_type , TypeType , TypeOfAny ,
@@ -203,6 +203,14 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
203
203
# directly or indirectly.
204
204
module_refs : Set [str ]
205
205
206
+ # A map from variable nodes to a snapshot of the frame ids of the
207
+ # frames that were active when the variable was declared. This can
208
+ # be used to determine nearest common ancestor frame of a variable's
209
+ # declaration and the current frame, which lets us determine if it
210
+ # was declared in a different branch of the same `if` statement
211
+ # (if that frame is a conditional_frame).
212
+ var_decl_frames : Dict [Var , Set [int ]]
213
+
206
214
# Plugin that provides special type checking rules for specific library
207
215
# functions such as open(), etc.
208
216
plugin : Plugin
@@ -229,6 +237,7 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option
229
237
self .dynamic_funcs = []
230
238
self .partial_types = []
231
239
self .partial_reported = set ()
240
+ self .var_decl_frames = {}
232
241
self .deferred_nodes = []
233
242
self .type_map = {}
234
243
self .module_refs = set ()
@@ -411,7 +420,7 @@ def accept_loop(self, body: Statement, else_body: Optional[Statement] = None, *,
411
420
Then check the else_body.
412
421
"""
413
422
# The outer frame accumulates the results of all iterations
414
- with self .binder .frame_context (can_skip = False ):
423
+ with self .binder .frame_context (can_skip = False , conditional_frame = True ):
415
424
while True :
416
425
with self .binder .frame_context (can_skip = True ,
417
426
break_frame = 2 , continue_frame = 1 ):
@@ -2167,6 +2176,31 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
2167
2176
rvalue_type , lvalue_type , infer_lvalue_type = self .check_member_assignment (
2168
2177
instance_type , lvalue_type , rvalue , context = rvalue )
2169
2178
else :
2179
+ # Hacky special case for assigning a literal None
2180
+ # to a variable defined in a previous if
2181
+ # branch. When we detect this, we'll go back and
2182
+ # make the type optional. This is somewhat
2183
+ # unpleasant, and a generalization of this would
2184
+ # be an improvement!
2185
+ if (is_literal_none (rvalue ) and
2186
+ isinstance (lvalue , NameExpr ) and
2187
+ lvalue .kind == LDEF and
2188
+ isinstance (lvalue .node , Var ) and
2189
+ lvalue .node .type and
2190
+ lvalue .node in self .var_decl_frames and
2191
+ not isinstance (get_proper_type (lvalue_type ), AnyType )):
2192
+ decl_frame_map = self .var_decl_frames [lvalue .node ]
2193
+ # Check if the nearest common ancestor frame for the definition site
2194
+ # and the current site is the enclosing frame of an if/elif/else block.
2195
+ has_if_ancestor = False
2196
+ for frame in reversed (self .binder .frames ):
2197
+ if frame .id in decl_frame_map :
2198
+ has_if_ancestor = frame .conditional_frame
2199
+ break
2200
+ if has_if_ancestor :
2201
+ lvalue_type = make_optional_type (lvalue_type )
2202
+ self .set_inferred_type (lvalue .node , lvalue , lvalue_type )
2203
+
2170
2204
rvalue_type = self .check_simple_assignment (lvalue_type , rvalue , context = rvalue ,
2171
2205
code = codes .ASSIGNMENT )
2172
2206
@@ -2992,6 +3026,9 @@ def set_inferred_type(self, var: Var, lvalue: Lvalue, type: Type) -> None:
2992
3026
if var and not self .current_node_deferred :
2993
3027
var .type = type
2994
3028
var .is_inferred = True
3029
+ if var not in self .var_decl_frames :
3030
+ # Used for the hack to improve optional type inference in conditionals
3031
+ self .var_decl_frames [var ] = {frame .id for frame in self .binder .frames }
2995
3032
if isinstance (lvalue , MemberExpr ) and self .inferred_attribute_types is not None :
2996
3033
# Store inferred attribute type so that we can check consistency afterwards.
2997
3034
if lvalue .def_var is not None :
@@ -3298,7 +3335,7 @@ def visit_if_stmt(self, s: IfStmt) -> None:
3298
3335
"""Type check an if statement."""
3299
3336
# This frame records the knowledge from previous if/elif clauses not being taken.
3300
3337
# Fall-through to the original frame is handled explicitly in each block.
3301
- with self .binder .frame_context (can_skip = False , fall_through = 0 ):
3338
+ with self .binder .frame_context (can_skip = False , conditional_frame = True , fall_through = 0 ):
3302
3339
for e , b in zip (s .expr , s .body ):
3303
3340
t = get_proper_type (self .expr_checker .accept (e ))
3304
3341
@@ -3437,7 +3474,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
3437
3474
# was the top frame on entry.
3438
3475
with self .binder .frame_context (can_skip = False , fall_through = 2 , try_frame = try_frame ):
3439
3476
# This frame receives exit via exception, and runs exception handlers
3440
- with self .binder .frame_context (can_skip = False , fall_through = 2 ):
3477
+ with self .binder .frame_context (can_skip = False , conditional_frame = True , fall_through = 2 ):
3441
3478
# Finally, the body of the try statement
3442
3479
with self .binder .frame_context (can_skip = False , fall_through = 2 , try_frame = True ):
3443
3480
self .accept (s .body )
@@ -4925,6 +4962,7 @@ def handle_partial_var_type(
4925
4962
if is_local or not self .options .allow_untyped_globals :
4926
4963
self .msg .need_annotation_for_var (node , context ,
4927
4964
self .options .python_version )
4965
+ self .partial_reported .add (node )
4928
4966
else :
4929
4967
# Defer the node -- we might get a better type in the outer scope
4930
4968
self .handle_cannot_determine_type (node .name , context )
0 commit comments