From 884b1c20b075a5d054d99aa3d67f73ee3c8ccbf0 Mon Sep 17 00:00:00 2001
From: Gabriel Goh
Date: Thu, 19 Nov 2020 02:08:09 -0800
Subject: [PATCH 1/2] Azure support

---
 lucid/misc/io/__init__.py |  1 -
 lucid/misc/io/loading.py  | 12 ++++++++++++
 lucid/misc/io/reading.py  |  8 ++++++++
 lucid/misc/io/saving.py   | 21 +++++++++++++++++++--
 lucid/misc/io/writing.py  | 11 ++++++++++-
 5 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/lucid/misc/io/__init__.py b/lucid/misc/io/__init__.py
index bd7c81d7..3217427a 100644
--- a/lucid/misc/io/__init__.py
+++ b/lucid/misc/io/__init__.py
@@ -1,4 +1,3 @@
-from lucid.misc.io.showing import show
 from lucid.misc.io.loading import load
 from lucid.misc.io.saving import save, CaptureSaveContext, batch_save
 from lucid.misc.io.scoping import io_scope, scope_url
diff --git a/lucid/misc/io/loading.py b/lucid/misc/io/loading.py
index a233de35..a14f4d6a 100644
--- a/lucid/misc/io/loading.py
+++ b/lucid/misc/io/loading.py
@@ -163,6 +163,18 @@ def _decompress_xz(handle, **kwargs):
 }


+modes = {
+    ".png": "wb",
+    ".jpg": "wb",
+    ".jpeg": "wb",
+    ".webp": "wb",
+    ".npy": "wb",
+    ".npz": "wb",
+    ".json": "w",
+    ".txt": "w",
+    ".pb": "wb",
+}
+
 unsafe_loaders = {
     ".pickle": _load_pickle,
     ".pkl": _load_pickle,
diff --git a/lucid/misc/io/reading.py b/lucid/misc/io/reading.py
index 94279257..ffb43db6 100644
--- a/lucid/misc/io/reading.py
+++ b/lucid/misc/io/reading.py
@@ -33,7 +33,9 @@

 from lucid.misc.io.writing import write_handle
 from lucid.misc.io.scoping import scope_url
+from lucid.misc.io.util import isazure

+import blobfile

 # create logger with module name, e.g. lucid.misc.io.reading
 log = logging.getLogger(__name__)
@@ -93,6 +95,12 @@ def read_handle(url, cache=None, mode="rb"):
     """
     url = scope_url(url)

+    if isazure(url):
+        handle = blobfile.BlobFile(url, mode)
+        yield handle
+        handle.close()
+        return
+
     scheme = urlparse(url).scheme

     if cache == "purge":
diff --git a/lucid/misc/io/saving.py b/lucid/misc/io/saving.py
index ff4ccd5d..2e1b5524 100644
--- a/lucid/misc/io/saving.py
+++ b/lucid/misc/io/saving.py
@@ -42,7 +42,7 @@
 from lucid.misc.io.writing import write_handle
 from lucid.misc.io.serialize_array import _normalize_array
 from lucid.misc.io.scoping import current_io_scopes, set_io_scopes
-
+from lucid.misc.io.util import isazure

 # create logger with module name, e.g. lucid.misc.io.saving
 log = logging.getLogger(__name__)
@@ -227,6 +227,18 @@ def compress_xz(handle, **kwargs):
     ".pb": save_pb,
 }

+modes = {
+    ".png": "wb",
+    ".jpg": "wb",
+    ".jpeg": "wb",
+    ".webp": "wb",
+    ".npy": "wb",
+    ".npz": "wb",
+    ".json": "w",
+    ".txt": "w",
+    ".pb": "wb",
+}
+
 unsafe_savers = {
     ".pickle": save_pickle,
     ".pkl": save_pickle,
@@ -255,6 +267,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona

     # Determine context
     # Is this a handle? What is the extension? Are we saving to GCS?
+
     is_handle = hasattr(url_or_handle, "write") and hasattr(url_or_handle, "name")
     if is_handle:
         path = url_or_handle.name
@@ -292,7 +305,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona
     else:
         handle_provider = write_handle

-    with handle_provider(url_or_handle) as handle:
+    with handle_provider(url_or_handle, mode = modes[ext]) as handle:
         with compressor(handle) as compressed_handle:
             result = saver(thing, compressed_handle, **kwargs)

@@ -309,6 +322,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona

     # capture save if a save context is available
     save_context = save_context if save_context is not None else CaptureSaveContext.current_save_context()
+
     if save_context:
         log.debug(
             "capturing save: resulted in {} -> {} in save_context {}".format(
@@ -320,6 +334,9 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona
     if result is not None and "url" in result and result["url"].startswith("gs://"):
         result["serve"] = "https://storage.googleapis.com/{}".format(result["url"][5:])

+    if isazure(result["url"]):
+        result["serve"] = url
+
     return result


diff --git a/lucid/misc/io/writing.py b/lucid/misc/io/writing.py
index fcb6d1ac..150d7dd6 100644
--- a/lucid/misc/io/writing.py
+++ b/lucid/misc/io/writing.py
@@ -26,8 +26,10 @@

 from contextlib import contextmanager
 from urllib.parse import urlparse
 from tensorflow import gfile
+import blobfile

 from lucid.misc.io.scoping import scope_url
+from lucid.misc.io.util import isazure

 log = logging.getLogger(__name__)
@@ -60,11 +62,18 @@ def write(data, url, mode="wb"):

         _write_to_path(data, url, mode=mode)

-
 @contextmanager
 def write_handle(path, mode=None):
     path = scope_url(path)

+    if isazure(path):
+        if mode is None:
+            mode = "w"
+        handle = blobfile.BlobFile(path, mode)
+        yield handle
+        handle.close()
+        return
+
     if _supports_make_dirs(path):
         gfile.MakeDirs(os.path.dirname(path))


From 7ca7e63ec369598b8de87170fcb92a13b0bed4e6 Mon Sep 17 00:00:00 2001
From: Gabriel Goh
Date: Thu, 19 Nov 2020 12:04:35 -0800
Subject: [PATCH 2/2] Addressed comments + added missing files

---
 lucid/misc/io/__init__.py          |  1 +
 lucid/misc/io/reading.py           | 10 ++++------
 lucid/misc/io/saving.py            | 18 +++++++++---------
 lucid/misc/io/scoping.py           |  1 -
 lucid/misc/io/writing.py           |  2 ++
 lucid/optvis/overrides/__init__.py |  2 +-
 6 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/lucid/misc/io/__init__.py b/lucid/misc/io/__init__.py
index 3217427a..bd7c81d7 100644
--- a/lucid/misc/io/__init__.py
+++ b/lucid/misc/io/__init__.py
@@ -1,3 +1,4 @@
+from lucid.misc.io.showing import show
 from lucid.misc.io.loading import load
 from lucid.misc.io.saving import save, CaptureSaveContext, batch_save
 from lucid.misc.io.scoping import io_scope, scope_url
diff --git a/lucid/misc/io/reading.py b/lucid/misc/io/reading.py
index ffb43db6..03a2a0b5 100644
--- a/lucid/misc/io/reading.py
+++ b/lucid/misc/io/reading.py
@@ -95,12 +95,6 @@ def read_handle(url, cache=None, mode="rb"):
     """
     url = scope_url(url)

-    if isazure(url):
-        handle = blobfile.BlobFile(url, mode)
-        yield handle
-        handle.close()
-        return
-
     scheme = urlparse(url).scheme

     if cache == "purge":
@@ -185,6 +179,8 @@ def _read_and_cache(url, mode="rb"):
     with lock:
         if os.path.exists(local_path):
             log.debug("Found cached file '%s'.", local_path)
+            if isazure(url):
+                return blobfile.BlobFile(url, mode)
             return _handle_gfile(local_path)
         log.debug("Caching URL '%s' locally at '%s'.", url, local_path)
         try:
@@ -194,6 +190,8 @@
                 for chunk in _file_chunk_iterator(input_handle):
                     output_handle.write(chunk)
             gc.collect()
+            if isazure(url):
+                return blobfile.BlobFile(url, mode)
             return _handle_gfile(local_path, mode=mode)
         except tf.errors.NotFoundError:
             raise
diff --git a/lucid/misc/io/saving.py b/lucid/misc/io/saving.py
index 2e1b5524..a8e30d18 100644
--- a/lucid/misc/io/saving.py
+++ b/lucid/misc/io/saving.py
@@ -110,21 +110,21 @@ def save_json(object, handle, indent=2):
     obj_json = json.dumps(object, indent=indent, cls=ClarityJSONEncoder)
     handle.write(obj_json)

-    return {"type": "json", "url": handle.name}
+    return {"type": "json", "url": handle.path}


 def save_npy(object, handle):
     """Save numpy array as npy file."""
     np.save(handle, object)

-    return {"type": "npy", "shape": object.shape, "dtype": str(object.dtype), "url": handle.name}
+    return {"type": "npy", "shape": object.shape, "dtype": str(object.dtype), "url": handle.path}


 def save_npz(object, handle):
     """Save dict of numpy array as npz file."""
     # there is a bug where savez doesn't actually accept a file handle.
     log.warning("Saving npz files currently only works locally. :/")
-    path = handle.name
+    path = handle.path
     handle.close()
     if type(object) is dict:
         np.savez(path, **object)
@@ -151,7 +151,7 @@ def save_img(object, handle, domain=None, **kwargs):

     return {
         "type": "image",
         "shape": object.size + (len(object.getbands()),),
-        "url": handle.name,
+        "url": handle.path,
     }

@@ -174,13 +174,13 @@ def save_txt(object, handle, **kwargs):
             line += b"\n"
         handle.write(line)

-    return {"type": "txt", "url": handle.name}
+    return {"type": "txt", "url": handle.path}


 def save_str(object, handle, **kwargs):
     assert isinstance(object, str)
     handle.write(object)
-    return {"type": "txt", "url": handle.name}
+    return {"type": "txt", "url": handle.path}


 def save_pb(object, handle, **kwargs):
@@ -194,7 +194,7 @@ def save_pb(object, handle, **kwargs):
         )
         raise
     finally:
-        return {"type": "pb", "url": handle.name}
+        return {"type": "pb", "url": handle.path}


 def save_pickle(object, handle, **kwargs):
@@ -208,7 +208,7 @@ def save_pickle(object, handle, **kwargs):
 def compress_xz(handle, **kwargs):
     try:
         ret = lzma.LZMAFile(handle, format=lzma.FORMAT_XZ, mode="wb")
-        ret.name = handle.name
+        ret.name = handle.path
         return ret
     except AttributeError as e:
         warnings.warn("`compress_xz` failed for handle {}. Re-raising original exception.".format(handle))
@@ -335,7 +335,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona
         result["serve"] = "https://storage.googleapis.com/{}".format(result["url"][5:])

     if isazure(result["url"]):
-        result["serve"] = url
+        result["serve"] = result["url"]

     return result

diff --git a/lucid/misc/io/scoping.py b/lucid/misc/io/scoping.py
index ea57effc..8681a26d 100644
--- a/lucid/misc/io/scoping.py
+++ b/lucid/misc/io/scoping.py
@@ -6,7 +6,6 @@

 _thread_local_scopes = threading.local()

-
 def current_io_scopes():
     ret = getattr(_thread_local_scopes, "io_scopes", None)
     if ret is None:
diff --git a/lucid/misc/io/writing.py b/lucid/misc/io/writing.py
index 150d7dd6..ea0fe7c2 100644
--- a/lucid/misc/io/writing.py
+++ b/lucid/misc/io/writing.py
@@ -70,6 +70,7 @@ def write_handle(path, mode=None):
         if mode is None:
             mode = "w"
         handle = blobfile.BlobFile(path, mode)
+        handle.path = path
         yield handle
         handle.close()
         return
@@ -84,5 +85,6 @@ def write_handle(path, mode=None):
         mode = "w"

     handle = gfile.Open(path, mode)
+    handle.path = handle.name
     yield handle
     handle.close()
diff --git a/lucid/optvis/overrides/__init__.py b/lucid/optvis/overrides/__init__.py
index 40ecd456..8a81ab49 100644
--- a/lucid/optvis/overrides/__init__.py
+++ b/lucid/optvis/overrides/__init__.py
@@ -38,7 +38,7 @@

 pooling_overrides_map = {"MaxPool": avg_smoothed_maxpool_grad}

-relu_overrides_map = {"Relu": redirected_relu_grad, "Relu6": redirected_relu6_grad}
+relu_overrides_map = {"Relu": redirected_relu_grad, "Relu6": redirected_relu6_grad, "EwZXa": redirected_relu_grad}

 default_overrides_map = {**pooling_overrides_map, **relu_overrides_map}
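
Note that neither commit's file list includes lucid/misc/io/util.py, so the isazure helper imported by reading.py, saving.py, and writing.py is not shown in this diff. A minimal sketch of what such a helper could look like, assuming blobfile-style Azure URLs (az://<account>/<container>/<path> or https://<account>.blob.core.windows.net/...), is:

    # Hypothetical sketch of lucid/misc/io/util.py -- not part of the patch above.
    # Assumes the Azure Blob Storage URL forms that blobfile accepts.
    from urllib.parse import urlparse


    def isazure(url):
        """Return True if `url` appears to point at Azure Blob Storage."""
        parsed = urlparse(url)
        return parsed.scheme == "az" or parsed.netloc.endswith(".blob.core.windows.net")

With a helper along these lines in place, Azure paths would flow through the same entry points as GCS paths, e.g. save(img, "az://my-account/my-container/img.png") (hypothetical URL), with blobfile.BlobFile handling the actual reads and writes.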