Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add pretty nbytes repr to .show and jupyter repr #3348

Merged
merged 6 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 112 additions & 78 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from awkward._regularize import is_non_string_like_iterable
from awkward._typing import Any, TypeVar
from awkward._util import STDOUT
from awkward.prettyprint import Formatter
from awkward.prettyprint import Formatter, highlevel_array_show_rows
from awkward.prettyprint import valuestr as prettyprint_valuestr

__all__ = ("Array", "ArrayBuilder", "Record")
Expand Down Expand Up @@ -1398,10 +1398,13 @@ def show(
self,
limit_rows=20,
limit_cols=80,
*,
type=False,
named_axis=False,
nbytes=False,
backend=False,
all=False,
stream=STDOUT,
*,
formatter=None,
precision=3,
):
Expand All @@ -1411,9 +1414,16 @@ def show(
limit_cols (int): Maximum number of columns (characters wide).
type (bool): If True, print the type as well. (Doesn't count toward number
of rows/lines limit.)
named_axis (bool): If True, print the named axis as well. (Doesn't count toward number
of rows/lines limit.)
nbytes (bool): If True, print the number of bytes as well. (Doesn't count toward number
of rows/lines limit.)
backend (bool): If True, print the backend of the array as well. (Doesn't count toward number
of rows/lines limit.)
all (bool): If True, print the 'type', 'named axis', 'nbytes', and 'backend' of the array. (Doesn't count toward number
of rows/lines limit.)
stream (object with a ``write(str)`` method or None): Stream to write the
output to. If None, return a string instead of writing to a stream.

formatter (Mapping or None): Mapping of types/type-classes to string formatters.
If None, use the default formatter.

Expand All @@ -1426,64 +1436,70 @@ def show(
key is ignored; instead, a `"bytes"` and/or `"str"` key is considered when formatting
string values, falling back upon `"str_kind"`.
"""
formatter_impl = Formatter(formatter, precision=precision)

valuestr = prettyprint_valuestr(
self, limit_rows, limit_cols, formatter=formatter_impl
rows = highlevel_array_show_rows(
array=self,
limit_rows=limit_rows,
limit_cols=limit_cols,
type=type or all,
named_axis=named_axis or all,
nbytes=nbytes or all,
backend=backend or all,
formatter=formatter,
precision=precision,
)
array_line = rows.pop(0)

out_io = io.StringIO()
if type:
out_io.write("type: ")
self.type.show(stream=out_io)
if named_axis and self.named_axis:
out_io.write("axes: ")
out_io.write(
_prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None)
)
# it's always the second row (after the array)
type_line = rows.pop(0)
out_io.write(type_line)

# the rest of the rows we sort by the length of their '<prefix>:'
# but we sort it from shortest to longest contrary to _repr_mimebundle_
sorted_rows = sorted([r for r in rows if r], key=lambda x: len(x.split(":")[0]))

if sorted_rows:
out_io.write("\n".join(sorted_rows))
out_io.write("\n")
out_io.write(valuestr)

out_io.write(array_line)
if stream is None:
return out_io
return out_io.getvalue()
else:
if stream is STDOUT:
stream = STDOUT.stream
stream.write(out_io.getvalue() + "\n")

def _repr_mimebundle_(self, include=None, exclude=None):
# order: 1. array, 2. named_axis, 3. type
value_buff = io.StringIO()
self.show(type=False, named_axis=False, stream=value_buff)
header_lines = value_buff.getvalue().splitlines()

named_axis_line = ""
if self.named_axis:
named_axis_buff = io.StringIO()
named_axis_buff.write("axes: ")
named_axis_buff.write(
_prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None)
)
named_axis_line = named_axis_buff.getvalue()
# order:
# first: array,
# last: type,
# middle: rest sorted by length of prefix (longest first)

rows = highlevel_array_show_rows(
array=self,
type=True,
named_axis=True,
nbytes=True,
backend=True,
)
header_lines = rows.pop(0).removesuffix("\n").splitlines()

type_buff = io.StringIO()
self.type.show(stream=type_buff)
footer_lines = type_buff.getvalue().splitlines()
# Prepend a `type: ` prefix to the type information
footer_lines[0] = f"type: {footer_lines[0]}"
# it's always the second row (after the array)
type_lines = [rows.pop(0).removesuffix("\n")]

if header_lines[-1] == "":
del header_lines[-1]
# the rest of the rows we sort by the length of their '<prefix>:'
# but we sort it from longest to shortest for _repr_mimebundle_
sorted_rows = sorted(rows, key=lambda x: -len(x.split(":")[0]))

n_cols = max(
len(line)
for line in itertools.chain(header_lines, [named_axis_line], footer_lines)
len(line) for line in itertools.chain(header_lines, sorted_rows, type_lines)
)
body_lines = header_lines
body_lines.append("-" * n_cols)
if named_axis_line:
body_lines.append(named_axis_line)
body_lines.extend(footer_lines)
body_lines.extend(sorted_rows)
body_lines.extend(type_lines)
body = "\n".join(body_lines)

return {
Expand Down Expand Up @@ -2317,10 +2333,13 @@ def show(
self,
limit_rows=20,
limit_cols=80,
*,
type=False,
named_axis=False,
nbytes=False,
backend=False,
all=False,
stream=STDOUT,
*,
formatter=None,
precision=3,
):
Expand All @@ -2330,6 +2349,14 @@ def show(
limit_cols (int): Maximum number of columns (characters wide).
type (bool): If True, print the type as well. (Doesn't count toward number
of rows/lines limit.)
named_axis (bool): If True, print the named axis as well. (Doesn't count toward number
of rows/lines limit.)
nbytes (bool): If True, print the number of bytes as well. (Doesn't count toward number
of rows/lines limit.)
backend (bool): If True, print the backend of the array as well. (Doesn't count toward number
of rows/lines limit.)
all (bool): If True, print the 'type', 'named axis', 'nbytes', and 'backend' of the array. (Doesn't count toward number
of rows/lines limit.)
stream (object with a ``write(str)`` method or None): Stream to write the
output to. If None, return a string instead of writing to a stream.
formatter (Mapping or None): Mapping of types/type-classes to string formatters.
Expand All @@ -2344,23 +2371,34 @@ def show(
key is ignored; instead, a `"bytes"` and/or `"str"` key is considered when formatting
string values, falling back upon `"str_kind"`.
"""
formatter_impl = Formatter(formatter, precision=precision)
valuestr = prettyprint_valuestr(
self, limit_rows, limit_cols, formatter=formatter_impl
rows = highlevel_array_show_rows(
array=self,
limit_rows=limit_rows,
limit_cols=limit_cols,
type=type or all,
named_axis=named_axis or all,
nbytes=nbytes or all,
backend=backend or all,
formatter=formatter,
precision=precision,
)
array_line = rows.pop(0)

out_io = io.StringIO()
if type:
out_io.write("type: ")
self.type.show(stream=out_io)
if named_axis and self.named_axis:
out_io.write("axes: ")
out_io.write(
_prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None)
)
# it's always the second row (after the array)
type_line = rows.pop(0)
out_io.write(type_line)

# the rest of the rows we sort by the length of their '<prefix>:'
# but we sort it from shortest to longest contrary to _repr_mimebundle_
sorted_rows = sorted([r for r in rows if r], key=lambda x: len(x.split(":")[0]))

if sorted_rows:
out_io.write("\n".join(sorted_rows))
out_io.write("\n")
out_io.write(valuestr)

out_io.write(array_line)
if stream is None:
return out_io.getvalue()
else:
Expand All @@ -2369,38 +2407,34 @@ def show(
stream.write(out_io.getvalue() + "\n")

def _repr_mimebundle_(self, include=None, exclude=None):
# order: 1. array, 2. named_axis, 3. type
value_buff = io.StringIO()
self.show(type=False, named_axis=False, stream=value_buff)
header_lines = value_buff.getvalue().splitlines()

named_axis_line = ""
if self.named_axis:
named_axis_buff = io.StringIO()
named_axis_buff.write("axes: ")
named_axis_buff.write(
_prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None)
)
named_axis_line = named_axis_buff.getvalue()
# order:
# first: array,
# last: type,
# middle: rest sorted by length of prefix (longest first)

rows = highlevel_array_show_rows(
array=self,
type=True,
named_axis=True,
nbytes=True,
backend=True,
)
header_lines = rows.pop(0).removesuffix("\n").splitlines()

type_buff = io.StringIO()
self.type.show(stream=type_buff)
footer_lines = type_buff.getvalue().splitlines()
# Prepend a `type: ` prefix to the type information
footer_lines[0] = f"type: {footer_lines[0]}"
# it's always the second row (after the array)
type_lines = [rows.pop(0).removesuffix("\n")]

if header_lines[-1] == "":
del header_lines[-1]
# the rest of the rows we sort by the length of their '<prefix>:'
# but we sort it from longest to shortest for _repr_mimebundle_
sorted_rows = sorted(rows, key=lambda x: -len(x.split(":")[0]))

n_cols = max(
len(line)
for line in itertools.chain(header_lines, [named_axis_line], footer_lines)
len(line) for line in itertools.chain(header_lines, sorted_rows, type_lines)
)
body_lines = header_lines
body_lines.append("-" * n_cols)
if named_axis_line:
body_lines.append(named_axis_line)
body_lines.extend(footer_lines)
body_lines.extend(sorted_rows)
body_lines.extend(type_lines)
body = "\n".join(body_lines)

return {
Expand Down
61 changes: 61 additions & 0 deletions src/awkward/prettyprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from __future__ import annotations

import io
import math
import re
from collections.abc import Callable

import awkward as ak
from awkward._layout import wrap_layout
from awkward._namedaxis import _prettify_named_axes
from awkward._nplikes.numpy import Numpy, NumpyMetadata
from awkward._typing import TYPE_CHECKING, Any, TypeAlias, TypedDict

Expand Down Expand Up @@ -436,3 +438,62 @@ def valuestr(

else:
raise AssertionError(type(data))


def bytes_repr(nbytes: int) -> str:
    """Format a byte count as a short human-readable string.

    Decimal (power-of-1000) units are used: ``B``, ``kB``, ``MB`` and ``GB``.
    Scaled values are shown with one decimal place and a thousands separator;
    plain byte counts are shown as integers with a thousands separator.

    Args:
        nbytes: Number of bytes to format.

    Returns:
        A string such as ``"24 B"``, ``"1.5 kB"`` or ``"12,345.7 GB"``.
    """
    # Thresholds are inclusive so that a value sitting exactly on a unit
    # boundary is promoted to the larger unit (1000 -> "1.0 kB" rather than
    # "1,000 B"), consistent with how 1001 is rendered.
    for factor, unit in ((1e9, "GB"), (1e6, "MB"), (1e3, "kB")):
        if nbytes >= factor:
            return f"{nbytes / factor:,.1f} {unit}"
    return f"{nbytes:,} B"


def highlevel_array_show_rows(
    array,
    limit_rows=20,
    limit_cols=80,
    type=False,
    named_axis=False,
    nbytes=False,
    backend=False,
    *,
    formatter=None,
    precision=3,
) -> list[str]:
    """Build the individual text rows used when showing a high-level array.

    The first row is always the rendered values of ``array``. Optional rows
    follow in a fixed order: ``type: ...`` (always second when requested),
    then ``named axis: ...``, ``nbytes: ...`` and ``backend: ...``.

    Args:
        array: Object to describe; this function reads its ``type``,
            ``named_axis``, ``nbytes`` and ``layout.backend.name`` attributes
            as required by the flags below.
        limit_rows: Passed through to ``valuestr`` as the row limit.
        limit_cols: Passed through to ``valuestr`` as the column limit.
        type: If True, add a ``type: `` row (trailing newline stripped).
        named_axis: If True and ``array.named_axis`` is truthy, add a
            ``named axis: `` row.
        nbytes: If True, add an ``nbytes: `` row formatted by ``bytes_repr``.
        backend: If True, add a ``backend: `` row.
        formatter: Forwarded to ``Formatter``.
        precision: Forwarded to ``Formatter``.

    Returns:
        The list of rows, values first.
    """
    # The rendered values are unconditionally the first row.
    rows = [
        valuestr(
            array,
            limit_rows,
            limit_cols,
            formatter=Formatter(formatter, precision=precision),
        )
    ]

    if type:
        type_buffer = io.StringIO()
        array.type.show(stream=type_buffer)
        rows.append("type: " + type_buffer.getvalue().removesuffix("\n"))

    # Remaining metadata rows; callers are free to reorder these.
    if named_axis and array.named_axis:
        pretty_axes = _prettify_named_axes(
            array.named_axis, delimiter=", ", maxlen=None
        )
        rows.append("named axis: " + pretty_axes)
    if nbytes:
        rows.append(f"nbytes: {bytes_repr(array.nbytes)}")
    if backend:
        rows.append(f"backend: {array.layout.backend.name}")

    # Sanity check: when present, the type row must stay in second position.
    if type:
        assert rows[1].startswith("type: ")
    return rows
Loading