Skip to content

Commit

Permalink
fix: ann assignments inside function bodies should never be quoted
Browse files Browse the repository at this point in the history
  • Loading branch information
Daverball committed May 28, 2024
1 parent fae7589 commit 448bb4b
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 8 deletions.
55 changes: 47 additions & 8 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,9 @@ def __init__(self, node: ast.Module | ast.ClassDef | Function, parent: Scope | N
# it in symbol lookups
self.class_name = node.name if isinstance(node, ast.ClassDef) else None

#: Whether or not annotation assignments never get evaluated in this scope
self.ann_assign_never_evaluates = not is_head and isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))

def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only: bool = True) -> Symbol | None:
"""
Simulate a symbol lookup.
Expand Down Expand Up @@ -824,22 +827,38 @@ def __init__(self) -> None:
#: All type annotations in the file, with quotes around them
self.wrapped_annotations: list[WrappedAnnotation] = []

#: All type annotations in the file with unnecessary quotes around them
self.excess_wrapped_annotations: list[WrappedAnnotation] = []

#: All the invalid uses of string literals inside ast.BinOp
self.invalid_binop_literals: list[ast.Constant] = []

#: Whether or not this annotation ever gets evaluated
# e.g. AnnAssign.annotation within a function body will never evaluate
self.never_evaluates = False

def visit(
self, node: ast.AST, scope: Scope | None = None, type: Literal['annotation', 'alias', 'new-alias'] | None = None
self,
node: ast.AST,
scope: Scope | None = None,
type: Literal['annotation', 'alias', 'new-alias'] | None = None,
never_evaluates: bool | None = None,
) -> None:
"""Visit the node with the given scope and annotation type."""
if scope is not None:
self.scope = scope
if type is not None:
self.type = type
if never_evaluates is not None:
self.never_evaluates = never_evaluates
super().visit(node)

def visit_annotation_name(self, node: ast.Name) -> None:
"""Register unwrapped annotation."""
setattr(node, ANNOTATION_PROPERTY, True)
if self.never_evaluates:
return

self.unwrapped_annotations.append(
UnwrappedAnnotation(node.lineno, node.col_offset, node.id, self.scope, self.type)
)
Expand All @@ -851,7 +870,7 @@ def visit_annotation_string(self, node: ast.Constant) -> None:
if getattr(node, BINOP_OPERAND_PROPERTY, False):
self.invalid_binop_literals.append(node)
else:
self.wrapped_annotations.append(
(self.excess_wrapped_annotations if self.never_evaluates else self.wrapped_annotations).append(
WrappedAnnotation(
node.lineno, node.col_offset, node.value, set(NAME_RE.findall(node.value)), self.scope, self.type
)
Expand Down Expand Up @@ -971,6 +990,11 @@ def wrapped_annotations(self) -> list[WrappedAnnotation]:
"""All type annotations in the file, with quotes around them."""
return self.annotation_visitor.wrapped_annotations

@property
def excess_wrapped_annotations(self) -> list[WrappedAnnotation]:
"""All type annotations in the file, with excess quotes around them."""
return self.annotation_visitor.excess_wrapped_annotations

@property
def invalid_binop_literals(self) -> list[ast.Constant]:
"""All invalid uses of binop literals."""
Expand Down Expand Up @@ -1330,10 +1354,14 @@ def visit_Constant(self, node: ast.Constant) -> ast.Constant:
return node

def add_annotation(
self, node: ast.AST, scope: Scope, type: Literal['annotation', 'alias', 'new-alias'] = 'annotation'
self,
node: ast.AST,
scope: Scope,
type: Literal['annotation', 'alias', 'new-alias'] = 'annotation',
never_evaluates: bool = False,
) -> None:
"""Map all annotations on an AST node."""
self.annotation_visitor.visit(node, scope, type)
self.annotation_visitor.visit(node, scope, type, never_evaluates)

@staticmethod
def set_child_node_attribute(node: Any, attr: str, val: Any) -> Any:
Expand Down Expand Up @@ -1363,7 +1391,10 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
a ForwardRef with a future import, like a true annotation would.
"""
super().visit_AnnAssign(node)
self.add_annotation(node.annotation, self.current_scope)

self.add_annotation(
node.annotation, self.current_scope, never_evaluates=self.current_scope.ann_assign_never_evaluates
)

if node.value is None:
return
Expand Down Expand Up @@ -1839,9 +1870,11 @@ def unused_imports(self) -> Flake8Generator:
unused_imports = set(self.visitor.imports) - self.visitor.names - self.visitor.mapped_names
used_imports = set(self.visitor.imports) - unused_imports
already_imported_modules = [self.visitor.imports[name].module for name in used_imports]
annotation_names = [n for i in self.visitor.wrapped_annotations for n in i.names] + [
i.annotation for i in self.visitor.unwrapped_annotations
]
annotation_names = (
[n for i in self.visitor.wrapped_annotations for n in i.names]
+ [i.annotation for i in self.visitor.unwrapped_annotations]
+ [n for i in self.visitor.excess_wrapped_annotations for n in i.names]
)

for name in unused_imports:
if name not in annotation_names:
Expand Down Expand Up @@ -2046,6 +2079,12 @@ def excess_quotes(self) -> Flake8Generator:

yield item.lineno, item.col_offset, error, None

for item in self.visitor.excess_wrapped_annotations:
assert item.type == 'annotation'
# This always generates an error in either case
yield item.lineno, item.col_offset, TC101.format(annotation=item.annotation), None
yield item.lineno, item.col_offset, TC201.format(annotation=item.annotation), None

@property
def errors(self) -> Flake8Generator:
"""
Expand Down
11 changes: 11 additions & 0 deletions tests/test_tc100.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@
('if TYPE_CHECKING:\n\tfrom typing import Dict\ndef example(x: Dict[str, int] = {}):\n\tpass', {'1:0 ' + TC100}),
# Import used for returns
('if TYPE_CHECKING:\n\tfrom typing import Dict\ndef example() -> Dict[str, int]:\n\tpass', {'1:0 ' + TC100}),
(
# Regression test for #186
textwrap.dedent('''
if TYPE_CHECKING:
from baz import Bar
def foo(self) -> None:
x: Bar
'''),
set(),
),
]

if sys.version_info >= (3, 12):
Expand Down
16 changes: 16 additions & 0 deletions tests/test_tc101.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,22 @@ def foo(self) -> 'X':
'''),
set(),
),
(
# Regression test for #186
textwrap.dedent('''
def foo(self) -> None:
x: Bar
'''),
set(),
),
(
# Reverse regression test for #186
textwrap.dedent('''
def foo(self) -> None:
x: 'Bar'
'''),
{'3:7 ' + TC101.format(annotation='Bar')},
),
]

if sys.version_info >= (3, 12):
Expand Down
11 changes: 11 additions & 0 deletions tests/test_tc200.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,17 @@ class FooDict(TypedDict):
'18:9 ' + TC200.format(annotation='Sequence'),
},
),
(
# Regression test for #186
textwrap.dedent('''
if TYPE_CHECKING:
from baz import Bar
def foo(self) -> None:
x: Bar
'''),
set(),
),
]

if sys.version_info >= (3, 11):
Expand Down
16 changes: 16 additions & 0 deletions tests/test_tc201.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,22 @@ def bar(*args: 'P.args', **kwargs: 'P.kwargs') -> None:
'''),
set(),
),
(
# Regression test for #186
textwrap.dedent('''
def foo(self) -> None:
x: Bar
'''),
set(),
),
(
# Reverse regression test for #186
textwrap.dedent('''
def foo(self) -> None:
x: 'Bar'
'''),
{'3:7 ' + TC201.format(annotation='Bar')},
),
]

if sys.version_info >= (3, 12):
Expand Down

0 comments on commit 448bb4b

Please sign in to comment.