Skip to content

Commit

Permalink
Adhoc support for lazy_imports
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 581790406
  • Loading branch information
Conchylicultor authored and The etils Authors committed Nov 13, 2023
1 parent 4cdd75c commit eb2b865
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 115 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest_and_autopublish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
# Install deps
- uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.10
- run: pip --version
- run: pip install -e .[all,dev]
- run: pip freeze
Expand Down
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ Changelog follow https://keepachangelog.com/ format.

* `ecolab`:
* Added protobuf repeated fields support to `ecolab.inspect`
* Fixed auto-display when the line contain UTF-8 character
* `ecolab.auto_display`:
* Support specifiers to customize auto-display (`;s`, `;a`, `;i`,...)
* Fixed auto-display when the line contain UTF-8 character
* `epy`:
* `epy.lazy_imports()` support adhoc imports (will re-create the original
`ecolab.adhoc` context when resolved)

## [1.5.2] - 2023-10-24

Expand Down
15 changes: 13 additions & 2 deletions etils/ecolab/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,25 @@ Add a trailing `;` to any statement (assignment, expression, return statement)
to display the current line. This call `IPython.display.display()` for pretty
display.

Format:

* `my_obj;`: Alias for `IPython.display.display(x)`
* `my_obj;s`: (`spec`) Alias for `IPython.display.display(etree.spec_like(x))`
* `my_obj;i`: (`inspect`) Alias for `ecolab.inspect(x)`
* `my_obj;a`: (`array`) Alias for `media.show_images(x)` /
`media.show_videos(x)` (`ecolab.auto_plot_array` behavior)
* `my_obj;q`: (`quiet`) Don't display the line (e.g. last line)

```python
x = my_fn(); # Display `my_fn()` output

my_fn(); # Display `my_fn()` output
my_fn();i # Inspect `my_fn()` output
```

Note that `;` added to the last statement of the cell still silence the
output (`IPython` default behavior).
Note that contrary to `IPython` default behavior, `;` added to the last
statement of the cell will display the line. To silence the last output, use
`;q`.

`;` behavior can be disabled with `ecolab.auto_display(False)`

Expand Down
19 changes: 10 additions & 9 deletions etils/ecolab/array_as_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def show(*objs, **kwargs) -> None:

def auto_plot_array(
*,
# If updating this, also update `_array_repr_html_inner` !!!
video_min_num_frames: int = 15,
# Images outside this range are rescalled
height: None | int | tuple[int, int] = (100, 250),
Expand All @@ -79,11 +80,8 @@ def auto_plot_array(
if ipython is None:
return # Non-notebook environement

show_images_kwargs = show_images_kwargs or {}
show_videos_kwargs = show_videos_kwargs or {}

array_repr_html_fn = functools.partial(
_array_repr_html,
array_repr_html,
video_min_num_frames=video_min_num_frames,
height=height,
show_images_kwargs=show_images_kwargs,
Expand Down Expand Up @@ -126,7 +124,7 @@ def auto_plot_array(
formatter.for_type(enp.lazy.np.ndarray, array_repr_html_fn)


def _array_repr_html(
def array_repr_html(
array: Array,
**kwargs: Any,
) -> Optional[str]:
Expand All @@ -142,12 +140,15 @@ def _array_repr_html(
def _array_repr_html_inner(
img: Array,
*,
video_min_num_frames: int,
height: None | int | tuple[int, int],
show_images_kwargs: dict[str, Any],
show_videos_kwargs: dict[str, Any],
# If updating this, also update `auto_plot_array` !!!
video_min_num_frames: int = 15,
height: None | int | tuple[int, int] = (100, 250),
show_images_kwargs: Optional[dict[str, Any]] = None,
show_videos_kwargs: Optional[dict[str, Any]] = None,
) -> Optional[str]:
"""Display the normalized img, or `None` if the input is not an image."""
show_images_kwargs = show_images_kwargs or {}
show_videos_kwargs = show_videos_kwargs or {}

if not enp.lazy.is_array(img): # Not an array
return None
Expand Down
6 changes: 3 additions & 3 deletions etils/ecolab/array_as_img_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_array_repr_html_valid(
valid_shape: tuple[int, ...],
):
# 2D images are displayed as images
assert '<img' in array_as_img._array_repr_html(
assert '<img' in array_as_img.array_repr_html(
xnp.zeros(valid_shape), **_ARRAY_REPR_HTML_KWARGS
)

Expand All @@ -74,7 +74,7 @@ def test_array_repr_video_html_valid(
valid_shape: tuple[int, ...],
):
# 2D images are displayed as video
assert '<video' in array_as_img._array_repr_html(
assert '<video' in array_as_img.array_repr_html(
xnp.zeros(valid_shape), **_ARRAY_REPR_HTML_KWARGS
)

Expand All @@ -99,7 +99,7 @@ def test_array_repr_html_invalid(
invalid_shape: tuple[int, ...],
):
assert (
array_as_img._array_repr_html(
array_as_img.array_repr_html(
xnp.zeros(invalid_shape), **_ARRAY_REPR_HTML_KWARGS
)
is None
Expand Down
170 changes: 132 additions & 38 deletions etils/ecolab/auto_display_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,17 @@
from __future__ import annotations

import ast
import dataclasses
import functools
import re
import traceback
import typing
from typing import TypeVar

from etils import epy
from etils.ecolab import array_as_img
from etils.ecolab.inspects import core as inspects
from etils.etree import jax as etree # pylint: disable=g-importing-member
import IPython
import packaging

Expand All @@ -40,8 +46,8 @@ def auto_display(activate: bool = True) -> None:
statement) to display the current line.
This call `IPython.display.display()` for pretty display.
Note that `;` added to the last statement of the cell still silence the
output (`IPython` default behavior).
This change the default IPython behavior where `;` added to the last statement
of the cell still silence the output.
```python
x = my_fn(); # Display `my_fn()`
Expand All @@ -51,6 +57,16 @@ def auto_display(activate: bool = True) -> None:
`;` behavior can be disabled with `ecolab.auto_display(False)`
Format:
* `my_obj;`: Alias for `IPython.display.display(x)`
* `my_obj;s`: (`spec`) Alias for
`IPython.display.display(etree.spec_like(x))`
* `my_obj;i`: (`inspect`) Alias for `ecolab.inspect(x)`
* `my_obj;a`: (`array`) Alias for `media.show_images(x)` /
`media.show_videos(x)` (`ecolab.auto_plot_array` behavior)
* `my_obj;q`: (`quiet`) Don't display the line (e.g. last line)
Args:
activate: Allow to disable `auto_display`
"""
Expand Down Expand Up @@ -116,6 +132,8 @@ class _RecordLines(IPython.core.inputtransformer.InputTransformer):
def __init__(self):
self._lines = []
self.last_lines = []

self.trailing_stmt_line_nums = {}
super().__init__()

def push(self, line):
Expand All @@ -132,6 +150,8 @@ def reset(self):
for line in self._lines:
self.last_lines.extend(line.split('\n'))
self._lines.clear()

self.trailing_stmt_line_nums.clear()
return

else:
Expand All @@ -141,9 +161,13 @@ class _RecordLines:

def __init__(self):
self.last_lines = []
# Additional state (reset at each cell) to keep track of which lines
# contain trailing statements
self.trailing_stmt_line_nums = {}

def __call__(self, lines: list[str]) -> list[str]:
self.last_lines = [l.rstrip('\n') for l in lines]
self.trailing_stmt_line_nums = {}
return lines


Expand All @@ -159,29 +183,45 @@ def _maybe_display(
) -> ast.AST:
"""Wrap the node in a `display()` call."""
try:
has_trailing, is_last_statement = _has_trailing_semicolon(
self.lines_recorder.last_lines, node
)
if has_trailing:
if is_last_statement and isinstance(node, ast.Expr):
# Last expressions are already displayed by IPython, so instead
# IPython silence the statement
pass
elif node.value is None: # `AnnAssign().value` can be `None` (`a: int`)
if self._is_alias_stmt(node): # Alias statements should be no-op
return ast.Pass()

line_info = _has_trailing_semicolon(self.lines_recorder.last_lines, node)
if line_info.has_trailing:
if node.value is None: # `AnnAssign().value` can be `None` (`a: int`)
pass
else:

fn_name = _ALIAS_TO_DISPLAY_FN[line_info.alias].__name__

node.value = ast.Call(
func=_parse_expr('ecolab.auto_display_utils._display_and_return'),
func=_parse_expr(f'ecolab.auto_display_utils.{fn_name}'),
args=[node.value],
keywords=[],
)
self.lines_recorder.trailing_stmt_line_nums[line_info.line_num] = (
line_info
)
except Exception as e:
code = '\n'.join(self.lines_recorder.last_lines)
print(f'Error for code:\n-----\n{code}\n-----')
traceback.print_exception(e)
raise
return node

def _is_alias_stmt(self, node: ast.AST) -> bool:
match node:
case ast.Expr(value=ast.Name(id=name)):
pass
case _:
return False
if name not in _ALIAS_TO_DISPLAY_FN:
return False
# The alias is not in the same line as a trailing `;`
if node.end_lineno - 1 not in self.lines_recorder.trailing_stmt_line_nums:
return False
return True

# pylint: disable=invalid-name
visit_Assign = _maybe_display
visit_AnnAssign = _maybe_display
Expand All @@ -194,46 +234,100 @@ def _parse_expr(code: str) -> ast.AST:
return ast.parse(code, mode='eval').body


@dataclasses.dataclass(frozen=True)
class _LineInfo:
has_trailing: bool
alias: str
line_num: int


def _has_trailing_semicolon(
code_lines: list[str],
node: ast.AST,
) -> tuple[bool, bool]:
) -> _LineInfo:
"""Check if `node` has trailing `;`."""
if isinstance(node, ast.AnnAssign) and node.value is None:
return False, False # `AnnAssign().value` can be `None` (`a: int`)
return _LineInfo(
has_trailing=False,
alias='',
line_num=-1,
) # `AnnAssign().value` can be `None` (`a: int`)

# Extract the lines of the statement
last_line = code_lines[node.end_lineno - 1] # lineno starts at `1`
# Check if the last character is a `;` token
has_trailing = False
line_num = node.end_lineno - 1
last_line = code_lines[line_num] # lineno starts at `1`

# `node.end_col_offset` is in bytes, so UTF-8 characters count 3.
last_part_of_line = last_line.encode('utf-8')
last_part_of_line = last_part_of_line[node.end_col_offset :]
last_part_of_line = last_part_of_line.decode('utf-8')
for char in last_part_of_line:
if char == ';':
has_trailing = True
elif char == ' ':
continue
elif char == '#': # Comment,...
break
else: # e.g. `a=1;b=2`
has_trailing = False
break

is_last_statement = True # Assume statement is the last one
for line in code_lines[node.end_lineno :]: # Next statements are all empty
line = line.strip()
if line and not line.startswith('#'):
is_last_statement = False
break
if last_line.startswith(' '):
# statement is inside `if` / `with` / ...
is_last_statement = False
return has_trailing, is_last_statement

# Check if the last character is a `;` token
has_trailing = False
alias = ''
if match := _detect_trailing_regex().match(last_part_of_line):
has_trailing = True
if match.group(1):
alias = match.group(1)

return _LineInfo(
has_trailing=has_trailing,
alias=alias,
line_num=line_num,
)


@functools.cache
def _detect_trailing_regex() -> re.Pattern[str]:
"""Check if the last character is a `;` token."""
# Match:
# * `; a`
# * `; a # Some comment`
# * `; # Some comment`
# Do not match:
# * `; a; b`
# * `; a=1`
available_chars = ''.join(_ALIAS_TO_DISPLAY_FN)
return re.compile(f' *; *([{available_chars}])? *(?:#.*)?$')


def _display_and_return(x: _T) -> _T:
"""Print `x` and return `x`."""
IPython.display.display(x)
return x


def _display_specs_and_return(x: _T) -> _T:
"""Print `x` and return `x`."""
IPython.display.display(etree.spec_like(x))
return x


def _inspect_and_return(x: _T) -> _T:
"""Print `x` and return `x`."""
inspects.inspect(x)
return x


def _display_array_and_return(x: _T) -> _T:
"""Print `x` and return `x`."""
html = array_as_img.array_repr_html(x)
if html is None:
IPython.display.display(x)
else:
IPython.display.display(IPython.display.HTML(html))
return x


def _return_quietly(x: _T) -> _T:
"""Return `x` without display."""
return x


_ALIAS_TO_DISPLAY_FN = {
'': _display_and_return,
's': _display_specs_and_return,
'i': _inspect_and_return,
'a': _display_array_and_return,
'q': _return_quietly,
}
Loading

0 comments on commit eb2b865

Please sign in to comment.