import ast
import os
import sys
from io import BytesIO
from typing import TYPE_CHECKING, Tuple, Union

if sys.version_info >= (3, 10) or TYPE_CHECKING:
    StrPath = Union[str, os.PathLike[str]]
else:
    StrPath = Union[str, os.PathLike]


class AnnotationVisitor(ast.NodeVisitor):
    """Find the locations of all parameter annotations that contain `float` or `complex`.

    Also, track a place to put an import if one is needed to accommodate adding
    `typing.Union` imports in pre-3.10 Python.
    """

    def __init__(self):
        # (lineno, col_offset) of every bare `float` name found in a parameter annotation.
        self.float_ann_locations: list[tuple[int, int]] = []
        # (lineno, col_offset) of every bare `complex` name found in a parameter annotation.
        self.complex_ann_locations: list[tuple[int, int]] = []
        # (lineno, col_offset) before which a new import can be inserted; (-1, -1) if
        # the module has no statement other than a docstring / `__future__` imports.
        self.import_position: tuple[int, int] = (-1, -1)

    def _visit_annotation(self, node: Union[ast.arg, ast.AnnAssign]) -> None:
        """Record the position of every `float` / `complex` name inside ``node.annotation``."""
        if node.annotation is None:
            return
        for ann_component in ast.walk(node.annotation):
            if not isinstance(ann_component, ast.Name):
                continue
            location = (ann_component.lineno, ann_component.col_offset)
            if ann_component.id == "float":
                self.float_ann_locations.append(location)
            elif ann_component.id == "complex":
                self.complex_ann_locations.append(location)

    def visit_arg(self, node: ast.arg) -> None:
        """Find all the parts of annotations in parameters that use either `float` or `complex`."""
        self._visit_annotation(node)
        self.generic_visit(node)

    def visit_Module(self, node: ast.Module) -> None:
        """Find the first available place in the module to insert an import if need be.

        This is the first top-level statement that is neither the module docstring nor a
        `from __future__` import. It is used to place an import for `typing.Union` in
        pre-3.10 Python.
        """
        expect_docstring = True
        for sub_node in node.body:
            if (
                expect_docstring
                and isinstance(sub_node, ast.Expr)
                and isinstance(sub_node.value, ast.Constant)
                and isinstance(sub_node.value.value, str)
            ):
                expect_docstring = False
            elif isinstance(sub_node, ast.ImportFrom) and sub_node.module == "__future__" and sub_node.level == 0:
                pass
            else:
                # Since Python 3.8 a decorated def/class reports the line of the `def`/`class`
                # keyword, not of its first decorator. Inserting there would separate the
                # decorator from its target and produce a SyntaxError, so start at the
                # earliest decorator instead.
                decorators = getattr(sub_node, "decorator_list", [])
                first_line = min([sub_node.lineno, *(decorator.lineno for decorator in decorators)])
                self.import_position = (first_line, sub_node.col_offset)
                break
        self.generic_visit(node)


def fix_file(filename: StrPath, py_version: Tuple[int, int]) -> None:
    """Update a file to adjust annotations for `float` and `complex` to `float | int` and `complex | float | int` respectively.

    Only parameter annotations are rewritten; return annotations and annotated
    assignments are left as-is. If no parameter annotation needs changing, the file is
    not rewritten at all (in particular, no `typing.Union` import is added).

    Parameters
    ----------
    filename: StrPath
        A path-like object that can be used to open a file. The resulting file will be modified.
    py_version: tuple[int, int]
        The version of Python this script should assume the file's code is written in. If below (3, 10), Union will
        be used instead of the pipe when substituting the annotations, resulting in an added import from `typing`.

    Examples
    --------
    Initial state of test.py:

        def example(a: float, *, b: complex, **kwargs: float) -> float:
            result: float = a + b + sum(kwargs.values())
            return result

    After `fix_file("test.py", (3, 9))`:

        from typing import Union
        def example(a: Union[float, int], *, b: Union[complex, float, int], **kwargs: Union[float, int]) -> float:
            result: float = a + b + sum(kwargs.values())
            return result

    After `fix_file("test.py", (3, 10))`:

        def example(a: float | int, *, b: complex | float | int, **kwargs: float | int) -> float:
            result: float = a + b + sum(kwargs.values())
            return result
    """
    if py_version >= (3, 10):
        float_replacement = b"float | int"
        complex_replacement = b"complex | float | int"
    else:
        float_replacement = b"Union[float, int]"
        complex_replacement = b"Union[complex, float, int]"

    with open(filename, "rb") as fp:
        source = fp.read()

    tree = ast.parse(source)
    visitor = AnnotationVisitor()
    visitor.visit(tree)

    # Group the edits by line: lineno -> [(col_offset, original_name, replacement), ...].
    # ast col_offsets are UTF-8 byte offsets, which match slicing the raw bytes below.
    edits_by_line: dict[int, list[tuple[int, bytes, bytes]]] = {}
    for lineno, col_offset in visitor.float_ann_locations:
        edits_by_line.setdefault(lineno, []).append((col_offset, b"float", float_replacement))
    for lineno, col_offset in visitor.complex_ann_locations:
        edits_by_line.setdefault(lineno, []).append((col_offset, b"complex", complex_replacement))

    if not edits_by_line:
        # Nothing to substitute: leave the file untouched rather than adding an unused import.
        return

    import_lineno = visitor.import_position[0] if py_version < (3, 10) else -1

    new_lines = []
    for lineno, line in enumerate(iter(BytesIO(source).readline, b""), start=1):
        if lineno == import_lineno:
            # Match the line ending already used in the file.
            newline = b"\r\n" if line.endswith(b"\r\n") else b"\n"
            new_lines.append(b"from typing import Union" + newline)
        # Apply right-to-left so earlier column offsets on the same line stay valid.
        for col_offset, name, replacement in sorted(edits_by_line.get(lineno, ()), reverse=True):
            line = line[:col_offset] + replacement + line[col_offset + len(name) :]
        new_lines.append(line)

    # Build the whole result before truncating the file, so a failure cannot leave it half-written.
    new_source = b"".join(new_lines)
    with open(filename, "wb") as fp:
        fp.write(new_source)


def _parse_python_version(text: str) -> Tuple[int, int]:
    """Parse a version string such as '3.9' or '3.10.2' into (major, minor).

    Raises ValueError if the text does not start with two dot-separated integers.
    """
    parts = text.strip().split(".")
    if len(parts) < 2:
        raise ValueError(f"expected 'MAJOR.MINOR', got {text!r}")
    return int(parts[0]), int(parts[1])


def main():
    """Command-line entry point: ``<script> FILENAME PYTHON_VERSION``."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("filename", help="The file that will be modified.")
    parser.add_argument(
        "python_version",
        help=(
            "The python version to take into account when modifying the file. Determines whether Union or | is used. "
            "Must be in the format '3.x', with x being a number."
        ),
    )
    args = parser.parse_args()

    try:
        py_version = _parse_python_version(args.python_version)
    except ValueError:
        # parser.error prints usage and exits with status 2 instead of a raw traceback.
        parser.error(f"invalid python_version {args.python_version!r}; expected the format '3.x'")

    fix_file(args.filename, py_version)


if __name__ == "__main__":
    raise SystemExit(main())