Skip to content

Commit b11f15a

Browse files
add assert_type() (#434)
1 parent 9879790 commit b11f15a

File tree

4 files changed

+80
-4
lines changed

4 files changed

+80
-4
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Unreleased
44

5+
- Add support for `assert_type()` (#433)
56
- `reveal_type()` and `dump_value()` now return their argument,
67
the anticipated behavior for `typing.reveal_type()` in Python
78
3.11 (#432)

pyanalyze/extensions.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,28 @@ def f(x: int) -> None:
392392
return value
393393

394394

395+
def assert_type(val: _T, typ: Any) -> _T:
396+
"""Assert the inferred static type of an expression.
397+
398+
When a static type checker encounters a call to this function,
399+
it checks that the inferred type of `val` matches the `typ`
400+
argument, and if it dooes not, it emits an error.
401+
402+
Example::
403+
404+
def f(x: int) -> None:
405+
assert_type(x, int) # ok
406+
assert_type(x, str) # error
407+
408+
This is useful for checking that the type checker interprets
409+
a complicated set of type annotations in the way the user intended.
410+
411+
At runtime this returns the first argument unchanged.
412+
413+
"""
414+
return val
415+
416+
395417
_overloads: Dict[str, List[Callable[..., Any]]] = defaultdict(list)
396418
_type_evaluations: Dict[str, List[Callable[..., Any]]] = defaultdict(list)
397419

pyanalyze/implementation.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import typing
2-
import typing_extensions
31
from .annotations import type_from_value
42
from .error_code import ErrorCode
5-
from .extensions import reveal_type
3+
from .extensions import assert_type, reveal_type
64
from .format_strings import parse_format_string
75
from .predicates import IsAssignablePredicate
86
from .safe import safe_hasattr, safe_isinstance, safe_issubclass
@@ -52,6 +50,7 @@
5250
concrete_values_from_iterable,
5351
kv_pairs_from_mapping,
5452
make_weak,
53+
unannotate,
5554
unite_values,
5655
flatten_values,
5756
replace_known_sequence_value,
@@ -66,6 +65,8 @@
6665
import inspect
6766
import warnings
6867
from types import FunctionType
68+
import typing
69+
import typing_extensions
6970
from typing import (
7071
Sequence,
7172
TypeVar,
@@ -1042,6 +1043,20 @@ def _cast_impl(ctx: CallContext) -> Value:
10421043
return type_from_value(typ, visitor=ctx.visitor, node=ctx.node)
10431044

10441045

1046+
def _assert_type_impl(ctx: CallContext) -> Value:
1047+
# TODO maybe we should walk over the whole value and remove Annotated.
1048+
val = unannotate(ctx.vars["val"])
1049+
typ = ctx.vars["typ"]
1050+
expected_type = type_from_value(typ, visitor=ctx.visitor, node=ctx.node)
1051+
if val != expected_type:
1052+
ctx.show_error(
1053+
f"Type is {val} (expected {expected_type})",
1054+
error_code=ErrorCode.inference_failure,
1055+
arg="obj",
1056+
)
1057+
return val
1058+
1059+
10451060
def _subclasses_impl(ctx: CallContext) -> Value:
10461061
"""Overridden because typeshed types make it (T) => List[T] instead."""
10471062
self_obj = ctx.vars["self"]
@@ -1423,7 +1438,18 @@ def get_default_argspecs() -> Dict[object, Signature]:
14231438
callable=str.format,
14241439
),
14251440
Signature.make(
1426-
[SigParameter("typ"), SigParameter("val")], callable=cast, impl=_cast_impl
1441+
[SigParameter("typ", _POS_ONLY), SigParameter("val", _POS_ONLY)],
1442+
callable=cast,
1443+
impl=_cast_impl,
1444+
),
1445+
Signature.make(
1446+
[
1447+
SigParameter("val", _POS_ONLY, annotation=TypeVarValue(T)),
1448+
SigParameter("typ", _POS_ONLY),
1449+
],
1450+
TypeVarValue(T),
1451+
callable=assert_type,
1452+
impl=_assert_type_impl,
14271453
),
14281454
# workaround for https://github.com/python/typeshed/pull/3501
14291455
Signature.make(
@@ -1566,4 +1592,20 @@ def get_default_argspecs() -> Dict[object, Signature]:
15661592
callable=reveal_type_func,
15671593
)
15681594
signatures.append(sig)
1595+
# Anticipating that this will be added to the stdlib
1596+
try:
1597+
assert_type_func = getattr(mod, "assert_type")
1598+
except AttributeError:
1599+
pass
1600+
else:
1601+
sig = Signature.make(
1602+
[
1603+
SigParameter("val", _POS_ONLY, annotation=TypeVarValue(T)),
1604+
SigParameter("typ", _POS_ONLY),
1605+
],
1606+
TypeVarValue(T),
1607+
callable=assert_type_func,
1608+
impl=_assert_type_impl,
1609+
)
1610+
signatures.append(sig)
15691611
return {sig.callable: sig for sig in signatures}

pyanalyze/test_implementation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,17 @@ def capybara():
11841184
assert_is_value(x, KnownValue(1))
11851185
assert_is_value(y, KnownValue(1))
11861186

1187+
@assert_passes()
1188+
def test_assert_type(self) -> None:
1189+
from pyanalyze.extensions import assert_type
1190+
from typing import Any
1191+
1192+
def capybara(x: int) -> None:
1193+
assert_type(x, int)
1194+
assert_type(x, "int")
1195+
assert_type(x, Any) # E: inference_failure
1196+
assert_type(x, str) # E: inference_failure
1197+
11871198

11881199
class TestCallableGuards(TestNameCheckVisitorBase):
11891200
@assert_passes()

0 commit comments

Comments
 (0)