From 6ee9d2ea12b964fd6a54deb63cc2dd7a0c791d40 Mon Sep 17 00:00:00 2001
From: ioistired
Date: Mon, 27 May 2024 00:48:40 +0000
Subject: [PATCH 1/6] [wip] port jishaku.inline_import

---
 README.md                            |   2 +-
 import_expression/__init__.py        |  21 +-
 import_expression/__main__.py        |   1 -
 import_expression/_codec/__init__.py |  63 ------
 import_expression/_codec/compat.py   |  27 ---
 import_expression/_parser.py         | 298 +++++++++++----------------
 import_expression/_syntax.py         | 215 +++++++++++--------
 import_expression/constants.py       |   3 +-
 setup.py                             |   4 +-
 tests.py                             |  61 +-----
 10 files changed, 255 insertions(+), 440 deletions(-)
 delete mode 100644 import_expression/_codec/__init__.py
 delete mode 100644 import_expression/_codec/compat.py

diff --git a/README.md b/README.md
index 556a104..cae884a 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ importlib.import_module('urllib.parse').quote('hello there')
 Counter({'e': 4, 'd': 3, 'c': 2, 'b': 1})
 ```

-The other public functions are `exec`, `compile`, `parse`, `find_imports`, and `update_globals`.
+The other public functions are `exec`, `compile`, `parse`, `find_imports`.
 See their docstrings for details.

 By default, the filename for `SyntaxError`s is `<string>`.
diff --git a/import_expression/__init__.py b/import_expression/__init__.py
index 3636a01..2c016fb 100755
--- a/import_expression/__init__.py
+++ b/import_expression/__init__.py
@@ -29,14 +29,14 @@
 from . import constants
 from ._syntax import fix_syntax as _fix_syntax
-from ._parser import parse_ast as _parse_ast
+from ._parser import transform_ast as _transform_ast
 from ._parser import find_imports as _find_imports
 from .version import __version__

 with _contextlib.suppress(NameError):
     del version

-__all__ = ('compile', 'parse', 'eval', 'exec', 'constants', 'find_imports', 'update_globals')
+__all__ = ('compile', 'parse', 'eval', 'exec', 'constants', 'find_imports')

 _source = _typing.Union[_ast.AST, _typing.AnyStr]

@@ -57,14 +57,14 @@ def parse(source: _source, filename=constants.DEFAULT_FILENAME, mode='exec', *,
     """
     # for some API compatibility with ast, allow parse(parse('foo')) to work
     if isinstance(source, _ast.AST):
-        return _parse_ast(source, filename=filename)
+        return _transform_ast(source, filename=filename)

     fixed = _fix_syntax(source, filename=filename)
     if flags & PyCF_DONT_IMPLY_DEDENT:
         # just run it for the syntax errors, which codeop picks up on
         _builtins.compile(fixed, filename, mode, flags)
     tree = _ast.parse(fixed, filename, mode, **kwargs)
-    return _parse_ast(tree, source=source, filename=filename)
+    return _transform_ast(tree, source=source, filename=filename)

 def compile(
     source: _source,
@@ -110,23 +110,10 @@ def find_imports(source: str, filename=constants.DEFAULT_FILENAME, mode='exec'):
     tree = _ast.parse(fixed, filename, mode)
     return _find_imports(tree, filename=filename)

-def update_globals(globals: dict) -> dict:
-    """Ensure that the variables required for eval/exec are present in the given dict.
-    Note that import_expression.eval and import_expression.exec do this for you automatically.
-    Calling this function yourself is only necessary if you want to call builtins.eval or builtins.exec
-    with the return value of import_expression.compile.
-
-    This function always returns the passed dictionary to make expression chaining easier.
-    """
-    globals.update({constants.IMPORTER: _importlib.import_module})
-    return globals
-
 def _parse_eval_exec_args(globals, locals):
     if globals is None:
         globals = {}

-    update_globals(globals)
-
     if locals is None:
         locals = globals

diff --git a/import_expression/__main__.py b/import_expression/__main__.py
index 2c1070d..60f6607 100644
--- a/import_expression/__main__.py
+++ b/import_expression/__main__.py
@@ -90,7 +90,6 @@ def __call__(self, source, filename, symbol, **kwargs):
 class ImportExpressionInteractiveConsole(code.InteractiveConsole):
     def __init__(self, locals=None, filename='<console>'):
         super().__init__(locals, filename)
-        self.locals.update({constants.IMPORTER: importlib.import_module})
         self.compile = ImportExpressionCommandCompiler()

 # we must vendor this class because it creates global variables that the main code depends on
diff --git a/import_expression/_codec/__init__.py b/import_expression/_codec/__init__.py
deleted file mode 100644
index a042a91..0000000
--- a/import_expression/_codec/__init__.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import ast
-from . import compat as astunparse
-import codecs, io, encodings
-from encodings import utf_8
-
-import import_expression as ie
-from ..constants import IMPORTER
-
-IMPORT_STATEMENT = ast.parse(f'from importlib import import_module as {IMPORTER}').body[0]
-
-def decode(b, errors='strict'):
-    if not b:
-        return '', 0
-
-    decoded = codecs.decode(b, errors=errors, encoding='utf-8')
-    parsed = ie.parse(decoded)
-    parsed.body.insert(0, IMPORT_STATEMENT)
-    unparsed = astunparse.unparse(parsed)
-    return unparsed, len(decoded)
-
-# copied from future_fstrings.py at bd6bf81
-
-class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
-    def _buffer_decode(self, input, errors, final):
-        if final:
-            return decode(input, errors)
-        else:
-            return '', 0
-
-class StreamReader(utf_8.StreamReader):  # pragma: no cover
-    """decode is deferred to support better error messages"""
-    _stream = None
-    _decoded = False
-
-    @property
-    def stream(self):
-        if not self._decoded:
-            text, _ = decode(self._stream.read())
-            self._stream = io.BytesIO(text.encode('utf-8'))
-            self._decoded = True
-        return self._stream
-
-    @stream.setter
-    def stream(self, stream):
-        self._stream = stream
-        self._decoded = False
-
-def search_function(encoding, codec_names={'import_expression', 'ie'}):
-    if encoding not in codec_names:  # pragma: no cover
-        return None
-    utf8 = encodings.search_function('utf-8')
-    return codecs.CodecInfo(
-        name='import_expression',
-        encode=utf8.encode,
-        decode=decode,
-        incrementalencoder=utf8.incrementalencoder,
-        incrementaldecoder=IncrementalDecoder,
-        streamreader=StreamReader,
-        streamwriter=utf8.streamwriter,
-    )
-
-def register():
-    codecs.register(search_function)
diff --git a/import_expression/_codec/compat.py b/import_expression/_codec/compat.py
deleted file mode 100644
index 40a690c..0000000
--- a/import_expression/_codec/compat.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import contextlib
-from io import StringIO
-from astunparse import Unparser
-
-class Unparser(Unparser):
-    def _Constant(self, t):
-        value = t.value
-        if isinstance(value, tuple):
-            self.write("(")
-            if len(value) == 1:
-                self._write_constant(value[0])
-                self.write(",")
-            else:
-                interleave(lambda: self.write(", "), self._write_constant, value)
-            self.write(")")
-        elif value is Ellipsis:  # instead of `...` for Py2 compatibility
-            self.write("...")
-        else:
-            with contextlib.suppress(AttributeError):
-                if t.kind == "u":
-                    self.write("u")
-            self._write_constant(t.value)
-
-def unparse(tree):
-    out = StringIO()
-    Unparser(tree, file=out)
-    return out.getvalue()
diff --git a/import_expression/_parser.py b/import_expression/_parser.py
index 9cf41dc..f36a1f7 100644
--- a/import_expression/_parser.py
+++ b/import_expression/_parser.py
@@ -19,206 +19,103 @@
 # THE SOFTWARE.

 import ast
+import sys
+import typing
+import functools
 import contextlib
 from collections import namedtuple
-
+from typing_extensions import ParamSpec
+from typing_extensions import Buffer as ReadableBuffer
 from .constants import *

-def parse_ast(root_node, **kwargs): return ast.fix_missing_locations(Transformer(**kwargs).visit(root_node))
+P = ParamSpec("P")
+T = typing.TypeVar("T")

-def find_imports(root_node, **kwargs):
-    t = ListingTransformer(**kwargs)
-    t.visit(root_node)
-    return t.imports
+# https://github.com/python/cpython/blob/5d04cc50e51cb262ee189a6ef0e79f4b372d1583/Objects/exceptions.c#L2438-L2441
+_sec_fields = 'filename lineno offset text'.split()
+if sys.version_info >= (3, 10):
+    _sec_fields.extend('end_lineno end_offset'.split())
+
+SyntaxErrorContext = namedtuple('SyntaxErrorContext', _sec_fields)

-def remove_string_right(haystack, needle):
-    left, needle, right = haystack.rpartition(needle)
-    if not right:
-        return left
-    # needle not found
-    return haystack
-
-def remove_import_op(name): return remove_string_right(name, MARKER)
-def has_any_import_op(name): return MARKER in name
-def has_invalid_import_op(name):
-    removed = remove_import_op(name)
-    return MARKER in removed or not removed
-def find_valid_imported_name(name):
-    """return a name preceding an import op, or False if there isn't one"""
-    return name.endswith(MARKER) and remove_import_op(name)
-
-# source: CPython Objects/exceptions.c:1362 at v3.8.0
-SyntaxErrorContext = namedtuple('SyntaxErrorContext', 'filename lineno offset text')
+
+del _sec_fields
+
+def transform_ast(root_node, **kwargs): return ast.fix_missing_locations(Transformer(**kwargs).visit(root_node))

 class Transformer(ast.NodeTransformer):
+    """An AST transformer that replaces calls to MARKER with '__import__("importlib").import_module(...)'."""
+
     def __init__(self, *, filename=None, source=None):
         self.filename = filename
         self.source_lines = source.splitlines() if source is not None else None

-    def visit_Attribute(self, node):
-        """
-        convert Attribute nodes containing import expressions into Attribute nodes containing import calls
-        """
-        self._ensure_only_valid_import_ops(node)
-
-        maybe_transformed = self._transform_attribute_attr(node)
-        if maybe_transformed:
-            return maybe_transformed
-        else:
-            transformed_lhs = self.visit(node.value)
-            return ast.copy_location(
-                ast.Attribute(
-                    value=transformed_lhs,
-                    ctx=node.ctx,
-                    attr=node.attr),
-                node)
-
-    def visit_Name(self, node):
-        """convert solitary Names that have import expressions, such as "a!", into import calls"""
-        self._ensure_only_valid_import_ops(node)
-
-        id = find_valid_imported_name(node.id)
-        if id:
-            return ast.copy_location(self.import_call(id, node.ctx), node)
-        return node
-
-    @staticmethod
-    def import_call(attribute_source, ctx):
-        return ast.Call(
-            func=ast.Name(id=IMPORTER, ctx=ctx),
-            args=[ast.Str(attribute_source)],
-            keywords=[])
-
-    def _transform_attribute_attr(self, node):
-        """convert an Attribute node's left hand side into an import call"""
-
-        attr = find_valid_imported_name(node.attr)
-        if not attr:
-            return None
-
-        node.attr = attr
-        as_source = self.attribute_source(node)
-
-        return ast.copy_location(
-            self.import_call(as_source, node.ctx),
-            node)
-
-    def attribute_source(self, node: ast.Attribute, _seen_import_op=False):
-        """return a source-code representation of an Attribute node"""
-        if self._find_valid_imported_name(node):
-            _seen_import_op = True
-
-        stripped = self._remove_import_op(node)
-        if type(node) is ast.Name:
-            if _seen_import_op:
-                raise self._syntax_error('multiple import expressions not allowed', node) from None
-            return stripped
-
-        lhs = self.attribute_source(node.value, _seen_import_op)
-        rhs = stripped
-
-        return lhs + '.' + rhs
-
-    def visit_def_(self, node):
-        if not has_any_import_op(node.name):
-            # it's valid so far, just ensure that arguments and body are also visited
-            return self.generic_visit(node)
-
-        if isinstance(node, ast.ClassDef):
-            type_name = 'class'
-        else:
-            type_name = 'function'
-
-        raise self._syntax_error(
-            f'"{IMPORT_OP}" not allowed in the name of a {type_name}',
-            node
-        ) from None
-
-    visit_FunctionDef = visit_def_
-    visit_ClassDef = visit_def_
-
-    def visit_arg(self, node):
-        """ensure foo(x! = 1) or def foo(x!) does not occur"""
-        if node.arg is not None and has_any_import_op(node.arg):
-            raise self._syntax_error(
-                f'"{IMPORT_OP}" not allowed in function arguments',
-                node
-            ) from None
-
-        # regular arguments may have import expr annotations as children
-        return super().generic_visit(node)
-
-    def visit_keyword(self, node):
-        self.visit_arg(node)
-        # keyword arguments may have import expressions as children
-        return super().generic_visit(node)
-
-    def visit_alias(self, node):
-        # from x import y **as z**
-        self._ensure_no_import_ops(node)
-        return node
-
-    def visit_ImportFrom(self, node):
-        self._ensure_no_import_ops(node)
-        # ImportFrom nodes can have alias children that we also need to check
-        return super().generic_visit(node)
-
-    def _ensure_only_valid_import_ops(self, node):
-        if self._for_any_child_node_string(has_invalid_import_op, node):
-            raise self._syntax_error(
-                f'"{IMPORT_OP}" only allowed at end of attribute name',
-                node
-            ) from None
+    def _collapse_attributes(self, node: typing.Union[ast.Attribute, ast.Name]) -> str:
+        if isinstance(node, ast.Name):
+            return node.id

-    def _ensure_no_import_ops(self, node):
-        if self._for_any_child_node_string(has_any_import_op, node):
+        if not (
+            isinstance(node, ast.Attribute)  # pyright: ignore[reportUnnecessaryIsInstance]
+            and isinstance(node.value, (ast.Attribute, ast.Name))
+        ):
             raise self._syntax_error(
-                'import expressions are only allowed in variables and attributes',
-                node
-            ) from None
-
-    @classmethod
-    def _for_any_child_node_string(cls, predicate, node):
-        for child_node in ast.walk(node):
-            if cls._for_any_node_string(predicate, node):
-                return True
-
-        return False
-
-    @staticmethod
-    def _for_any_node_string(predicate, node):
-        for field, value in ast.iter_fields(node):
-            if isinstance(value, str) and predicate(value):
-                return True
-
-        return False
-
-    def _call_on_name_or_attribute(func):
-        def checker(self, node):
-            if type(node) is ast.Attribute:
-                to_check = node.attr
-            elif type(node) is ast.Name:
-                to_check = node.id
-            else:
-                raise self._syntax_error('invalid syntax', node)
-            return func(to_check)
-
-        return checker
-
-    _find_valid_imported_name = _call_on_name_or_attribute(find_valid_imported_name)
-    _remove_import_op = _call_on_name_or_attribute(remove_import_op)
-
-    del _call_on_name_or_attribute
+                "Only names and attribute access (dot operator) "
+                "can be within the inline import expression.",
+                node,
+            )  # noqa: TRY004
+
+        return self._collapse_attributes(node.value) + f".{node.attr}"
+
+    def visit_Call(self, node: ast.Call) -> ast.AST:
+        """Replace the import calls with a valid inline import expression."""
+
+        if (
+            isinstance(node.func, ast.Name)
+            and node.func.id == MARKER
+            and len(node.args) == 1
+            and isinstance(node.args[0], (ast.Attribute, ast.Name))
+        ):
+            node.func = ast.Attribute(
+                value=ast.Call(
+                    func=ast.Name(id="__import__", ctx=ast.Load()),
+                    args=[ast.Constant(value="importlib")],
+                    keywords=[],
+                ),
+                attr="import_module",
+                ctx=ast.Load(),
+            )
+            node.args[0] = ast.Constant(value=self._collapse_attributes(node.args[0]))
+
+        return self.generic_visit(node)

     def _syntax_error(self, message, node):
         lineno = getattr(node, 'lineno', None)
         offset = getattr(node, 'col_offset', None)
-        line = None
+        end_lineno = getattr(node, 'end_lineno', None)
+        end_offset = getattr(node, 'end_offset', None)
+
+        text = None
         if self.source_lines is not None and lineno:
+            if end_offset is None:
+                sl = lineno-1
+            else:
+                sl = slice(lineno-1, end_lineno-1)
+
             with contextlib.suppress(IndexError):
-                line = self.source_lines[lineno-1]
-        ctx = SyntaxErrorContext(filename=self.filename, lineno=lineno, offset=offset, text=line)
-        return SyntaxError(message, ctx)
+                text = self.source_lines[sl]
+
+        kwargs = dict(
+            filename=self.filename,
+            lineno=lineno,
+            offset=offset,
+            text=text,
+        )
+        if sys.version_info >= (3, 10):
+            kwargs.update(dict(
+                end_lineno=end_lineno,
+                end_offset=end_offset,
+            ))
+
+        return SyntaxError(message, SyntaxErrorContext(**kwargs))

 class ListingTransformer(Transformer):
     """like the parent class but lists all imported modules as self.imports"""
@@ -230,3 +127,44 @@ def __init__(self, *args, **kwargs):
     def import_call(self, attribute_source, *args, **kwargs):
         self.imports.append(attribute_source)
         return super().import_call(attribute_source, *args, **kwargs)
+
+def find_imports(root_node, **kwargs):
+    t = ListingTransformer(**kwargs)
+    t.visit(root_node)
+    return t.imports
+
+def copy_annotations(
+    original_func: typing.Callable[P, T],
+) -> typing.Callable[[typing.Callable[..., typing.Any]], typing.Callable[P, T]]:
+    """A decorator that applies the annotations from one function onto another.
+
+    It can be a lie, but it aids the type checker and any IDE intellisense.
+    """
+
+    def inner(new_func: typing.Callable[..., typing.Any]) -> typing.Callable[P, T]:
+        return functools.update_wrapper(new_func, original_func, ("__doc__", "__annotations__"))  # type: ignore
+
+    return inner
+
+
+# Some of the parameter annotations are too narrow or wide, but they should be "overridden" by this decorator.
+@copy_annotations(ast.parse)
+def parse(
+    source: typing.Union[str, ReadableBuffer],
+    filename: str = DEFAULT_FILENAME,
+    mode: str = "exec",
+    *,
+    type_comments: bool = False,
+    feature_version: typing.Optional[typing.Tuple[int, int]] = None,
+) -> ast.Module:
+    """Convert source code with inline import expressions to an AST. Has the same signature as ast.parse."""
+
+    return transform_ast(
+        ast.parse(
+            transform_source(source),
+            filename,
+            mode,
+            type_comments=type_comments,
+            feature_version=feature_version,
+        )
+    )
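Taken together with the tokenizer changes in the next diff, the port works in two stages: a token-level pass wraps each inline import in a call to the `MARKER` name, and `visit_Call` above rewrites that call into a real import. A minimal sketch of the intended intermediate forms, assuming `MARKER == '_IMPORT_MARKER'` as set in `constants.py` later in this patch (illustrative only, not part of the diff):

```py
# Stage 1 (token level): "urllib.parse!.quote" is rewritten into a marker call.
stage1 = "_IMPORT_MARKER(urllib.parse).quote"

# Stage 2 (AST level): visit_Call expands the marker call into a real import,
# so the compiled result behaves like this ordinary Python:
stage2 = "__import__('importlib').import_module('urllib.parse').quote"

print(eval(stage2)('hello there'))  # hello%20there
```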
diff --git a/import_expression/_syntax.py b/import_expression/_syntax.py
index f157bf4..d13ad19 100644
--- a/import_expression/_syntax.py
+++ b/import_expression/_syntax.py
@@ -22,23 +22,20 @@
 # It is used under the Python Software Foundation License Version 2.
 # See LICENSE for details.

-import collections
 import io
 import re
+import sys
 import string
-import tokenize as tokenize_
 import typing
+import collections
 from token import *

-# TODO only import what we need
-vars().update({k: v for k, v in vars(tokenize_).items() if not k.startswith('_')})
-
 from .constants import *
+import tokenize as tokenize_
+from typing_extensions import Buffer as ReadableBuffer
+from typing_extensions import ParamSpec

-tokenize_.TokenInfo.value = property(lambda self: self.string)
-
-is_import = lambda token: token.type == tokenize_.ERRORTOKEN and token.string == IMPORT_OP
-
-NEWLINES = {NEWLINE, tokenize_.NL}
+P = ParamSpec("P")
+T = typing.TypeVar("T")

 def fix_syntax(s: typing.AnyStr, filename=DEFAULT_FILENAME) -> bytes:
     try:
@@ -53,81 +50,127 @@ def fix_syntax(s: typing.AnyStr, filename=DEFAULT_FILENAME) -> bytes:

         raise SyntaxError(message, (filename, lineno-1, offset, source_line)) from None

-    return Untokenizer().untokenize(tokens)
-
-# modified from Lib/tokenize.py at 3.6
-class Untokenizer:
-    def __init__(self):
-        self.tokens = collections.deque()
-        self.indents = collections.deque()
-        self.prev_row = 1
-        self.prev_col = 0
-        self.startline = False
-        self.encoding = None
-
-    def add_whitespace(self, start):
-        row, col = start
-        if row < self.prev_row or row == self.prev_row and col < self.prev_col:
-            raise ValueError(
-                "start ({},{}) precedes previous end ({},{})".format(row, col, self.prev_row, self.prev_col))
-
-        col_offset = col - self.prev_col
-        self.tokens.append(" " * col_offset)
-
-    def untokenize(self, iterable):
-        indents = []
-        startline = False
-        for token in iterable:
-            if token.type == tokenize_.ENCODING:
-                self.encoding = token.value
+    transformed = transform_tokens(tokens)
+    return tokenize_.untokenize(tokens)
+
+def offset_token_horizontal(tok: tokenize_.TokenInfo, offset: int) -> tokenize_.TokenInfo:
+    """Takes a token and returns a new token with the columns for start and end offset by a given amount."""
+
+    start_row, start_col = tok.start
+    end_row, end_col = tok.end
+    return tok._replace(start=(start_row, start_col + offset), end=(end_row, end_col + offset))
+
+def offset_line_horizontal(
+    tokens: typing.List[tokenize_.TokenInfo],
+    start_index: int = 0,
+    *,
+    line: int,
+    offset: int,
+) -> None:
+    """Takes a list of tokens and changes the offset of some of the tokens in place."""
+
+    for i, tok in enumerate(tokens[start_index:], start=start_index):
+        if tok.start[0] != line:
+            break
+        tokens[i] = offset_token_horizontal(tok, offset)
+
+def transform_tokens(tokens: typing.Iterable[tokenize_.TokenInfo]) -> typing.List[tokenize_.TokenInfo]:
+    """Find the inline import expressions in a list of tokens and replace the relevant tokens to wrap the imported
+    modules with a call to MARKER.
+
+    Later, the AST transformer step will replace those with valid import expressions.
+    """
+
+    orig_tokens = list(tokens)
+    new_tokens: typing.List[tokenize_.TokenInfo] = []
+
+    for orig_i, tok in enumerate(orig_tokens):
+        # "!" is only an OP in >=3.12.
+        if tok.type in {tokenize_.OP, tokenize_.ERRORTOKEN} and tok.string == IMPORT_OP:
+            has_invalid_syntax = False
+
+            # Collect all name and attribute access-related tokens directly connected to the "!".
+            last_place = len(new_tokens)
+            looking_for_name = True
+
+            for old_tok in reversed(new_tokens):
+                if old_tok.exact_type != (tokenize_.NAME if looking_for_name else tokenize_.DOT):
+                    # The "!" was placed somewhere in a class definition, e.g. "class Fo!o: pass".
+                    has_invalid_syntax = (old_tok.exact_type == tokenize_.NAME and old_tok.string == "class")
+
+                    # There's a name immediately following "!". Might be a f-string conversion flag
+                    # like "f'{thing!r}'" or just something invalid like "def fo!o(): pass".
+                    try:
+                        peek = orig_tokens[orig_i + 1]
+                    except IndexError:
+                        pass
+                    else:
+                        has_invalid_syntax = (has_invalid_syntax or peek.type == tokenize_.NAME)
+
+                    break
+
+                last_place -= 1
+                looking_for_name = not looking_for_name
+
+            # The "!" is just by itself or in a bad spot. Let it error later if it's wrong.
+            # Also allows other token transformers to work with it without erroring early.
+            if has_invalid_syntax or last_place == len(new_tokens):
+                new_tokens.append(tok)
                 continue
-            if token.type == tokenize_.ENDMARKER:
-                break
-
-            # XXX this abomination comes from tokenize.py
-            # i tried to move it to a separate method but failed
-
-            if token.type == tokenize_.INDENT:
-                indents.append(token.value)
                 continue
-            elif token.type == tokenize_.DEDENT:
-                indents.pop()
-                self.prev_row, self.prev_col = token.end
-                continue
-            elif token.type in NEWLINES:
-                startline = True
-            elif startline and indents:
-                indent = indents[-1]
-                start_row, start_col = token.start
-                if start_col >= len(indent):
-                    self.tokens.append(indent)
-                    self.prev_col = len(indent)
-                startline = False
-
-            # end abomination
-
-            self.add_whitespace(token.start)
-
-            if is_import(token):
-                self.tokens.append(MARKER)
-            else:
-                self.tokens.append(token.value)
-
-            self.prev_row, self.prev_col = token.end
-
-            # don't ask me why this shouldn't be "in NEWLINES",
-            # but ignoring tokenize_.NL here fixes #3
-            if token.type == NEWLINE:
-                self.prev_row += 1
-                self.prev_col = 0
-
-        return "".join(self.tokens)
-
-def tokenize(string: typing.AnyStr):
-    if isinstance(string, bytes):
-        # call the internal tokenize func to avoid sniffing the encoding
-        # if it tried to sniff the encoding of a "# encoding: import_expression" file,
-        # it would call our code again, resulting in a RecursionError.
-        return tokenize_._tokenize(io.BytesIO(string).readline, encoding='utf-8')
-    return tokenize_.generate_tokens(io.StringIO(string).readline)
+            # Insert a call to the MARKER just before the inline import expression.
+            old_first = new_tokens[last_place]
+            old_f_row, old_f_col = old_first.start
+
+            new_tokens[last_place:last_place] = [
+                old_first._replace(type=tokenize_.NAME, string=MARKER, end=(old_f_row, old_f_col + len(MARKER))),
+                tokenize_.TokenInfo(
+                    tokenize_.OP,
+                    "(",
+                    (old_f_row, old_f_col + 17),
+                    (old_f_row, old_f_col + 18),
+                    old_first.line,
+                ),
+            ]
+
+            # Adjust the positions of the following tokens within the inline import expression.
+            new_tokens[last_place + 2:] = (offset_token_horizontal(tok, 18) for tok in new_tokens[last_place + 2:])
+
+            # Add a closing parenthesis.
+            (end_row, end_col) = new_tokens[-1].end
+            line = new_tokens[-1].line
+            end_paren_token = tokenize_.TokenInfo(tokenize_.OP, ")", (end_row, end_col), (end_row, end_col + 1), line)
+            new_tokens.append(end_paren_token)
+
+            # Fix the positions of the rest of the tokens on the same line.
+            fixed_line_tokens: typing.List[tokenize_.TokenInfo] = []
+            offset_line_horizontal(orig_tokens, orig_i, line=new_tokens[-1].start[0], offset=18)
+
+            # Check the rest of the line for inline import expressions.
+            new_tokens.extend(transform_tokens(fixed_line_tokens))
+
+        else:
+            new_tokens.append(tok)
+
+    # Hack to get around a bug where code that ends in a comment, but no newline, has an extra
+    # NEWLINE token added in randomly. This patch wasn't backported to 3.8.
+    # https://github.com/python/cpython/issues/79288
+    # https://github.com/python/cpython/issues/88833
+    if sys.version_info < (3, 9):
+        if len(new_tokens) >= 4 and (
+            new_tokens[-4].type == tokenize_.COMMENT
+            and new_tokens[-3].type == tokenize_.NL
+            and new_tokens[-2].type == tokenize_.NEWLINE
+            and new_tokens[-1].type == tokenize_.ENDMARKER
+        ):
+            del new_tokens[-2]
+
+    return new_tokens
+
+def tokenize(source: typing.Union[str, ReadableBuffer]) -> str:
+    if isinstance(source, str):
+        source = source.encode('utf-8')
+    stream = io.BytesIO(source)
+    encoding, _ = tokenize_.detect_encoding(stream.readline)
+    stream.seek(0)
+    return tokenize_.tokenize(stream.readline)
diff --git a/import_expression/constants.py b/import_expression/constants.py
index 31ec456..e74534d 100644
--- a/import_expression/constants.py
+++ b/import_expression/constants.py
@@ -19,7 +19,6 @@
 # THE SOFTWARE.

 IMPORT_OP = '!'
-IMPORTER = '_IMPORT_MODULE'
-MARKER = '_IMPORT_EXPR_END'  # TODO replace with UUIDs
+MARKER = '_IMPORT_MARKER'

 DEFAULT_FILENAME = '<string>'
diff --git a/setup.py b/setup.py
index 445b7da..b1123bd 100755
--- a/setup.py
+++ b/setup.py
@@ -155,9 +155,7 @@ def finalize_options(self):

     packages=['import_expression', 'import_expression._codec'],

-    install_requires=[
-        'astunparse>=1.6.3,<2.0.0',
-    ],
+    install_requires=[],

     extras_require={
         'test': [
diff --git a/tests.py b/tests.py
index 70af093..d735dee 100644
--- a/tests.py
+++ b/tests.py
@@ -26,14 +26,6 @@

 import import_expression as ie

-try:
-    import astunparse
-except ImportError:
-    HAVE_ASTUNPARSE = False
-else:
-    HAVE_ASTUNPARSE = True
-
-
 invalid_attribute_cases = (
     # arrange this as if ! is binary 1, empty str is 0
     '!a',
@@ -216,11 +208,6 @@ def bar():

     assert g['foo'](1) == 'can we make it into jishaku?'

-def test_importer_name_not_mangled():
-    # if import_expression.constants.IMPORTER.startswith('__'),
-    # this will fail
-    ie.exec('class Foo: x = io!')
-
 def test_flags():
     import ast
     assert isinstance(ie.compile('foo', flags=ast.PyCF_ONLY_AST), ast.AST)
@@ -248,7 +235,7 @@ def test_dont_imply_dedent():
     with pytest.raises(SyntaxError):
         ie.compile('def foo():\n\tpass', mode='single', flags=PyCF_DONT_IMPLY_DEDENT)

-def test_parse_ast():
+def test_transform_ast():
     from typing import Any
     node = ie.parse(ie.parse('typing!.Any', mode='eval'))
     assert ie.eval(node) is Any
@@ -257,12 +244,6 @@ def test_locals_arg():
     ie.exec('assert locals() is globals()', {})
     ie.exec('assert locals() is not globals()', {}, {})

-def test_update_globals():
-    import collections
-    code = ie.compile('collections!.Counter', mode='eval')
-    g = ie.update_globals({})
-    assert eval(code, g) is collections.Counter
-
 def test_find_imports():
     with pytest.raises(SyntaxError):
         ie.find_imports('x; y', mode='eval')
@@ -287,46 +268,6 @@ def test_bytes():
     import typing
     assert ie.eval(b'typing!.TYPE_CHECKING') == typing.TYPE_CHECKING

-@pytest.mark.skipif(not HAVE_ASTUNPARSE, reason='requires the [codec] setup.py extra')
-@pytest.mark.parametrize('encoding', ['import_expression', 'ie'])
-def test_encoding(encoding):
-    import import_expression._codec
-    import_expression._codec.register()
-
-    import tempfile
-    import typing
-    fn = tempfile.mktemp()
-    with remove(fn), open(fn, mode='w+', encoding=encoding) as f:
-        f.write('x = typing!.TYPE_CHECKING')
-        f.seek(0)
-        g = {}
-        exec(f.read(), g)
-        assert g['x'] == typing.TYPE_CHECKING
-        assert not f.read()
-
-        f.seek(0)
-        while f.readline():  # we must reach EOF eventually
-            pass
-
-@pytest.mark.skipif(not HAVE_ASTUNPARSE, reason='requires the [codec] setup.py extra')
-@pytest.mark.parametrize('encoding', ['import_expression', 'ie'])
-def test_encoding_2(encoding):
-    import codecs
-    import typing
-    g = {}
-    exec(codecs.decode(b'x = typing!.TYPE_CHECKING', encoding=encoding), g)
-    assert g['x'] == typing.TYPE_CHECKING
-
-@pytest.mark.skipif(not HAVE_ASTUNPARSE, reason='no need to test built in encoding without the [codec] setup.py extra')
-def test_utf8_unaffected():
-    import tempfile
-    fn = tempfile.mktemp()
-    with remove(fn), open(fn, mode='w+', encoding='shift-jis') as f:
-        f.write('foo')
-        f.seek(0)
-        assert f.read() == 'foo'
-        assert not f.read()
-
 def test_beat_is_gay():
     with pytest.raises(SyntaxError):
         ie.compile('"beat".succ!')
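With `update_globals` gone from both the library and the tests, nothing has to be injected into a namespace before evaluating compiled import expressions: the rewritten AST reaches importlib through `__import__`, which evaluation always gets from builtins. A quick sketch of the new contract, assuming the patched package is installed (it mirrors the `test_update_globals` test deleted above):

```py
import collections
import import_expression as ie

# No update_globals() step any more: an empty globals dict is enough,
# because the compiled code relies only on the builtin __import__.
code = ie.compile('collections!.Counter', mode='eval')
assert eval(code, {}) is collections.Counter
```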
From 57b51aa67937674085288bb23b505ffe78edeb08 Mon Sep 17 00:00:00 2001
From: ioistired
Date: Mon, 27 May 2024 01:44:34 +0000
Subject: [PATCH 2/6] attribute Thanos

---
 import_expression/_parser.py | 1 +
 import_expression/_syntax.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/import_expression/_parser.py b/import_expression/_parser.py
index f36a1f7..8e9d8ff 100644
--- a/import_expression/_parser.py
+++ b/import_expression/_parser.py
@@ -1,4 +1,5 @@
 # Copyright © io mintz
+# Copyright © Thanos <111999343+Sachaa-Thanasius@users.noreply.github.com>

 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the “Software”),
diff --git a/import_expression/_syntax.py b/import_expression/_syntax.py
index d13ad19..ef9cb54 100644
--- a/import_expression/_syntax.py
+++ b/import_expression/_syntax.py
@@ -1,4 +1,5 @@
 # Copyright © io mintz
+# Copyright © Thanos <111999343+Sachaa-Thanasius@users.noreply.github.com>

 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the “Software”),

From 8574478f939e79fea0628ef7d2cea9771f0c505a Mon Sep 17 00:00:00 2001
From: ioistired
Date: Mon, 27 May 2024 02:17:50 +0000
Subject: [PATCH 3/6] [wip] port jishaku.inline_import

---
 import_expression/__init__.py | 12 +-------
 import_expression/_main2.py   | 36 ------------------------
 import_expression/_parser.py  | 52 -----------------------------------
 import_expression/_syntax.py  | 15 +++++-----
 4 files changed, 9 insertions(+), 106 deletions(-)
 delete mode 100644 import_expression/_main2.py

diff --git a/import_expression/__init__.py b/import_expression/__init__.py
index 2c016fb..445f1dd 100755
--- a/import_expression/__init__.py
+++ b/import_expression/__init__.py
@@ -30,13 +30,12 @@
 from . import constants
 from ._syntax import fix_syntax as _fix_syntax
 from ._parser import transform_ast as _transform_ast
-from ._parser import find_imports as _find_imports
 from .version import __version__

 with _contextlib.suppress(NameError):
     del version

-__all__ = ('compile', 'parse', 'eval', 'exec', 'constants', 'find_imports')
+__all__ = ('compile', 'parse', 'eval', 'exec', 'constants')

 _source = _typing.Union[_ast.AST, _typing.AnyStr]

@@ -101,15 +100,6 @@ def exec(source: _code, globals=None, locals=None):
         return _builtins.eval(source, globals, locals)
     _builtins.eval(compile(source, constants.DEFAULT_FILENAME, 'exec'), globals, locals)

-def find_imports(source: str, filename=constants.DEFAULT_FILENAME, mode='exec'):
-    """return a list of all module names required by the given source code."""
-    # passing an AST is not supported because it doesn't make sense to.
-    # either the AST is one that we made, in which case the imports have already been made and calling parse_ast again
-    # would find no imports, or it's an AST made by parsing the output of fix_syntax, which is internal.
-    fixed = _fix_syntax(source, filename=filename)
-    tree = _ast.parse(fixed, filename, mode)
-    return _find_imports(tree, filename=filename)
-
 def _parse_eval_exec_args(globals, locals):
     if globals is None:
         globals = {}
diff --git a/import_expression/_main2.py b/import_expression/_main2.py
deleted file mode 100644
index 3a9f596..0000000
--- a/import_expression/_main2.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import argparse
-import shutil
-import sys
-
-import import_expression
-import import_expression._codec
-
-def main():
-    import argparse
-
-    version_info = (
-        f'Import Expression Parser {import_expression.__version__}\n'
-        f'Python {sys.version}'
-    )
-
-    parser = argparse.ArgumentParser(
-        prog='import-expression-rewrite',
-        description='rewrites import expresion python to standard python',
-    )
-    parser.add_argument('-i', '--in-place', dest='in_place', action='store_true', help='whether to rewrite in place')
-    parser.add_argument('filename', metavar='module', help='path to a python file to rewrite to stdout')
-
-    args = parser.parse_args()
-
-    import_expression._codec.register()
-    if args.in_place:
-        with open(args.filename, 'r+', encoding='import_expression') as infp:
-            buf = infp.read()
-            infp.seek(0)
-            infp.write(buf)
-    else:
-        with open(args.filename, encoding='import_expression') as infp:
-            shutil.copyfileobj(infp, sys.stdout)
-
-if __name__ == '__main__':
-    main()
diff --git a/import_expression/_parser.py b/import_expression/_parser.py
index 8e9d8ff..1214bac 100644
--- a/import_expression/_parser.py
+++ b/import_expression/_parser.py
@@ -117,55 +117,3 @@ def _syntax_error(self, message, node):
         ))

         return SyntaxError(message, SyntaxErrorContext(**kwargs))
-
-class ListingTransformer(Transformer):
-    """like the parent class but lists all imported modules as self.imports"""
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.imports = []
-
-    def import_call(self, attribute_source, *args, **kwargs):
-        self.imports.append(attribute_source)
-        return super().import_call(attribute_source, *args, **kwargs)
-
-def find_imports(root_node, **kwargs):
-    t = ListingTransformer(**kwargs)
-    t.visit(root_node)
-    return t.imports
-
-def copy_annotations(
-    original_func: typing.Callable[P, T],
-) -> typing.Callable[[typing.Callable[..., typing.Any]], typing.Callable[P, T]]:
-    """A decorator that applies the annotations from one function onto another.
-
-    It can be a lie, but it aids the type checker and any IDE intellisense.
-    """
-
-    def inner(new_func: typing.Callable[..., typing.Any]) -> typing.Callable[P, T]:
-        return functools.update_wrapper(new_func, original_func, ("__doc__", "__annotations__"))  # type: ignore
-
-    return inner
-
-
-# Some of the parameter annotations are too narrow or wide, but they should be "overridden" by this decorator.
-@copy_annotations(ast.parse)
-def parse(
-    source: typing.Union[str, ReadableBuffer],
-    filename: str = DEFAULT_FILENAME,
-    mode: str = "exec",
-    *,
-    type_comments: bool = False,
-    feature_version: typing.Optional[typing.Tuple[int, int]] = None,
-) -> ast.Module:
-    """Convert source code with inline import expressions to an AST. Has the same signature as ast.parse."""
-
-    return transform_ast(
-        ast.parse(
-            transform_source(source),
-            filename,
-            mode,
-            type_comments=type_comments,
-            feature_version=feature_version,
-        )
-    )
diff --git a/import_expression/_syntax.py b/import_expression/_syntax.py
index ef9cb54..d402784 100644
--- a/import_expression/_syntax.py
+++ b/import_expression/_syntax.py
@@ -40,7 +40,7 @@

 def fix_syntax(s: typing.AnyStr, filename=DEFAULT_FILENAME) -> bytes:
     try:
-        tokens = list(tokenize(s))
+        tokens, encoding = tokenize(s)
     except tokenize_.TokenError as ex:
         message, (lineno, offset) = ex.args

@@ -51,8 +51,9 @@ def fix_syntax(s: typing.AnyStr, filename=DEFAULT_FILENAME) -> bytes:

         raise SyntaxError(message, (filename, lineno-1, offset, source_line)) from None

+    tokens = list(tokens)
     transformed = transform_tokens(tokens)
-    return tokenize_.untokenize(tokens)
+    return tokenize_.untokenize(transformed).decode(encoding)

@@ -128,14 +129,14 @@ def transform_tokens(tokens: typing.Iterable[tokenize_.TokenInfo]) -> typing.Lis
             tokenize_.TokenInfo(
                 tokenize_.OP,
                 "(",
-                (old_f_row, old_f_col + 17),
-                (old_f_row, old_f_col + 18),
+                (old_f_row, old_f_col + len(MARKER)),
+                (old_f_row, old_f_col + len(MARKER)+1),
                 old_first.line,
             ),
         ]

         # Adjust the positions of the following tokens within the inline import expression.
-        new_tokens[last_place + 2:] = (offset_token_horizontal(tok, 18) for tok in new_tokens[last_place + 2:])
+        new_tokens[last_place + 2:] = (offset_token_horizontal(tok, len(MARKER)+1) for tok in new_tokens[last_place + 2:])

         # Add a closing parenthesis.
         (end_row, end_col) = new_tokens[-1].end
@@ -145,7 +146,7 @@ def transform_tokens(tokens: typing.Iterable[tokenize_.TokenInfo]) -> typing.Lis

         # Fix the positions of the rest of the tokens on the same line.
         fixed_line_tokens: typing.List[tokenize_.TokenInfo] = []
-        offset_line_horizontal(orig_tokens, orig_i, line=new_tokens[-1].start[0], offset=18)
+        offset_line_horizontal(orig_tokens, orig_i, line=new_tokens[-1].start[0], offset=len(MARKER)+1)

         # Check the rest of the line for inline import expressions.
         new_tokens.extend(transform_tokens(fixed_line_tokens))
@@ -174,4 +175,4 @@ def tokenize(source: typing.Union[str, ReadableBuffer]) -> str:
     stream = io.BytesIO(source)
     encoding, _ = tokenize_.detect_encoding(stream.readline)
     stream.seek(0)
-    return tokenize_.tokenize(stream.readline)
+    return tokenize_.tokenize(stream.readline), encoding
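Patch 3 rounds out the token transformer (replacing the hardcoded 17/18 offsets with `len(MARKER)`-based ones). One visible consequence of its "leave `!` alone when a name follows it" rule is that f-string conversion flags survive untouched. A short sketch of the resulting behavior, assuming the patched package is installed (hypothetical snippet, not part of the patch series):

```py
import import_expression as ie

# Conversion flags keep working: "!r" is followed by a name, so the token
# transformer leaves it alone and the f-string parses as usual.
print(ie.eval('f"{0.5!r}"'))  # 0.5

# Outside f-strings, "!" still marks an inline import.
print(ie.eval("urllib.parse!.quote('a b')"))  # a%20b
```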
From 52d69b700ef1d518f8cc7abaccf07d5279944078 Mon Sep 17 00:00:00 2001
From: ioistired
Date: Mon, 27 May 2024 02:19:16 +0000
Subject: [PATCH 4/6] [wip] port jishaku.inline_import

---
 tests.py | 20 --------------------
 1 file changed, 20 deletions(-)

diff --git a/tests.py b/tests.py
index d735dee..e0b04ce 100644
--- a/tests.py
+++ b/tests.py
@@ -244,26 +244,6 @@ def test_locals_arg():
     ie.exec('assert locals() is globals()', {})
     ie.exec('assert locals() is not globals()', {}, {})

-def test_find_imports():
-    with pytest.raises(SyntaxError):
-        ie.find_imports('x; y', mode='eval')
-
-    assert set(ie.find_imports(textwrap.dedent("""
-        x = a!
-        y = a.b!.c
-        z = d.e.f
-    """))) == {'a', 'a.b'}
-
-    assert ie.find_imports('urllib.parse!.quote', mode='eval') == ['urllib.parse']
-
-class remove(contextlib.AbstractContextManager):
-    def __init__(self, name):
-        self.name = name
-    def __enter__(self):
-        return self
-    def __exit__(self, *excinfo):
-        os.remove(self.name)
-
 def test_bytes():
     import typing
     assert ie.eval(b'typing!.TYPE_CHECKING') == typing.TYPE_CHECKING

From 729e5d7209db8de945bad74f4c80a1600ff7b981 Mon Sep 17 00:00:00 2001
From: ioistired
Date: Mon, 27 May 2024 02:19:51 +0000
Subject: [PATCH 5/6] [wip] port jishaku.inline_import

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index cae884a..5412d60 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ importlib.import_module('urllib.parse').quote('hello there')
 Counter({'e': 4, 'd': 3, 'c': 2, 'b': 1})
 ```

-The other public functions are `exec`, `compile`, `parse`, `find_imports`.
+The other public functions are `exec`, `compile`, and `parse`.
 See their docstrings for details.

 By default, the filename for `SyntaxError`s is `<string>`.

From 098d13f5001c3f5ee8ce19d29ebaae4264e9d15d Mon Sep 17 00:00:00 2001
From: ioistired
Date: Mon, 27 May 2024 02:25:09 +0000
Subject: [PATCH 6/6] [wip] port jishaku.inline_import

---
 README.md | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/README.md b/README.md
index 5412d60..a235bb5 100644
--- a/README.md
+++ b/README.md
@@ -47,15 +47,6 @@ for line in sys.stdin:
     print(import_expression.eval(code, dict(l=line)))
 ```

-### Custom encoding
-
-```py
-# encoding: import_expression
-print(typing!.TYPE_CHECKING)
-```
-
-This file, when run, will print True/False. For maximum laziness you can also do `#coding:ie`.
-
 ### REPL usage

 Run `import-expression` for an import expression enabled REPL. \
 See `import-expression --help` for more details.

 ### CLI usage

 Run `import-expression <filename>`.

-### File rewriter
-
-Run `import-expression-rewrite <filename>` to rewrite a file containing import expressions to standard Python. \
-Add the `-i` flag to rewrite in-place.
-
 ## Limitations / Known Issues

 * Due to the hell that is f-string parsing, and because `!` is already an operator inside f-strings,