From bfc472e782ba072e1e42799b6597ba9caa8486d2 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Fri, 29 Dec 2023 14:19:42 +0100 Subject: [PATCH] Add more typing to safer --- safer/__init__.py | 101 ++++++++++++++++++++++++++-------------------- 1 file changed, 57 insertions(+), 44 deletions(-) diff --git a/safer/__init__.py b/safer/__init__.py index 5cb3484..364320f 100644 --- a/safer/__init__.py +++ b/safer/__init__.py @@ -147,32 +147,27 @@ import contextlib import functools import io -import json import os import shutil import sys import tempfile import traceback +import typing as t from pathlib import Path -from typing import IO, Callable, Optional, Union - -# There's an edge case in #23 I can't yet fix, so I fail -# deliberately -BUG_MESSAGE = 'Sorry, safer.writer fails if temp_file (#23)' __all__ = 'writer', 'open', 'closer', 'dump', 'printer' def writer( - stream: Union[Callable, None, IO, Path, str] = None, - is_binary: Optional[bool] = None, + stream: t.Union[t.Callable, None, t.IO, Path, str] = None, + is_binary: t.Optional[bool] = None, close_on_exit: bool = False, temp_file: bool = False, chunk_size: int = 0x100000, delete_failures: bool = True, - dry_run: Union[bool, Callable] = False, + dry_run: t.Union[bool, t.Callable] = False, enabled: bool = True, -) -> Union[Callable, IO]: +) -> t.Union[t.Callable, t.IO]: """ Write safely to file streams, sockets and callables. @@ -233,7 +228,7 @@ def writer( if not enabled: return stream - write: Optional[Callable] + write: t.Optional[t.Callable] if callable(dry_run): write, dry_run = dry_run, True @@ -307,21 +302,26 @@ def write(v): return closer.fp +# There's an edge case in #23 I can't yet fix, so I fail +# deliberately +BUG_MESSAGE = 'Sorry, safer.writer fails if temp_file (#23)' + + def open( - name: Union[Path, str], + name: t.Union[Path, str], mode: str = 'r', buffering: int = -1, - encoding: Optional[str] = None, - errors: Optional[str] = None, - newline: Optional[str] = None, + encoding: t.Optional[str] = None, + errors: t.Optional[str] = None, + newline: t.Optional[str] = None, closefd: bool = True, - opener: Optional[Callable] = None, + opener: t.Optional[t.Callable] = None, make_parents: bool = False, delete_failures: bool = True, temp_file: bool = False, - dry_run: Union[bool, Callable] = False, + dry_run: t.Union[bool, t.Callable] = False, enabled: bool = True, -) -> IO: +) -> t.IO: """ Args: make_parents: If true, create the parent directory of the file if needed @@ -443,7 +443,9 @@ def simple_write(value): return closer._make_stream(buffering, mode, **kwargs) -def closer(stream, is_binary=None, close_on_exit=True, **kwds): +def closer( + stream: t.IO, is_binary: t.Optional[bool] = None, close_on_exit: bool = True, **kwds +) -> t.Union[t.Callable, t.IO]: """ Like `safer.writer()` but with `close_on_exit=True` by default @@ -453,7 +455,12 @@ def closer(stream, is_binary=None, close_on_exit=True, **kwds): return writer(stream, is_binary, close_on_exit, **kwds) -def dump(obj, stream=None, dump=None, **kwargs): +def dump( + obj, + stream: t.Union[t.Callable, None, t.IO, Path, str] = None, + dump: t.Any = None, + **kwargs, +) -> t.Any: """ Safely serialize `obj` as a formatted stream to `fp`` (a `.write()`-supporting file-like object, or a filename), @@ -476,23 +483,34 @@ def dump(obj, stream=None, dump=None, **kwargs): kwargs: Additional arguments to `dump`. """ - if isinstance(stream, str): - name = stream - is_binary = False - else: - name = getattr(stream, 'name', None) - mode = getattr(stream, 'mode', None) + if not isinstance(stream, str): + name = getattr(stream, 'name', '') + mode = getattr(stream, 'mode', '') if name and mode: is_binary = 'b' in mode else: is_binary = hasattr(stream, 'recv') and hasattr(stream, 'send') + else: + name = stream + is_binary = False - if name and not dump: - dump = Path(name).suffix[1:] or None - if dump == 'yml': - dump = 'yaml' + dump = _get_dumper(dump or Path(name).suffix[1:]) + + with t.cast(t.IO, writer(stream)) as fp: + if is_binary: + write = fp.write + fp.write = lambda s: write(s.encode('utf-8')) # type: ignore + return dump(obj, fp) + + +def _get_dumper(dump: t.Any) -> t.Callable: if isinstance(dump, str): + if not dump: + dump = 'json' + elif dump == 'yml': + dump = 'yaml' + try: dump = __import__(dump) except ImportError: @@ -501,24 +519,19 @@ def dump(obj, stream=None, dump=None, **kwargs): mod, name = dump.rsplit('.', maxsplit=1) dump = getattr(__import__(mod), name) - if dump is None: - dump = json.dump - - elif not callable(dump): - try: - dump = dump.safe_dump - except AttributeError: - dump = dump.dump + if callable(dump): + return dump - with writer(stream) as fp: - if is_binary: - write = fp.write - fp.write = lambda s: write(s.encode('utf-8')) - return dump(obj, fp) + try: + return dump.safe_dump + except AttributeError: + return dump.dump @contextlib.contextmanager -def printer(name, mode='w', *args, **kwargs): +def printer( + name: t.Union[Path, str], mode: str = 'w', *args, **kwargs +) -> t.Generator[t.Callable, None, None]: """ A context manager that yields a function that prints to the opened file, only writing to the original file at the exit of the context,