/79082fec3743eb30a3
Created 1 year, 1 month ago...
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
import ast
import os
import sys
from io import BytesIO
from typing import TYPE_CHECKING, Tuple, Union
# Type alias for anything `open()` accepts as a path. `os.PathLike` is only subscripted when running on
# Python 3.10+ or when a static type checker is analysing the file; otherwise the bare, unparameterized
# `os.PathLike` is used.
if sys.version_info < (3, 10) and not TYPE_CHECKING:
    StrPath = Union[str, os.PathLike]
else:
    StrPath = Union[str, os.PathLike[str]]
class AnnotationVisitor(ast.NodeVisitor):
    """Find the locations of all parameter annotations that contain `float` or `complex`.

    Also, track a place to put an import if needed to accommodate adding `typing.Union` imports in pre-3.10 Python.

    Only parameter annotations (``ast.arg`` nodes, including ``*args``/``**kwargs``) are inspected; return
    annotations and annotated assignments are not.

    Attributes
    ----------
    float_ann_locations: list[tuple[int, int]]
        ``(lineno, col_offset)`` of every ``float`` name found inside a parameter annotation, as reported by ``ast``.
    complex_ann_locations: list[tuple[int, int]]
        ``(lineno, col_offset)`` of every ``complex`` name found inside a parameter annotation, as reported by ``ast``.
    import_position: tuple[int, int]
        ``(lineno, col_offset)`` before which a module-level import can be inserted: the first statement that is
        neither the module docstring nor a ``from __future__`` import (its first decorator line if it is decorated).
        Stays ``(-1, -1)`` when the module has no such statement.
    """

    def __init__(self) -> None:
        # (lineno, col_offset) of each `float` name found inside a parameter annotation.
        self.float_ann_locations: list[tuple[int, int]] = []
        # (lineno, col_offset) of each `complex` name found inside a parameter annotation.
        self.complex_ann_locations: list[tuple[int, int]] = []
        # Insertion point for a module-level import; (-1, -1) until visit_Module finds one.
        self.import_position: tuple[int, int] = (-1, -1)

    def _visit_annotation(self, node: Union[ast.arg, ast.AnnAssign]) -> None:
        """Record the position of every `float`/`complex` name anywhere inside `node`'s annotation (if any)."""
        if node.annotation is not None:
            # ast.walk also reaches nested names, e.g. the `float` in `list[float]` or `Optional[float]`.
            for ann_component in ast.walk(node.annotation):
                if isinstance(ann_component, ast.Name):
                    location = (ann_component.lineno, ann_component.col_offset)
                    if ann_component.id == "float":
                        self.float_ann_locations.append(location)
                    elif ann_component.id == "complex":
                        self.complex_ann_locations.append(location)

    def visit_arg(self, node: ast.arg) -> None:
        """Find all the parts of annotations in parameters that use either `float` or `complex`."""
        self._visit_annotation(node)
        self.generic_visit(node)

    def visit_Module(self, node: ast.Module) -> None:
        """Find the first available place in the module to insert an import if need be.

        This will be used to place an import for `typing.Union` in pre-3.10 Python. The chosen statement is the first
        one after the module docstring and any `from __future__` imports. For a decorated function or class, the
        position is the line of its first decorator, so an import inserted there does not separate the decorators
        from the definition they decorate.
        """
        expect_docstring = True
        for sub_node in node.body:
            if (
                expect_docstring
                and isinstance(sub_node, ast.Expr)
                and isinstance(sub_node.value, ast.Constant)
                and isinstance(sub_node.value.value, str)
            ):
                expect_docstring = False
            elif isinstance(sub_node, ast.ImportFrom) and sub_node.module == "__future__" and sub_node.level == 0:
                # `from __future__` imports must stay ahead of any other import, so skip past them.
                pass
            else:
                lineno = sub_node.lineno
                # Since Python 3.8 a decorated def/class reports the line of its `def`/`class` keyword rather than
                # of its decorators; inserting an import there would split the decorators from their target.
                decorators = getattr(sub_node, "decorator_list", None) or []
                if decorators:
                    lineno = min(lineno, *(decorator.lineno for decorator in decorators))
                self.import_position = (lineno, sub_node.col_offset)
                break
        self.generic_visit(node)
def fix_file(filename: StrPath, py_version: Tuple[int, int]) -> None:
    """Update a file to adjust annotations for `float` and `complex` to `float | int` and `complex | float | int` respectively.

    Only parameter annotations are rewritten; return annotations and annotated assignments are left as they are.
    If no parameter annotation mentions `float` or `complex`, the file is not rewritten at all (and, for
    pre-3.10 targets, no `typing` import is added).

    Parameters
    ----------
    filename: StrPath
        A path-like object that can be used to open a file. The resulting file will be modified.
    py_version: tuple[int, int]
        The version of Python this script should assume the file's code is written in. If below (3, 10), Union will be
        used instead of the pipe when substituting the annotations, resulting in an added import from `typing`.

    Raises
    ------
    SyntaxError
        If `ast.parse` cannot parse the file's contents. The file is left unmodified in that case, because
        parsing happens before anything is written.
    OSError
        If the file cannot be read or written.

    Examples
    --------
    Initial state of test.py:

        def example(a: float, *, b: complex, **kwargs: float) -> float:
            result: float = a + b + sum(kwargs.values())
            return result

    After `fix_file("test.py", (3, 9))`:

        from typing import Union
        def example(a: Union[float, int], *, b: Union[complex, float, int], **kwargs: Union[float, int]) -> float:
            result: float = a + b + sum(kwargs.values())
            return result

    After `fix_file("test.py", (3, 10))`:

        def example(a: float | int, *, b: complex | float | int, **kwargs: float | int) -> float:
            result: float = a + b + sum(kwargs.values())
            return result
    """
    if py_version >= (3, 10):
        float_replacement = b"float | int"
        complex_replacement = b"complex | float | int"
    else:
        float_replacement = b"Union[float, int]"
        complex_replacement = b"Union[complex, float, int]"
    replacements = {b"float": float_replacement, b"complex": complex_replacement}

    with open(filename, "rb") as fp:
        source = fp.read()

    tree = ast.parse(source)
    visitor = AnnotationVisitor()
    visitor.visit(tree)

    # Group (col_offset, type_name) pairs by line so each line is looked up once, instead of rescanning every
    # recorded location for every line of the file.
    locations_by_line: dict[int, list[tuple[int, bytes]]] = {}
    for type_name, locations in ((b"float", visitor.float_ann_locations), (b"complex", visitor.complex_ann_locations)):
        for lineno, col_offset in locations:
            locations_by_line.setdefault(lineno, []).append((col_offset, type_name))

    # Nothing to rewrite: leave the file untouched rather than adding an unused `typing` import.
    if not locations_by_line:
        return

    needs_union_import = py_version < (3, 10)

    def modify_source():
        """Yield the rewritten source line by line, inserting the `Union` import where the visitor found room."""
        for lineno, line in enumerate(iter(BytesIO(source).readline, b""), start=1):
            if needs_union_import and lineno == visitor.import_position[0]:
                # Reuse the file's line ending so the inserted import does not introduce mixed newlines.
                newline = b"\r\n" if line.endswith(b"\r\n") else b"\n"
                yield b"from typing import Union" + newline
            # Splice right-to-left so the column offsets of earlier names on this line stay valid.
            for col_offset, type_name in sorted(locations_by_line.get(lineno, ()), reverse=True):
                line = line[:col_offset] + replacements[type_name] + line[col_offset + len(type_name) :]
            yield line

    # Build the complete output before truncating the file, so a failure while rewriting cannot leave it empty.
    new_source = b"".join(modify_source())
    with open(filename, "wb") as fp:
        fp.write(new_source)
def main():
    """Command-line entry point: parse the arguments and run `fix_file` on the given file.

    Takes two positional arguments: the path of the file to rewrite in place, and the target Python version in the
    form ``3.x``. A malformed version is reported through argparse (usage message on stderr, exit status 2) instead
    of escaping as a ``ValueError`` traceback.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("filename", help="The file that will be modified.")
    parser.add_argument(
        "python_version",
        help=(
            "The python version to take into account when modifying the file. Determines whether Union or | is used. "
            "Must be in the format '3.x', with x being a number."
        ),
    )
    args = parser.parse_args()
    filename = args.filename
    try:
        # Exactly two dot-separated integers; a wrong part count or a non-numeric part both raise ValueError.
        py_major, py_minor = (int(part) for part in args.python_version.split("."))
    except ValueError:
        # parser.error prints the usage message and raises SystemExit(2).
        parser.error(f"python_version must be in the format '3.x', got {args.python_version!r}")
    fix_file(filename, (py_major, py_minor))
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status (None maps to 0).
    sys.exit(main())