diff --git a/lucid/misc/io/loading.py b/lucid/misc/io/loading.py
index a233de35..a14f4d6a 100644
--- a/lucid/misc/io/loading.py
+++ b/lucid/misc/io/loading.py
@@ -163,6 +163,18 @@ def _decompress_xz(handle, **kwargs):
 }
 
 
+modes = {
+    ".png": "wb",
+    ".jpg": "wb",
+    ".jpeg": "wb",
+    ".webp": "wb",
+    ".npy": "wb",
+    ".npz": "wb",
+    ".json": "w",
+    ".txt": "w",
+    ".pb": "wb",
+}
+
 unsafe_loaders = {
     ".pickle": _load_pickle,
     ".pkl": _load_pickle,
diff --git a/lucid/misc/io/reading.py b/lucid/misc/io/reading.py
index 94279257..03a2a0b5 100644
--- a/lucid/misc/io/reading.py
+++ b/lucid/misc/io/reading.py
@@ -33,7 +33,9 @@ from lucid.misc.io.writing import write_handle
 from lucid.misc.io.scoping import scope_url
+from lucid.misc.io.util import isazure
 
+import blobfile
 
 # create logger with module name, e.g. lucid.misc.io.reading
 log = logging.getLogger(__name__)
@@ -177,6 +179,8 @@ def _read_and_cache(url, mode="rb"):
     with lock:
         if os.path.exists(local_path):
             log.debug("Found cached file '%s'.", local_path)
+            if isazure(url):
+                return blobfile.BlobFile(url, mode)
             return _handle_gfile(local_path)
         log.debug("Caching URL '%s' locally at '%s'.", url, local_path)
         try:
@@ -186,6 +190,8 @@ def _read_and_cache(url, mode="rb"):
                 for chunk in _file_chunk_iterator(input_handle):
                     output_handle.write(chunk)
             gc.collect()
+            if isazure(url):
+                return blobfile.BlobFile(url, mode)
             return _handle_gfile(local_path, mode=mode)
         except tf.errors.NotFoundError:
             raise
diff --git a/lucid/misc/io/saving.py b/lucid/misc/io/saving.py
index ff4ccd5d..a8e30d18 100644
--- a/lucid/misc/io/saving.py
+++ b/lucid/misc/io/saving.py
@@ -42,7 +42,7 @@ from lucid.misc.io.writing import write_handle
 from lucid.misc.io.serialize_array import _normalize_array
 from lucid.misc.io.scoping import current_io_scopes, set_io_scopes
-
+from lucid.misc.io.util import isazure
 
 # create logger with module name, e.g. lucid.misc.io.saving
 log = logging.getLogger(__name__)
@@ -110,21 +110,21 @@ def save_json(object, handle, indent=2):
     obj_json = json.dumps(object, indent=indent, cls=ClarityJSONEncoder)
     handle.write(obj_json)
-    return {"type": "json", "url": handle.name}
+    return {"type": "json", "url": handle.path}
 
 
 def save_npy(object, handle):
     """Save numpy array as npy file."""
     np.save(handle, object)
-    return {"type": "npy", "shape": object.shape, "dtype": str(object.dtype), "url": handle.name}
+    return {"type": "npy", "shape": object.shape, "dtype": str(object.dtype), "url": handle.path}
 
 
 def save_npz(object, handle):
     """Save dict of numpy array as npz file."""
     # there is a bug where savez doesn't actually accept a file handle.
     log.warning("Saving npz files currently only works locally. :/")
-    path = handle.name
+    path = handle.path
     handle.close()
     if type(object) is dict:
         np.savez(path, **object)
@@ -151,7 +151,7 @@ def save_img(object, handle, domain=None, **kwargs):
     return {
         "type": "image",
         "shape": object.size + (len(object.getbands()),),
-        "url": handle.name,
+        "url": handle.path,
     }
 
 
@@ -174,13 +174,13 @@ def save_txt(object, handle, **kwargs):
             line += b"\n"
         handle.write(line)
-    return {"type": "txt", "url": handle.name}
+    return {"type": "txt", "url": handle.path}
 
 
 def save_str(object, handle, **kwargs):
     assert isinstance(object, str)
     handle.write(object)
-    return {"type": "txt", "url": handle.name}
+    return {"type": "txt", "url": handle.path}
 
 
 def save_pb(object, handle, **kwargs):
@@ -194,7 +194,7 @@ def save_pb(object, handle, **kwargs):
         )
         raise
     finally:
-        return {"type": "pb", "url": handle.name}
+        return {"type": "pb", "url": handle.path}
 
 
 def save_pickle(object, handle, **kwargs):
@@ -208,7 +208,7 @@ def save_pickle(object, handle, **kwargs):
 def compress_xz(handle, **kwargs):
     try:
         ret = lzma.LZMAFile(handle, format=lzma.FORMAT_XZ, mode="wb")
-        ret.name = handle.name
+        ret.path = ret.name = handle.path
         return ret
     except AttributeError as e:
         warnings.warn("`compress_xz` failed for handle {}. Re-raising original exception.".format(handle))
@@ -227,6 +227,18 @@ def compress_xz(handle, **kwargs):
     ".pb": save_pb,
 }
 
+modes = {
+    ".png": "wb",
+    ".jpg": "wb",
+    ".jpeg": "wb",
+    ".webp": "wb",
+    ".npy": "wb",
+    ".npz": "wb",
+    ".json": "w",
+    ".txt": "w",
+    ".pb": "wb",
+}
+
 unsafe_savers = {
     ".pickle": save_pickle,
     ".pkl": save_pickle,
@@ -255,6 +267,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona
 
     # Determine context
     # Is this a handle? What is the extension? Are we saving to GCS?
+
     is_handle = hasattr(url_or_handle, "write") and hasattr(url_or_handle, "name")
     if is_handle:
         path = url_or_handle.name
@@ -292,7 +305,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona
     else:
         handle_provider = write_handle
 
-    with handle_provider(url_or_handle) as handle:
+    with handle_provider(url_or_handle, mode=modes.get(ext, "wb")) as handle:
         with compressor(handle) as compressed_handle:
             result = saver(thing, compressed_handle, **kwargs)
 
@@ -309,6 +322,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona
 
     # capture save if a save context is available
     save_context = save_context if save_context is not None else CaptureSaveContext.current_save_context()
+
     if save_context:
         log.debug(
             "capturing save: resulted in {} -> {} in save_context {}".format(
@@ -320,6 +334,9 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona
     if result is not None and "url" in result and result["url"].startswith("gs://"):
         result["serve"] = "https://storage.googleapis.com/{}".format(result["url"][5:])
 
+    if result is not None and "url" in result and isazure(result["url"]):
+        result["serve"] = result["url"]
+
     return result
diff --git a/lucid/misc/io/scoping.py b/lucid/misc/io/scoping.py
index ea57effc..8681a26d 100644
--- a/lucid/misc/io/scoping.py
+++ b/lucid/misc/io/scoping.py
@@ -6,7 +6,6 @@
 
 _thread_local_scopes = threading.local()
 
-
 def current_io_scopes():
     ret = getattr(_thread_local_scopes, "io_scopes", None)
     if ret is None:
diff --git a/lucid/misc/io/writing.py b/lucid/misc/io/writing.py
index fcb6d1ac..ea0fe7c2 100644
--- a/lucid/misc/io/writing.py
+++ b/lucid/misc/io/writing.py
@@ -26,8 +26,10 @@ from contextlib import contextmanager
 from urllib.parse import urlparse
 
 from tensorflow import gfile
+import blobfile
 
 from lucid.misc.io.scoping import scope_url
+from lucid.misc.io.util import isazure
 
 log = logging.getLogger(__name__)
@@ -60,11 +62,21 @@ def write(data, url, mode="wb"):
     _write_to_path(data, url, mode=mode)
 
-
 @contextmanager
 def write_handle(path, mode=None):
     path = scope_url(path)
 
+    if isazure(path):
+        if mode is None:
+            mode = "w"
+        handle = blobfile.BlobFile(path, mode)
+        handle.path = path
+        try:
+            yield handle
+        finally:
+            handle.close()
+        return
+
     if _supports_make_dirs(path):
         gfile.MakeDirs(os.path.dirname(path))
@@ -75,5 +87,8 @@ def write_handle(path, mode=None):
         mode = "w"
 
     handle = gfile.Open(path, mode)
+    handle.path = handle.name
-    yield handle
-    handle.close()
+    try:
+        yield handle
+    finally:
+        handle.close()
diff --git a/lucid/optvis/overrides/__init__.py b/lucid/optvis/overrides/__init__.py
index 40ecd456..8a81ab49 100644
--- a/lucid/optvis/overrides/__init__.py
+++ b/lucid/optvis/overrides/__init__.py
@@ -38,7 +38,7 @@
 
 pooling_overrides_map = {"MaxPool": avg_smoothed_maxpool_grad}
 
-relu_overrides_map = {"Relu": redirected_relu_grad, "Relu6": redirected_relu6_grad}
+relu_overrides_map = {"Relu": redirected_relu_grad, "Relu6": redirected_relu6_grad, "EwZXa": redirected_relu_grad}
 
 default_overrides_map = {**pooling_overrides_map, **relu_overrides_map}