Skip to content

Commit d21c5ab

Browse files
authored
Treat generators with await as async. (#12925)
Treat generators with await as async.
1 parent 1636a05 commit d21c5ab

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

mypy/checkexpr.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
make_optional_type,
1515
)
1616
from mypy.semanal_enum import ENUM_BASES
17+
from mypy.traverser import has_await_expression
1718
from mypy.types import (
1819
Type, AnyType, CallableType, Overloaded, NoneType, TypeVarType,
1920
TupleType, TypedDictType, Instance, ErasedType, UnionType,
@@ -3798,8 +3799,8 @@ def visit_set_comprehension(self, e: SetComprehension) -> Type:
37983799

37993800
def visit_generator_expr(self, e: GeneratorExpr) -> Type:
38003801
# If any of the comprehensions use async for, the expression will return an async generator
3801-
# object
3802-
if any(e.is_async):
3802+
# object, or if the left-side expression uses await.
3803+
if any(e.is_async) or has_await_expression(e.left_expr):
38033804
typ = 'typing.AsyncGenerator'
38043805
# received type is always None in async generator expressions
38053806
additional_args: List[Type] = [NoneType()]

mypy/traverser.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ConditionalExpr, TypeApplication, ExecStmt, Import, ImportFrom,
1919
LambdaExpr, ComparisonExpr, OverloadedFuncDef, YieldFromExpr,
2020
YieldExpr, StarExpr, BackquoteExpr, AwaitExpr, PrintStmt, SuperExpr, Node, REVEAL_TYPE,
21+
Expression,
2122
)
2223

2324

@@ -397,6 +398,21 @@ def has_yield_expression(fdef: FuncBase) -> bool:
397398
return seeker.found
398399

399400

401+
class AwaitSeeker(TraverserVisitor):
402+
def __init__(self) -> None:
403+
super().__init__()
404+
self.found = False
405+
406+
def visit_await_expr(self, o: AwaitExpr) -> None:
407+
self.found = True
408+
409+
410+
def has_await_expression(expr: Expression) -> bool:
411+
seeker = AwaitSeeker()
412+
expr.accept(seeker)
413+
return seeker.found
414+
415+
400416
class ReturnCollector(FuncCollectorBase):
401417
def __init__(self) -> None:
402418
super().__init__()

test-data/unit/check-async-await.test

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,3 +864,14 @@ async with C() as x: # E: "async with" outside async function
864864

865865
[builtins fixtures/async_await.pyi]
866866
[typing fixtures/typing-async.pyi]
867+
868+
[case testAsyncGeneratorExpressionAwait]
869+
from typing import AsyncGenerator
870+
871+
async def f() -> AsyncGenerator[int, None]:
872+
async def g(x: int) -> int:
873+
return x
874+
875+
return (await g(x) for x in [1, 2, 3])
876+
877+
[typing fixtures/typing-async.pyi]

0 commit comments

Comments
 (0)