diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 49527fac70d9..65417daac580 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -255,7 +255,7 @@ def visit_unbound_type(self, t: UnboundType) -> str: s = t.name self.stubgen.import_tracker.require_name(s) if t.args: - s += '[{}]'.format(self.list_str(t.args)) + s += '[{}]'.format(self.args_str(t.args)) return s def visit_none_type(self, t: NoneType) -> str: @@ -264,6 +264,22 @@ def visit_none_type(self, t: NoneType) -> str: def visit_type_list(self, t: TypeList) -> str: return '[{}]'.format(self.list_str(t.items)) + def args_str(self, args: Iterable[Type]) -> str: + """Convert an array of arguments to strings and join the results with commas. + + The main difference from list_str is the preservation of quotes for string + arguments + """ + types = ['builtins.bytes', 'builtins.unicode'] + res = [] + for arg in args: + arg_str = arg.accept(self) + if isinstance(arg, UnboundType) and arg.original_str_fallback in types: + res.append("'{}'".format(arg_str)) + else: + res.append(arg_str) + return ', '.join(res) + class AliasPrinter(NodeVisitor[str]): """Visitor used to collect type aliases _and_ type variable definitions. diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 7d95072800dc..e193f61cd016 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -64,6 +64,25 @@ def g(x: Foo = Foo()) -> Bar: ... def f(x: Foo) -> Bar: ... def g(x: Foo = ...) -> Bar: ... +[case testPreserveFunctionAnnotationWithArgs] +def f(x: foo['x']) -> bar: ... +def g(x: foo[x]) -> bar: ... +def h(x: foo['x', 'y']) -> bar: ... +def i(x: foo[x, y]) -> bar: ... +def j(x: foo['x', y]) -> bar: ... +def k(x: foo[x, 'y']) -> bar: ... +def lit_str(x: Literal['str']) -> Literal['str']: ... +def lit_int(x: Literal[1]) -> Literal[1]: ... +[out] +def f(x: foo['x']) -> bar: ... +def g(x: foo[x]) -> bar: ... +def h(x: foo['x', 'y']) -> bar: ... +def i(x: foo[x, y]) -> bar: ... +def j(x: foo['x', y]) -> bar: ... +def k(x: foo[x, 'y']) -> bar: ... +def lit_str(x: Literal['str']) -> Literal['str']: ... +def lit_int(x: Literal[1]) -> Literal[1]: ... + [case testPreserveVarAnnotation] x: Foo [out]