Skip to content

Commit 56b6803

Browse files
authored
Support making a variable Optional in an else branch (#11002)
That is, support patterns such as: ``` if condition: foo = Foo() else: foo = None ``` Currently this does not work, but the the *reverse* does (because foo will be inferred as a PartialType). I think it might be worth tackling this in a more general way, for other types, though I think that is a little fiddlier and likely to be more controversial, so I'm starting with something special-cased for the "assigning literal None" case first. The rule we implement is that we allow updating the type of a variable when assigning `None` to it if the variable's type was inferred and it was defined in an earlier branch of the same `if/then/else` statement. Some infrastructure is added to make determinations about that.
1 parent ea7fed1 commit 56b6803

File tree

3 files changed

+303
-11
lines changed

3 files changed

+303
-11
lines changed

mypy/binder.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ class Frame:
3131
that were assigned in that frame.
3232
"""
3333

34-
def __init__(self) -> None:
34+
def __init__(self, id: int, conditional_frame: bool = False) -> None:
35+
self.id = id
3536
self.types: Dict[Key, Type] = {}
3637
self.unreachable = False
38+
self.conditional_frame = conditional_frame
3739

3840
# Should be set only if we're entering a frame where it's not
3941
# possible to accurately determine whether or not contained
@@ -72,13 +74,15 @@ class A:
7274
type_assignments: Optional[Assigns] = None
7375

7476
def __init__(self) -> None:
77+
self.next_id = 1
78+
7579
# The stack of frames currently used. These map
7680
# literal_hash(expr) -- literals like 'foo.bar' --
7781
# to types. The last element of this list is the
7882
# top-most, current frame. Each earlier element
7983
# records the state as of when that frame was last
8084
# on top of the stack.
81-
self.frames = [Frame()]
85+
self.frames = [Frame(self._get_id())]
8286

8387
# For frames higher in the stack, we record the set of
8488
# Frames that can escape there, either by falling off
@@ -101,6 +105,10 @@ def __init__(self) -> None:
101105
self.break_frames: List[int] = []
102106
self.continue_frames: List[int] = []
103107

108+
def _get_id(self) -> int:
109+
self.next_id += 1
110+
return self.next_id
111+
104112
def _add_dependencies(self, key: Key, value: Optional[Key] = None) -> None:
105113
if value is None:
106114
value = key
@@ -109,9 +117,9 @@ def _add_dependencies(self, key: Key, value: Optional[Key] = None) -> None:
109117
for elt in subkeys(key):
110118
self._add_dependencies(elt, value)
111119

112-
def push_frame(self) -> Frame:
120+
def push_frame(self, conditional_frame: bool = False) -> Frame:
113121
"""Push a new frame into the binder."""
114-
f = Frame()
122+
f = Frame(self._get_id(), conditional_frame)
115123
self.frames.append(f)
116124
self.options_on_return.append([])
117125
return f
@@ -349,7 +357,7 @@ def allow_jump(self, index: int) -> None:
349357
# so make sure the index is positive
350358
if index < 0:
351359
index += len(self.options_on_return)
352-
frame = Frame()
360+
frame = Frame(self._get_id())
353361
for f in self.frames[index + 1:]:
354362
frame.types.update(f.types)
355363
if f.unreachable:
@@ -367,6 +375,7 @@ def handle_continue(self) -> None:
367375
@contextmanager
368376
def frame_context(self, *, can_skip: bool, fall_through: int = 1,
369377
break_frame: int = 0, continue_frame: int = 0,
378+
conditional_frame: bool = False,
370379
try_frame: bool = False) -> Iterator[Frame]:
371380
"""Return a context manager that pushes/pops frames on enter/exit.
372381
@@ -401,7 +410,7 @@ def frame_context(self, *, can_skip: bool, fall_through: int = 1,
401410
if try_frame:
402411
self.try_frames.add(len(self.frames) - 1)
403412

404-
new_frame = self.push_frame()
413+
new_frame = self.push_frame(conditional_frame)
405414
if try_frame:
406415
# An exception may occur immediately
407416
self.allow_jump(-1)

mypy/checker.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
Context, Decorator, PrintStmt, BreakStmt, PassStmt, ContinueStmt,
2323
ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, PromoteExpr,
2424
Import, ImportFrom, ImportAll, ImportBase, TypeAlias,
25-
ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF,
25+
ARG_POS, ARG_STAR, LITERAL_TYPE, LDEF, MDEF, GDEF,
2626
CONTRAVARIANT, COVARIANT, INVARIANT, TypeVarExpr, AssignmentExpr,
2727
is_final_node,
2828
ARG_NAMED)
2929
from mypy import nodes
3030
from mypy import operators
3131
from mypy.literals import literal, literal_hash, Key
32-
from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any
32+
from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any, make_optional_type
3333
from mypy.types import (
3434
Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType,
3535
Instance, NoneType, strip_type, TypeType, TypeOfAny,
@@ -203,6 +203,14 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
203203
# directly or indirectly.
204204
module_refs: Set[str]
205205

206+
# A map from variable nodes to a snapshot of the frame ids of the
207+
# frames that were active when the variable was declared. This can
208+
# be used to determine nearest common ancestor frame of a variable's
209+
# declaration and the current frame, which lets us determine if it
210+
# was declared in a different branch of the same `if` statement
211+
# (if that frame is a conditional_frame).
212+
var_decl_frames: Dict[Var, Set[int]]
213+
206214
# Plugin that provides special type checking rules for specific library
207215
# functions such as open(), etc.
208216
plugin: Plugin
@@ -229,6 +237,7 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option
229237
self.dynamic_funcs = []
230238
self.partial_types = []
231239
self.partial_reported = set()
240+
self.var_decl_frames = {}
232241
self.deferred_nodes = []
233242
self.type_map = {}
234243
self.module_refs = set()
@@ -411,7 +420,7 @@ def accept_loop(self, body: Statement, else_body: Optional[Statement] = None, *,
411420
Then check the else_body.
412421
"""
413422
# The outer frame accumulates the results of all iterations
414-
with self.binder.frame_context(can_skip=False):
423+
with self.binder.frame_context(can_skip=False, conditional_frame=True):
415424
while True:
416425
with self.binder.frame_context(can_skip=True,
417426
break_frame=2, continue_frame=1):
@@ -2167,6 +2176,31 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
21672176
rvalue_type, lvalue_type, infer_lvalue_type = self.check_member_assignment(
21682177
instance_type, lvalue_type, rvalue, context=rvalue)
21692178
else:
2179+
# Hacky special case for assigning a literal None
2180+
# to a variable defined in a previous if
2181+
# branch. When we detect this, we'll go back and
2182+
# make the type optional. This is somewhat
2183+
# unpleasant, and a generalization of this would
2184+
# be an improvement!
2185+
if (is_literal_none(rvalue) and
2186+
isinstance(lvalue, NameExpr) and
2187+
lvalue.kind == LDEF and
2188+
isinstance(lvalue.node, Var) and
2189+
lvalue.node.type and
2190+
lvalue.node in self.var_decl_frames and
2191+
not isinstance(get_proper_type(lvalue_type), AnyType)):
2192+
decl_frame_map = self.var_decl_frames[lvalue.node]
2193+
# Check if the nearest common ancestor frame for the definition site
2194+
# and the current site is the enclosing frame of an if/elif/else block.
2195+
has_if_ancestor = False
2196+
for frame in reversed(self.binder.frames):
2197+
if frame.id in decl_frame_map:
2198+
has_if_ancestor = frame.conditional_frame
2199+
break
2200+
if has_if_ancestor:
2201+
lvalue_type = make_optional_type(lvalue_type)
2202+
self.set_inferred_type(lvalue.node, lvalue, lvalue_type)
2203+
21702204
rvalue_type = self.check_simple_assignment(lvalue_type, rvalue, context=rvalue,
21712205
code=codes.ASSIGNMENT)
21722206

@@ -2992,6 +3026,9 @@ def set_inferred_type(self, var: Var, lvalue: Lvalue, type: Type) -> None:
29923026
if var and not self.current_node_deferred:
29933027
var.type = type
29943028
var.is_inferred = True
3029+
if var not in self.var_decl_frames:
3030+
# Used for the hack to improve optional type inference in conditionals
3031+
self.var_decl_frames[var] = {frame.id for frame in self.binder.frames}
29953032
if isinstance(lvalue, MemberExpr) and self.inferred_attribute_types is not None:
29963033
# Store inferred attribute type so that we can check consistency afterwards.
29973034
if lvalue.def_var is not None:
@@ -3298,7 +3335,7 @@ def visit_if_stmt(self, s: IfStmt) -> None:
32983335
"""Type check an if statement."""
32993336
# This frame records the knowledge from previous if/elif clauses not being taken.
33003337
# Fall-through to the original frame is handled explicitly in each block.
3301-
with self.binder.frame_context(can_skip=False, fall_through=0):
3338+
with self.binder.frame_context(can_skip=False, conditional_frame=True, fall_through=0):
33023339
for e, b in zip(s.expr, s.body):
33033340
t = get_proper_type(self.expr_checker.accept(e))
33043341

@@ -3437,7 +3474,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
34373474
# was the top frame on entry.
34383475
with self.binder.frame_context(can_skip=False, fall_through=2, try_frame=try_frame):
34393476
# This frame receives exit via exception, and runs exception handlers
3440-
with self.binder.frame_context(can_skip=False, fall_through=2):
3477+
with self.binder.frame_context(can_skip=False, conditional_frame=True, fall_through=2):
34413478
# Finally, the body of the try statement
34423479
with self.binder.frame_context(can_skip=False, fall_through=2, try_frame=True):
34433480
self.accept(s.body)
@@ -4925,6 +4962,7 @@ def handle_partial_var_type(
49254962
if is_local or not self.options.allow_untyped_globals:
49264963
self.msg.need_annotation_for_var(node, context,
49274964
self.options.python_version)
4965+
self.partial_reported.add(node)
49284966
else:
49294967
# Defer the node -- we might get a better type in the outer scope
49304968
self.handle_cannot_determine_type(node.name, context)

0 commit comments

Comments
 (0)