Skip to content
This repository has been archived by the owner on Apr 10, 2024. It is now read-only.

Azure support #269

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions lucid/misc/io/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,18 @@ def _decompress_xz(handle, **kwargs):
}


# File-open mode per supported save extension: serialized/binary formats are
# written as bytes ("wb"), human-readable formats as text ("w").
# NOTE: an identical table appears in saving.py in this same PR — consider
# hoisting it into one shared module instead of maintaining two copies.
modes = {
    ext: ("w" if ext in (".json", ".txt") else "wb")
    for ext in (
        ".png", ".jpg", ".jpeg", ".webp",
        ".npy", ".npz",
        ".json", ".txt",
        ".pb",
    )
}

unsafe_loaders = {
".pickle": _load_pickle,
".pkl": _load_pickle,
Expand Down
6 changes: 6 additions & 0 deletions lucid/misc/io/reading.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@

from lucid.misc.io.writing import write_handle
from lucid.misc.io.scoping import scope_url
from lucid.misc.io.util import isazure

import blobfile

# create logger with module name, e.g. lucid.misc.io.reading
log = logging.getLogger(__name__)
Expand Down Expand Up @@ -177,6 +179,8 @@ def _read_and_cache(url, mode="rb"):
with lock:
if os.path.exists(local_path):
log.debug("Found cached file '%s'.", local_path)
if isazure(url):
return blobfile.BlobFile(url, mode)
return _handle_gfile(local_path)
log.debug("Caching URL '%s' locally at '%s'.", url, local_path)
try:
Expand All @@ -186,6 +190,8 @@ def _read_and_cache(url, mode="rb"):
for chunk in _file_chunk_iterator(input_handle):
output_handle.write(chunk)
gc.collect()
if isazure(url):
return blobfile.BlobFile(url, mode)
return _handle_gfile(local_path, mode=mode)
except tf.errors.NotFoundError:
raise
Expand Down
37 changes: 27 additions & 10 deletions lucid/misc/io/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from lucid.misc.io.writing import write_handle
from lucid.misc.io.serialize_array import _normalize_array
from lucid.misc.io.scoping import current_io_scopes, set_io_scopes

from lucid.misc.io.util import isazure

# create logger with module name, e.g. lucid.misc.io.saving
log = logging.getLogger(__name__)
Expand Down Expand Up @@ -110,21 +110,21 @@ def save_json(object, handle, indent=2):
obj_json = json.dumps(object, indent=indent, cls=ClarityJSONEncoder)
handle.write(obj_json)

return {"type": "json", "url": handle.name}
return {"type": "json", "url": handle.path}


def save_npy(object, handle):
    """Save numpy array as npy file."""
    # Writes the array directly to the (binary-mode) handle via numpy.
    np.save(handle, object)

    # NOTE(review): the two return statements below are the old/new sides of a
    # diff hunk (handle.name -> handle.path); only the FIRST ever executes and
    # the second is dead code. Resolve the duplication — the PR intends the
    # ".path" variant, which write_handle() now sets on every handle.
    return {"type": "npy", "shape": object.shape, "dtype": str(object.dtype), "url": handle.name}
    return {"type": "npy", "shape": object.shape, "dtype": str(object.dtype), "url": handle.path}


def save_npz(object, handle):
"""Save dict of numpy array as npz file."""
# there is a bug where savez doesn't actually accept a file handle.
log.warning("Saving npz files currently only works locally. :/")
path = handle.name
path = handle.path
handle.close()
if type(object) is dict:
np.savez(path, **object)
Expand All @@ -151,7 +151,7 @@ def save_img(object, handle, domain=None, **kwargs):
return {
"type": "image",
"shape": object.size + (len(object.getbands()),),
"url": handle.name,
"url": handle.path,
}


Expand All @@ -174,13 +174,13 @@ def save_txt(object, handle, **kwargs):
line += b"\n"
handle.write(line)

return {"type": "txt", "url": handle.name}
return {"type": "txt", "url": handle.path}


def save_str(object, handle, **kwargs):
    """Write a plain string to a text-mode handle and describe it as txt."""
    # NOTE(review): assert for input validation is stripped under `python -O`;
    # a TypeError/ValueError raise would be more robust.
    assert isinstance(object, str)
    handle.write(object)
    # NOTE(review): the two returns below are the old/new sides of a diff hunk
    # (handle.name -> handle.path); only the first executes, the second is
    # dead code. The PR intends the ".path" variant — resolve the duplication.
    return {"type": "txt", "url": handle.name}
    return {"type": "txt", "url": handle.path}


def save_pb(object, handle, **kwargs):
Expand All @@ -194,7 +194,7 @@ def save_pb(object, handle, **kwargs):
)
raise
finally:
return {"type": "pb", "url": handle.name}
return {"type": "pb", "url": handle.path}


def save_pickle(object, handle, **kwargs):
Expand All @@ -208,7 +208,7 @@ def save_pickle(object, handle, **kwargs):
def compress_xz(handle, **kwargs):
try:
ret = lzma.LZMAFile(handle, format=lzma.FORMAT_XZ, mode="wb")
ret.name = handle.name
ret.name = handle.path
return ret
except AttributeError as e:
warnings.warn("`compress_xz` failed for handle {}. Re-raising original exception.".format(handle))
Expand All @@ -227,6 +227,18 @@ def compress_xz(handle, **kwargs):
".pb": save_pb,
}

# Maps each savable extension to the file-open mode its saver needs:
# text serializers (json, txt) write str and need "w"; everything else
# writes bytes and needs "wb".
# NOTE: duplicated verbatim in loading.py within this PR — a single shared
# definition would be less error-prone.
_text_extensions = (".json", ".txt")
_binary_extensions = (".png", ".jpg", ".jpeg", ".webp", ".npy", ".npz", ".pb")
modes = dict(
    [(ext, "wb") for ext in _binary_extensions[:6]]
    + [(ext, "w") for ext in _text_extensions]
    + [(".pb", "wb")]
)

unsafe_savers = {
".pickle": save_pickle,
".pkl": save_pickle,
Expand Down Expand Up @@ -255,6 +267,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona

# Determine context
# Is this a handle? What is the extension? Are we saving to GCS?

is_handle = hasattr(url_or_handle, "write") and hasattr(url_or_handle, "name")
if is_handle:
path = url_or_handle.name
Expand Down Expand Up @@ -292,7 +305,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona
else:
handle_provider = write_handle

with handle_provider(url_or_handle) as handle:
with handle_provider(url_or_handle, mode = modes[ext]) as handle:
with compressor(handle) as compressed_handle:
result = saver(thing, compressed_handle, **kwargs)

Expand All @@ -309,6 +322,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona

# capture save if a save context is available
save_context = save_context if save_context is not None else CaptureSaveContext.current_save_context()

if save_context:
log.debug(
"capturing save: resulted in {} -> {} in save_context {}".format(
Expand All @@ -320,6 +334,9 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona
if result is not None and "url" in result and result["url"].startswith("gs://"):
result["serve"] = "https://storage.googleapis.com/{}".format(result["url"][5:])

if isazure(result["url"]):
result["serve"] = result["url"]

return result


Expand Down
1 change: 0 additions & 1 deletion lucid/misc/io/scoping.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

_thread_local_scopes = threading.local()


def current_io_scopes():
ret = getattr(_thread_local_scopes, "io_scopes", None)
if ret is None:
Expand Down
13 changes: 12 additions & 1 deletion lucid/misc/io/writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from contextlib import contextmanager
from urllib.parse import urlparse
from tensorflow import gfile
import blobfile

from lucid.misc.io.scoping import scope_url
from lucid.misc.io.util import isazure

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,11 +62,19 @@ def write(data, url, mode="wb"):

_write_to_path(data, url, mode=mode)


@contextmanager
def write_handle(path, mode=None):
path = scope_url(path)

if isazure(path):
if mode is None:
mode = "w"
handle = blobfile.BlobFile(path, mode)
handle.path = path
yield handle
handle.close()
return

if _supports_make_dirs(path):
gfile.MakeDirs(os.path.dirname(path))

Expand All @@ -75,5 +85,6 @@ def write_handle(path, mode=None):
mode = "w"

handle = gfile.Open(path, mode)
handle.path = handle.name
yield handle
handle.close()
2 changes: 1 addition & 1 deletion lucid/optvis/overrides/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

# Gradient-override tables used by lucid.optvis: op-type name -> replacement
# gradient function.
pooling_overrides_map = {"MaxPool": avg_smoothed_maxpool_grad}

# NOTE(review): the two assignments below are the old/new sides of a diff
# hunk — only the second (last) one takes effect; delete the stale first line.
# Also, "EwZXa" is not a known TensorFlow op type name — it looks like an
# accidental/junk key added in this PR. Confirm it is intentional (e.g. a
# renamed custom op) before merging, otherwise drop it.
relu_overrides_map = {"Relu": redirected_relu_grad, "Relu6": redirected_relu6_grad}
relu_overrides_map = {"Relu": redirected_relu_grad, "Relu6": redirected_relu6_grad, "EwZXa": redirected_relu_grad}

# Combined default table: pooling overrides plus relu overrides (relu keys
# win on collision, though the key sets are disjoint here).
default_overrides_map = {**pooling_overrides_map, **relu_overrides_map}

Expand Down