Skip to content

Commit

Permalink
Add more typing to safer
Browse files Browse the repository at this point in the history
  • Loading branch information
rec committed Dec 29, 2023
1 parent 1d481eb commit bfc472e
Showing 1 changed file with 57 additions and 44 deletions.
101 changes: 57 additions & 44 deletions safer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,32 +147,27 @@
import contextlib
import functools
import io
import json
import os
import shutil
import sys
import tempfile
import traceback
import typing as t
from pathlib import Path
from typing import IO, Callable, Optional, Union

# There's an edge case in #23 I can't yet fix, so I fail
# deliberately
BUG_MESSAGE = 'Sorry, safer.writer fails if temp_file (#23)'

__all__ = 'writer', 'open', 'closer', 'dump', 'printer'


def writer(
stream: Union[Callable, None, IO, Path, str] = None,
is_binary: Optional[bool] = None,
stream: t.Union[t.Callable, None, t.IO, Path, str] = None,
is_binary: t.Optional[bool] = None,
close_on_exit: bool = False,
temp_file: bool = False,
chunk_size: int = 0x100000,
delete_failures: bool = True,
dry_run: Union[bool, Callable] = False,
dry_run: t.Union[bool, t.Callable] = False,
enabled: bool = True,
) -> Union[Callable, IO]:
) -> t.Union[t.Callable, t.IO]:
"""
Write safely to file streams, sockets and callables.
Expand Down Expand Up @@ -233,7 +228,7 @@ def writer(
if not enabled:
return stream

write: Optional[Callable]
write: t.Optional[t.Callable]

if callable(dry_run):
write, dry_run = dry_run, True
Expand Down Expand Up @@ -307,21 +302,26 @@ def write(v):
return closer.fp


# There's an edge case in #23 I can't yet fix, so I fail
# deliberately
BUG_MESSAGE = 'Sorry, safer.writer fails if temp_file (#23)'


def open(
name: Union[Path, str],
name: t.Union[Path, str],
mode: str = 'r',
buffering: int = -1,
encoding: Optional[str] = None,
errors: Optional[str] = None,
newline: Optional[str] = None,
encoding: t.Optional[str] = None,
errors: t.Optional[str] = None,
newline: t.Optional[str] = None,
closefd: bool = True,
opener: Optional[Callable] = None,
opener: t.Optional[t.Callable] = None,
make_parents: bool = False,
delete_failures: bool = True,
temp_file: bool = False,
dry_run: Union[bool, Callable] = False,
dry_run: t.Union[bool, t.Callable] = False,
enabled: bool = True,
) -> IO:
) -> t.IO:
"""
Args:
make_parents: If true, create the parent directory of the file if needed
Expand Down Expand Up @@ -443,7 +443,9 @@ def simple_write(value):
return closer._make_stream(buffering, mode, **kwargs)


def closer(stream, is_binary=None, close_on_exit=True, **kwds):
def closer(
stream: t.IO, is_binary: t.Optional[bool] = None, close_on_exit: bool = True, **kwds
) -> t.Union[t.Callable, t.IO]:
"""
Like `safer.writer()` but with `close_on_exit=True` by default
Expand All @@ -453,7 +455,12 @@ def closer(stream, is_binary=None, close_on_exit=True, **kwds):
return writer(stream, is_binary, close_on_exit, **kwds)


def dump(obj, stream=None, dump=None, **kwargs):
def dump(
obj,
stream: t.Union[t.Callable, None, t.IO, Path, str] = None,
dump: t.Any = None,
**kwargs,
) -> t.Any:
"""
Safely serialize `obj` as a formatted stream to `fp`` (a
`.write()`-supporting file-like object, or a filename),
Expand All @@ -476,23 +483,34 @@ def dump(obj, stream=None, dump=None, **kwargs):
kwargs:
Additional arguments to `dump`.
"""
if isinstance(stream, str):
name = stream
is_binary = False
else:
name = getattr(stream, 'name', None)
mode = getattr(stream, 'mode', None)
if not isinstance(stream, str):
name = getattr(stream, 'name', '')
mode = getattr(stream, 'mode', '')
if name and mode:
is_binary = 'b' in mode
else:
is_binary = hasattr(stream, 'recv') and hasattr(stream, 'send')
else:
name = stream
is_binary = False

if name and not dump:
dump = Path(name).suffix[1:] or None
if dump == 'yml':
dump = 'yaml'
dump = _get_dumper(dump or Path(name).suffix[1:])

with t.cast(t.IO, writer(stream)) as fp:
if is_binary:
write = fp.write
fp.write = lambda s: write(s.encode('utf-8')) # type: ignore

return dump(obj, fp)


def _get_dumper(dump: t.Any) -> t.Callable:
if isinstance(dump, str):
if not dump:
dump = 'json'
elif dump == 'yml':
dump = 'yaml'

try:
dump = __import__(dump)
except ImportError:
Expand All @@ -501,24 +519,19 @@ def dump(obj, stream=None, dump=None, **kwargs):
mod, name = dump.rsplit('.', maxsplit=1)
dump = getattr(__import__(mod), name)

if dump is None:
dump = json.dump

elif not callable(dump):
try:
dump = dump.safe_dump
except AttributeError:
dump = dump.dump
if callable(dump):
return dump

with writer(stream) as fp:
if is_binary:
write = fp.write
fp.write = lambda s: write(s.encode('utf-8'))
return dump(obj, fp)
try:
return dump.safe_dump
except AttributeError:
return dump.dump


@contextlib.contextmanager
def printer(name, mode='w', *args, **kwargs):
def printer(
name: t.Union[Path, str], mode: str = 'w', *args, **kwargs
) -> t.Generator[t.Callable, None, None]:
"""
A context manager that yields a function that prints to the opened file,
only writing to the original file at the exit of the context,
Expand Down

0 comments on commit bfc472e

Please sign in to comment.