diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 0f514164..3e4fd02e 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -16,10 +16,10 @@ jobs:
steps:
- uses: actions/checkout@v2
- - name: Set up Python 3.8
+ - name: Set up Python 3.6
uses: actions/setup-python@v2
with:
- python-version: 3.8
+ python-version: 3.6
- name: Install DELTA
run: |
./scripts/setup.sh
diff --git a/MANIFEST.in b/MANIFEST.in
index 25b7fa53..1abc830e 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,2 +1,4 @@
include delta/config/delta.yaml
include delta/config/networks/*.yaml
+include delta/extensions/sources/snap_process_sentinel1.sh
+include delta/extensions/sources/sentinel1_default_snap_preprocess_graph.xml
diff --git a/README.md b/README.md
index 28479067..efcae9a6 100644
--- a/README.md
+++ b/README.md
@@ -1,83 +1,88 @@
**DELTA** (Deep Earth Learning, Tools, and Analysis) is a framework for deep learning on satellite imagery,
-based on Tensorflow. Use DELTA to train and run neural networks to classify large satellite images. DELTA
-provides pre-trained autoencoders for a variety of satellites to reduce required training data
-and time.
+based on Tensorflow. DELTA classifies large satellite images with neural networks, automatically handling
+tiling large imagery.
DELTA is currently under active development by the
-[NASA Ames Intelligent Robotics Group](https://ti.arc.nasa.gov/tech/asr/groups/intelligent-robotics/). Expect
-frequent changes. It is initially being used to map floods for disaster response, in collaboration with the
+[NASA Ames Intelligent Robotics Group](https://ti.arc.nasa.gov/tech/asr/groups/intelligent-robotics/).
+Initially, it is mapping floods for disaster response, in collaboration with the
[U.S. Geological Survey](http://www.usgs.gov), [National Geospatial Intelligence Agency](https://www.nga.mil/),
[National Center for Supercomputing Applications](http://www.ncsa.illinois.edu/), and
-[University of Alabama](https://www.ua.edu/). DELTA is a component of the
-[Crisis Mapping Toolkit](https://github.com/nasa/CrisisMappingToolkit), in addition
-to our previous software for mapping floods with Google Earth Engine.
+[University of Alabama](https://www.ua.edu/).
Installation
============
-1. Install [python3](https://www.python.org/downloads/), [GDAL](https://gdal.org/download.html), and the [GDAL python bindings](https://pypi.org/project/GDAL/).
- For Ubuntu Linux, you can run `scripts/setup.sh` from the DELTA repository to install these dependencies.
+1. Install [python3](https://www.python.org/downloads/), [GDAL](https://gdal.org/download.html),
+ and the [GDAL python bindings](https://pypi.org/project/GDAL/). For Ubuntu Linux, you can run
+ `scripts/setup.sh` from the DELTA repository to install these dependencies.
-2. Install Tensorflow with pip following the [instructions](https://www.tensorflow.org/install). For
+2. Install Tensorflow following the [instructions](https://www.tensorflow.org/install). For
GPU support in DELTA (highly recommended) follow the directions in the
[GPU guide](https://www.tensorflow.org/install/gpu).
3. Checkout the delta repository and install with pip:
- ```
- git clone http://github.com/nasa/delta
- python3 -m pip install delta
- ```
+```bash
+git clone http://github.com/nasa/delta
+python3 -m pip install delta
+```
+
+DELTA is now installed and ready to use!
+
+Documentation
+=============
+DELTA can be used either as a command line tool or as a python library.
+See the python documentation for the master branch [here](https://nasa.github.io/delta/),
+or generate the documentation with `scripts/docs.sh`.
+
+Example
+=======
+
+As a simple example, consider training a neural network to map clouds with Landsat-8 images.
+The script `scripts/example/l8_cloud.sh` trains such a network using DELTA from the
+[USGS SPARCS dataset](https://www.usgs.gov/core-science-systems/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs),
+and shows how DELTA can be used. The steps involved in this, and other, classification processes are:
+
+1. **Collect** training data. The SPARCS dataset contains Landsat-8 imagery with and without clouds.
- This installs DELTA and all dependencies (except for GDAL which must be installed manually in step 1).
+2. **Label** training data. The SPARCS labels classify each pixel according to cloud, land, water and other classes.
-Usage
-=====
+3. **Train** the neural network. The script `scripts/example/l8_cloud.sh` invokes the command
-As a simple example, consider training a neural network to map water in Worldview imagery.
-You would:
+ ```
+ delta train --config l8_cloud.yaml l8_clouds.h5
+ ```
-1. **Collect** training data. Find and save Worldview images with and without water. For a robust
- classifier, the training data should be as representative as possible of the evaluation data.
+ where `scripts/example/l8_cloud.yaml` is a configuration file specifying the labeled training data and
+ training parameters (learn more about configuration files below). A neural network file
+ `l8_clouds.h5` is output.
-2. **Label** training data. Create images matching the training images pixel for pixel, where each pixel
- in the label is 0 if it is not water and 1 if it is.
+4. **Classify** with the trained network. The script runs
-3. **Train** the neural network. Run
- ```
- delta train --config wv_water.yaml wv_water.h5
- ```
- where `wv_water.yaml` is a configuration file specifying the labeled training data and any
- training parameters (learn more about configuration files below). The command will output a
- neural network file `wv_water.h5` which can be
- used for classification. The neural network operates on the level of *chunks*, inputting
- and output smaller blocks of the image at a time.
+ ```
+ delta classify --config l8_cloud.yaml --image-dir ./validate --overlap 32 l8_clouds.h5
+ ```
-4. **Classify** with the trained network. Run
- ```
- delta classify --image image.tiff wv_water.h5
- ```
- to classify `image.tiff` using the network `wv_water.h5` learned previously.
- The file `image_predicted.tiff` will be written to the current directory showing the resulting labels.
+ to classify the images in the `validate` folder using the network `l8_clouds.h5` learned previously.
+ The overlap tiles to ignore border regions when possible to make a more aesthetically pleasing classified
+ image. The command outputs a predicted image and confusion matrix.
-Configuration Files
--------------------
+The results could be improved--- with more training, more data, an improved network, or more--- but this
+example shows the basic usage of DETLA.
-DELTA is configured with YAML files. Some options can be overwritten with command line options (use
-`delta --help` to see which). [Learn more about DELTA configuration files](./delta/config/README.md).
+Configuration and Extensions
+============================
-All available configuration options and their default values are shown [here](./delta/config/delta.yaml).
-We suggest that users create one reusable configuration file to describe the parameters specific
-to each dataset, and separate configuration files to train on or classify that dataset.
+DELTA provides many options for customizing data inputs and training. All options are configured via
+YAML files. Some options can be overwritten with command line options (use
+`delta --help` to see which). See the `delta.config` README to learn about available configuration
+options.
-Supported Image Formats
------------------------
-DELTA supports tiff files and a few other formats, listed [here](./delta/imagery/sources/README.md).
-Users can extend DELTA with their own custom formats. We are looking to expand DELTA to support other
-useful file formats.
+DELTA can be extended to support custom neural network layers, image types, preprocessing operations, metrics, losses,
+and training callbacks. Learn about DELTA extensions in the `delta.config.extensions` documentation.
-MLFlow
-------
+Data Management
+=============
DELTA integrates with [MLFlow](http://mlflow.org) to track training. MLFlow options can
be specified in the corresponding area of the configuration file. By default, training and
@@ -93,18 +98,6 @@ View all the logged training information through mlflow by running::
and navigating to the printed URL in a browser. This makes it easier to keep track when running
experiments and adjusting parameters.
-Using DELTA from Code
-=====================
-You can also call DELTA as a python library and customize it with your own extensions, for example,
-custom image types. The python API documentation can be generated as HTML. To do so:
-
-```
- pip install pdoc3
- ./scripts/docs.sh
-```
-
-Then open `html/delta/index.html` in a web browser.
-
Contributors
============
We welcome pull requests to contribute to DELTA. However, due to NASA legal restrictions, we must require
diff --git a/delta/config/README.md b/delta/config/README.md
index e7ff0f42..db2b7e08 100644
--- a/delta/config/README.md
+++ b/delta/config/README.md
@@ -5,9 +5,33 @@ all options, showing all parameters DELTA and their default values, see [delta.y
`delta` accepts multiple config files on the command line. For example, run
- delta train --config dataset.yaml --config train.yaml
+```bash
+delta train --config dataset.yaml --config train.yaml
+```
+
+to train on a dataset specified by `dataset.yaml`:
+
+```yaml
+dataset:
+ images:
+ type: tiff
+ directory: train/
+ labels:
+ type: tiff
+ directory: labels/
+ classes: 2
+```
+
+with training parameters given in `train.yaml`:
+
+```yaml
+train:
+ network:
+ model:
+ yaml_file: networks/convpool.yaml
+ epochs: 10
+```
-to train on a dataset specified by `dataset.yaml` with training parameters given in `train.yaml`.
Parameters can be overriden globally for all runs of `delta` as well, by placing options in
`$HOME/.config/delta/delta.yaml` on Linux. This is only recommended for global parameters
such as the cache directory.
@@ -17,8 +41,7 @@ only setting the necessary options.
Note that some configuration options can be overwritten on the command line: run
`delta --help` to see which.
-The remainder of this document details the available configuration parameters. Note that
-DELTA is still under active development and parts are likely to change in the future.
+The remainder of this document details the available configuration parameters.
Dataset
-----------------
@@ -26,71 +49,129 @@ Images and labels are specified with the `images` and `labels` fields respective
within `dataset`. Both share the
same underlying options.
- * `type`: Indicates which loader to use, e.g., `tiff` for geotiff.
- The available loaders are listed [here](../imagery/sources/README.md).
+ * `type`: Indicates which `delta.imagery.delta_image.DeltaImage` image reader to use, e.g., `tiff` for geotiff.
+ The reader should previously be registered with `delta.config.extensions.register_image_reader`.
* Files to load must be specified in one of three ways:
- * `directory` and `extension`: Use all images in the directory ending with the given extension.
- * `file_list`: Provide a text file with one image file name per line.
- * `files`: Provide a list of file names in yaml.
- * `preprocess`: Supports limited image preprocessing. Currently only scaling is supported. We recommend
+ * `directory` and `extension`: Use all images in the directory ending with the given extension.
+ * `file_list`: Provide a text file with one image file name per line.
+ * `files`: Provide a list of file names in yaml.
+ * `preprocess`: Specify a preprocessing chain. We recommend
scaling input imagery in the range 0.0 to 1.0 for best results with most of our networks.
- * `enabled`: Turn preprocessing on or off.
- * `scale_factor`: Factor to scale all readings by.
+ DELTA also supports custom preprocessing commands. Default actions include:
+ * `scale` with `factor` argument: Divide all values by amount.
+ * `offset` with `factor` argument: Add `factor` to pixel values.
+ * `clip` with `bounds` argument: clip all pixels to bounds.
+ Preprocessing commands are registered with `delta.config.extensions.register_preprocess`.
+ A full list of defaults (and examples of how to create new ones) can be found in `delta.extensions.preprocess`.
* `nodata_value`: A pixel value to ignore in the images.
+ * `classes`: Either an integer number of classes or a list of individual classes. If individual classes are specified,
+ each list item should be the pixel value of the class in the label images, and a dictionary with the
+ following potential attributes (see example below):
+ * `name`: Name of the class.
+ * `color`: Integer to use as the RGB representation for some classification options.
+ * `weight`: How much to weight the class during training (useful for underrepresented classes).
As an example:
- ```
- dataset:
- images:
- type: worldview
- directory: images/
- labels:
- type: tiff
- directory: labels/
- extension: _label.tiff
- ```
-
-This configuration will load worldview files ending in `.zip` from the `images/` directory.
+```yaml
+dataset:
+ images:
+ type: tiff
+ directory: images/
+ preprocess:
+ - scale:
+ factor: 256.0
+ nodata_value: 0
+ labels:
+ type: tiff
+ directory: labels/
+ extension: _label.tiff
+ nodata_value: 0
+ classes:
+ - 1:
+ name: Cloud
+ color: 0x0000FF
+ weight: 2.0
+ - 2:
+ name: Not Cloud
+ color: 0xFFFFFF
+ weight: 1.0
+```
+
+This configuration will load tiff files ending in `.tiff` from the `images/` directory.
It will then find matching tiff files ending in `_label.tiff` from the `labels` directory
-to use as labels.
+to use as labels. The image values will be divied by a factor of 256 before they are used.
+(It is often helpful to scale images to a range of 0-1 before training.) The labels represent two classes:
+clouds and non-clouds. Since there are fewer clouds, these are weighted more havily. The label
+images should contain 0 for nodata, 1 for cloud pixels, and 2 for non-cloud pixels.
Train
-----
These options are used in the `delta train` command.
- * `network`: The nueral network to train. See the next section for details.
- * `chunk_stride`: When collecting training samples, skip every `n` pixels between adjacent blocks. Keep the
- default of 1 to use all available training data.
- * `batch_size`: The number of chunks to train on in a group. May affect convergence speed. Larger
- batches allow higher training data throughput, but may encounter memory limitations.
+ * `network`: The nueral network to train. One of `yaml_file` or `layers` must be specified.
+ * `yaml_file`: A path to a yaml file with only the params and layers fields. See `delta/config/networks`
+ for examples.
+ * `params`: A dictionary of parameters to substitute in the `layers` field.
+ * `layers`: A list of layers which compose the network. See the following section for details.
+ * `stride`: When collecting training samples, skip every `n` pixels between adjacent blocks. Keep the
+ default of ~ or 1 to use all available training data. Not used for fully convolutional networks.
+ * `batch_size`: The number of patches to train on at a time. If running out of memory, reducing
+ batch size may be helpful.
* `steps`: If specified, stop training for each epoch after the given number of batches.
* `epochs`: the number of times to iterate through all training data during training.
- * `loss_function`: [Keras loss function](https://keras.io/losses/). For integer classes, use
- `sparse_categorical_cross_entropy`.
- * `metrics`: A list of [Keras metrics](https://keras.io/metrics/) to evaluate.
- * `optimizer`: The [Keras optimizer](https://keras.io/optimizers/) to use.
+ * `loss`: [Keras loss function](https://keras.io/losses/). For integer classes, use
+ `sparse_categorical_cross_entropy`. May be specified either as a string, or as a dictionary
+ with arguments to pass to the loss function constructor. Custom losses registered with
+ `delta.config.extensions.register_loss` may be used.
+ * `metrics`: A list of [Keras metrics](https://keras.io/metrics/) to evaluate. Either the string
+ name or a dictionary with the constructor arguments may be used. Custom metrics registered with
+ `delta.config.extensions.register_metric` or loss functions may also be used.
+ * `optimizer`: The [Keras optimizer](https://keras.io/optimizers/) to use. May be specified as a string or
+ as a dictionary with constructor parameters.
+ * `callbacks`: A list of [Keras callbacks)(https://keras.io/api/callbacks/) to use during training, specified as
+ either a string or as a dictionary with constructor parameters. Custom callbacks registered with
+ `delta.config.extensions.register_metric` may also be used.
* `validation`: Specify validation data. The validation data is tested after each epoch to evaluate the
classifier performance. Always use separate training and validation data!
* `from_training` and `steps`: If `from_training` is true, take the `steps` training batches
and do not use it for training but for validation instead.
* `images` and `labels`: Specified using the same format as the input data. Use this imagery as testing data
if `from_training` is false.
+ * `log_folder` and `resume_cutoff`: If log_folder is specified, store read records of how much of each image
+ has been trained on in this folder. If the number of reads exceeds resume_cutoff, skip the tile when resuming
+ training. This allows resuming training skipping part of an epoch. You should generally not bother using this except
+ on very large training sets (thousands of large images).
### Network
-These options configure the neural network to train with the `delta train` command.
+For the `layers` attribute, any [Keras Layer](https://keras.io/api/layers/) can
+be used, including custom layers registered with `delta.config.extensions.register_layer`.
- * `chunk_size`: The width and height of each chunks to input to the neural network
- * `output_size`: The width and height of the output from the neural network for each chunk
- * `classes`: The number of classes in the input data. The classes must currently have values
- 0 - n in the label images.
- * `model`: The network structure specification.
- folder. You can either point to another `yaml_file`, such as the ones in the delta/config/networks
- directory, or specify one under the `model` field in the same format as these files. The network
- layers are specified using the [Keras functional layers API](https://keras.io/layers/core/)
- converted to YAML files.
+Sub-fields of the layer are argument names and values which are passed to the layer's constructor.
+A special sub-field, `inputs`, is a list of the names of layers to pass as inputs to this layer.
+If `inputs` is not specified, the previous layer is used by default. Layer names can be specified `name`.
+
+```yaml
+layers:
+ Input:
+ shape: [~, ~, num_bands]
+ name: input
+ Add:
+ inputs: [input, input]
+```
+
+This simple example takes an input and adds it to itself.
+
+Since this network takes inputs of variable size ((~, ~, `num_bands`) is the input shape) it is a **fully
+convolutional network**. This means that during training and classification, it will be evaluated on entire
+tiles rather than smaller chunks.
+
+A few special parameters are available by default:
+
+ * `num_bands`: The number of bands / channels in an image.
+ * `num_classes`: The number of classes provided in dataset.classes.
MLFlow
------
@@ -120,9 +201,18 @@ General
-------
* `gpus`: The number of GPUs to use, or `-1` for all.
+ * `verbose`: Trigger verbose printing.
+ * `extensions`: List of extensions to load. Add custom modules here and they will be loaded when
+ delta starts.
+
+I/O
+-------
* `threads`: The number of threads to use for loading images into tensorflow.
- * `block_size_mb`: The size of blocks in images to load at a time. If too small may be data starved.
- * `tile_ratio` The ratio of block width and height when loading images. Can affect disk use efficiency.
- * `cache`: Configure cacheing options. The subfield `dir` specifies a directory on disk to store cached files,
- and `limit` is the number of files to retain in the cache. Used mainly for image types
- which much be extracted from archive files.
+ * `tile_size`: The size of a tile to load into memory at a time. For fully convolutional networks, the
+ entire tile will be processed at a time, for others it will be chunked.
+ * `interleave_images`: The number of images to interleave between. If this value is three, three images will
+ be opened at a time. Chunks / tiles will be interleaved from the first three tiles until one is completed, then
+ a new image will be opened. Larger interleaves can aid training (but comes at a cost in memory).
+ * `cache`: Options for a cache, which is used by a few image types (currently worldview and landsat).
+ * `dir`: Directory to store the cache. `default` gives a reasonable OS-specific default.
+ * `limit`: Maximum number of items to store in the cache before deleting old entries.
diff --git a/delta/config/__init__.py b/delta/config/__init__.py
index 87a2c7d5..36ae09e3 100644
--- a/delta/config/__init__.py
+++ b/delta/config/__init__.py
@@ -18,6 +18,8 @@
"""
Configuration via YAML files and command line options.
+.. include:: README.md
+
Access the singleton `delta.config.config` to get configuration
values, specified either in YAML files or on the command line,
and to load additional YAML files.
@@ -26,4 +28,4 @@
`delta/config/delta.yaml`.
"""
-from .config import config, DeltaConfigComponent, validate_path, validate_positive
+from .config import config, DeltaConfigComponent, validate_path, validate_positive, validate_non_negative
diff --git a/delta/config/config.py b/delta/config/config.py
index bbb2127f..e3d977d6 100644
--- a/delta/config/config.py
+++ b/delta/config/config.py
@@ -14,14 +14,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Loading configuration from command line arguments and yaml files.
+
+Most users will want to use the global object `delta.config.config.config`
+to access configuration parameters.
+"""
import os.path
+from typing import Any, Callable, List, Optional, Tuple, Union
import yaml
import pkg_resources
import appdirs
-def validate_path(path, base_dir):
+def validate_path(path: str, base_dir: str) -> str:
+ """
+ Normalizes a path.
+
+ Parameters
+ ----------
+ path: str
+ Input path
+ base_dir: str
+ The base directory for relative paths.
+
+ Returns
+ -------
+ str
+ The normalized path.
+ """
if path == 'default':
return path
path = os.path.expanduser(path)
@@ -30,11 +52,46 @@ def validate_path(path, base_dir):
path = os.path.normpath(os.path.join(base_dir, path))
return path
-def validate_positive(num, _):
+def validate_positive(num: Union[int, float], _: str) -> Union[int, float]:
+ """
+ Checks that a number is positive.
+
+ Parameters
+ ----------
+ num: Union[int, float]
+ Input number
+ _: str
+ Unused base path.
+
+ Raises
+ ------
+ ValueError
+ If number is not positive.
+ """
if num <= 0:
raise ValueError('%d is not positive' % (num))
return num
+def validate_non_negative(num: Union[int, float], _: str) -> Union[int, float]:
+ """
+ Checks that a number is not negative.
+
+ Parameters
+ ----------
+ num: Union[int, float]
+ Input number
+ _: str
+ Unused base path.
+
+ Raises
+ ------
+ ValueError
+ If number is negative.
+ """
+ if num < 0:
+ raise ValueError('%d is negative' % (num))
+ return num
+
class _NotSpecified: #pylint:disable=too-few-public-methods
pass
@@ -44,15 +101,15 @@ class DeltaConfigComponent:
Handles one subsection of a config file. Generally subclasses
will want to register fields and components in the constructor,
- and possibly override setup_arg_parser and parse_args to handle
+ and possibly override `setup_arg_parser` and `parse_args` to handle
command line options.
-
- section_header is the title of the section for command line
- arguments in the help.
"""
- def __init__(self, section_header = None):
+ def __init__(self, section_header: Optional[str] = None):
"""
- Constructs the component.
+ Parameters
+ ----------
+ section_header: Optional[str]
+ The title of the section for command line arguments in the help.
"""
self._config_dict = {}
self._components = {}
@@ -71,9 +128,18 @@ def reset(self):
for c in self._components.values():
c.reset()
- def register_component(self, component, name : str, attr_name = None):
+ def register_component(self, component: 'DeltaConfigComponent', name : str, attr_name: Optional[str] = None):
"""
- Register a subcomponent with a name and attribute name (access as self.attr_name)
+ Register a subcomponent.
+
+ Parameters
+ ----------
+ component: DeltaConfigComponent
+ The subcomponent to add.
+ name: str
+ Name of the subcomponent. Must be unique.
+ attr_name: Optional[str]
+ If specified, can access the component as self.attr_name.
"""
assert name not in self._components
self._components[name] = component
@@ -81,16 +147,25 @@ def register_component(self, component, name : str, attr_name = None):
attr_name = name
setattr(self, attr_name, component)
- def register_field(self, name : str, types, accessor = None, validate_fn = None, desc = None):
+ def register_field(self, name: str, types: Union[type, Tuple[type, ...]], accessor: Optional[str] = None,
+ validate_fn: Optional[Callable[[Any, str], Any]] = None, desc = None):
"""
Register a field in this component of the configuration.
- types is a single type or a tuple of valid types
-
- validate_fn (optional) should take two strings as input, the field's value and
- the base directory, and return what to save to the config dictionary.
- It should raise an exception if the field is invalid.
- accessor is an optional name to create an accessor function with
+ Parameters
+ ----------
+ name: str
+ Name of the field (must be unique).
+ types: type or tuple of types
+ Valid type or types for the field.
+ accessor: Optional[str]
+ If set, defines a function self.accessor() which retrieves the field.
+ validate_fn: Optional[Callable[[Any, str], Any]]
+ If specified, sets input = validate_fn(input, base_path) before using it, where
+ base_path is the current directory. The validate function should raise an error
+ if the input is invalid.
+ desc: Optional[str]
+ A description to use in help messages.
"""
self._fields.append(name)
self._validate[name] = validate_fn
@@ -103,16 +178,25 @@ def access(self) -> types:
access.__doc__ = desc
setattr(self.__class__, accessor, access)
- def register_arg(self, field, argname, **kwargs):
+ def register_arg(self, field: str, argname: str, options_name: Optional[str] =None, **kwargs):
"""
- Registers a command line argument in this component.
-
- field is the (registered) field this argument modifies.
- argname is the name of the flag on the command line (i.e., '--flag')
- **kwargs are arguments to ArgumentParser.add_argument.
-
- If help and type are not specified, will use the ones for the field.
- If default is not specified, will use the value from the config files.
+ Registers a command line argument in this component. Command line arguments override the
+ values in the config files when specified.
+
+ Parameters
+ ----------
+ field: str
+ The previously registered field this argument modifies.
+ argname: str
+ The name of the flag on the command line (i.e., '--flag')
+ options_name: Optional[str]
+ Name stored in the options object. It defaults to the
+ field if not specified. Only needed for duplicates, such as for multiple image
+ specifications.
+ **kwargs:
+ Further arguments are passed to ArgumentParser.add_argument.
+ If `help` and `type` are not specified, will use the values from field registration.
+ If `default` is not specified, will use the value from the config files.
"""
assert field in self._fields, 'Field %s not registered.' % (field)
if 'help' not in kwargs:
@@ -123,7 +207,7 @@ def register_arg(self, field, argname, **kwargs):
del kwargs['type']
if 'default' not in kwargs:
kwargs['default'] = _NotSpecified
- self._cmd_args[argname] = (field, kwargs)
+ self._cmd_args[argname] = (field, field if options_name is None else options_name, kwargs)
def to_dict(self) -> dict:
"""
@@ -150,48 +234,62 @@ def _set_field(self, name : str, value : str, base_dir : str):
if self._validate[name] and value is not None:
try:
value = self._validate[name](value, base_dir)
- except:
- print('Value %s for %s is invalid.' % (value, name))
- raise
+ except Exception as e:
+ raise AssertionError('Value %s for %s is invalid.' % (value, name)) from e
self._config_dict[name] = value
def _load_dict(self, d : dict, base_dir):
"""
Loads the dictionary d, assuming it came from the given base_dir (for relative paths).
"""
+ if not d:
+ return
for (k, v) in d.items():
if k in self._components:
self._components[k]._load_dict(v, base_dir) #pylint:disable=protected-access
else:
self._set_field(k, v, base_dir)
- def setup_arg_parser(self, parser, components = None) -> None:
+ def setup_arg_parser(self, parser : 'argparse.ArgumentParser', components: Optional[List[str]] = None) -> None:
"""
- Adds arguments to the parser. Must overridden by child classes.
+ Adds arguments to the parser. May be overridden by child classes.
+
+ Parameters
+ ----------
+ parser: argparse.ArgumentParser
+ The praser to set up arguments with and later pass the command line flags to.
+ components: Optional[List[str]]
+ If specified, only parse arguments from the given components, specified by name.
"""
if self._section_header is not None:
parser = parser.add_argument_group(self._section_header)
for (arg, value) in self._cmd_args.items():
- (field, kwargs) = value
- parser.add_argument(arg, dest=field, **kwargs)
+ (_, options_name, kwargs) = value
+ parser.add_argument(arg, dest=options_name, **kwargs)
for (name, c) in self._components.items():
if components is None or name in components:
c.setup_arg_parser(parser)
- def parse_args(self, options):
+ def parse_args(self, options: 'argparse.Namespace'):
"""
- Parse options extracted from an ArgParser configured with
+ Parse options extracted from an `argparse.ArgumentParser` configured with
`setup_arg_parser` and override the appropriate
configuration values.
+
+ Parameters
+ ----------
+ options: argparse.Namespace
+ Options returned from a call to parse_args on a parser initialized with
+ setup_arg_parser.
"""
d = {}
- for (field, _) in self._cmd_args.values():
- if not hasattr(options, field) or getattr(options, field) is None:
+ for (field, options_name, _) in self._cmd_args.values():
+ if not hasattr(options, options_name) or getattr(options, options_name) is None:
continue
- if getattr(options, field) is _NotSpecified:
+ if getattr(options, options_name) is _NotSpecified:
continue
- d[field] = getattr(options, field)
+ d[field] = getattr(options, options_name)
self._load_dict(d, None)
for c in self._components.values():
@@ -199,20 +297,25 @@ def parse_args(self, options):
class DeltaConfig(DeltaConfigComponent):
"""
- DELTA configuration manager.
-
- Access and control all configuration parameters.
+ DELTA configuration manager. Access and control all configuration parameters.
"""
- def load(self, yaml_file: str = None, yaml_str: str = None):
+ def load(self, yaml_file: Optional[str] = None, yaml_str: Optional[str] = None):
"""
Loads a config file, then updates the default configuration
with the loaded values.
+
+ Parameters
+ ----------
+ yaml_file: Optional[str]
+ Filename of a yaml file to load.
+ yaml_str: Optional[str]
+ Load yaml directly from a str. Exactly one of `yaml_file` and `yaml_str`
+ must be specified.
"""
base_path = None
if yaml_file:
- #print("Loading config file: " + yaml_file)
if not os.path.exists(yaml_file):
- raise Exception('Config file does not exist: ' + yaml_file)
+ raise FileNotFoundError('Config file does not exist: ' + yaml_file)
with open(yaml_file, 'r') as f:
config_data = yaml.safe_load(f)
base_path = os.path.normpath(os.path.dirname(yaml_file))
@@ -234,10 +337,16 @@ def reset(self):
super().reset()
self.load(pkg_resources.resource_filename('delta', 'config/delta.yaml'))
- def initialize(self, options, config_files = None):
+ def initialize(self, options: 'argparse.Namespace', config_files: Optional[List[str]] = None):
"""
- Loads the default files unless config_files is specified, in which case it
- loads them. Then loads options (from argparse).
+ Loads all config files, then parses all command line arguments.
+ Parameters
+ ----------
+ options: argparse.Namespace
+ Command line options from `setup_arg_parser` to parse.
+ config_files: Optional[List[str]]
+ If specified, loads only the listed files. Otherwise, loads the default config
+ files.
"""
self.reset()
@@ -254,3 +363,4 @@ def initialize(self, options, config_files = None):
config.parse_args(options)
config = DeltaConfig()
+"""Global config object. Use this to access all configuration."""
diff --git a/delta/config/delta.yaml b/delta/config/delta.yaml
index 82dbf2f8..ec55319c 100644
--- a/delta/config/delta.yaml
+++ b/delta/config/delta.yaml
@@ -1,32 +1,28 @@
general:
# negative is all
gpus: -1
- stop_on_input_error: true # If false skip past bad input files without halting training
+ # print more debug information
+ verbose: false
+ extensions:
+ - delta.extensions
io:
threads: 1
- # size to load in memory at a time (width x height x bands x bit-width x chunk-size^2)
- block_size_mb: 1
+ # size of tile to load into memory at a time as [rows, columns]
+ tile_size: [256, 1024]
# number of different images to interleave at a time when loading
interleave_images: 5
- # ratio of tile width and height when loading images
- tile_ratio: 5.0
- # When resuming training with a log_folder, skip input image where we have
- # already loaded this many tiles.
- resume_cutoff: 5000
cache:
# default is OS-specific, in Linux, ~/.cache/delta
dir: default
limit: 8
dataset:
- log_folder: ~ # Storage location for any record keeping files about the input dataset
images:
type: tiff
# preprocess the images when loading (i.e., scaling)
preprocess:
- enabled: true
- scale_factor: default
+ - scale
nodata_value: ~
directory: ~
extension: default
@@ -35,9 +31,7 @@ dataset:
labels:
type: tiff
- preprocess:
- enabled: false
- scale_factor: default
+ preprocess: ~
nodata_value: ~
directory: ~
extension: default
@@ -69,20 +63,20 @@ dataset:
train:
network:
- chunk_size: 16
- output_size: 8
- model:
- yaml_file: networks/convpool.yaml
- params: ~
- layers: ~
- chunk_stride: 1
+ yaml_file: networks/convpool.yaml
+ params: ~
+ layers: ~
+ stride: ~
batch_size: 500
steps: ~ # number of batches to train on (or ~ for all)
epochs: 5
- loss_function: sparse_categorical_crossentropy
+ loss: sparse_categorical_crossentropy
metrics:
- sparse_categorical_accuracy
- optimizer: adam
+ optimizer:
+ Adam:
+ learning_rate: 0.001
+ callbacks: ~
validation:
steps: 1000
# if true, skips the first steps from the training set to use for validation instead
@@ -91,8 +85,7 @@ train:
images:
type: tiff
preprocess:
- enabled: true
- scale_factor: default
+ - scale
nodata_value: ~
directory: ~
extension: default
@@ -100,14 +93,16 @@ train:
files: ~
labels:
type: tiff
- preprocess:
- enabled: false
- scale_factor: default
+ preprocess: ~
nodata_value: ~
directory: ~
extension: default
file_list: ~
files: ~
+ log_folder: ~ # Storage location for any record keeping files about the input dataset
+ # When resuming training with a log_folder, skip input image where we have
+ # already loaded this many tiles.
+ resume_cutoff: 10
mlflow:
# default to ~/.local/share/delta/mlflow
diff --git a/delta/config/extensions.py b/delta/config/extensions.py
new file mode 100644
index 00000000..ed009530
--- /dev/null
+++ b/delta/config/extensions.py
@@ -0,0 +1,307 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Manage extensions to DELTA.
+
+To extend delta, add the name for your extension to the `extensions` field
+in a DELTA config file. It will then be imported when DELTA loads.
+The named python module should then call the appropriate registration
+function (e.g., `register_layer` to register a custom Keras layer) and
+the extensions can be used like existing DELTA options.
+
+All extensions can take keyword arguments that can be specified in the config file.
+"""
+
+#pylint:disable=global-statement
+
+import importlib
+
+__extensions_to_load = set()
+__layers = {}
+__readers = {}
+__writers = {}
+__losses = {}
+__metrics = {}
+__callbacks = {}
+__prep_funcs = {}
+
+def __initialize():
+ """
+ This function is called before each use of extensions to import
+ the needed modules. This is only done at first use to not delay loading.
+ """
+ global __extensions_to_load
+ while __extensions_to_load:
+ ext = __extensions_to_load.pop()
+ importlib.import_module(ext)
+
+def register_extension(name : str):
+ """
+ Register an extension python module.
+ For internal use --- users should use the config files.
+
+ Parameters
+ ----------
+ name: str
+ Name of the extension to load.
+ """
+ global __extensions_to_load
+ __extensions_to_load.add(name)
+
+def register_layer(layer_type : str, layer_constructor):
+ """
+ Register a custom layer for use by DELTA.
+
+ Parameters
+ ----------
+ layer_type: str
+ Name of the layer.
+ layer_constructor
+ Either a class extending
+ [tensorflow.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/LayerFunction),
+ or a function that returns a function that inputs and outputs tensors.
+
+ See Also
+ --------
+ delta.ml.train.DeltaLayer : Layer wrapper with Delta extensions
+ """
+ global __layers
+ __layers[layer_type] = layer_constructor
+
+def register_image_reader(image_type : str, image_class):
+ """
+ Register a custom image type for reading by DELTA.
+
+ Parameters
+ ----------
+ image_type: str
+ Name of the image type.
+ image_class: Type[`delta.imagery.delta_image.DeltaImage`]
+ A class that extends `delta.imagery.delta_image.DeltaImage`.
+ """
+ global __readers
+ __readers[image_type] = image_class
+
+def register_image_writer(image_type : str, writer_class):
+ """
+ Register a custom image type for writing by DELTA.
+
+ Parameters
+ ----------
+ image_type: str
+ Name of the image type.
+ writer_class: Type[`delta.imagery.delta_image.DeltaImageWriter`]
+ A class that extends `delta.imagery.delta_image.DeltaImageWriter`.
+ """
+ global __writers
+ __writers[image_type] = writer_class
+
+def register_loss(loss_type : str, custom_loss):
+ """
+ Register a custom loss function for use by DELTA.
+
+ Note that loss functions can also be used as metrics.
+
+ Parameters
+ ----------
+ loss_type: str
+ Name of the loss function.
+ custom_loss:
+ Either a loss extending [Loss](https://www.tensorflow.org/api_docs/python/tf/keras/losses/Loss) or a
+ function of the form loss(y_true, y_pred) which returns a tensor of the loss.
+ """
+ global __losses
+ __losses[loss_type] = custom_loss
+
+def register_metric(metric_type : str, custom_metric):
+ """
+ Register a custom metric for use by DELTA.
+
+ Parameters
+ ----------
+ metric_type: str
+ Name of the metric.
+ custom_metric: Type[`tensorflow.keras.metrics.Metric`]
+ A class extending [Metric](https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Metric).
+ """
+ global __metrics
+ __metrics[metric_type] = custom_metric
+
+def register_callback(cb_type : str, cb):
+ """
+ Register a custom training callback for use by DELTA.
+
+ Parameters
+ ----------
+ cb_type: str
+ Name of the callback.
+ cb: Type[`tensorflow.keras.callbacks.Callback`]
+ A class extending [Callback](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback)
+ or a function that returns one.
+ """
+ global __callbacks
+ __callbacks[cb_type] = cb
+
+def register_preprocess(function_name : str, prep_function):
+ """
+ Register a preprocessing function for use in delta.
+
+ Parameters
+ ----------
+ function_name: str
+ Name of the preprocessing function.
+ prep_function:
+ A function of the form prep_function(data, rectangle, bands_list), where data is an input
+ numpy array, rectangle a `delta.imagery.rectangle.Rectangle` specifying the region covered by data,
+ and bands_list is an integer list of bands loaded. The function must return a numpy array.
+ """
+ global __prep_funcs
+ __prep_funcs[function_name] = prep_function
+
+def layer(layer_type : str):
+ """
+ Retrieve a custom layer by name.
+
+ Parameters
+ ----------
+ layer_type: str
+ Name of the layer.
+
+ Returns
+ -------
+ Layer
+ The previously registered layer.
+ """
+ __initialize()
+ return __layers.get(layer_type)
+
+def loss(loss_type : str):
+ """
+ Retrieve a custom loss by name.
+
+ Parameters
+ ----------
+ loss_type: str
+ Name of the loss function.
+
+ Returns
+ -------
+ Loss
+ The previously registered loss function.
+ """
+ __initialize()
+ return __losses.get(loss_type)
+
+def metric(metric_type : str):
+ """
+ Retrieve a custom metric by name.
+
+ Parameters
+ ----------
+ metric_type: str
+ Name of the metric.
+
+ Returns
+ -------
+ Metric
+ The previously registered metric.
+ """
+ __initialize()
+ return __metrics.get(metric_type)
+
+def callback(cb_type : str):
+ """
+ Retrieve a custom callback by name.
+
+ Parameters
+ ----------
+ cb_type: str
+ Name of the callback function.
+
+ Returns
+ -------
+ Callback
+ The previously registered callback.
+ """
+ __initialize()
+ return __callbacks.get(cb_type)
+
+def preprocess_function(prep_type : str):
+ """
+ Retrieve a custom preprocessing function by name.
+
+ Parameters
+ ----------
+ prep_type: str
+ Name of the preprocessing function.
+
+ Returns
+ -------
+ Preprocessing Function
+ The previously registered preprocessing function.
+ """
+ __initialize()
+ return __prep_funcs.get(prep_type)
+
+def image_reader(reader_type : str):
+ """
+ Get the reader of the given type.
+
+ Parameters
+ ----------
+ reader_type: str
+ Name of the image reader.
+
+ Returns
+ -------
+ Type[`delta.imagery.delta_image.DeltaImage`]
+ The previously registered image reader.
+ """
+ __initialize()
+ return __readers.get(reader_type)
+
+def image_writer(writer_type : str):
+ """
+ Get the writer of the given type.
+
+ Parameters
+ ----------
+ writer_type: str
+ Name of the image writer.
+
+ Returns
+ -------
+ Type[`delta.imagery.delta_image.DeltaImageWriter`]
+ The previously registered image writer.
+ """
+ __initialize()
+ return __writers.get(writer_type)
+
+def custom_objects():
+ """
+ Returns a dictionary of all supported custom objects for use
+ by tensorflow. Passed as an argument to load_model.
+
+ Returns
+ -------
+ dict
+ A dictionary of registered custom tensorflow objects.
+ """
+ __initialize()
+ d = __layers.copy()
+ d.update(__losses.copy())
+ return d
diff --git a/delta/config/modules.py b/delta/config/modules.py
index 2b53419d..d50316e7 100644
--- a/delta/config/modules.py
+++ b/delta/config/modules.py
@@ -21,13 +21,45 @@
import delta.imagery.imagery_config
import delta.ml.ml_config
+from .config import config, DeltaConfigComponent
+from .extensions import register_extension
+
+class ExtensionsConfig(DeltaConfigComponent):
+ """
+ Configuration component for extensions.
+ """
+ def __init__(self):
+ super().__init__()
+
+ # register immediately, don't override
+ def _load_dict(self, d : dict, base_dir):
+ if not d:
+ return
+ if isinstance(d, list):
+ for ext in d:
+ register_extension(ext)
+ elif isinstance(d, str):
+ register_extension(d)
+ else:
+ raise ValueError('extensions should be a list or string.')
_config_initialized = False
def register_all():
+ """
+ Register all default config modules.
+ """
global _config_initialized #pylint: disable=global-statement
# needed to call twice when testing subcommands and when not
if _config_initialized:
return
+ config.register_component(DeltaConfigComponent('General'), 'general')
+ config.general.register_component(ExtensionsConfig(), 'extensions')
+ config.general.register_field('extensions', list, 'extensions', None,
+ 'Python modules to import as extensions.')
+ config.general.register_field('verbose', bool, 'verbose', None,
+ 'Print debugging information.')
+ config.general.register_arg('verbose', '--verbose', action='store_const',
+ const=True, type=None)
delta.imagery.imagery_config.register()
delta.ml.ml_config.register()
_config_initialized = True
diff --git a/delta/config/networks/autoencoder_conv.yaml b/delta/config/networks/autoencoder_conv.yaml
index 751c0c1c..d1686df7 100644
--- a/delta/config/networks/autoencoder_conv.yaml
+++ b/delta/config/networks/autoencoder_conv.yaml
@@ -1,6 +1,6 @@
layers:
- Input:
- shape: in_shape
+ shape: [~, ~, num_bands]
- Conv2D:
filters: 50
kernel_size: [3, 3]
diff --git a/delta/config/networks/autoencoder_conv_med_filters.yaml b/delta/config/networks/autoencoder_conv_med_filters.yaml
index 505deaf6..a80dd69a 100644
--- a/delta/config/networks/autoencoder_conv_med_filters.yaml
+++ b/delta/config/networks/autoencoder_conv_med_filters.yaml
@@ -1,6 +1,6 @@
layers:
- Input:
- shape: in_shape
+ shape: [~, ~, num_bands]
- Conv2D:
filters: 50
kernel_size: [5, 5]
diff --git a/delta/config/networks/autoencoder_conv_wide_filters.yaml b/delta/config/networks/autoencoder_conv_wide_filters.yaml
index 7548cd72..3f09f8b3 100644
--- a/delta/config/networks/autoencoder_conv_wide_filters.yaml
+++ b/delta/config/networks/autoencoder_conv_wide_filters.yaml
@@ -1,6 +1,6 @@
layers:
- Input:
- shape: in_shape
+ shape: [~, ~, num_bands]
- Conv2D:
filters: 50
kernel_size: [7, 7]
diff --git a/delta/config/networks/autoencoder_dense.yaml b/delta/config/networks/autoencoder_dense.yaml
index 5524a3d3..d793f4c8 100644
--- a/delta/config/networks/autoencoder_dense.yaml
+++ b/delta/config/networks/autoencoder_dense.yaml
@@ -1,13 +1,12 @@
layers:
- Input:
- shape: in_shape
+ shape: [8, 8, 3]
- Flatten:
- input_shape: in_shape
- Dense:
units: 300
activation: relu
- Dense:
- units: in_dims
+ units: 192
activation: relu
- Reshape:
- target_shape: in_shape
+ target_shape: [8, 8, 3]
diff --git a/delta/config/networks/conv_autoencoder_128_chunk.yaml b/delta/config/networks/conv_autoencoder_128_chunk.yaml
index 1b8ac40f..cf79e5f5 100644
--- a/delta/config/networks/conv_autoencoder_128_chunk.yaml
+++ b/delta/config/networks/conv_autoencoder_128_chunk.yaml
@@ -1,6 +1,6 @@
layers:
- Input:
- shape: in_shape
+ shape: [~, ~, num_bands]
- Conv2D:
filters: 100
kernel_size: [7, 7]
diff --git a/delta/config/networks/convpool.yaml b/delta/config/networks/convpool.yaml
index c1f8d0c9..f01bef4c 100644
--- a/delta/config/networks/convpool.yaml
+++ b/delta/config/networks/convpool.yaml
@@ -1,9 +1,10 @@
params:
dropout_rate: 0.3
num_filters: 64
+ out_dims: 27 # 3 * 3 * 3
layers:
- Input:
- shape: in_shape
+ shape: [5, 5, num_bands]
- Conv2D:
filters: num_filters
kernel_size: [5, 5]
@@ -27,6 +28,6 @@ layers:
units: out_dims
activation: linear
- Reshape:
- target_shape: out_shape
+ target_shape: [3, 3, num_classes]
- Softmax:
axis: 3
diff --git a/delta/config/networks/efficientnet_autoencoder.yaml b/delta/config/networks/efficientnet_autoencoder.yaml
new file mode 100644
index 00000000..da347c69
--- /dev/null
+++ b/delta/config/networks/efficientnet_autoencoder.yaml
@@ -0,0 +1,128 @@
+layers:
+ - Input:
+ shape: [~, ~, num_bands]
+ - EfficientNet:
+ input_shape: [~, ~, num_bands]
+ name: efficientnet
+ width_coefficient: 1.1
+ depth_coefficient: 1.2
+ - UpSampling2D:
+ size: 2
+ - Conv2D:
+ filters: 240
+ kernel_size: [2, 2]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - Conv2D:
+ filters: 240
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - Conv2D:
+ filters: 240
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - UpSampling2D:
+ size: 2
+ - Conv2D:
+ filters: 240
+ kernel_size: [2, 2]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - Conv2D:
+ filters: 96
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - Conv2D:
+ filters: 96
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - UpSampling2D:
+ size: 2
+ - Conv2D:
+ filters: 96
+ kernel_size: [2, 2]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - Conv2D:
+ filters: 48
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - Conv2D:
+ filters: 48
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - UpSampling2D:
+ size: 2
+ - Conv2D:
+ filters: 48
+ kernel_size: [2, 2]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - Conv2D:
+ filters: 32
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - Conv2D:
+ filters: 32
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - UpSampling2D:
+ size: 2
+ - Conv2D:
+ filters: 32
+ kernel_size: [2, 2]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - Conv2D:
+ filters: 16
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - Conv2D:
+ filters: 16
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ - Conv2D:
+ filters: num_bands
+ kernel_size: [3, 3]
+ activation: relu
+ padding: same
diff --git a/delta/config/networks/fcn8s-voc.yaml b/delta/config/networks/fcn8s-voc.yaml
new file mode 100644
index 00000000..bf92a48f
--- /dev/null
+++ b/delta/config/networks/fcn8s-voc.yaml
@@ -0,0 +1,203 @@
+# fully convolutional neural network
+# based on https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn8s/net.py
+params:
+ num_classes: 3
+layers:
+ - Input:
+ shape: [~, ~, num_bands]
+
+ - Conv2D:
+ filters: 64
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - Conv2D:
+ filters: 64
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - MaxPooling2D:
+ name: pool1
+ pool_size: [2, 2]
+ strides: 2
+
+ - Conv2D:
+ filters: 128
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - Conv2D:
+ filters: 128
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - MaxPooling2D:
+ name: pool2
+ pool_size: [2, 2]
+ strides: 2
+
+ - Conv2D:
+ filters: 256
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - Conv2D:
+ filters: 256
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - Conv2D:
+ filters: 256
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - MaxPooling2D:
+ name: pool3
+ pool_size: [2, 2]
+ strides: 2
+
+ - Conv2D:
+ filters: 512
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - Conv2D:
+ filters: 512
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - Conv2D:
+ filters: 512
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - MaxPooling2D:
+ name: pool4
+ pool_size: [2, 2]
+ strides: 2
+
+ - Conv2D:
+ filters: 512
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - Conv2D:
+ filters: 512
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - Conv2D:
+ filters: 512
+ kernel_size: 3
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - MaxPooling2D:
+ name: pool5
+ pool_size: [2, 2]
+ strides: 2
+
+ - Conv2D:
+ filters: 4096
+ kernel_size: 7
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - Dropout:
+ rate: 0.5
+ - Conv2D:
+ filters: 4096
+ kernel_size: 1
+ padding: same
+ strides: 1
+ - BatchNormalization:
+ - Activation:
+ activation: 'relu'
+ - Dropout:
+ rate: 0.5
+
+ - Conv2D:
+ filters: num_classes
+ kernel_size: 1
+ strides: 1
+ - Conv2DTranspose:
+ name: upscore2
+ filters: num_classes
+ padding: same
+ kernel_size: 4
+ strides: 2
+ use_bias: false
+
+ - Conv2D:
+ inputs: pool4
+ name: score_pool4
+ filters: num_classes
+ padding: same
+ kernel_size: 1
+ - Add:
+ inputs: [upscore2, score_pool4]
+ - Conv2DTranspose:
+ name: upscore4
+ filters: num_classes
+ padding: same
+ kernel_size: 4
+ strides: 2
+ use_bias: false
+
+ - Conv2D:
+ inputs: pool3
+ name: score_pool3
+ filters: num_classes
+ padding: same
+ kernel_size: 1
+ - Add:
+ inputs: [upscore4, score_pool3]
+ - Conv2DTranspose:
+ name: upscore8
+ filters: num_classes
+ padding: same
+ kernel_size: 16
+ strides: 8
+ use_bias: false
+ - Softmax:
+ axis: 3
diff --git a/delta/config/networks/segnet-medium.yaml b/delta/config/networks/segnet-medium.yaml
index 749b66e3..ea38c311 100644
--- a/delta/config/networks/segnet-medium.yaml
+++ b/delta/config/networks/segnet-medium.yaml
@@ -1,6 +1,6 @@
layers:
- Input:
- shape: in_shape
+ shape: [~, ~, num_bands]
- Conv2D:
filters: 64
kernel_size: [7, 7]
diff --git a/delta/config/networks/segnet-short-fewer-filters.yaml b/delta/config/networks/segnet-short-fewer-filters.yaml
index c044bcac..601e91ea 100644
--- a/delta/config/networks/segnet-short-fewer-filters.yaml
+++ b/delta/config/networks/segnet-short-fewer-filters.yaml
@@ -1,6 +1,6 @@
layers:
- Input:
- shape: in_shape
+ shape: [~, ~, num_bands]
- Conv2D:
filters: 64
kernel_size: [7, 7]
diff --git a/delta/config/networks/segnet-short-small-filters.yaml b/delta/config/networks/segnet-short-small-filters.yaml
index 4b0b9498..fa2f9a60 100644
--- a/delta/config/networks/segnet-short-small-filters.yaml
+++ b/delta/config/networks/segnet-short-small-filters.yaml
@@ -1,6 +1,6 @@
layers:
- Input:
- shape: in_shape
+ shape: [~, ~, num_bands]
- Conv2D:
filters: 100
kernel_size: [5, 5]
diff --git a/delta/config/networks/segnet-short.yaml b/delta/config/networks/segnet-short.yaml
index c044bcac..601e91ea 100644
--- a/delta/config/networks/segnet-short.yaml
+++ b/delta/config/networks/segnet-short.yaml
@@ -1,6 +1,6 @@
layers:
- Input:
- shape: in_shape
+ shape: [~, ~, num_bands]
- Conv2D:
filters: 64
kernel_size: [7, 7]
diff --git a/delta/config/networks/segnet.yaml b/delta/config/networks/segnet.yaml
index a693b238..24588149 100644
--- a/delta/config/networks/segnet.yaml
+++ b/delta/config/networks/segnet.yaml
@@ -1,6 +1,6 @@
layers:
- Input:
- shape: in_shape
+ shape: [~, ~, num_bands]
- Conv2D:
filters: 64
kernel_size: [7, 7]
diff --git a/delta/config/networks/variational_autoencoder.yaml b/delta/config/networks/variational_autoencoder.yaml
index 62fdbee9..7acfe7ac 100644
--- a/delta/config/networks/variational_autoencoder.yaml
+++ b/delta/config/networks/variational_autoencoder.yaml
@@ -1,6 +1,6 @@
layers:
- Input:
- shape: in_shape
+ shape: [16, 16, num_bands]
- Conv2D:
filters: 300
kernel_size: [3, 3]
diff --git a/tests/test_worldview.py b/delta/extensions/__init__.py
similarity index 60%
rename from tests/test_worldview.py
rename to delta/extensions/__init__.py
index 2f84b6a9..6b63af4d 100644
--- a/tests/test_worldview.py
+++ b/delta/extensions/__init__.py
@@ -15,22 +15,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-#pylint:disable=redefined-outer-name
"""
-Test for worldview class.
-"""
-import pytest
+Module for extensions to DELTA.
-from delta.imagery.sources import worldview
+This is a collection of default extensions that come with DELTA. If you
+are interested in making your own extensions, see `delta.config.extensions`.
+"""
-@pytest.fixture(scope="function")
-def wv_image(worldview_filenames):
- return worldview.WorldviewImage(worldview_filenames[0])
+from .defaults import initialize
-# very basic, doesn't actually look at content
-def test_wv_image(wv_image):
- assert wv_image.meta_path() is not None
- buf = wv_image.read()
- assert buf.shape == (64, 32, 1)
- assert len(wv_image.scale()) == 1
- assert len(wv_image.bandwidth()) == 1
+initialize()
diff --git a/delta/extensions/callbacks.py b/delta/extensions/callbacks.py
new file mode 100644
index 00000000..7bd8cc71
--- /dev/null
+++ b/delta/extensions/callbacks.py
@@ -0,0 +1,82 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Custom callbacks that come with DELTA.
+"""
+
+import tensorflow
+import tensorflow.keras.callbacks
+
+from delta.config.extensions import register_callback
+from delta.ml.train import ContinueTrainingException
+
+class SetTrainable(tensorflow.keras.callbacks.Callback):
+ """
+ Changes whether a given layer is trainable during training.
+
+ This is useful for transfer learning, to do an initial training and then allow fine-tuning.
+ """
+ def __init__(self, layer_name: str, epoch: int, trainable: bool=True, learning_rate: float=None):
+ """
+ Parameters
+ ----------
+ layer_name: str
+ The layer to modify.
+ epoch: int
+ The change will take place at the start of this epoch (the first epoch is 1).
+ trainable: bool
+ Whether the layer will be made trainable or not trainable.
+ learning_rate: float
+ Optionally change the learning rate as well.
+ """
+ super().__init__()
+ self._layer_name = layer_name
+ self._epoch = epoch - 1
+ self._make_trainable = trainable
+ self._lr = learning_rate
+ self._triggered = False
+
+ def on_epoch_begin(self, epoch, logs=None):
+ if epoch == self._epoch:
+ if self._triggered:
+ return
+ self._triggered = True # don't repeat twice
+ l = self.model.get_layer(self._layer_name)
+ l.trainable = True
+ # have to abort, recompile changed model, and continue training
+ raise ContinueTrainingException(completed_epochs=epoch, recompile_model=True, learning_rate=self._lr)
+
+def ExponentialLRScheduler(start_epoch: int=10, multiplier: float=0.95):
+ """
+ Schedule the learning rate exponentially.
+
+ Parameters
+ ----------
+ start_epoch: int
+ The epoch to begin.
+ multiplier: float
+ After `start_epoch`, multiply the learning rate by this amount each epoch.
+ """
+ def schedule(epoch, lr):
+ if epoch < start_epoch:
+ return lr
+ return multiplier * lr
+ return tensorflow.keras.callbacks.LearningRateScheduler(schedule)
+
+register_callback('SetTrainable', SetTrainable)
+register_callback('ExponentialLRScheduler', ExponentialLRScheduler)
diff --git a/delta/extensions/defaults.py b/delta/extensions/defaults.py
new file mode 100644
index 00000000..d67a8691
--- /dev/null
+++ b/delta/extensions/defaults.py
@@ -0,0 +1,51 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Module to install all extensions that come with DELTA by default.
+"""
+
+from delta.config.extensions import register_extension, register_image_reader, register_image_writer
+
+from .sources import tiff
+from .sources import landsat
+from .sources import npy
+from .sources import worldview
+from .sources import sentinel1
+
+def initialize():
+ """
+ Register all default extensions.
+ """
+ register_extension('delta.extensions.callbacks')
+
+ register_extension('delta.extensions.layers.pretrained')
+ register_extension('delta.extensions.layers.gaussian_sample')
+ register_extension('delta.extensions.layers.efficientnet')
+ register_extension('delta.extensions.layers.simple')
+
+ register_extension('delta.extensions.losses')
+ register_extension('delta.extensions.metrics')
+ register_extension('delta.extensions.preprocess')
+
+ register_image_reader('tiff', tiff.TiffImage)
+ register_image_reader('npy', npy.NumpyImage)
+ register_image_reader('landsat', landsat.LandsatImage)
+ register_image_reader('worldview', worldview.WorldviewImage)
+ register_image_reader('sentinel1', sentinel1.Sentinel1Image)
+
+ register_image_writer('tiff', tiff.TiffWriter)
+ register_image_writer('npy', npy.NumpyWriter)
diff --git a/delta/imagery/sources/__init__.py b/delta/extensions/layers/__init__.py
similarity index 95%
rename from delta/imagery/sources/__init__.py
rename to delta/extensions/layers/__init__.py
index d0baa8d2..d51fe988 100644
--- a/delta/imagery/sources/__init__.py
+++ b/delta/extensions/layers/__init__.py
@@ -16,5 +16,5 @@
# limitations under the License.
"""
-Module for reading imagery.
+Custom layers provided by DELTA.
"""
diff --git a/delta/extensions/layers/efficientnet.py b/delta/extensions/layers/efficientnet.py
new file mode 100644
index 00000000..85182935
--- /dev/null
+++ b/delta/extensions/layers/efficientnet.py
@@ -0,0 +1,480 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#pylint:disable=dangerous-default-value, too-many-arguments
+"""
+An implementation of EfficientNet. This is
+taken from tensorflow but modified to remove initial layers.
+"""
+# https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from copy import deepcopy
+import math
+
+import tensorflow
+#from tensorflow.keras.applications import imagenet_utils
+
+from delta.config.extensions import register_layer
+
+backend = tensorflow.keras.backend
+layers = tensorflow.keras.layers
+models = tensorflow.keras.models
+keras_utils = tensorflow.keras.utils
+
+#BASE_WEIGHTS_PATH = (
+# 'https://github.com/Callidior/keras-applications/'
+# 'releases/download/efficientnet/')
+#WEIGHTS_HASHES = {
+# 'b0': ('e9e877068bd0af75e0a36691e03c072c',
+# '345255ed8048c2f22c793070a9c1a130'),
+# 'b1': ('8f83b9aecab222a9a2480219843049a1',
+# 'b20160ab7b79b7a92897fcb33d52cc61'),
+# 'b2': ('b6185fdcd190285d516936c09dceeaa4',
+# 'c6e46333e8cddfa702f4d8b8b6340d70'),
+# 'b3': ('b2db0f8aac7c553657abb2cb46dcbfbb',
+# 'e0cf8654fad9d3625190e30d70d0c17d'),
+# 'b4': ('ab314d28135fe552e2f9312b31da6926',
+# 'b46702e4754d2022d62897e0618edc7b'),
+# 'b5': ('8d60b903aff50b09c6acf8eaba098e09',
+# '0a839ac36e46552a881f2975aaab442f'),
+# 'b6': ('a967457886eac4f5ab44139bdd827920',
+# '375a35c17ef70d46f9c664b03b4437f2'),
+# 'b7': ('e964fd6e26e9a4c144bcb811f2a10f20',
+# 'd55674cc46b805f4382d18bc08ed43c1')
+#}
+
+DEFAULT_BLOCKS_ARGS = [
+ {'kernel_size': 3, 'repeats': 1, 'filters_in': 32, 'filters_out': 16,
+ 'expand_ratio': 1, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25},
+ {'kernel_size': 3, 'repeats': 2, 'filters_in': 16, 'filters_out': 24,
+ 'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
+ {'kernel_size': 5, 'repeats': 2, 'filters_in': 24, 'filters_out': 40,
+ 'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
+ {'kernel_size': 3, 'repeats': 3, 'filters_in': 40, 'filters_out': 80,
+ 'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
+ {'kernel_size': 5, 'repeats': 3, 'filters_in': 80, 'filters_out': 112,
+ 'expand_ratio': 6, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25},
+ {'kernel_size': 5, 'repeats': 4, 'filters_in': 112, 'filters_out': 192,
+ 'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
+ {'kernel_size': 3, 'repeats': 1, 'filters_in': 192, 'filters_out': 320,
+ 'expand_ratio': 6, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25}
+]
+
+CONV_KERNEL_INITIALIZER = {
+ 'class_name': 'VarianceScaling',
+ 'config': {
+ 'scale': 2.0,
+ 'mode': 'fan_out',
+ # EfficientNet actually uses an untruncated normal distribution for
+ # initializing conv layers, but keras.initializers.VarianceScaling use
+ # a truncated distribution.
+ # We decided against a custom initializer for better serializability.
+ 'distribution': 'normal'
+ }
+}
+
+DENSE_KERNEL_INITIALIZER = {
+ 'class_name': 'VarianceScaling',
+ 'config': {
+ 'scale': 1. / 3.,
+ 'mode': 'fan_out',
+ 'distribution': 'uniform'
+ }
+}
+
+def correct_pad(inputs, kernel_size):
+ """Returns a tuple for zero-padding for 2D convolution with downsampling.
+ # Arguments
+ input_size: An integer or tuple/list of 2 integers.
+ kernel_size: An integer or tuple/list of 2 integers.
+ # Returns
+ A tuple.
+ """
+ img_dim = 2 if backend.image_data_format() == 'channels_first' else 1
+ input_size = backend.int_shape(inputs)[img_dim:(img_dim + 2)]
+
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size, kernel_size)
+
+ if input_size[0] is None:
+ adjust = (1, 1)
+ else:
+ adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
+
+ correct = (kernel_size[0] // 2, kernel_size[1] // 2)
+
+ return ((correct[0] - adjust[0], correct[0]),
+ (correct[1] - adjust[1], correct[1]))
+
+def swish(x):
+ """Swish activation function.
+
+ # Arguments
+ x: Input tensor.
+
+ # Returns
+ The Swish activation: `x * sigmoid(x)`.
+
+ # References
+ [Searching for Activation Functions](https://arxiv.org/abs/1710.05941)
+ """
+ if backend.backend() == 'tensorflow':
+ try:
+ # The native TF implementation has a more
+ # memory-efficient gradient implementation
+ return backend.tf.nn.swish(x)
+ except AttributeError:
+ pass
+
+ return x * backend.sigmoid(x)
+
+
+def block(inputs, activation_fn=swish, drop_rate=0., name='',
+ filters_in=32, filters_out=16, kernel_size=3, strides=1,
+ expand_ratio=1, se_ratio=0., id_skip=True):
+ """A mobile inverted residual block.
+
+ # Arguments
+ inputs: input tensor.
+ activation_fn: activation function.
+ drop_rate: float between 0 and 1, fraction of the input units to drop.
+ name: string, block label.
+ filters_in: integer, the number of input filters.
+ filters_out: integer, the number of output filters.
+ kernel_size: integer, the dimension of the convolution window.
+ strides: integer, the stride of the convolution.
+ expand_ratio: integer, scaling coefficient for the input filters.
+ se_ratio: float between 0 and 1, fraction to squeeze the input filters.
+ id_skip: boolean.
+
+ # Returns
+ output tensor for the block.
+ """
+ bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
+
+ # Expansion phase
+ filters = filters_in * expand_ratio
+ if expand_ratio != 1:
+ x = layers.Conv2D(filters, 1,
+ padding='same',
+ use_bias=False,
+ kernel_initializer=CONV_KERNEL_INITIALIZER,
+ name=name + 'expand_conv')(inputs)
+ x = layers.BatchNormalization(axis=bn_axis, name=name + 'expand_bn')(x)
+ x = layers.Activation(activation_fn, name=name + 'expand_activation')(x)
+ else:
+ x = inputs
+
+ # Depthwise Convolution
+ if strides == 2:
+ x = layers.ZeroPadding2D(padding=correct_pad(x, kernel_size),
+ name=name + 'dwconv_pad')(x)
+ conv_pad = 'valid'
+ else:
+ conv_pad = 'same'
+ x = layers.DepthwiseConv2D(kernel_size,
+ strides=strides,
+ padding=conv_pad,
+ use_bias=False,
+ depthwise_initializer=CONV_KERNEL_INITIALIZER,
+ name=name + 'dwconv')(x)
+ x = layers.BatchNormalization(axis=bn_axis, name=name + 'bn')(x)
+ x = layers.Activation(activation_fn, name=name + 'activation')(x)
+
+ # Squeeze and Excitation phase
+ if 0 < se_ratio <= 1:
+ filters_se = max(1, int(filters_in * se_ratio))
+ se = layers.GlobalAveragePooling2D(name=name + 'se_squeeze')(x)
+ if bn_axis == 1:
+ se = layers.Reshape((filters, 1, 1), name=name + 'se_reshape')(se)
+ else:
+ se = layers.Reshape((1, 1, filters), name=name + 'se_reshape')(se)
+ se = layers.Conv2D(filters_se, 1,
+ padding='same',
+ activation=activation_fn,
+ kernel_initializer=CONV_KERNEL_INITIALIZER,
+ name=name + 'se_reduce')(se)
+ se = layers.Conv2D(filters, 1,
+ padding='same',
+ activation='sigmoid',
+ kernel_initializer=CONV_KERNEL_INITIALIZER,
+ name=name + 'se_expand')(se)
+ if backend.backend() == 'theano':
+ # For the Theano backend, we have to explicitly make
+ # the excitation weights broadcastable.
+ se = layers.Lambda(
+ lambda x: backend.pattern_broadcast(x, [True, True, True, False]),
+ output_shape=lambda input_shape: input_shape,
+ name=name + 'se_broadcast')(se)
+ x = layers.multiply([x, se], name=name + 'se_excite')
+
+ # Output phase
+ x = layers.Conv2D(filters_out, 1,
+ padding='same',
+ use_bias=False,
+ kernel_initializer=CONV_KERNEL_INITIALIZER,
+ name=name + 'project_conv')(x)
+ x = layers.BatchNormalization(axis=bn_axis, name=name + 'project_bn')(x)
+ if (id_skip is True and strides == 1 and filters_in == filters_out):
+ if drop_rate > 0:
+ x = layers.Dropout(drop_rate,
+ noise_shape=(None, 1, 1, 1),
+ name=name + 'drop')(x)
+ x = layers.add([x, inputs], name=name + 'add')
+
+ return x
+
+
+def EfficientNet(width_coefficient,
+ depth_coefficient,
+ drop_connect_rate=0.2,
+ depth_divisor=8,
+ activation_fn=swish,
+ blocks_args=DEFAULT_BLOCKS_ARGS,
+ model_name='efficientnet',
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ name=None):
+ #pylint: disable=too-many-locals
+ if input_tensor is None:
+ img_input = layers.Input(shape=input_shape)
+ else:
+ if not backend.is_keras_tensor(input_tensor):
+ img_input = layers.Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
+
+ bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
+
+ def round_filters(filters, divisor=depth_divisor):
+ """Round number of filters based on depth multiplier."""
+ filters *= width_coefficient
+ new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_filters < 0.9 * filters:
+ new_filters += divisor
+ return int(new_filters)
+
+ def round_repeats(repeats):
+ """Round number of repeats based on depth multiplier."""
+ return int(math.ceil(depth_coefficient * repeats))
+
+ # Build stem
+ x = img_input
+ x = layers.ZeroPadding2D(padding=correct_pad(x, 3),
+ name='stem_conv_pad')(x)
+ x = layers.Conv2D(round_filters(32), 3,
+ strides=2,
+ padding='valid',
+ use_bias=False,
+ kernel_initializer=CONV_KERNEL_INITIALIZER,
+ name='stem_conv')(x)
+ x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
+ x = layers.Activation(activation_fn, name='stem_activation')(x)
+
+ # Build blocks
+ blocks_args = deepcopy(blocks_args)
+
+ b = 0
+ blocks = float(sum(args['repeats'] for args in blocks_args))
+ for (i, args) in enumerate(blocks_args):
+ assert args['repeats'] > 0
+ # Update block input and output filters based on depth multiplier.
+ args['filters_in'] = round_filters(args['filters_in'])
+ args['filters_out'] = round_filters(args['filters_out'])
+
+ for j in range(round_repeats(args.pop('repeats'))):
+ # The first block needs to take care of stride and filter size increase.
+ if j > 0:
+ args['strides'] = 1
+ args['filters_in'] = args['filters_out']
+ x = block(x, activation_fn, drop_connect_rate * b / blocks,
+ name='block{}{}_'.format(i + 1, chr(j + 97)), **args)
+ b += 1
+
+ # Build top
+ #x = layers.Conv2D(round_filters(1280), 1,
+ # padding='same',
+ # use_bias=False,
+ # kernel_initializer=CONV_KERNEL_INITIALIZER,
+ # name='top_conv')(x)
+ #x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
+ #x = layers.Activation(activation_fn, name='top_activation')(x)
+ #if include_top:
+ # x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
+ # if dropout_rate > 0:
+ # x = layers.Dropout(dropout_rate, name='top_dropout')(x)
+ # x = layers.Dense(classes,
+ # activation='softmax',
+ # kernel_initializer=DENSE_KERNEL_INITIALIZER,
+ # name='probs')(x)
+ #else:
+ # if pooling == 'avg':
+ # x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
+ # elif pooling == 'max':
+ # x = layers.GlobalMaxPooling2D(name='max_pool')(x)
+
+ # Ensure that the model takes into account
+ # any potential predecessors of `input_tensor`.
+ if input_tensor is not None:
+ inputs = keras_utils.get_source_inputs(input_tensor)
+ else:
+ inputs = img_input
+
+ # Create model.
+ model = models.Model(inputs, x, name=name if name is not None else model_name)
+
+ # Load weights.
+ #if weights == 'imagenet':
+ # if include_top:
+ # file_suff = '_weights_tf_dim_ordering_tf_kernels_autoaugment.h5'
+ # file_hash = WEIGHTS_HASHES[model_name[-2:]][0]
+ # else:
+ # file_suff = '_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5'
+ # file_hash = WEIGHTS_HASHES[model_name[-2:]][1]
+ # file_name = model_name + file_suff
+ # weights_path = keras_utils.get_file(file_name,
+ # BASE_WEIGHTS_PATH + file_name,
+ # cache_subdir='models',
+ # file_hash=file_hash)
+ # model.load_weights(weights_path, by_name=True, skip_mismatch=True)
+ if weights is not None:
+ model.load_weights(weights)
+
+ return model
+
+
+#def EfficientNetB0(include_top=True,
+# input_tensor=None,
+# input_shape=None,
+# **kwargs):
+# return EfficientNet(1.0, 1.0,
+# model_name='efficientnet-b0',
+# include_top=include_top,
+# input_tensor=input_tensor, input_shape=input_shape,
+# **kwargs)
+#
+#
+#def EfficientNetB1(include_top=True,
+# input_tensor=None,
+# input_shape=None,
+# **kwargs):
+# return EfficientNet(1.0, 1.1,
+# model_name='efficientnet-b1',
+# include_top=include_top,
+# input_tensor=input_tensor, input_shape=input_shape,
+# **kwargs)
+#
+#
+#def EfficientNetB2(include_top=True,
+# input_tensor=None,
+# input_shape=None,
+# **kwargs):
+# return EfficientNet(1.1, 1.2,
+# model_name='efficientnet-b2',
+# include_top=include_top,
+# input_tensor=input_tensor, input_shape=input_shape,
+# **kwargs)
+#
+#
+#def EfficientNetB3(include_top=True,
+# input_tensor=None,
+# input_shape=None,
+# **kwargs):
+# return EfficientNet(1.2, 1.4,
+# model_name='efficientnet-b3',
+# include_top=include_top,
+# input_tensor=input_tensor, input_shape=input_shape,
+# **kwargs)
+#
+#
+#def EfficientNetB4(include_top=True,
+# input_tensor=None,
+# input_shape=None,
+# **kwargs):
+# return EfficientNet(1.4, 1.8,
+# model_name='efficientnet-b4',
+# include_top=include_top,
+# input_tensor=input_tensor, input_shape=input_shape,
+# **kwargs)
+#
+#
+#def EfficientNetB5(include_top=True,
+# input_tensor=None,
+# input_shape=None,
+# **kwargs):
+# return EfficientNet(1.6, 2.2,
+# model_name='efficientnet-b5',
+# include_top=include_top,
+# input_tensor=input_tensor, input_shape=input_shape,
+# **kwargs)
+#
+#
+#def EfficientNetB6(include_top=True,
+# input_tensor=None,
+# input_shape=None,
+# **kwargs):
+# return EfficientNet(1.8, 2.6,
+# model_name='efficientnet-b6',
+# include_top=include_top,
+# input_tensor=input_tensor, input_shape=input_shape,
+# **kwargs)
+#
+#
+#def EfficientNetB7(include_top=True,
+# input_tensor=None,
+# input_shape=None,
+# **kwargs):
+# return EfficientNet(2.0, 3.1,
+# model_name='efficientnet-b7',
+# include_top=include_top,
+# input_tensor=input_tensor, input_shape=input_shape,
+# **kwargs)
+
+def DeltaEfficientNet(input_shape, width_coefficient=1.1, depth_coefficient=1.2, name=None):
+ return EfficientNet(width_coefficient, depth_coefficient,
+ input_shape=input_shape, weights=None, name=name)
+
+#def preprocess_input(x, data_format=None, **kwargs):
+# """Preprocesses a numpy array encoding a batch of images.
+#
+# # Arguments
+# x: a 3D or 4D numpy array consists of RGB values within [0, 255].
+# data_format: data format of the image tensor.
+#
+# # Returns
+# Preprocessed array.
+# """
+# return imagenet_utils.preprocess_input(x, data_format,
+# mode='torch', **kwargs)
+
+
+#setattr(EfficientNetB0, '__doc__', EfficientNet.__doc__)
+#setattr(EfficientNetB1, '__doc__', EfficientNet.__doc__)
+#setattr(EfficientNetB2, '__doc__', EfficientNet.__doc__)
+#setattr(EfficientNetB3, '__doc__', EfficientNet.__doc__)
+#setattr(EfficientNetB4, '__doc__', EfficientNet.__doc__)
+#setattr(EfficientNetB5, '__doc__', EfficientNet.__doc__)
+#setattr(EfficientNetB6, '__doc__', EfficientNet.__doc__)
+#setattr(EfficientNetB7, '__doc__', EfficientNet.__doc__)
+
+register_layer('EfficientNet', DeltaEfficientNet)
diff --git a/delta/ml/layers.py b/delta/extensions/layers/gaussian_sample.py
similarity index 65%
rename from delta/ml/layers.py
rename to delta/extensions/layers/gaussian_sample.py
index 28ad3974..5f7d99c6 100644
--- a/delta/ml/layers.py
+++ b/delta/extensions/layers/gaussian_sample.py
@@ -16,29 +16,35 @@
# limitations under the License.
"""
-DELTA specific network layers.
+Gaussian sampling layer, used in variational autoencoders.
"""
-import tensorflow.keras.models
import tensorflow.keras.backend as K
-from tensorflow.keras.layers import Layer
from tensorflow.keras.callbacks import Callback
-class DeltaLayer(Layer):
- # optionally return a Keras callback
- def callback(self): # pylint:disable=no-self-use
- return None
+from delta.config.extensions import register_layer
+from delta.ml.train import DeltaLayer
# If layers inherit from callback as well we add them automatically on fit
class GaussianSample(DeltaLayer):
def __init__(self, kl_loss=True, **kwargs):
- super(GaussianSample, self).__init__(**kwargs)
+ """
+ A layer that takes two inputs, a mean and a log variance, both of the same
+ dimensions. This layer returns a tensor of the same dimensions, sample
+ according to the provided mean and variance.
+
+ Parameters
+ ----------
+ kl_loss: bool
+ Add a kl loss term for the layer if true, to encourage a Normal(0, 1) distribution.
+ """
+ super().__init__(**kwargs)
self._use_kl_loss = kl_loss
self._kl_enabled = K.variable(0.0, name=self.name + ':kl_enabled')
self.trainable = False
def get_config(self):
- config = super(GaussianSample, self).get_config()
+ config = super().get_config()
config.update({'kl_loss': self._use_kl_loss})
return config
@@ -73,29 +79,4 @@ def call(self, inputs, **_):
return result
-def pretrained_model(filename, encoding_layer, trainable=False, **kwargs):
- '''
- Loads a pretrained model and extracts the enocoding layers.
- '''
- assert filename is not None, 'Did not specify pre-trained model.'
- assert encoding_layer is not None, 'Did not specify encoding layer point.'
-
- temp_model = tensorflow.keras.models.load_model(filename, compile=False)
-
- output_layers = []
- if isinstance(encoding_layer, int):
- break_point = lambda x, y: x == encoding_layer
- elif isinstance(encoding_layer, str):
- break_point = lambda x, y: y.name == encoding_layer
-
- for idx, l in enumerate(temp_model.layers):
- output_layers.append(l)
- output_layers[-1].trainable = trainable
- if break_point(idx, l):
- break
- return tensorflow.keras.models.Sequential(output_layers, **kwargs)
-
-ALL_LAYERS = {
- 'GaussianSample' : GaussianSample,
- 'Pretrained' : pretrained_model
-}
+register_layer('GaussianSample', GaussianSample)
diff --git a/delta/extensions/layers/pretrained.py b/delta/extensions/layers/pretrained.py
new file mode 100644
index 00000000..49233cca
--- /dev/null
+++ b/delta/extensions/layers/pretrained.py
@@ -0,0 +1,120 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Use a pretrained model inside another network.
+"""
+from typing import List, Optional
+import tensorflow
+import tensorflow.keras.models
+
+from delta.config.extensions import register_layer
+
+class InputSelectLayer(tensorflow.keras.layers.Layer):
+ """
+ A layer that takes any number of inputs, and returns a given one.
+ """
+ def __init__(self, arg_number, **kwargs):
+ """
+ Parameters
+ ----------
+ arg_number: int
+ The index of the input to select.
+ """
+ super().__init__(**kwargs)
+ self._arg = arg_number
+ def call(self, inputs, **kwargs):
+ return inputs[self._arg]
+ def get_config(self):
+ return {'arg_number' : self._arg}
+
+def _model_to_output_layers(model, break_point, trainable):
+ output_layers = []
+ for idx, l in enumerate(model.layers):
+ if not isinstance(l, tensorflow.keras.layers.BatchNormalization):
+ l.trainable = trainable
+ if isinstance(l, tensorflow.keras.models.Model): # assumes sequential
+ output_layers.extend(_model_to_output_layers(l, break_point, trainable))
+ else:
+ output_layers.append(l)
+ if break_point(idx, l):
+ break
+ return output_layers
+
+def pretrained(filename, encoding_layer, outputs: Optional[List[str]]=None, trainable: bool=True,
+ training: bool=True, **kwargs):
+ """
+ Creates pre-trained layer from an existing model file.
+ Only works with sequential models. This was quite tricky to get right with tensorflow.
+
+ Parameters
+ ----------
+ filename: str
+ Model file to load.
+ encoding_layer: str
+ Name of the layer to stop at.
+ outputs: Optional[List[str]]
+ List of names of output layers that may be used later in the model.
+ Only layers listed here will be accessible as inputs to other layers, in the form
+ this_layer_name/internal_name. (internal_name must be included in outputs to do so)
+ trainable: bool
+ Whether to update weights during training for this layer.
+ training: bool
+ Standard tensorflow option, used for batch norm layers.
+ """
+ model = tensorflow.keras.models.load_model(filename, compile=False)
+
+ if isinstance(encoding_layer, int):
+ break_point = lambda x, y: x == encoding_layer
+ elif isinstance(encoding_layer, str):
+ break_point = lambda x, y: y.name == encoding_layer
+
+ output_layers = _model_to_output_layers(model, break_point, trainable)
+
+ output_tensors = []
+ cur = model.inputs[0]
+ old_to_new = {}
+ old_to_new[cur.ref()] = cur
+ for l in output_layers:
+ if isinstance(l, tensorflow.keras.layers.InputLayer):
+ old_to_new[l.output.ref()] = cur
+ output_tensors.append(cur)
+ continue
+ if isinstance(l.input, list):
+ inputs = [old_to_new[t.ref()] for t in l.input]
+ else:
+ inputs = old_to_new[l.input.ref()]
+ cur = l(inputs)
+ old_to_new[l.output.ref()] = cur
+ output_tensors.append(cur)
+ new_model = tensorflow.keras.models.Model(model.inputs, output_tensors, **kwargs)
+
+ layers_dict = {}
+ if outputs:
+ for (i, l) in enumerate(output_layers):
+ if l.name not in outputs:
+ continue
+ layers_dict[l.name] = InputSelectLayer(i)
+
+ def call(*inputs):
+ result = new_model(inputs, training=training)
+ output = (InputSelectLayer(len(output_layers)-1)(result), {k : v(result) for k, v in layers_dict.items()})
+ return output
+ return call
+
+register_layer('InputSelectLayer', InputSelectLayer)
+register_layer('Pretrained', pretrained)
diff --git a/delta/extensions/layers/simple.py b/delta/extensions/layers/simple.py
new file mode 100644
index 00000000..b3b9aa70
--- /dev/null
+++ b/delta/extensions/layers/simple.py
@@ -0,0 +1,64 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Simple helpful layers.
+"""
+
+import tensorflow as tf
+import tensorflow.keras.layers
+import tensorflow.keras.backend as K
+
+from delta.config.extensions import register_layer
+
+class RepeatedGlobalAveragePooling2D(tensorflow.keras.layers.Layer):
+ """
+ Global average pooling in 2D for fully convolutional networks.
+
+ Takes the global average over the entire input, and repeats
+ it to return a tensor the same size as the input.
+ """
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+ def call(self, inputs, **_):
+ ones = tf.fill(tf.shape(inputs)[:-1], 1.0)
+ ones = tf.expand_dims(ones, -1)
+ mean = K.mean(inputs, axis=[1, 2])
+ mean = tf.expand_dims(mean, 1)
+ mean = tf.expand_dims(mean, 1)
+ return mean * ones
+
+class ReflectionPadding2D(tensorflow.keras.layers.Layer):
+ """
+ Add reflected padding of the given size surrounding the input.
+ """
+ def __init__(self, padding=(1, 1), **kwargs):
+ super().__init__(**kwargs)
+ self.padding = tuple(padding)
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({'padding': self.padding})
+ return config
+
+ def call(self, inputs, **_):
+ w_pad,h_pad = self.padding
+ return tf.pad(inputs, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')
+
+register_layer('RepeatedGlobalAveragePooling2D', RepeatedGlobalAveragePooling2D)
+register_layer('ReflectionPadding2D', ReflectionPadding2D)
diff --git a/delta/extensions/losses.py b/delta/extensions/losses.py
new file mode 100644
index 00000000..b176953e
--- /dev/null
+++ b/delta/extensions/losses.py
@@ -0,0 +1,154 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Various helpful loss functions.
+"""
+
+import numpy as np
+
+import tensorflow as tf
+import tensorflow.keras.losses
+import tensorflow.keras.backend as K
+
+from delta.config import config
+from delta.config.extensions import register_loss
+
+def ms_ssim(y_true, y_pred):
+ """
+ `tf.image.ssim_multiscale` as a loss function.
+ """
+ return 1.0 - tf.image.ssim_multiscale(y_true, y_pred, 4.0)
+
+def ms_ssim_mse(y_true, y_pred):
+ """
+ Sum of MS-SSIM and Mean Squared Error.
+ """
+ return ms_ssim(y_true, y_pred) + K.mean(K.mean(tensorflow.keras.losses.MSE(y_true, y_pred), -1), -1)
+
+# from https://gist.github.com/wassname/7793e2058c5c9dacb5212c0ac0b18a8a
+def dice_coef(y_true, y_pred, smooth=1):
+ """
+ Dice = (2*|X & Y|)/ (|X|+ |Y|)
+ = 2*sum(|A*B|)/(sum(A^2)+sum(B^2))
+ ref: https://arxiv.org/pdf/1606.04797v1.pdf
+ """
+ intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
+ return (2. * intersection + smooth) / (K.sum(K.square(y_true),-1) + K.sum(K.square(y_pred),-1) + smooth)
+
+def dice_loss(y_true, y_pred):
+ """
+ Dice coefficient as a loss function.
+ """
+ return 1 - dice_coef(y_true, y_pred)
+
+class MappedLoss(tf.keras.losses.Loss): #pylint: disable=abstract-method
+ def __init__(self, mapping, name=None):
+ """
+ This is a base class for losses when the labels of the input images do not match the labels
+ output by the network. For example, if one class in the labels should be ignored, or two
+ classes in the label should map to the same output, or one label should be treated as a probability
+ between two classes. It applies a transform to the output labels and then applies the loss function.
+
+ Note that the transform is applied after preprocessing (labels in the config will be transformed to 0-n
+ in order, and nodata will be n+1).
+
+ Parameters
+ ----------
+ mapping
+ One of:
+ * A list with transforms, where the first entry is what to transform the first label, to etc., i.e.,
+ [1, 0] will swap the order of two labels.
+ * A dictionary with classes mapped to transformed values. Classes can be referenced by name or by
+ number (see `delta.imagery.imagery_config.ClassesConfig.class_id` for class formats).
+ name: Optional[str]
+ Optional name for the loss function.
+ """
+ super().__init__(name=name)
+ if isinstance(mapping, list):
+ map_list = mapping
+ else:
+ # automatically set nodata to 0 (even if there is none it's fine)
+ entry = mapping[next(iter(mapping))]
+ if np.isscalar(entry):
+ map_list = np.zeros((len(config.dataset.classes) + 1,))
+ else:
+ map_list = np.zeros((len(config.dataset.classes) + 1, len(entry)))
+ assert len(mapping) == len(config.dataset.classes), 'Must specify all classes in loss mapping.'
+ for k in mapping:
+ i = config.dataset.classes.class_id(k)
+ if isinstance(mapping[k], (int, float)):
+ map_list[i] = mapping[k]
+ else:
+ assert len(mapping[k]) == map_list.shape[1], 'Mapping entry wrong length.'
+ map_list[i, :] = np.asarray(mapping[k])
+ self._lookup = tf.constant(map_list, dtype=tf.float32)
+
+class MappedCategoricalCrossentropy(MappedLoss):
+ """
+ `MappedLoss` for categorical_crossentropy.
+ """
+ def call(self, y_true, y_pred):
+ y_true = tf.squeeze(y_true)
+ true_convert = tf.gather(self._lookup, tf.cast(y_true, tf.int32), axis=None)
+ return tensorflow.keras.losses.categorical_crossentropy(true_convert, y_pred)
+
+class MappedBinaryCrossentropy(MappedLoss):
+ """
+ `MappedLoss` for binary_crossentropy.
+ """
+ def call(self, y_true, y_pred):
+ true_convert = tf.gather(self._lookup, tf.cast(y_true, tf.int32), axis=None)
+ return tensorflow.keras.losses.binary_crossentropy(true_convert, y_pred)
+
+class MappedDiceLoss(MappedLoss):
+ """
+ `MappedLoss` for `dice_loss`.
+ """
+ def call(self, y_true, y_pred):
+ true_convert = tf.gather(self._lookup, tf.cast(y_true, tf.int32), axis=None)
+ return dice_loss(true_convert, y_pred)
+
+class MappedMsssim(MappedLoss):
+ """
+ `MappedLoss` for `ms_ssim`.
+ """
+ def call(self, y_true, y_pred):
+ true_convert = tf.gather(self._lookup, tf.cast(y_true, tf.int32), axis=None)
+ return ms_ssim(true_convert, y_pred)
+
+class MappedDiceBceMsssim(MappedLoss):
+ """
+ `MappedLoss` for sum of `ms_ssim`, `dice_loss`, and `binary_crossentropy`.
+ """
+ def call(self, y_true, y_pred):
+ true_convert = tf.gather(self._lookup, tf.cast(y_true, tf.int32), axis=None)
+ dice = dice_loss(true_convert, y_pred)
+ bce = tensorflow.keras.losses.binary_crossentropy(true_convert, y_pred)
+ bce = K.mean(bce)
+ msssim = K.mean(ms_ssim(true_convert, y_pred))
+ return dice + bce + msssim
+
+
+register_loss('ms_ssim', ms_ssim)
+register_loss('ms_ssim_mse', ms_ssim_mse)
+register_loss('dice', dice_loss)
+register_loss('MappedCategoricalCrossentropy', MappedCategoricalCrossentropy)
+register_loss('MappedBinaryCrossentropy', MappedBinaryCrossentropy)
+register_loss('MappedDice', MappedDiceLoss)
+register_loss('MappedMsssim', MappedMsssim)
+register_loss('MappedDiceBceMsssim', MappedDiceBceMsssim)
diff --git a/delta/extensions/metrics.py b/delta/extensions/metrics.py
new file mode 100644
index 00000000..30a237e5
--- /dev/null
+++ b/delta/extensions/metrics.py
@@ -0,0 +1,147 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=too-many-ancestors
+"""
+Various helpful loss functions.
+"""
+
+import tensorflow as tf
+import tensorflow.keras.metrics
+
+from delta.config import config
+from delta.config.extensions import register_metric
+
+class SparseMetric(tensorflow.keras.metrics.Metric): # pylint:disable=abstract-method # pragma: no cover
+ """
+ An abstract class for metrics applied to integer class labels,
+ with networks that output one-hot encoding.
+ """
+ def __init__(self, label, class_id: int=None, name: str=None, binary: int=False):
+ """
+ Parameters
+ ----------
+ label
+ A class identifier accepted by `delta.imagery.imagery_config.ClassesConfig.class_id`.
+ Compared to valuse in the label image.
+ class_id: Optional[int]
+ For multi-class one-hot outputs, used if the output class ID is different than the
+ one in the label image.
+ name: str
+ Metric name.
+ binary: bool
+ Use binary threshold (0.5) or argmax on one-hot encoding.
+ """
+ super().__init__(name=name)
+ self._binary = binary
+ self._label_id = config.dataset.classes.class_id(label)
+ self._class_id = class_id if class_id is not None else self._label_id
+
+ def reset_state(self):
+ for s in self.variables:
+ s.assign(tf.zeros(shape=s.shape))
+
+class SparseRecall(SparseMetric): # pragma: no cover
+ """
+ Recall.
+ """
+ def __init__(self, label, class_id: int=None, name: str=None, binary: int=False):
+ super().__init__(label, class_id, name, binary)
+ self._total_class = self.add_weight('total_class', initializer='zeros')
+ self._true_positives = self.add_weight('true_positives', initializer='zeros')
+
+ def update_state(self, y_true, y_pred, sample_weight=None): #pylint: disable=unused-argument, arguments-differ
+ y_true = tf.squeeze(y_true)
+ right_class = tf.math.equal(y_true, self._label_id)
+ if self._binary:
+ y_pred = y_pred >= 0.5
+ right_class_pred = tf.squeeze(y_pred)
+ else:
+ y_pred = tf.math.argmax(y_pred, axis=-1)
+ right_class_pred = tf.math.equal(y_pred, self._class_id)
+ total_class = tf.math.reduce_sum(tf.cast(right_class, tf.float32))
+ self._total_class.assign_add(total_class)
+ true_positives = tf.math.logical_and(right_class, right_class_pred)
+ true_positives = tf.math.reduce_sum(tf.cast(true_positives, tf.float32))
+ self._true_positives.assign_add(true_positives)
+
+ def result(self):
+ return tf.math.divide_no_nan(self._true_positives, self._total_class)
+
+class SparsePrecision(SparseMetric): # pragma: no cover
+ """
+ Precision.
+ """
+ def __init__(self, label, class_id: int=None, name: str=None, binary: int=False):
+ super().__init__(label, class_id, name, binary)
+ self._total_class = self.add_weight('total_class', initializer='zeros')
+ self._true_positives = self.add_weight('true_positives', initializer='zeros')
+
+ def update_state(self, y_true, y_pred, sample_weight=None): #pylint: disable=unused-argument, arguments-differ
+ y_true = tf.squeeze(y_true)
+ right_class = tf.math.equal(y_true, self._label_id)
+ if self._binary:
+ y_pred = y_pred >= 0.5
+ right_class_pred = tf.squeeze(y_pred)
+ else:
+ y_pred = tf.math.argmax(y_pred, axis=-1)
+ right_class_pred = tf.math.equal(y_pred, self._class_id)
+
+ total_class = tf.math.reduce_sum(tf.cast(right_class_pred, tf.float32))
+ self._total_class.assign_add(total_class)
+ true_positives = tf.math.logical_and(right_class, right_class_pred)
+ true_positives = tf.math.reduce_sum(tf.cast(true_positives, tf.float32))
+ self._true_positives.assign_add(true_positives)
+
+ def result(self):
+ return tf.math.divide_no_nan(self._true_positives, self._total_class)
+
+class SparseBinaryAccuracy(SparseMetric): # pragma: no cover
+ """
+ Accuracy.
+ """
+ def __init__(self, label, name: str=None):
+ super().__init__(label, label, name, False)
+ self._nodata_id = config.dataset.classes.class_id('nodata')
+ self._total = self.add_weight('total', initializer='zeros')
+ self._correct = self.add_weight('correct', initializer='zeros')
+
+ def update_state(self, y_true, y_pred, sample_weight=None): #pylint: disable=unused-argument, arguments-differ
+ y_true = tf.squeeze(y_true)
+ y_pred = tf.squeeze(y_pred)
+
+ right_class = tf.math.equal(y_true, self._label_id)
+ right_class_pred = y_pred >= 0.5
+ true_positives = tf.math.logical_and(right_class, right_class_pred)
+ false_negatives = tf.math.logical_and(tf.math.logical_not(right_class), tf.math.logical_not(right_class_pred))
+ if self._nodata_id:
+ valid = tf.math.not_equal(y_true, self._nodata_id)
+ false_negatives = tf.math.logical_and(false_negatives, valid)
+ total = tf.math.reduce_sum(tf.cast(valid, tf.float32))
+ else:
+ total = tf.size(y_true)
+
+ true_positives = tf.math.reduce_sum(tf.cast(true_positives, tf.float32))
+ false_negatives = tf.math.reduce_sum(tf.cast(false_negatives, tf.float32))
+ self._correct.assign_add(true_positives + false_negatives)
+ self._total.assign_add(total)
+
+ def result(self):
+ return tf.math.divide(self._correct, self._total)
+
+register_metric('SparseRecall', SparseRecall)
+register_metric('SparsePrecision', SparsePrecision)
+register_metric('SparseBinaryAccuracy', SparseBinaryAccuracy)
diff --git a/delta/extensions/preprocess.py b/delta/extensions/preprocess.py
new file mode 100644
index 00000000..1ea57635
--- /dev/null
+++ b/delta/extensions/preprocess.py
@@ -0,0 +1,120 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#pylint:disable=unused-argument
+"""
+Various helpful preprocessing functions.
+
+These are intended to be included in image: preprocess in a yaml file.
+See the `delta.config` documentation for details. Note that for all
+functions, the image_type will be specified automatically: other
+parameters must be specified in the config file.
+"""
+import numpy as np
+
+from delta.config.extensions import register_preprocess
+
+__DEFAULT_SCALE_FACTORS = {'tiff' : 1024.0,
+ 'worldview' : 1024.0,
+ 'landsat' : 120.0,
+ 'npy' : None,
+ 'sentinel1' : None}
+
+def scale(image_type, factor='default'):
+ """
+ Divides by a given scale factor.
+
+ Parameters
+ ----------
+ factor: Union[str, float]
+ Scale factor to divide by. 'default' will scale by an image type specific
+ default amount.
+ """
+ if factor == 'default':
+ factor = __DEFAULT_SCALE_FACTORS[image_type]
+ factor = np.float32(factor)
+ return (lambda data, _, dummy: data / factor)
+
+def offset(image_type, factor):
+ """
+ Add an amount to all pixels.
+
+ Parameters
+ ----------
+ factor: float
+ Number to add.
+ """
+ factor = np.float32(factor)
+ return lambda data, _, dummy: data + factor
+
+def clip(image_type, bounds):
+ """
+ Clips all pixel values within a range.
+
+ Parameters
+ ----------
+ bounds: List[float]
+ List of two floats to clip all values between.
+ """
+ if isinstance(bounds, list):
+ assert len(bounds) == 2, 'Bounds must have two items.'
+ else:
+ bounds = (bounds, bounds)
+ bounds = (np.float32(bounds[0]), np.float32(bounds[1]))
+ return lambda data, _, dummy: np.clip(data, bounds[0], bounds[1])
+
+def cbrt(image_type):
+ """
+ Cubic root.
+ """
+ return lambda data, _, dummy: np.cbrt(data)
+def sqrt(image_type):
+ """
+ Square root.
+ """
+ return lambda data, _, dummy: np.sqrt(data)
+
+def gauss_mult_noise(image_type, stddev):
+ """
+ Multiplies each pixel by p ~ Normal(1, stddev)
+
+ Parameters
+ ----------
+ stddev: float
+ Standard deviation of distribution to sample from.
+ """
+ return lambda data, _, dummy: data * np.random.normal(1.0, stddev, data.shape)
+
+def substitute(image_type, mapping):
+ """
+ Replaces pixels in image with the listed values.
+
+ Parameters
+ ----------
+ mapping: List[Any]
+ For example, to change a binary image to a one-hot representation,
+ use [[1, 0], [0, 1]]. This replaces all 0 pixels with [1, 0] and all
+ 1 pixels with [0, 1].
+ """
+ return lambda data, _, dummy: np.take(mapping, data)
+
+register_preprocess('scale', scale)
+register_preprocess('offset', offset)
+register_preprocess('clip', clip)
+register_preprocess('sqrt', sqrt)
+register_preprocess('cbrt', cbrt)
+register_preprocess('gauss_mult_noise', gauss_mult_noise)
+register_preprocess('substitute', substitute)
diff --git a/delta/extensions/sources/__init__.py b/delta/extensions/sources/__init__.py
new file mode 100644
index 00000000..9b3fff39
--- /dev/null
+++ b/delta/extensions/sources/__init__.py
@@ -0,0 +1,25 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Imagery types for DELTA.
+
+These are specified in the "type" field in the configuration yaml file.
+
+Note that while DELTA supports compressed images for some satellites, we
+recommend extracting these images to tiffs beforehand as it will speed up training.
+"""
diff --git a/delta/imagery/sources/landsat.py b/delta/extensions/sources/landsat.py
similarity index 84%
rename from delta/imagery/sources/landsat.py
rename to delta/extensions/sources/landsat.py
index 49a4be14..f9550360 100644
--- a/delta/imagery/sources/landsat.py
+++ b/delta/extensions/sources/landsat.py
@@ -32,48 +32,22 @@
# Use this for all the output Landsat data we write.
OUTPUT_NODATA = 0.0
-def _allocate_bands_for_spacecraft(landsat_number):
- """Set up value storage for _parse_mtl_file()"""
-
- BAND_COUNTS = {'5':7, '7':9, '8':11}
-
- num_bands = BAND_COUNTS[landsat_number]
- data = dict()
-
- # There are fewer K constants but we store in the the
- # appropriate band indices.
- data['FILE_NAME' ] = [''] * num_bands
- data['RADIANCE_MULT' ] = [None] * num_bands
- data['RADIANCE_ADD' ] = [None] * num_bands
- data['REFLECTANCE_MULT'] = [None] * num_bands
- data['REFLECTANCE_ADD' ] = [None] * num_bands
- data['K1_CONSTANT' ] = [None] * num_bands
- data['K2_CONSTANT' ] = [None] * num_bands
-
- return data
-
def _parse_mtl_file(mtl_path):
"""Parse out the needed values from the MTL file"""
if not os.path.exists(mtl_path):
- raise Exception('MTL file not found: ' + mtl_path)
+ raise FileNotFoundError('MTL file not found: ' + mtl_path)
# These are all the values we want to read in
DESIRED_TAGS = ['FILE_NAME', 'RADIANCE_MULT', 'RADIANCE_ADD',
'REFLECTANCE_MULT', 'REFLECTANCE_ADD',
'K1_CONSTANT', 'K2_CONSTANT']
- data = None
+ data = dict()
with open(mtl_path, 'r') as f:
for line in f:
-
line = line.replace('"','') # Clean up
- # Get the spacecraft ID and allocate storage
- if 'SPACECRAFT_ID = LANDSAT_' in line:
- spacecraft_id = line.split('_')[-1].strip()
- data = _allocate_bands_for_spacecraft(spacecraft_id)
-
if 'SUN_ELEVATION = ' in line:
value = line.split('=')[-1].strip()
data['SUN_ELEVATION'] = float(value)
@@ -97,6 +71,8 @@ def _parse_mtl_file(mtl_path):
except ValueError: # Means this is not a proper match
break
+ if tag not in data:
+ data[tag] = dict()
if tag == 'FILE_NAME':
data[tag][band] = value # String
else:
@@ -116,27 +92,22 @@ def get_scene_info(path):
output['date' ] = parts[3]
return output
+__LANDSAT_BANDS_DICT = {
+ '5': [1, 2, 3, 4, 5, 6, 7],
+ '7': [1, 2, 3, 4, 5, 6, 7], # Don't forget the extra thermal band!
+ '8': [1, 2, 3, 4, 5, 6, 7, 9]
+}
+
def _get_landsat_bands_to_use(sensor_name):
"""Return the list of one-based band indices that we are currently
using to process the given landsat sensor.
"""
- # For now just the 30 meter bands, in original order.
- LS5_DESIRED_BANDS = [1, 2, 3, 4, 5, 6, 7]
- LS7_DESIRED_BANDS = [1, 2, 3, 4, 5, 6, 7] # Don't forget the extra thermal band!
- LS8_DESIRED_BANDS = [1, 2, 3, 4, 5, 6, 7, 9]
-
- if '5' in sensor_name:
- bands = LS5_DESIRED_BANDS
- else:
- if '7' in sensor_name:
- bands = LS7_DESIRED_BANDS
- else:
- if '8' in sensor_name:
- bands = LS8_DESIRED_BANDS
- else:
- raise Exception('Unknown landsat type: ' + sensor_name)
- return bands
+ for (k, v) in __LANDSAT_BANDS_DICT.items():
+ if k in sensor_name:
+ return v
+ print('Unknown landsat type: ' + sensor_name)
+ return None
def _get_band_paths(mtl_data, folder, bands_to_use=None):
"""Return full paths to all band files that should be in the folder.
@@ -175,8 +146,11 @@ def _find_mtl_file(folder):
class LandsatImage(tiff.TiffImage):
- """Compressed Landsat image tensorflow dataset wrapper (see imagery_dataset.py)"""
+ """Compressed Landsat image. Loads a compressed zip or tar file with a .mtl file."""
+ def __init__(self, paths, nodata_value=None, bands=None):
+ self._bands = bands
+ super().__init__(paths, nodata_value)
def _prep(self, paths):
"""Prepares a Landsat file from the archive for processing.
@@ -193,7 +167,7 @@ def _prep(self, paths):
# Get the folder where this will be stored from the cache manager
name = '_'.join([self._sensor, self._lpath, self._lrow, self._date])
- untar_folder = config.cache_manager().register_item(name)
+ untar_folder = config.io.cache.manager().register_item(name)
# Check if we already unpacked this data
all_files_present = False
@@ -209,7 +183,7 @@ def _prep(self, paths):
print('Unpacking tar file ' + paths + ' to folder ' + untar_folder)
utilities.unpack_to_folder(paths, untar_folder)
- bands_to_use = _get_landsat_bands_to_use(self._sensor)
+ bands_to_use = _get_landsat_bands_to_use(self._sensor) if self._bands is None else self._bands
# Generate all the band file names (the MTL file is not returned)
self._mtl_path = _find_mtl_file(untar_folder)
diff --git a/delta/imagery/sources/npy.py b/delta/extensions/sources/npy.py
similarity index 71%
rename from delta/imagery/sources/npy.py
rename to delta/extensions/sources/npy.py
index 8272f021..bb3a7b49 100644
--- a/delta/imagery/sources/npy.py
+++ b/delta/extensions/sources/npy.py
@@ -20,17 +20,28 @@
"""
import os
+from typing import Optional
+
import numpy as np
-from . import delta_image
+from delta.imagery import delta_image
class NumpyImage(delta_image.DeltaImage):
"""
- Numpy image data tensorflow dataset wrapper (see imagery_dataset.py).
- Can set either path to load a file, or data to load a numpy array directly.
+ Load a numpy array as an image.
"""
- def __init__(self, data=None, path=None):
- super(NumpyImage, self).__init__()
+ def __init__(self, data: Optional[np.ndarray]=None, path: Optional[str]=None, nodata_value=None):
+ """
+ Parameters
+ ----------
+ data: Optional[numpy.ndarray]
+ Loads a numpy array directly.
+ path: Optional[str]
+ Load a numpy array from a file with `numpy.load`. Only one of data or path should be given.
+ nodata_value
+ The pixel value representing no data.
+ """
+ super().__init__(nodata_value)
if path:
assert not data
@@ -43,30 +54,26 @@ def __init__(self, data=None, path=None):
self._data = data
def _read(self, roi, bands, buf=None):
- """
- Read the image of the given data type. An optional roi specifies the boundaries.
-
- This function is intended to be overwritten by subclasses.
- """
if buf is None:
buf = np.zeros(shape=(roi.width(), roi.height(), self.num_bands() ), dtype=self._data.dtype)
- (min_x, max_x, min_y, max_y) = roi.get_bounds()
+ (min_x, max_x, min_y, max_y) = roi.bounds()
buf = self._data[min_y:max_y,min_x:max_x,:]
return buf
def size(self):
- """Return the size of this image in pixels, as (width, height)."""
return (self._data.shape[1], self._data.shape[0])
def num_bands(self):
- """Return the number of bands in the image."""
return self._data.shape[2]
-class NumpyImageWriter(delta_image.DeltaImageWriter):
+ def dtype(self):
+ return self._data.dtype
+
+class NumpyWriter(delta_image.DeltaImageWriter):
def __init__(self):
self._buffer = None
- def initialize(self, size, numpy_dtype, metadata=None):
+ def initialize(self, size, numpy_dtype, metadata=None, nodata_value=None):
self._buffer = np.zeros(shape=size, dtype=numpy_dtype)
def write(self, data, x, y):
diff --git a/delta/extensions/sources/sentinel1.py b/delta/extensions/sources/sentinel1.py
new file mode 100644
index 00000000..dfe2d07f
--- /dev/null
+++ b/delta/extensions/sources/sentinel1.py
@@ -0,0 +1,184 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Functions to support Sentinel1 satellites.
+"""
+
+import os
+import portalocker
+
+from delta.config import config
+from delta.imagery import utilities
+from . import tiff
+
+
+# Unpack procedure:
+# - Start with .zip files
+# - Unpack to .SAFE folders containing .tif files
+# - Use SNAP to
+# - (or) Use gdalbuildvrt to creat a merged.vrt file
+
+
+this_folder = os.path.dirname(os.path.abspath(__file__))
+SNAP_GRAPH_PATH = os.path.join(this_folder,
+ 'sentinel1_default_snap_preprocess_graph.xml')
+SNAP_SCRIPT_PATH = os.path.join(this_folder, 'snap_process_sentinel1.sh')
+
+
+# Using the .vrt does not make much sense with SNAP but it is consistent
+# with the gdalbuildvrt option and makes it easier to search for unpacked
+# Sentinel1 images
+def get_merged_path(unpack_folder):
+ return os.path.join(unpack_folder, 'merged.vrt')
+
+def get_files_from_unpack_folder(folder):
+ """Return the source image file paths from the given unpack folder.
+ Returns [] if the files were not found.
+ """
+
+ # All of the image files are located in the measurement folier
+ measurement_folder = os.path.join(folder, 'measurement')
+ if not os.path.exists(folder) or not os.path.exists(measurement_folder):
+ return []
+
+ tiff_files = []
+ measure_files = os.listdir(measurement_folder)
+ for f in measure_files:
+ ext = os.path.splitext(f)[1]
+ if (ext.lower() == '.tiff') or (ext.lower() == '.tif'):
+ tiff_files.append(os.path.join(measurement_folder, f))
+
+ return tiff_files
+
+
+def unpack_s1_to_folder(zip_path, unpack_folder):
+ '''Returns the merged image path from the unpack folder.
+ Unpacks the zip file and merges the source images as needed.'''
+
+ with portalocker.Lock(zip_path, 'r', timeout=300) as unused: #pylint: disable=W0612
+
+ merged_path = get_merged_path(unpack_folder)
+ try:
+ test_image = tiff.TiffImage(merged_path) #pylint: disable=W0612
+ except Exception: #pylint: disable=W0703
+ test_image = None
+
+ if test_image: # Merged image is ready to use
+ print('Already have unpacked files in ' + unpack_folder)
+ return merged_path
+ # Otherwise go through the entire unpack process
+
+ NUM_SOURCE_CHANNELS = 2
+ need_to_unpack = True
+ if os.path.exists(unpack_folder):
+ source_image_paths = get_files_from_unpack_folder(unpack_folder)
+ if len(source_image_paths) == NUM_SOURCE_CHANNELS:
+ need_to_unpack = False
+ print('Already have files')
+ else:
+ print('Clearing unpack folder missing image files.')
+ os.system('rm -rf ' + unpack_folder)
+
+ if need_to_unpack:
+ print('Unpacking file ' + zip_path + ' to folder ' + unpack_folder)
+ utilities.unpack_to_folder(zip_path, unpack_folder)
+ subdirs = os.listdir(unpack_folder)
+ if len(subdirs) != 1:
+ raise Exception('Unexpected Sentinel1 subdirectories: ' + str(subdirs))
+ cmd = 'mv ' + os.path.join(unpack_folder, subdirs[0]) +'/* ' + unpack_folder
+ print(cmd)
+ os.system(cmd)
+ source_image_paths = get_files_from_unpack_folder(unpack_folder)
+
+ if len(source_image_paths) != NUM_SOURCE_CHANNELS:
+ raise Exception('Did not find two image files in ' + zip_path)
+
+ USE_SNAP = True # To get real results we need to use SNAP
+
+ if USE_SNAP: # Requires the Sentinel1 processing software to be installed
+ # Run the preconfigured SNAP preprocessing graph
+ # - The SNAP tool *must* write to a .tif extension, so we have to
+ # rename the file if we want something else.
+ temp_out_path = merged_path.replace('.vrt', '.tif')
+ cmd = (SNAP_SCRIPT_PATH + ' ' + SNAP_GRAPH_PATH + ' '
+ + unpack_folder + ' ' + temp_out_path)
+ print(cmd)
+ os.system(cmd)
+ MIN_IMAGE_SIZE = 1024*1024*500 # 500 MB, expected size is much larger
+ if not os.path.exists(temp_out_path):
+ raise Exception('Failed to run ESA SNAP preprocessing.\n'
+ +'Do you have SNAP installed in the default location?')
+ if os.path.getsize(temp_out_path) < MIN_IMAGE_SIZE:
+ raise Exception('SNAP encountered a problem processing the file!')
+ os.system('mv ' + temp_out_path + ' ' + merged_path)
+ else:
+ # Generate a merged file containing all input images as an N channel image
+ cmd = 'gdalbuildvrt -separate ' + merged_path
+ for f in source_image_paths:
+ cmd += ' ' + f
+ print(cmd)
+ os.system(cmd)
+
+ # Verify that we generated a valid image file
+ try:
+ test_image = tiff.TiffImage(merged_path) #pylint: disable=W0612
+ except Exception as e: #pylint: disable=W0703
+ raise Exception('Failed to generate merged Sentinel1 file: ' + merged_path) from e
+
+ return merged_path
+
+
+class Sentinel1Image(tiff.TiffImage):
+ """Sentinel1 image tensorflow dataset wrapper (see imagery_dataset.py)"""
+ def __init__(self, paths, nodata_value=None):
+ self._meta_path = None
+ self._meta = None
+ self._sensor = None
+ self._date = None
+ self._name = None
+ super().__init__(paths, nodata_value)
+
+ def _unpack(self, zip_path):
+ # Get the folder where this will be stored from the cache manager
+ unpack_folder = config.io.cache.manager().register_item(self._name)
+
+ return unpack_s1_to_folder(zip_path, unpack_folder)
+
+ # This function is currently set up for the HDDS archived WV data, files from other
+ # locations will need to be handled differently.
+ def _prep(self, paths):
+ """Prepares a Sentinel1 file from the archive for processing.
+ Returns the path to the file ready to use.
+ --> This version does not do any preprocessing!!!
+ """
+ assert isinstance(paths, str)
+ ext = os.path.splitext(paths)[1]
+
+ tif_path = None
+ if ext == '.zip': # Need to unpack
+
+ tif_path = self._unpack(paths)
+
+ if ext == '.vrt': # Already unpacked
+
+ unpack_folder = os.path.dirname(paths)
+ tif_path = get_merged_path(unpack_folder)
+
+ assert tif_path is not None, f'Error: Unsupported extension {ext}'
+
+ return [tif_path]
diff --git a/delta/extensions/sources/sentinel1_default_snap_preprocess_graph.xml b/delta/extensions/sources/sentinel1_default_snap_preprocess_graph.xml
new file mode 100644
index 00000000..107dd79b
--- /dev/null
+++ b/delta/extensions/sources/sentinel1_default_snap_preprocess_graph.xml
@@ -0,0 +1,75 @@
+
+ 1.0
+
+ Calibration
+
+ ${sourceProduct}
+
+
+
+ Product Auxiliary File
+
+ false
+ false
+ false
+ false
+
+ true
+ false
+ false
+
+
+
+ Terrain-Correction
+
+
+
+
+
+ SRTM 1Sec HGT
+
+ 0.0
+ true
+ BILINEAR_INTERPOLATION
+ BILINEAR_INTERPOLATION
+ 10.0
+ 8.983152841195215E-5
+ GEOGCS["WGS84(DD)",
+ DATUM["WGS84",
+ SPHEROID["WGS84", 6378137.0, 298.257223563]],
+ PRIMEM["Greenwich", 0.0],
+ UNIT["degree", 0.017453292519943295],
+ AXIS["Geodetic longitude", EAST],
+ AXIS["Geodetic latitude", NORTH]]
+ false
+ 0.0
+ 0.0
+ true
+ false
+ false
+ false
+ false
+ false
+ true
+ false
+ false
+ false
+ false
+ false
+ Use projected local incidence angle from DEM
+ Use projected local incidence angle from DEM
+ Latest Auxiliary File
+
+
+
+
+ Write
+
+
+
+
+ ${targetProduct}
+ GeoTIFF-BigTIFF
+
+
+
diff --git a/delta/extensions/sources/snap_process_sentinel1.sh b/delta/extensions/sources/snap_process_sentinel1.sh
new file mode 100644
index 00000000..85553de4
--- /dev/null
+++ b/delta/extensions/sources/snap_process_sentinel1.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This is the default SNAP install location
+export PATH=~/snap/bin:$PATH
+gptPath="gpt"
+
+# Get input parameters
+graphXmlPath="$1"
+sourceFile="$2"
+targetFile="$3"
+
+# Execute the graph
+${gptPath} ${graphXmlPath} -e -PtargetProduct=${targetFile} ${sourceFile}
diff --git a/delta/imagery/sources/tiff.py b/delta/extensions/sources/tiff.py
similarity index 67%
rename from delta/imagery/sources/tiff.py
rename to delta/extensions/sources/tiff.py
index 3b868370..da36b5df 100644
--- a/delta/imagery/sources/tiff.py
+++ b/delta/extensions/sources/tiff.py
@@ -21,25 +21,44 @@
import os
import math
-from osgeo import gdal
import numpy as np
+from osgeo import gdal
+
+from delta.imagery import delta_image, rectangle
+
-from delta.config import config
-from delta.imagery import rectangle
+# Suppress GDAL warnings, errors become exceptions so we get them
+gdal.SetConfigOption('CPL_LOG', '/dev/null')
+gdal.UseExceptions()
-from . import delta_image
+_GDAL_TO_NUMPY_TYPES = {
+ gdal.GDT_Byte: np.dtype(np.uint8),
+ gdal.GDT_UInt16: np.dtype(np.uint16),
+ gdal.GDT_UInt32: np.dtype(np.uint32),
+ gdal.GDT_Float32: np.dtype(np.float32),
+ gdal.GDT_Float64: np.dtype(np.float64)
+}
+_NUMPY_TO_GDAL_TYPES = {v: k for k, v in _GDAL_TO_NUMPY_TYPES.items()}
class TiffImage(delta_image.DeltaImage):
- """For geotiffs."""
+ """Images supported by GDAL."""
- def __init__(self, path):
- '''
- Opens a geotiff for reading. paths can be either a single filename or a list.
- For a list, the images are opened in order as a multi-band image, assumed to overlap.
- '''
- super(TiffImage, self).__init__()
+ def __init__(self, path, nodata_value=None):
+ """
+ Opens a geotiff for reading.
+
+ Parameters
+ ----------
+ paths: str or List[str]
+ Either a single filename or a list.
+ For a list, the images are opened in order as a multi-band image, assumed to overlap.
+ nodata_value: dtype of image
+ Value representing no data.
+ """
+ super().__init__(nodata_value)
paths = self._prep(path)
+ self._path = path
self._paths = paths
self._handles = []
for p in paths:
@@ -63,8 +82,17 @@ def _prep(self, paths): #pylint:disable=no-self-use
"""
Prepare the file to be opened by other tools (unpack, etc).
- Returns a list of underlying files to load instead of the original path.
- This is intended to be overwritten by subclasses.
+ This can be overwritten by subclasses to, for example,
+ unpack a zip file to a cache directory.
+
+ Parameters
+ ----------
+ paths: str or List[str]
+ Paths passed to constructor
+
+ Returns
+ -------
+ Returns a list of underlying files to load instead of the original paths.
"""
if isinstance(paths, str):
return [paths]
@@ -75,10 +103,19 @@ def __asert_open(self):
raise IOError('Operating on an image that has been closed.')
def close(self):
+ """
+ Close the image.
+ """
self._handles = None # gdal doesn't have a close function for some reason
self._band_map = None
self._paths = None
+ def path(self):
+ """
+ Returns the paths returned by `_prep`.
+ """
+ return self._path
+
def num_bands(self):
self.__asert_open()
return len(self._band_map)
@@ -89,9 +126,10 @@ def size(self):
def _read(self, roi, bands, buf=None):
self.__asert_open()
+ num_bands = len(bands) if bands else self.num_bands()
if buf is None:
- buf = np.zeros(shape=(self.num_bands(), roi.width(), roi.height()), dtype=self.numpy_type())
+ buf = np.zeros(shape=(num_bands, roi.width(), roi.height()), dtype=self.dtype())
for i, b in enumerate(bands):
band_handle = self._gdal_band(b)
s = buf[i, :, :].shape
@@ -107,67 +145,43 @@ def _gdal_band(self, band):
assert ret
return ret
- def nodata_value(self, band=0):
- '''
- Returns the value that indicates no data is present in a pixel for the specified band.
- '''
- self.__asert_open()
- return self._gdal_band(band).GetNoDataValue()
-
- def data_type(self, band=0):
- '''
+ def _gdal_type(self, band=0):
+ """
Returns the GDAL data type of the image.
- '''
+ """
self.__asert_open()
return self._gdal_band(band).DataType
- def numpy_type(self, band=0):
+ def dtype(self):
self.__asert_open()
- dtype = self.data_type(band)
- if dtype == gdal.GDT_Byte:
- return np.uint8
- if dtype == gdal.GDT_UInt16:
- return np.uint16
- if dtype == gdal.GDT_UInt32:
- return np.uint32
- if dtype == gdal.GDT_Float32:
- return np.float32
- if dtype == gdal.GDT_Float64:
- return np.float64
+ dtype = self._gdal_type(0)
+ if dtype in _GDAL_TO_NUMPY_TYPES:
+ return _GDAL_TO_NUMPY_TYPES[dtype]
raise Exception('Unrecognized gdal data type: ' + str(dtype))
def bytes_per_pixel(self, band=0):
- '''
- Returns the number of bytes per pixel
- '''
+ """
+ Returns
+ -------
+ int:
+ the number of bytes per pixel
+ """
self.__asert_open()
- results = {
- gdal.GDT_Byte: 1,
- gdal.GDT_UInt16: 2,
- gdal.GDT_UInt32: 4,
- gdal.GDT_Float32: 4,
- gdal.GDT_Float64: 8
- }
- return results.get(self.data_type(band))
-
- def block_info(self, band=0):
- """Returns ((block height, block width), (num blocks x, num blocks y))"""
+ return gdal.GetDataTypeSize(self._gdal_type(band)) // 8
+
+ def block_size(self):
+ """
+ Returns
+ -------
+ (int, int):
+ block height, block width
+ """
self.__asert_open()
- band_handle = self._gdal_band(band)
+ band_handle = self._gdal_band(0)
block_size = band_handle.GetBlockSize()
-
- num_blocks_x = int(math.ceil(self.height() / block_size[1]))
- num_blocks_y = int(math.ceil(self.width() / block_size[0]))
-
- # we are backwards from gdal I think
- return ((block_size[1], block_size[0]), (num_blocks_x, num_blocks_y))
+ return (block_size[1], block_size[0])
def metadata(self):
- '''
- Returns all useful image metadata.
-
- If multiple images were specified, returns the information from the first.
- '''
self.__asert_open()
data = dict()
h = self._handles[0]
@@ -179,17 +193,13 @@ def metadata(self):
return data
def block_aligned_roi(self, desired_roi):
- '''
- Returns the block aligned pixel region to read in a Rectangle format
- to get the requested data region while respecting block boundaries.
- '''
self.__asert_open()
bounds = rectangle.Rectangle(0, 0, width=self.width(), height=self.height())
if not bounds.contains_rect(desired_roi):
raise Exception('desired_roi ' + str(desired_roi)
+ ' is outside the bounds of image with size' + str(self.size()))
- (block_size, unused_num_blocks) = self.block_info(0)
+ block_size = self.block_size()
start_block_x = int(math.floor(desired_roi.min_x / block_size[0]))
start_block_y = int(math.floor(desired_roi.min_y / block_size[1]))
# Rect max is exclusive
@@ -208,28 +218,36 @@ def block_aligned_roi(self, desired_roi):
bounds = rectangle.Rectangle(0, 0, width=self.width(), height=self.height())
return ans.get_intersection(bounds)
- def save(self, path, tile_size=(0,0), nodata_value=None, show_progress=False):
+ def save(self, path, tile_size=None, nodata_value=None, show_progress=False):
"""
- Save a TiffImage to the file output_path, optionally overwriting the tile_size.
+ Save to file, with preprocessing applied.
+
+ Parameters
+ ----------
+ path: str
+ Filename to save to.
+ tile_size: (int, int)
+ If specified, overwrite block size
+ nodata_value: image dtype
+ If specified, overwrite nodata value
+ show_progress: bool
+ Write progress bar to stdout
"""
if nodata_value is None:
nodata_value = self.nodata_value()
# Use the input tile size for the block size unless the user specified one.
- (bs, _) = self.block_info()
- block_size_x = bs[0]
- block_size_y = bs[1]
- if tile_size[0] > 0:
+ block_size_y, block_size_x = self.block_size()
+ if tile_size is not None:
block_size_x = tile_size[0]
- if tile_size[1] > 0:
block_size_y = tile_size[1]
# Set up the output image
- with TiffWriter(path, self.width(), self.height(), self.num_bands(),
- self.data_type(), block_size_x, block_size_y,
- nodata_value, self.metadata()) as writer:
+ with _TiffWriter(path, self.width(), self.height(), self.num_bands(),
+ self._gdal_type(), block_size_x, block_size_y,
+ nodata_value, self.metadata()) as writer:
input_bounds = rectangle.Rectangle(0, 0, width=self.width(), height=self.height())
- output_rois = input_bounds.make_tile_rois(block_size_x, block_size_y, include_partials=True)
+ output_rois = input_bounds.make_tile_rois((block_size_x, block_size_y), include_partials=True)
def callback_function(output_roi, data):
"""Callback function to write the first channel to the output file."""
@@ -244,53 +262,36 @@ def callback_function(output_roi, data):
self.process_rois(output_rois, callback_function, show_progress=show_progress)
-class RGBAImage(TiffImage):
- """Basic RGBA images where the alpha channel needs to be stripped"""
-
- def _prep(self, paths):
- """Converts RGBA images to RGB images"""
-
- # Get the path to the cached image
- fname = os.path.basename(paths)
- output_path = config.cache_manager().register_item(fname)
-
- if not os.path.exists(output_path):
- # Just remove the alpha band from the original image
- cmd = 'gdal_translate -b 1 -b 2 -b 3 ' + paths + ' ' + output_path
- os.system(cmd)
- return [output_path]
-
-def numpy_dtype_to_gdal_type(dtype): #pylint: disable=R0911
- if dtype == np.uint8:
- return gdal.GDT_Byte
- if dtype == np.uint16:
- return gdal.GDT_UInt16
- if dtype == np.uint32:
- return gdal.GDT_UInt32
- if dtype == np.int16:
- return gdal.GDT_Int16
- if dtype == np.int32:
- return gdal.GDT_Int32
- if dtype == np.float32:
- return gdal.GDT_Float32
- if dtype == np.float64:
- return gdal.GDT_Float64
+def _numpy_dtype_to_gdal_type(dtype): #pylint: disable=R0911
+ if dtype in _NUMPY_TO_GDAL_TYPES:
+ return _NUMPY_TO_GDAL_TYPES[dtype]
raise Exception('Unrecognized numpy data type: ' + str(dtype))
-def write_tiff(output_path, data, metadata=None):
- """Try to write a tiff file"""
+def write_tiff(output_path: str, data: np.ndarray, metadata: dict=None):
+ """
+ Write a numpy array to a file as a tiff.
+
+ Parameters
+ ----------
+ output_path: str
+ Filename to save tiff file to
+ data: numpy.ndarray
+ Image data to save.
+ metadata: dict
+ Optional metadata to include.
+ """
if len(data.shape) < 3:
num_bands = 1
else:
num_bands = data.shape[2]
- data_type = numpy_dtype_to_gdal_type(data.dtype)
+ data_type = _numpy_dtype_to_gdal_type(data.dtype)
TILE_SIZE=256
- with TiffWriter(output_path, data.shape[0], data.shape[1], num_bands=num_bands,
- data_type=data_type, metadata=metadata, tile_width=min(TILE_SIZE, data.shape[0]),
- tile_height=min(TILE_SIZE, data.shape[1])) as writer:
+ with _TiffWriter(output_path, data.shape[0], data.shape[1], num_bands=num_bands,
+ data_type=data_type, metadata=metadata, tile_width=min(TILE_SIZE, data.shape[0]),
+ tile_height=min(TILE_SIZE, data.shape[1])) as writer:
for x in range(0, data.shape[0], TILE_SIZE):
for y in range(0, data.shape[1], TILE_SIZE):
block = (x // TILE_SIZE, y // TILE_SIZE)
@@ -300,11 +301,12 @@ def write_tiff(output_path, data, metadata=None):
for b in range(num_bands):
writer.write_block(data[x:x+TILE_SIZE, y:y+TILE_SIZE, b], block[0], block[1], b)
-class TiffWriter:
- """Class to manage block writes to a Geotiff file.
+class _TiffWriter:
+ """
+ Class to manage block writes to a Geotiff file. Internal helper class.
"""
def __init__(self, path, width, height, num_bands=1, data_type=gdal.GDT_Byte, #pylint:disable=too-many-arguments
- tile_width=256, tile_height=256, no_data_value=None, metadata=None):
+ tile_width=256, tile_height=256, nodata_value=None, metadata=None):
self._width = width
self._height = height
self._tile_height = tile_height
@@ -313,8 +315,8 @@ def __init__(self, path, width, height, num_bands=1, data_type=gdal.GDT_Byte, #p
# Constants
options = ['COMPRESS=LZW', 'BigTIFF=IF_SAFER', 'INTERLEAVE=BAND']
- options += ['BLOCKXSIZE='+str(self._tile_height),
- 'BLOCKYSIZE='+str(self._tile_width)]
+ options += ['BLOCKXSIZE='+str(self._tile_width),
+ 'BLOCKYSIZE='+str(self._tile_height)]
MIN_SIZE_FOR_TILES=100
if width > MIN_SIZE_FOR_TILES or height > MIN_SIZE_FOR_TILES:
options += ['TILED=YES']
@@ -324,9 +326,9 @@ def __init__(self, path, width, height, num_bands=1, data_type=gdal.GDT_Byte, #p
if not self._handle:
raise Exception('Failed to create output file: ' + path)
- if no_data_value is not None:
+ if nodata_value is not None:
for i in range(1,num_bands+1):
- self._handle.GetRasterBand(i).SetNoDataValue(no_data_value)
+ self._handle.GetRasterBand(i).SetNoDataValue(nodata_value)
if metadata:
self._handle.SetProjection (metadata['projection' ])
@@ -349,12 +351,6 @@ def close(self):
self._handle.FlushCache()
self._handle = None
- def get_size(self):
- return (self._width, self._height)
-
- def get_tile_size(self):
- return (self._tile_width, self._tile_height)
-
def get_num_tiles(self):
num_x = int(math.ceil(self._width / self._tile_width))
num_y = int(math.ceil(self._height / self._tile_height))
@@ -408,31 +404,26 @@ def write_region(self, data, x, y):
assert gdal_band
gdal_band.WriteArray(data[:, :, band], y, x)
-class DeltaTiffWriter(delta_image.DeltaImageWriter):
+class TiffWriter(delta_image.DeltaImageWriter):
+ """
+ Write a geotiff to a file.
+ """
def __init__(self, filename):
self._filename = filename
self._tiff_w = None
- def initialize(self, size, numpy_dtype, metadata=None):
- """
- Prepare for writing with the given size and dtype.
- """
+ def initialize(self, size, numpy_dtype, metadata=None, nodata_value=None):
assert (len(size) == 3), ('Error: len(size) of '+str(size)+' != 3')
TILE_SIZE = 256
- self._tiff_w = TiffWriter(self._filename, size[0], size[1], num_bands=size[2],
- data_type=numpy_dtype_to_gdal_type(numpy_dtype), metadata=metadata,
- tile_width=min(TILE_SIZE, size[0]), tile_height=min(TILE_SIZE, size[1]))
+ self._tiff_w = _TiffWriter(self._filename, size[0], size[1], num_bands=size[2],
+ data_type=_numpy_dtype_to_gdal_type(numpy_dtype), metadata=metadata,
+ nodata_value=nodata_value,
+ tile_width=min(TILE_SIZE, size[0]), tile_height=min(TILE_SIZE, size[1]))
def write(self, data, x, y):
- """
- Writes the data as a rectangular block starting at the given coordinates.
- """
self._tiff_w.write_region(data, x, y)
def close(self):
- """
- Finish writing.
- """
if self._tiff_w is not None:
self._tiff_w.close()
diff --git a/delta/imagery/sources/worldview.py b/delta/extensions/sources/worldview.py
similarity index 57%
rename from delta/imagery/sources/worldview.py
rename to delta/extensions/sources/worldview.py
index fb1b0923..db827669 100644
--- a/delta/imagery/sources/worldview.py
+++ b/delta/extensions/sources/worldview.py
@@ -19,16 +19,12 @@
Functions to support the WorldView satellites.
"""
-import math
import zipfile
import functools
import os
-import sys
import numpy as np
import portalocker
-import tensorflow as tf
-
from delta.config import config
from delta.imagery import utilities
from . import tiff
@@ -36,7 +32,7 @@
# Use this value for all WorldView nodata values we write, though they usually don't have any nodata.
OUTPUT_NODATA = 0.0
-def _get_files_from_unpack_folder(folder):
+def get_files_from_unpack_folder(folder):
"""Return the image and header file paths from the given unpack folder.
Returns (None, None) if the files were not found.
"""
@@ -51,48 +47,62 @@ def _get_files_from_unpack_folder(folder):
main_files = os.listdir(folder)
vendor_files = os.listdir(vendor_folder)
for f in vendor_files:
- if os.path.splitext(f)[1] == '.IMD':
+ ext = os.path.splitext(f)[1]
+ if ext.lower() == '.imd':
imd_path = os.path.join(vendor_folder, f)
break
for f in main_files:
- if os.path.splitext(f)[1] == '.tif':
+ ext = os.path.splitext(f)[1]
+ if ext.lower() == '.tif':
tif_path = os.path.join(folder, f)
break
return (tif_path, imd_path)
+
+def unpack_wv_to_folder(zip_path, unpack_folder):
+
+ with portalocker.Lock(zip_path, 'r', timeout=300) as unused: #pylint: disable=W0612
+ # Check if we already unpacked this data
+ (tif_path, imd_path) = get_files_from_unpack_folder(unpack_folder)
+
+ if imd_path and tif_path:
+ pass
+ else:
+ print('Unpacking file ' + zip_path + ' to folder ' + unpack_folder)
+ utilities.unpack_to_folder(zip_path, unpack_folder)
+ # some worldview zip files have a subdirectory with the name of the image
+ if not os.path.exists(os.path.join(unpack_folder, 'vendor_metadata')):
+ subdir = os.path.join(unpack_folder, os.path.splitext(os.path.basename(zip_path))[0])
+ if not os.path.exists(os.path.join(subdir, 'vendor_metadata')):
+ raise Exception('vendor_metadata not found in %s.' % (zip_path))
+ for filename in os.listdir(subdir):
+ os.rename(os.path.join(subdir, filename), os.path.join(unpack_folder, filename))
+ os.rmdir(subdir)
+ (tif_path, imd_path) = get_files_from_unpack_folder(unpack_folder)
+ return (tif_path, imd_path)
+
+
class WorldviewImage(tiff.TiffImage):
- """Compressed WorldView image tensorflow dataset wrapper (see imagery_dataset.py)"""
- def __init__(self, paths):
+ """Compressed WorldView image. Loads an image from a zip file with a tiff and a .imd file."""
+ def __init__(self, paths, nodata_value=None):
self._meta_path = None
- self._meta = None
- super(WorldviewImage, self).__init__(paths)
+ self._meta = None
+ self._sensor = None
+ self._date = None
+ self._name = None
+ super().__init__(paths, nodata_value)
- def _unpack(self, paths):
+ def _unpack(self, zip_path):
# Get the folder where this will be stored from the cache manager
unpack_folder = config.io.cache.manager().register_item(self._name)
+ return unpack_wv_to_folder(zip_path, unpack_folder)
+
+ def _set_info_from_tif_name(self, tif_name):
+ parts = os.path.basename(tif_name).split('_')
+ self._sensor = parts[0][0:4]
+ self._date = parts[2][6:14]
+ self._name = os.path.splitext(os.path.basename(tif_name))[0]
- with portalocker.Lock(paths, 'r', timeout=300) as unused: #pylint: disable=W0612
- # Check if we already unpacked this data
- (tif_path, imd_path) = _get_files_from_unpack_folder(unpack_folder)
-
- if imd_path and tif_path:
- #tf.print('Already have unpacked files in ' + unpack_folder,
- # output_stream=sys.stdout)
- pass
- else:
- tf.print('Unpacking file ' + paths + ' to folder ' + unpack_folder,
- output_stream=sys.stdout)
- utilities.unpack_to_folder(paths, unpack_folder)
- # some worldview zip files have a subdirectory with the name of the image
- if not os.path.exists(os.path.join(unpack_folder, 'vendor_metadata')):
- subdir = os.path.join(unpack_folder, os.path.splitext(os.path.basename(paths))[0])
- if not os.path.exists(os.path.join(subdir, 'vendor_metadata')):
- raise Exception('vendor_metadata not found in %s.' % (paths))
- for filename in os.listdir(subdir):
- os.rename(os.path.join(subdir, filename), os.path.join(unpack_folder, filename))
- os.rmdir(subdir)
- (tif_path, imd_path) = _get_files_from_unpack_folder(unpack_folder)
- return (tif_path, imd_path)
# This function is currently set up for the HDDS archived WV data, files from other
# locations will need to be handled differently.
@@ -103,24 +113,36 @@ def _prep(self, paths):
"""
assert isinstance(paths, str)
(_, ext) = os.path.splitext(paths)
- assert '.zip' in ext, f'Error: Was assuming a zip file. Found {paths}'
+ tif_name = None
- zip_file = zipfile.ZipFile(paths, 'r')
- tif_names = list(filter(lambda x: '.tif' in x, zip_file.namelist()))
- assert len(tif_names) > 0, f'Error: no tif files in the file {paths}'
- assert len(tif_names) == 1, f'Error: too many tif files in {paths}: {tif_names}'
- tif_name = tif_names[0]
+ if ext == '.zip': # Need to unpack
+ zip_file = zipfile.ZipFile(paths, 'r')
+ tif_names = list(filter(lambda x: x.lower().endswith('.tif'), zip_file.namelist()))
+ assert len(tif_names) > 0, f'Error: no tif files in the file {paths}'
+ assert len(tif_names) == 1, f'Error: too many tif files in {paths}: {tif_names}'
+ tif_name = tif_names[0]
- parts = os.path.basename(tif_name).split('_')
- self._sensor = parts[0][0:4]
- self._date = parts[2][6:14]
- self._name = os.path.splitext(os.path.basename(tif_name))[0]
+ self._set_info_from_tif_name(tif_name)
+
+ (tif_path, imd_path) = self._unpack(paths)
+
+ if ext == '.tif': # Already unpacked
- (tif_path, imd_path) = self._unpack(paths)
+ # Both files should be present in the same folder
+ tif_name = paths
+ unpack_folder = os.path.dirname(paths)
+ (tif_path, imd_path) = get_files_from_unpack_folder(unpack_folder)
+
+ if not (imd_path and tif_path):
+ raise Exception('vendor_metadata not found in %s.' % (paths))
+ self._set_info_from_tif_name(tif_name)
+
+ assert tif_name is not None, f'Error: Unsupported extension {ext}'
self._meta_path = imd_path
self.__parse_meta_file(imd_path)
+
return [tif_path]
def meta_path(self):
@@ -171,23 +193,23 @@ def bandwidth(self):
return self._meta['EFFECTIVEBANDWIDTH']
# TOA correction
-def _get_esun_value(sat_id, band):
- """Get the ESUN value for the given satellite and band"""
-
- VALUES = {'WV02':[1580.814, 1758.2229, 1974.2416, 1856.4104,
- 1738.4791, 1559.4555, 1342.0695, 1069.7302, 861.2866],
- 'WV03':[1583.58, 1743.81, 1971.48, 1856.26,
- 1749.4, 1555.11, 1343.95, 1071.98, 863.296]}
- try:
- return VALUES[sat_id][band]
- except Exception:
- raise Exception('No ESUN value for ' + sat_id
- + ', band ' + str(band))
-
-def _get_earth_sun_distance():
- """Returns the distance between the Earth and the Sun in AU for the given date"""
- # TODO: Copy the calculation from the WV manuals.
- return 1.0
+#def _get_esun_value(sat_id, band):
+# """Get the ESUN value for the given satellite and band"""
+#
+# VALUES = {'WV02':[1580.814, 1758.2229, 1974.2416, 1856.4104,
+# 1738.4791, 1559.4555, 1342.0695, 1069.7302, 861.2866],
+# 'WV03':[1583.58, 1743.81, 1971.48, 1856.26,
+# 1749.4, 1555.11, 1343.95, 1071.98, 863.296]}
+# try:
+# return VALUES[sat_id][band]
+# except Exception as e:
+# raise Exception('No ESUN value for ' + sat_id
+# + ', band ' + str(band)) from e
+
+#def _get_earth_sun_distance():
+# """Returns the distance between the Earth and the Sun in AU for the given date"""
+# # TODO: Copy the calculation from the WV manuals.
+# return 1.0
# The np.where clause handles input nodata values.
@@ -200,17 +222,17 @@ def _apply_toa_radiance(data, _, bands, factors, widths):
buf[:, :, b] = np.where(data[:, :, b] > 0, (data[:, :, b] * f) / w, OUTPUT_NODATA)
return buf
-def _apply_toa_reflectance(data, band, factor, width, sun_elevation,
- satellite, earth_sun_distance):
- """Apply a top of atmosphere reflectance conversion to WorldView data"""
- f = factor[band]
- w = width [band]
-
- esun = _get_esun_value(satellite, band)
- des2 = earth_sun_distance*earth_sun_distance
- theta = np.pi/2.0 - sun_elevation
- scaling = (des2*np.pi) / (esun*math.cos(theta))
- return np.where(data>0, ((data*f)/w)*scaling, OUTPUT_NODATA)
+#def _apply_toa_reflectance(data, band, factor, width, sun_elevation,
+# satellite, earth_sun_distance):
+# """Apply a top of atmosphere reflectance conversion to WorldView data"""
+# f = factor[band]
+# w = width [band]
+#
+# esun = _get_esun_value(satellite, band)
+# des2 = earth_sun_distance*earth_sun_distance
+# theta = np.pi/2.0 - sun_elevation
+# scaling = (des2*np.pi) / (esun*math.cos(theta))
+# return np.where(data>0, ((data*f)/w)*scaling, OUTPUT_NODATA)
def toa_preprocess(image, calc_reflectance=False):
diff --git a/delta/imagery/__init__.py b/delta/imagery/__init__.py
index c569a57a..373fbf07 100644
--- a/delta/imagery/__init__.py
+++ b/delta/imagery/__init__.py
@@ -21,5 +21,5 @@
For loading training data into Tensorflow, see
`delta.imagery.imagery_dataset.ImageryDataset`.
For dealing with imagery directly, see
-`delta.imagery.sources`.
+`delta.extensions.sources`.
"""
diff --git a/delta/imagery/delta_image.py b/delta/imagery/delta_image.py
new file mode 100644
index 00000000..f206da98
--- /dev/null
+++ b/delta/imagery/delta_image.py
@@ -0,0 +1,393 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Base classes for reading and writing images.
+"""
+
+from abc import ABC, abstractmethod
+import concurrent.futures
+import copy
+import functools
+from typing import Callable, Iterator, List, Tuple
+
+import numpy as np
+
+from delta.imagery import rectangle, utilities
+
+class DeltaImage(ABC):
+ """
+ Base class used for wrapping input images in DELTA. Can be extended
+ to support new data types. A variety of image types are implemented in
+ `delta.extensions.sources`.
+ """
+ def __init__(self, nodata_value=None):
+ """
+ Parameters
+ ----------
+ nodata_value: Optional[Any]
+ Nodata value for the image, if any.
+ """
+ self.__preprocess_function = None
+ self.__nodata_value = nodata_value
+
+ def read(self, roi: rectangle.Rectangle=None, bands: List[int]=None, buf: np.ndarray=None) -> np.ndarray:
+ """
+ Reads the image in [row, col, band] indexing.
+
+ Subclasses should generally not overwrite this method--- they will likely want to implement
+ `_read`.
+
+ Parameters
+ ----------
+ roi: `rectangle.Rectangle`
+ The bounding box to read from the image. If None, read the entire image.
+ bands: List[int]
+ Bands to load (zero-indexed). If None, read all bands.
+ buf: np.ndarray
+ If specified, reads the image into this buffer. Must be sufficiently large.
+
+ Returns
+ -------
+ np.ndarray:
+ A buffer containing the requested part of the image.
+ """
+ if roi is None:
+ roi = rectangle.Rectangle(0, 0, width=self.width(), height=self.height())
+ else:
+ if roi.min_x < 0 or roi.min_y < 0 or roi.max_x > self.width() or roi.max_y > self.height():
+ raise IndexError('Rectangle (%d, %d, %d, %d) outside of bounds (%d, %d).' %
+ (roi.min_x, roi.min_y, roi.max_x, roi.max_y, self.width(), self.height()))
+ if bands is None:
+ bands = range(self.num_bands())
+ if isinstance(bands, int):
+ result = self._read(roi, [bands], buf)
+ result = result[:, :, 0] # reduce dimensions
+ else:
+ result = self._read(roi, bands, buf)
+ if self.__preprocess_function:
+ return self.__preprocess_function(result, roi, bands)
+ return result
+
+ def set_preprocess(self, callback: Callable[[np.ndarray, rectangle.Rectangle, List[int]], np.ndarray]):
+ """
+ Set a preproprocessing function callback to be applied to the results of
+ all reads on the image.
+
+ Parameters
+ ----------
+ callback: Callable[[np.ndarray, rectangle.Rectangle, List[in]], np.ndarray]
+ A function to be called on loading image data, callback(image, roi, bands),
+ where `image` is the numpy array containing the read data, `roi` is the region of interest read,
+ and `bands` is a list of the bands read. Must return a numpy array.
+ """
+ self.__preprocess_function = callback
+
+ def get_preprocess(self) -> Callable[[np.ndarray, rectangle.Rectangle, List[int]], np.ndarray]:
+ """
+ Returns
+ -------
+ Callable[[np.ndarray, rectangle.Rectangle, List[int]], np.ndarray]
+ The preprocess function currently set.
+ """
+ return self.__preprocess_function
+
+ def nodata_value(self):
+ """
+ Returns
+ -------
+ The value of pixels to treat as nodata.
+ """
+ return self.__nodata_value
+
+ @abstractmethod
+ def _read(self, roi: rectangle.Rectangle, bands: List[int], buf: np.ndarray=None) -> np.ndarray:
+ """
+ Read the image.
+
+ Abstract function to be implemented by subclasses. Users should call `read` instead.
+
+ Parameters
+ ----------
+ roi: rectangle.Rectangle
+ Segment of the image to read.
+ bands: List[int]
+ List of bands to read (zero-indexed).
+ buf: np.ndarray
+ Buffer to read into. If not specified, a new buffer should be allocated.
+
+ Returns
+ -------
+ np.ndarray:
+ The relevant part of the image as a numpy array.
+ """
+
+ def metadata(self): #pylint:disable=no-self-use
+ """
+ Returns
+ -------
+ A dictionary of metadata, if any is given for the image type.
+ """
+ return {}
+
+ @abstractmethod
+ def size(self) -> Tuple[int, int]:
+ """
+ Returns
+ -------
+ Tuple[int, int]:
+ The size of this image in pixels, as (width, height).
+ """
+
+ @abstractmethod
+ def num_bands(self) -> int:
+ """
+ Returns
+ -------
+ int:
+ The number of bands in this image.
+ """
+
+ @abstractmethod
+ def dtype(self) -> np.dtype:
+ """
+ Returns
+ -------
+ numpy.dtype:
+ The underlying data type of the image.
+ """
+
+ def block_aligned_roi(self, desired_roi: rectangle.Rectangle) -> rectangle.Rectangle:#pylint:disable=no-self-use
+ """
+ Parameters
+ ----------
+ desired_roi: rectangle.Rectangle
+ Original region of interest.
+
+ Returns
+ -------
+ rectangle.Rectangle:
+ The block-aligned roi containing the specified roi.
+ """
+ return desired_roi
+
+ def block_size(self): #pylint: disable=no-self-use
+ """
+ Returns
+ -------
+ (int, int):
+ The suggested block size for efficient reading.
+ """
+ return (256, 256)
+
+ def width(self) -> int:
+ """
+ Returns
+ -------
+ int:
+ The number of image columns
+ """
+ return self.size()[0]
+
+ def height(self) -> int:
+ """
+ Returns
+ -------
+ int:
+ The number of image rows
+ """
+ return self.size()[1]
+
+ def tiles(self, shape, overlap_shape=(0, 0), partials: bool=True, min_shape=(0, 0),
+ partials_overlap: bool=False, by_block=False):
+ """
+ Splits the image into tiles with the given properties.
+
+ Parameters
+ ----------
+ shape: (int, int)
+ Shape of each tile
+ overlap_shape: (int, int)
+ Amount to overlap tiles in x and y direction
+ partials: bool
+ If true, include partial tiles at the edge of the image.
+ min_shape: (int, int)
+ If true and `partials` is true, keep partial tiles of this minimum size.
+ partials_overlap: bool
+ If `partials` is false, and this is true, expand partial tiles
+ to the desired size. Tiles may overlap in some areas.
+ by_block: bool
+ If true, changes the returned generator to group tiles by block.
+ This is intended to optimize disk reads by reading the entire block at once.
+
+ Returns
+ -------
+ List[Rectangle] or List[(Rectangle, List[Rectangle])]
+ List of ROIs. If `by_block` is true, returns a list of (Rectangle, List[Rectangle])
+ instead, where the first rectangle is a larger block containing multiple tiles in a list.
+ """
+ input_bounds = rectangle.Rectangle(0, 0, max_x=self.width(), max_y=self.height())
+ return input_bounds.make_tile_rois(shape, overlap_shape=overlap_shape, include_partials=partials,
+ min_shape=min_shape, partials_overlap=partials_overlap,
+ by_block=by_block)
+
+ def roi_generator(self, requested_rois: Iterator[rectangle.Rectangle]) -> \
+ Iterator[Tuple[rectangle.Rectangle, np.ndarray, int, int]]:
+ """
+ Generator that yields image blocks of the requested rois.
+
+ Parameters
+ ----------
+ requested_rois: Iterator[Rectangle]
+ Regions of interest to read.
+
+ Returns
+ -------
+ Iterator[Tuple[Rectangle, numpy.ndarray, int, int]]
+ A generator with read image regions. In each tuple, the first item
+ is the region of interest, the second is a numpy array of the image contents,
+ the third is the index of the current region of interest, and the fourth is the total
+ number of rois.
+ """
+ block_rois = copy.copy(requested_rois)
+
+ whole_bounds = rectangle.Rectangle(0, 0, width=self.width(), height=self.height())
+ for roi in requested_rois:
+ if not whole_bounds.contains_rect(roi):
+ raise Exception('Roi outside image bounds: ' + str(roi) + str(whole_bounds))
+
+ # gdal doesn't work reading multithreading. But this let's a thread
+ # take care of IO input while we do computation.
+ jobs = []
+
+ total_rois = len(block_rois)
+ while block_rois:
+ # For the next (output) block, figure out the (input block) aligned
+ # data read that we need to perform to get it.
+ read_roi = self.block_aligned_roi(block_rois[0])
+
+ applicable_rois = []
+
+ # Loop through the remaining ROIs and apply the callback function to each
+ # ROI that is contained in the section we read in.
+ index = 0
+ while index < len(block_rois):
+
+ if not read_roi.contains_rect(block_rois[index]):
+ index += 1
+ continue
+ applicable_rois.append(block_rois.pop(index))
+
+ jobs.append((read_roi, applicable_rois))
+
+ # only do a few reads ahead since otherwise we will exhaust our memory
+ pending = []
+ exe = concurrent.futures.ThreadPoolExecutor(1)
+ NUM_AHEAD = 2
+ for i in range(min(NUM_AHEAD, len(jobs))):
+ pending.append(exe.submit(functools.partial(self.read, jobs[i][0])))
+ num_remaining = total_rois
+ for (i, (read_roi, rois)) in enumerate(jobs):
+ buf = pending.pop(0).result()
+ for roi in rois:
+ x0 = roi.min_x - read_roi.min_x
+ y0 = roi.min_y - read_roi.min_y
+ num_remaining -= 1
+ yield (roi, buf[x0:x0 + roi.width(), y0:y0 + roi.height(), :],
+ (total_rois - num_remaining, total_rois))
+ if i + NUM_AHEAD < len(jobs):
+ pending.append(exe.submit(functools.partial(self.read, jobs[i + NUM_AHEAD][0])))
+
+ def process_rois(self, requested_rois: Iterator[rectangle.Rectangle],
+ callback_function: Callable[[rectangle.Rectangle, np.ndarray], None],
+ show_progress: bool=False) -> None:
+ """
+ Apply a callback function to a list of ROIs.
+
+ Parameters
+ ----------
+ requested_rois: Iterator[Rectangle]
+ Regions of interest to evaluate
+ callback_function: Callable[[rectangle.Rectangle, np.ndarray], None]
+ A function to apply to each requested region. Pass the bounding box
+ of the current region and a numpy array of pixel values as inputs.
+ show_progress: bool
+ Print a progress bar on the command line if true.
+ """
+ for (roi, buf, (i, total)) in self.roi_generator(requested_rois):
+ callback_function(roi, buf)
+ if show_progress:
+ utilities.progress_bar('%d / %d' % (i, total), i / total, prefix='Blocks Processed:')
+ if show_progress:
+ print()
+
+class DeltaImageWriter(ABC):
+ """
+ Base class for writing images in DELTA.
+ """
+ @abstractmethod
+ def initialize(self, size, numpy_dtype, metadata=None, nodata_value=None):
+ """
+ Prepare for writing.
+
+ Parameters
+ ----------
+ size: tuple of ints
+ Dimensions of the image to write.
+ numpy_dtype: numpy.dtype
+ Type of the underling data.
+ metadata: dict
+ Dictionary of metadata to save with the image.
+ nodata_value: numpy_dtype
+ Value representing nodata in the image.
+ """
+
+ @abstractmethod
+ def write(self, data: np.ndarray, x: int, y: int):
+ """
+ Write a portion of the image.
+
+ Parameters
+ ----------
+ data: np.ndarray
+ A block of image data to write.
+ x: int
+ y: int
+ Top-left coordinates of the block of data to write.
+ """
+
+ @abstractmethod
+ def close(self):
+ """
+ Finish writing, perform cleanup.
+ """
+
+ @abstractmethod
+ def abort(self):
+ """
+ Cancel writing before finished, perform cleanup.
+ """
+
+ def __del__(self):
+ self.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *unused):
+ self.close()
+ return False
diff --git a/delta/imagery/disk_folder_cache.py b/delta/imagery/disk_folder_cache.py
index a900117e..f15fb343 100644
--- a/delta/imagery/disk_folder_cache.py
+++ b/delta/imagery/disk_folder_cache.py
@@ -26,10 +26,14 @@ class DiskCache:
It is safe to mix different datasets in the cache folder, though all items in
the folder will count towards the limit.
"""
- def __init__(self, top_folder, limit):
+ def __init__(self, top_folder: str, limit: int):
"""
- The top level folder to store cached items in and the number to store
- are specified.
+ Parameters
+ ----------
+ top_folder: str
+ Top level cache directory.
+ limit: int
+ Maximum number of items to keep in cache.
"""
if limit < 1:
raise Exception('Illegal limit passed to Disk Cache: ' + str(limit))
@@ -37,8 +41,8 @@ def __init__(self, top_folder, limit):
if not os.path.exists(top_folder):
try:
os.mkdir(top_folder)
- except:
- raise Exception('Could not create disk cache folder: ' + top_folder)
+ except Exception as e:
+ raise Exception('Could not create disk cache folder: ' + top_folder) from e
self._limit = limit
self._folder = top_folder
@@ -48,25 +52,44 @@ def __init__(self, top_folder, limit):
def limit(self):
"""
- The number of items to store in the cache.
+ Returns
+ -------
+ int:
+ The maximum number of items to cache.
"""
return self._limit
def folder(self):
"""
- The directory to store cached items in.
+ Returns
+ -------
+ str:
+ The cache directory.
"""
return self._folder
def num_cached(self):
"""
- The number of items currently cached.
+ Returns
+ -------
+ int:
+ The number of items currently cached.
"""
return len(self._item_list)
def register_item(self, name):
"""
- Register a new item with the cache manager and return the full path to it.
+ Register a new item with the cache manager.
+
+ Parameters
+ ----------
+ name: str
+ Filename of the item to add to the cache.
+
+ Returns
+ -------
+ str:
+ Full path to store the item in the cache.
"""
# If we already have the name just move it to the back of the list
diff --git a/delta/imagery/imagery_config.py b/delta/imagery/imagery_config.py
index 9720a37b..d511153c 100644
--- a/delta/imagery/imagery_config.py
+++ b/delta/imagery/imagery_config.py
@@ -25,6 +25,7 @@
import appdirs
from delta.config import config, DeltaConfigComponent, validate_path, validate_positive
+from delta.config.extensions import image_reader, preprocess_function
from . import disk_folder_cache
@@ -36,13 +37,18 @@ class ImageSet:
"""
def __init__(self, images, image_type, preprocess=None, nodata_value=None):
"""
- The parameters for the constructor are:
-
- * An iterable of image filenames `images`
- * The image type (i.e., tiff, worldview, landsat) `image_type`
- * An optional preprocessing function to apply to the image,
- following the signature in `delta.imagery.sources.delta_image.DeltaImage.set_process`.
- * A `nodata_value` for pixels to disregard
+ Parameters
+ ----------
+ images: Iterator[str]
+ Image filenames
+ image_type: str
+ The image type as a string (i.e., tiff, worldview, landsat). Must have
+ been previously registered with `delta.config.extensions.register_image_reader`.
+ preprocess: Callable
+ Optional preprocessing function to apply to the image
+ following the signature in `delta.imagery.delta_image.DeltaImage.set_preprocess`.
+ nodata_value: image dtype
+ A no data value for pixels to disregard
"""
self._images = images
self._image_type = image_type
@@ -51,19 +57,59 @@ def __init__(self, images, image_type, preprocess=None, nodata_value=None):
def type(self):
"""
- The type of the image (used by `delta.imagery.sources.loader`).
+ Returns
+ -------
+ str:
+ The type of the image
"""
return self._image_type
def preprocess(self):
"""
- Return the preprocessing function.
+ Returns
+ -------
+ Callable:
+ The preprocessing function
"""
return self._preprocess
def nodata_value(self):
"""
- Value of pixels to disregard.
+ Returns
+ -------
+ image dtype:
+ Value of pixels to disregard.
"""
return self._nodata_value
+
+ def set_nodata_value(self, nodata):
+ """
+ Set the pixel value to disregard.
+
+ Parameters
+ ----------
+ nodata: image dtype
+ The pixel value to set as nodata
+ """
+ self._nodata_value = nodata
+
+ def load(self, index):
+ """
+ Loads the image of the given index.
+
+ Parameters
+ ----------
+ index: int
+ Index of the image to load.
+
+ Returns
+ -------
+ `delta.imagery.delta_image.DeltaImage`:
+ The image
+ """
+ img = image_reader(self.type())(self[index], self.nodata_value())
+ if self._preprocess:
+ img.set_preprocess(self._preprocess)
+ return img
+
def __len__(self):
return len(self._images)
def __getitem__(self, index):
@@ -76,23 +122,13 @@ def __iter__(self):
__DEFAULT_EXTENSIONS = {'tiff' : '.tiff',
'worldview' : '.zip',
'landsat' : '.zip',
- 'npy' : '.npy'}
-__DEFAULT_SCALE_FACTORS = {'tiff' : 1024.0,
- 'worldview' : 2048.0,
- 'landsat' : 120.0,
- 'npy' : None}
+ 'npy' : '.npy',
+ 'sentinel1' : '.zip'}
+
def __extension(conf):
if conf['extension'] == 'default':
return __DEFAULT_EXTENSIONS.get(conf['type'])
return conf['extension']
-def __scale_factor(image_comp):
- f = image_comp.preprocess.scale_factor()
- if f == 'default':
- return __DEFAULT_SCALE_FACTORS.get(image_comp.type())
- try:
- return float(f)
- except ValueError:
- raise ValueError('Scale factor is %s, must be a float.' % (f))
def __find_images(conf, matching_images=None, matching_conf=None):
'''
@@ -100,25 +136,26 @@ def __find_images(conf, matching_images=None, matching_conf=None):
If matching_images and matching_conf are specified, we find the labels matching these images.
'''
images = []
- if (conf['files'] is None) != (conf['file_list'] is None) != (conf['directory'] is None):
- raise ValueError('''Too many image specification methods used.\n
- Choose one of "files", "file_list" and "directory" when indicating
- file locations.''')
if conf['type'] not in __DEFAULT_EXTENSIONS:
raise ValueError('Unexpected image type %s.' % (conf['type']))
if conf['files']:
+ assert conf['file_list'] is None and conf['directory'] is None, 'Only one image specification allowed.'
images = conf['files']
+ for (i, im) in enumerate(images):
+ images[i] = os.path.normpath(im)
elif conf['file_list']:
+ assert conf['directory'] is None, 'Only one image specification allowed.'
with open(conf['file_list'], 'r') as f:
for line in f:
- images.append(line)
+ images.append(os.path.normpath(line.strip()))
elif conf['directory']:
extension = __extension(conf)
if not os.path.exists(conf['directory']):
raise ValueError('Supplied images directory %s does not exist.' % (conf['directory']))
if matching_images is None:
- for root, _, filenames in os.walk(conf['directory']):
+ for root, _, filenames in os.walk(conf['directory'],
+ followlinks=True):
for filename in filenames:
if filename.endswith(extension):
images.append(os.path.join(root, filename))
@@ -127,21 +164,18 @@ def __find_images(conf, matching_images=None, matching_conf=None):
for m in matching_images:
rel_path = os.path.relpath(m, matching_conf['directory'])
label_path = os.path.join(conf['directory'], rel_path)
- images.append(os.path.splitext(label_path)[0] + extension)
+ if matching_conf['directory'] is None:
+ images.append(os.path.splitext(label_path)[0] + extension)
+ else:
+ # if custom extension, remove it
+ label_path = label_path[:-len(__extension(matching_conf))]
+ images.append(label_path + extension)
for img in images:
if not os.path.exists(img):
raise ValueError('Image file %s does not exist.' % (img))
return images
-def __preprocess_function(image_comp):
- if not image_comp.preprocess.enabled():
- return None
- f = __scale_factor(image_comp)
- if f is None:
- return None
- return lambda data, _, dummy: data / np.float32(f)
-
def load_images_labels(images_comp, labels_comp, classes_comp):
'''
Takes two configuration subsections and returns (image set, label set). Also takes classes
@@ -160,7 +194,7 @@ def load_images_labels(images_comp, labels_comp, classes_comp):
label_extension = __extension(labels_dict)
images = [img for img in images if not img.endswith(label_extension)]
- pre = __preprocess_function(images_comp)
+ pre = images_comp.preprocess_function()
imageset = ImageSet(images, images_dict['type'], pre, images_dict['nodata_value'])
if (labels_dict['files'] is None) and (labels_dict['file_list'] is None) and (labels_dict['directory'] is None):
@@ -171,18 +205,72 @@ def load_images_labels(images_comp, labels_comp, classes_comp):
if len(labels) != len(images):
raise ValueError('%d images found, but %d labels found.' % (len(images), len(labels)))
- pre = pre_orig = __preprocess_function(labels_comp)
- conv = classes_comp.classes_to_indices_func()
- if conv is not None:
- pre = lambda data, _, dummy: conv(pre_orig(data, _, dummy) if pre_orig is not None else data)
+ labels_nodata = labels_dict['nodata_value']
+ pre_orig = labels_comp.preprocess_function()
+ # we shift the label images to always be 0...n[+1], Class 1, Class 2, ... Class N, [nodata]
+ def class_shift(data, _, dummy):
+ if pre_orig is not None:
+ data = pre_orig(data, _, dummy)
+ # set any nodata values to be past the expected range
+ if labels_nodata is not None:
+ nodata_indices = (data == labels_nodata)
+ conv = classes_comp.classes_to_indices_func()
+ if conv is not None:
+ data = conv(data)
+ if labels_nodata is not None:
+ data[nodata_indices] = len(classes_comp)
+ return data
return (imageset, ImageSet(labels, labels_dict['type'],
- pre, labels_dict['nodata_value']))
+ class_shift, len(classes_comp) if labels_nodata is not None else None))
class ImagePreprocessConfig(DeltaConfigComponent):
+ """
+ Configuration for image preprocessing.
+
+ Expects a list of preprocessing functions registered
+ with `delta.config.extensions.register_preprocess`.
+ """
def __init__(self):
super().__init__()
- self.register_field('enabled', bool, 'enabled', None, 'Turn on preprocessing.')
- self.register_field('scale_factor', (float, str), 'scale_factor', None, 'Image scale factor.')
+ self._functions = []
+
+ def _load_dict(self, d, base_dir):
+ if d is None:
+ self._functions = []
+ return
+ if not d:
+ return
+ self._functions = []
+ assert isinstance(d, list), 'preprocess should be list of commands'
+ for func in d:
+ if isinstance(func, str):
+ self._functions.append((func, {}))
+ else:
+ assert isinstance(func, dict), 'preprocess items must be strings or dicts'
+ assert len(func) == 1, 'One preprocess item per list entry.'
+ name = list(func.keys())[0]
+ self._functions.append((name, func[name]))
+
+ def function(self, image_type):
+ """
+ Parameters
+ ----------
+ image_type: str
+ Type of the image
+ Returns
+ -------
+ Callable:
+ The specified preprocessing function to apply to the image.
+ """
+ prep = lambda data, _, dummy: data
+ for (name, args) in self._functions:
+ t = preprocess_function(name)
+ assert t is not None, 'Preprocess function %s not found.' % (name)
+ p = t(image_type=image_type, **args)
+ def helper(cur, prev):
+ return lambda data, roi, bands: cur(prev(data, roi, bands), roi, bands)
+ prep = helper(p, prep)
+ return prep
def _validate_paths(paths, base_dir):
out = []
@@ -191,23 +279,37 @@ def _validate_paths(paths, base_dir):
return out
class ImageSetConfig(DeltaConfigComponent):
+ """
+ Configuration for a set of images.
+
+ Used for images, labels, and validation images and labels.
+ """
def __init__(self, name=None):
super().__init__()
self.register_field('type', str, 'type', None, 'Image type.')
self.register_field('files', list, None, _validate_paths, 'List of image files.')
- self.register_field('file_list', list, None, validate_path, 'File listing image files.')
+ self.register_field('file_list', str, None, validate_path, 'File listing image files.')
self.register_field('directory', str, None, validate_path, 'Directory of image files.')
self.register_field('extension', str, None, None, 'Image file extension.')
self.register_field('nodata_value', (float, int), None, None, 'Value of pixels to ignore.')
if name:
- self.register_arg('type', '--' + name + '-type')
- self.register_arg('file_list', '--' + name + '-file-list')
- self.register_arg('directory', '--' + name + '-dir')
- self.register_arg('extension', '--' + name + '-extension')
+ self.register_arg('type', '--' + name + '-type', name + '_type')
+ self.register_arg('file_list', '--' + name + '-file-list', name + '_file_list')
+ self.register_arg('directory', '--' + name + '-dir', name + '_directory')
+ self.register_arg('extension', '--' + name + '-extension', name + '_extension')
self.register_component(ImagePreprocessConfig(), 'preprocess')
self._name = name
+ def preprocess_function(self):
+ """
+ Returns
+ -------
+ Callable:
+ Preprocessing function for the set of images.
+ """
+ return self._components['preprocess'].function(self._config_dict['type'])
+
def setup_arg_parser(self, parser, components = None) -> None:
if self._name is None:
return
@@ -221,9 +323,26 @@ def parse_args(self, options):
super().parse_args(options)
if hasattr(options, self._name) and getattr(options, self._name) is not None:
self._config_dict['files'] = [getattr(options, self._name)]
+ self._config_dict['directory'] = None
+ self._config_dict['file_list'] = None
class LabelClass:
+ """
+ Label configuration.
+ """
def __init__(self, value, name=None, color=None, weight=None):
+ """
+ Parameters
+ ----------
+ value: int
+ Pixel of the label
+ name: str
+ Name of the class to display
+ color: int
+ In visualizations, set the class to this RGB color.
+ weight: float
+ During training weight this class by this amount.
+ """
color_order = [0x1f77b4, 0xff7f0e, 0x2ca02c, 0xd62728, 0x9467bd, 0x8c564b, \
0xe377c2, 0x7f7f7f, 0xbcbd22, 0x17becf]
if name is None:
@@ -234,11 +353,17 @@ def __init__(self, value, name=None, color=None, weight=None):
self.name = name
self.color = color
self.weight = weight
+ self.end_value = None
def __repr__(self):
return 'Color: ' + self.name
class ClassesConfig(DeltaConfigComponent):
+ """
+ Configuration for classes.
+
+ Specify either a number of classes or list of classes with details.
+ """
def __init__(self):
super().__init__()
self._classes = []
@@ -247,6 +372,9 @@ def __init__(self):
def __iter__(self):
return self._classes.__iter__()
+ def __getitem__(self, key):
+ return self._classes[key]
+
def __len__(self):
return len(self._classes)
@@ -271,6 +399,11 @@ def _load_dict(self, d : dict, base_dir):
inner_dict = c[k]
self._classes.append(LabelClass(k, str(inner_dict.get('name')),
inner_dict.get('color'), inner_dict.get('weight')))
+ elif isinstance(d, dict):
+ for k in d:
+ assert isinstance(k, int), 'Class label value must be int.'
+ self._classes.append(LabelClass(k, str(d[k].get('name')),
+ d[k].get('color'), d[k].get('weight')))
else:
raise ValueError('Expected classes to be an int or list in config, was ' + str(d))
# make sure the order is consistent for same values, and create preprocessing function
@@ -279,8 +412,36 @@ def _load_dict(self, d : dict, base_dir):
for (i, v) in enumerate(self._classes):
if v.value != i:
self._conversions.append(v.value)
+ v.end_value = i
+
+ def class_id(self, class_name):
+ """
+ Parameters
+ ----------
+ class_name: int or str
+ Either the original pixel value in images (int) or the name (str) of a class.
+ The special value 'nodata' will give the nodata class, if any.
+
+ Returns
+ -------
+ int:
+ the ID of the class in the labels after default image preprocessing (labels are arranged
+ to a canonical order, with nodata always coming after them.)
+ """
+ if class_name == len(self._classes) or class_name == 'nodata':
+ return len(self._classes)
+ for (i, c) in enumerate(self._classes):
+ if class_name in (c.value, c.name):
+ return i
+ raise ValueError('Class ' + class_name + ' not found.')
def weights(self):
+ """
+ Returns
+ -------
+ List[float]
+ List of class weights for use in training, if specified.
+ """
weights = []
for c in self._classes:
if c.weight is not None:
@@ -291,6 +452,12 @@ def weights(self):
return weights
def classes_to_indices_func(self):
+ """
+ Returns
+ -------
+ Callable[[numpy.ndarray], numpy.ndarray]:
+ Function to convert label image to canonical form
+ """
if not self._conversions:
return None
def convert(data):
@@ -301,6 +468,12 @@ def convert(data):
return convert
def indices_to_classes_func(self):
+ """
+ Returns
+ -------
+ Callable[[numpy.ndarray], numpy.ndarray]:
+ Reverse of `classes_to_indices_func`.
+ """
if not self._conversions:
return None
def convert(data):
@@ -311,14 +484,15 @@ def convert(data):
return convert
class DatasetConfig(DeltaConfigComponent):
+ """
+ Configuration for a dataset.
+ """
def __init__(self):
super().__init__('Dataset')
self.register_component(ImageSetConfig('image'), 'images', '__image_comp')
self.register_component(ImageSetConfig('label'), 'labels', '__label_comp')
self.__images = None
self.__labels = None
- self.register_field('log_folder', str, 'log_folder', validate_path,
- 'Directory where dataset progress is recorded.')
self.register_component(ClassesConfig(), 'classes')
def reset(self):
@@ -328,7 +502,10 @@ def reset(self):
def images(self) -> ImageSet:
"""
- Returns the training images.
+ Returns
+ -------
+ ImageSet:
+ the training images
"""
if self.__images is None:
(self.__images, self.__labels) = load_images_labels(self._components['images'],
@@ -338,7 +515,10 @@ def images(self) -> ImageSet:
def labels(self) -> ImageSet:
"""
- Returns the label images.
+ Returns
+ -------
+ ImageSet:
+ the label images
"""
if self.__labels is None:
(self.__images, self.__labels) = load_images_labels(self._components['images'],
@@ -347,6 +527,9 @@ def labels(self) -> ImageSet:
return self.__labels
class CacheConfig(DeltaConfigComponent):
+ """
+ Configuration for cache.
+ """
def __init__(self):
super().__init__()
self.register_field('dir', str, None, validate_path, 'Cache directory.')
@@ -360,34 +543,57 @@ def reset(self):
def manager(self) -> disk_folder_cache.DiskCache:
"""
- Returns the disk cache object to manage the cache.
+ Returns
+ -------
+ `disk_folder_cache.DiskCache`:
+ the object to manage the cache
"""
if self._cache_manager is None:
+ # Auto-populating defaults here is a workaround so small tools can skip the full
+ # command line config setup. Could be improved!
+ if 'dir' not in self._config_dict:
+ self._config_dict['dir'] = 'default'
+ if 'limit' not in self._config_dict:
+ self._config_dict['limit'] = 8
cdir = self._config_dict['dir']
if cdir == 'default':
cdir = appdirs.AppDirs('delta', 'nasa').user_cache_dir
self._cache_manager = disk_folder_cache.DiskCache(cdir, self._config_dict['limit'])
return self._cache_manager
+def _validate_tile_size(size, _):
+ assert len(size) == 2, 'Size must have two components.'
+ assert isinstance(size[0], int) and isinstance(size[1], int), 'Size must be integer.'
+ assert size[0] > 0 and size[1] > 1, 'Size must be positive.'
+ return size
+
class IOConfig(DeltaConfigComponent):
+ """
+ Configuration for I/O.
+ """
def __init__(self):
- super().__init__()
- self.register_field('threads', int, 'threads', None, 'Number of threads to use.')
- self.register_field('block_size_mb', int, 'block_size_mb', validate_positive,
- 'Size of an image block to load in memory at once.')
+ super().__init__('IO')
+ self.register_field('threads', int, None, None, 'Number of threads to use.')
+ self.register_field('tile_size', list, 'tile_size', _validate_tile_size,
+ 'Size of an image tile to load in memory at once.')
self.register_field('interleave_images', int, 'interleave_images', validate_positive,
'Number of images to interleave at a time when training.')
- self.register_field('tile_ratio', float, 'tile_ratio', validate_positive,
- 'Width to height ratio of blocks to load in images.')
- self.register_field('resume_cutoff', int, 'resume_cutoff', None,
- 'When resuming a dataset, skip images where we have read this many tiles.')
self.register_arg('threads', '--threads')
- self.register_arg('block_size_mb', '--block-size-mb')
- self.register_arg('tile_ratio', '--tile-ratio')
self.register_component(CacheConfig(), 'cache')
+ def threads(self):
+ """
+ Returns
+ -------
+ int:
+ number of threads to use for I/O
+ """
+ if 'threads' in self._config_dict and self._config_dict['threads']:
+ return self._config_dict['threads']
+ return min(1, os.cpu_count() // 2)
+
def register():
"""
Registers imagery config options with the global config manager.
diff --git a/delta/imagery/imagery_dataset.py b/delta/imagery/imagery_dataset.py
index 00bae651..103ab31d 100644
--- a/delta/imagery/imagery_dataset.py
+++ b/delta/imagery/imagery_dataset.py
@@ -18,278 +18,529 @@
"""
Tools for loading input images into the TensorFlow Dataset class.
"""
+from concurrent.futures import ThreadPoolExecutor
+import copy
import functools
-import math
import random
-import sys
import os
import portalocker
import numpy as np
import tensorflow as tf
from delta.config import config
-from delta.imagery import rectangle
-from delta.imagery.sources import loader
-class ImageryDataset:
- """Create dataset with all files as described in the provided config file.
+class ImageryDataset: # pylint: disable=too-many-instance-attributes
+ """
+ A dataset for tiling very large imagery for training with tensorflow.
"""
- def __init__(self, images, labels, chunk_size, output_size, chunk_stride=1,
- resume_mode=False, log_folder=None):
+ def __init__(self, images, labels, output_shape, chunk_shape, stride=None,
+ tile_shape=(256, 256), tile_overlap=None):
"""
- Initialize the dataset based on the specified image and label ImageSets
+ Parameters
+ ----------
+ images: ImageSet
+ Images to train on
+ labels: ImageSet
+ Corresponding labels to train on
+ output_shape: (int, int)
+ Shape of the corresponding labels for a given chunk or tile size.
+ chunk_shape: (int, int)
+ If specified, divide tiles into individual chunks of this shape.
+ stride: (int, int)
+ Skip this stride between chunks. Only valid with chunk_shape.
+ tile_shape: (int, int)
+ Size of tiles to load from the images at a time.
+ tile_overlap: (int, int)
+ If specified, overlap tiles by this amount.
"""
- self._resume_mode = resume_mode
- self._log_folder = log_folder
- if self._log_folder and not os.path.exists(self._log_folder):
- os.mkdir(self._log_folder)
+ self._resume_mode = False
+ self._log_folder = None
+ self._iopool = ThreadPoolExecutor(1)
# Record some of the config values
- assert (chunk_size % 2) == (output_size % 2), 'Chunk size and output size must both be either even or odd.'
- self._chunk_size = chunk_size
- self._output_size = output_size
+ self.set_chunk_output_shapes(chunk_shape, output_shape)
self._output_dims = 1
- self._chunk_stride = chunk_stride
+ # one for imagery, one for labels
+ if stride is None:
+ stride = (1, 1)
+ self._stride = stride
self._data_type = tf.float32
self._label_type = tf.uint8
+ self._tile_shape = tile_shape
+ if tile_overlap is None:
+ tile_overlap = (0, 0)
+ self._tile_overlap = tile_overlap
if labels:
assert len(images) == len(labels)
self._images = images
self._labels = labels
+ self._access_counts = [np.zeros(0, np.uint8), np.zeros(0, np.uint8)]
# Load the first image to get the number of bands for the input files.
- self._num_bands = loader.load_image(images, 0).num_bands()
+ self._num_bands = images.load(0).num_bands()
- def _get_image_read_log_path(self, image_path):
- """Return the path to the read log for an input image"""
+ # TODO: I am skeptical that this works with multiple epochs.
+ # It is also less important now that training is so much faster.
+ # I think we should probably get rid of it at some point.
+ def set_resume_mode(self, resume_mode, log_folder):
+ """
+ Enable / disable resume mode and configure it.
+
+ Parameters
+ ----------
+ resume_mode: bool
+ If true, log and check access counts for if imagery can be skipped
+ this epoch.
+ log_folder: str
+ Folder to log access counts to
+ """
+ self._resume_mode = resume_mode
+ self._log_folder = log_folder
+ if self._log_folder and not os.path.exists(self._log_folder):
+ os.mkdir(self._log_folder)
+
+ def _resume_log_path(self, image_id):
+ """
+ Parameters
+ ----------
+ image_id: int
+
+ Returns
+ -------
+ str:
+ the path to the read log for an input image
+ """
if not self._log_folder:
return None
+ image_path = self._images[image_id]
image_name = os.path.basename(image_path)
file_name = os.path.splitext(image_name)[0] + '_read.log'
log_path = os.path.join(self._log_folder, file_name)
return log_path
- def _get_image_read_count(self, image_path):
- """Return the number of ROIs we have read from an image"""
- log_path = self._get_image_read_log_path(image_path)
- if (not log_path) or not os.path.exists(log_path):
- return 0
- counter = 0
- with portalocker.Lock(log_path, 'r', timeout=300) as f:
- for line in f: #pylint: disable=W0612
- counter += 1
- return counter
-
- def _load_tensor_imagery(self, is_labels, image_index, bbox):
- """Loads a single image as a tensor."""
- data = self._labels if is_labels else self._images
-
- if not is_labels: # Record each time we write a tile
- file_path = data[image_index.numpy()]
- log_path = self._get_image_read_log_path(file_path)
- if log_path:
- with portalocker.Lock(log_path, 'a', timeout=300) as f:
- f.write(str(bbox) + '\n')
- # TODO: What to write and when to clear it?
-
+ def resume_log_read(self, image_id): #pylint: disable=R0201
+ """
+ Reads an access count file containing a boolean and a count.
+
+ Parameters
+ ----------
+ image_id: int
+ Image id to check logs for
+
+ Returns
+ -------
+ (bool, int):
+ need_to_check, access count
+ The boolean is set to true if we need to check the count.
+ """
+ path = self._resume_log_path(image_id)
try:
- image = loader.load_image(data, image_index.numpy())
- w = int(bbox[2])
- h = int(bbox[3])
- rect = rectangle.Rectangle(int(bbox[0]), int(bbox[1]), w, h)
- r = image.read(rect)
- except Exception as e: #pylint: disable=W0703
- print('Caught exception loading tile from image: ' + data[image_index.numpy()] + ' -> ' + str(e)
- + '\nSkipping tile: ' + str(bbox))
- if config.general.stop_on_input_error():
- print('Aborting processing, set --bypass-input-errors to bypass this error.')
+ with portalocker.Lock(path, 'r', timeout=300) as f:
+ line = f.readline()
+ parts = line.split()
+ if len(parts) == 1: # Legacy files
+ return (True, int(parts[0]))
+ needToCheck = (parts[0] == '1')
+ return (needToCheck, int(parts[1]))
+ except OSError as e:
+ if e.errno == 122: # Disk quota exceeded
raise
- # Else just skip this tile
- r = np.zeros(shape=(0,0,0), dtype=np.float32)
- return r
-
- def _tile_images(self):
- max_block_bytes = config.io.block_size_mb() * 1024 * 1024
- def tile_generator():
- tgs = []
- for i in range(len(self._images)):
-
- if self._resume_mode:
- # TODO: Improve feature to work with multiple epochs
- # Skip images which we have already read some number of tiles from
- if self._get_image_read_count(self._images[i]) > config.io.resume_cutoff():
- continue
-
- try:
- img = loader.load_image(self._images, i)
-
- if self._labels: # If we have labels make sure they are the same size as the input images
- label = loader.load_image(self._labels, i)
- if label.size() != img.size():
- raise Exception('Label file ' + self._labels[i] + ' with size ' + str(label.size())
- + ' does not match input image size of ' + str(img.size()))
- # w * h * bands * 4 * chunk * chunk = max_block_bytes
- tile_width = int(math.sqrt(max_block_bytes / img.num_bands() / self._data_type.size /
- config.io.tile_ratio()))
- tile_height = int(config.io.tile_ratio() * tile_width)
- min_block_size = self._chunk_size ** 2 * config.io.tile_ratio() * img.num_bands() * 4
- if max_block_bytes < min_block_size:
- print('Warning: max_block_bytes=%g MB, but %g MB is recommended (minimum: %g MB)'
- % (max_block_bytes / 1024 / 1024,
- min_block_size * 2 / 1024 / 1024, min_block_size / 1024/ 1024),
- file=sys.stderr)
- if tile_width < self._chunk_size or tile_height < self._chunk_size:
- raise ValueError('max_block_bytes is too low.')
- tiles = img.tiles(tile_width, tile_height, min_width=self._chunk_size, min_height=self._chunk_size,
- overlap=self._chunk_size - 1)
- except Exception as e: #pylint: disable=W0703
- print('Caught exception tiling image: ' + self._images[i] + ' -> ' + str(e)
- + '\nWill not load any tiles from this image')
- if config.general.stop_on_input_error():
- print('Aborting processing, set --bypass-input-errors to bypass this error.')
- raise
- tiles = [] # Else move past this image without loading any tiles
-
- random.Random(0).shuffle(tiles) # gives consistent random ordering so labels will match
- tgs.append((i, tiles))
- if not tgs:
- return
- while tgs:
- cur = tgs[:config.io.interleave_images()]
- tgs = tgs[config.io.interleave_images():]
- done = False
- while not done:
- done = True
- for it in cur:
- if not it[1]:
- continue
- t = it[1].pop(0)
- if t:
- done = False
- yield (it[0], t.min_x, t.min_y, t.max_x, t.max_y)
- if done:
- break
- return tf.data.Dataset.from_generator(tile_generator,
- (tf.int32, tf.int32, tf.int32, tf.int32, tf.int32))
+ return (False, 0)
+ except Exception: #pylint: disable=W0703
+ # If there is a problem reading the count just treat as zero
+ return (False, 0)
- def _load_images(self, is_labels, data_type):
+ def resume_log_update(self, image_id, count=None, need_check=False): #pylint: disable=R0201
"""
- Loads a list of images as tensors.
- If label_list is specified, load labels instead. The corresponding image files are still required however.
+ Update logs of when images are read. Should only be needed internally.
+
+ Parameters
+ ----------
+ image_id: int
+ The image to update
+ count: int
+ Number of tiles that have been read
+ need_check: bool
+ Set flag for if a check is needed
+ """
+ log_path = self._resume_log_path(image_id)
+ if not log_path:
+ return
+ if count is None:
+ (_, count) = self.resume_log_read(image_id)
+ count += 1
+ with portalocker.Lock(log_path, 'w', timeout=300) as f:
+ f.write('%d %d' % (int(need_check), count))
+
+ def reset_access_counts(self, set_need_check=False):
"""
- ds_input = self._tile_images()
- def load_tile(image_index, x1, y1, x2, y2):
- img = tf.py_function(functools.partial(self._load_tensor_imagery,
- is_labels),
- [image_index, [x1, y1, x2, y2]], data_type)
- return img
- ret = ds_input.map(load_tile, num_parallel_calls=tf.data.experimental.AUTOTUNE)#config.io.threads())
+ Go through all the access files and reset the counts. Should be done at the end of each epoch.
- # Don't let the entire session be taken down by one bad dataset input.
- # - Would be better to handle this somehow but it is not clear if TF supports that.
-# ret = ret.apply(tf.data.experimental.ignore_errors())
+ Parameters
+ ----------
+ set_need_check: bool
+ if true, keep the count and mark that it needs to be checked. (should be
+ set at the start of training)
+ """
+ if not self._log_folder:
+ return
+ if config.general.verbose():
+ print('Resetting access counts in folder: ' + self._log_folder)
+ for i in range(len(self._images)):
+ self.resume_log_update(i, count=0, need_check=set_need_check)
- return ret
+ def _list_tiles(self, i): # pragma: no cover
+ """
+ Parameters
+ ----------
+ i: int
+ Image to list tiles for.
+
+ Returns
+ -------
+ List[Rectangle]:
+ List of tiles to read from the given image
+ """
+ # If we need to skip this file because of the read count, no need to look up tiles.
+ if self._resume_mode:
+ file_path = self._images[i]
+ log_path = self._resume_log_path(i)
+ if log_path:
+ if config.general.verbose():
+ print('get_image_tile_list for index ' + str(i) + ' -> ' + file_path)
+ (need_to_check, count) = self.resume_log_read(i)
+ if need_to_check and (count > config.train.resume_cutoff()):
+ if config.general.verbose():
+ print('Skipping index ' + str(i) + ' tile gen with count '
+ + str(count) + ' -> ' + file_path)
+ return []
+ if config.general.verbose():
+ print('Computing tile list for index ' + str(i) + ' with count '
+ + str(count) + ' -> ' + file_path)
+ else:
+ if config.general.verbose():
+ print('No read log file for index ' + str(i))
+
+ img = self._images.load(i)
+
+ if self._labels: # If we have labels make sure they are the same size as the input images
+ label = self._labels.load(i)
+ if label.size() != img.size():
+ raise AssertionError('Label file ' + self._labels[i] + ' with size ' + str(label.size())
+ + ' does not match input image size of ' + str(img.size()))
+ tile_shape = self._tile_shape
+ if self._chunk_shape:
+ assert tile_shape[0] >= self._chunk_shape[0] and \
+ tile_shape[1] >= self._chunk_shape[1], 'Tile too small.'
+ return img.tiles((tile_shape[0], tile_shape[1]), min_shape=self._chunk_shape,
+ overlap_shape=(self._chunk_shape[0] - 1, self._chunk_shape[1] - 1),
+ by_block=True)
+ return img.tiles((tile_shape[0], tile_shape[1]), partials=False, partials_overlap=True,
+ overlap_shape=self._tile_overlap, by_block=True)
+
+ def _tile_generator(self, i, is_labels): # pragma: no cover
+ """
+ A generator that yields image tiles from the given image.
+
+ Parameters
+ ----------
+ i: int
+ Image id
+ is_labels: bool
+ Load the label if true, image if false
+
+ Returns
+ -------
+ Iterator[numpy.ndarray]:
+ Iterator over iamge tiles.
+ """
+ i = int(i)
+ tiles = self._list_tiles(i)
+ # track epoch (must be same for label and non-label)
+ epoch = self._access_counts[1 if is_labels else 0][i]
+ self._access_counts[1 if is_labels else 0][i] += 1
+ if not tiles:
+ return
+
+ # different order each epoch
+ random.Random(epoch * i * 11617).shuffle(tiles)
+
+ image = (self._labels if is_labels else self._images).load(i)
+ preprocess = image.get_preprocess()
+ image.set_preprocess(None) # parallelize the preprocessing, not in disk i/o threadpool
+ bands = range(image.num_bands())
+
+ # read one row ahead of what we process now
+ next_buf = self._iopool.submit(lambda: image.read(tiles[0][0]))
+ for (c, (rect, sub_tiles)) in enumerate(tiles):
+ cur_buf = next_buf
+ if c + 1 < len(tiles):
+ # extra lambda to bind c in closure
+ next_buf = self._iopool.submit((lambda x: (lambda: image.read(tiles[x + 1][0])))(c))
+ if cur_buf is None:
+ continue
+ buf = cur_buf.result()
+ (rect, sub_tiles) = tiles[c]
+ for s in sub_tiles:
+ if preprocess:
+ t = copy.copy(s)
+ t.shift(rect.min_x, rect.min_y)
+ yield preprocess(buf[s.min_x:s.max_x, s.min_y:s.max_y, :], t, bands)
+ else:
+ yield buf[s.min_x:s.max_x, s.min_y:s.max_y, :]
+
+ if not is_labels: # update access count per row
+ self.resume_log_update(i, need_check=False)
+
+ def _load_images(self, is_labels, data_type):
+ """
+ Loads a list of images as tensors.
- def _chunk_image(self, image):
+ Parameters
+ ----------
+ is_labels: bool
+ Load labels if true, images if not
+ data_type: numpy.dtype
+ Data type that will be returned.
+
+ Returns
+ -------
+ Dataset:
+ Dataset of image tiles
+ """
+ r = tf.data.Dataset.range(len(self._images))
+ r = r.shuffle(1000, seed=0, reshuffle_each_iteration=True) # shuffle same way for labels and non-labels
+ self._access_counts[1 if is_labels else 0] = np.zeros(len(self._images), np.uint8) # count epochs for random
+ # different seed for each image, use ge
+ gen_func = lambda x: tf.data.Dataset.from_generator(functools.partial(self._tile_generator,
+ is_labels=is_labels),
+ output_types=data_type,
+ output_shapes=tf.TensorShape((None, None, None)), args=(x,))
+ return r.interleave(gen_func, cycle_length=config.io.interleave_images(),
+ num_parallel_calls=config.io.threads())
+
+ def _chunk_image(self, image): # pragma: no cover
"""Split up a tensor image into tensor chunks"""
- ksizes = [1, self._chunk_size, self._chunk_size, 1] # Size of the chunks
- strides = [1, self._chunk_stride, self._chunk_stride, 1] # Spacing between chunk starts
+
+ ksizes = [1, self._chunk_shape[0], self._chunk_shape[1], 1] # Size of the chunks
+ strides = [1, self._stride[0], self._stride[1], 1] # Spacing between chunk starts
rates = [1, 1, 1, 1]
result = tf.image.extract_patches(tf.expand_dims(image, 0), ksizes, strides, rates,
padding='VALID')
# Output is [1, M, N, chunk*chunk*bands]
- result = tf.reshape(result, [-1, self._chunk_size, self._chunk_size, self._num_bands])
+ result = tf.reshape(result, [-1, self._chunk_shape[0], self._chunk_shape[1], self._num_bands])
return result
- def _reshape_labels(self, labels):
+ def _reshape_labels(self, labels): # pragma: no cover
"""Reshape the labels to account for the chunking process."""
- w = (self._chunk_size - self._output_size) // 2
- labels = tf.image.crop_to_bounding_box(labels, w, w, tf.shape(labels)[0] - 2 * w,
- tf.shape(labels)[1] - 2 * w) #pylint: disable=C0330
-
- ksizes = [1, self._output_size, self._output_size, 1]
- strides = [1, self._chunk_stride, self._chunk_stride, 1]
+ if self._chunk_shape:
+ w = (self._chunk_shape[0] - self._output_shape[0]) // 2
+ h = (self._chunk_shape[1] - self._output_shape[1]) // 2
+ else:
+ w = (tf.shape(labels)[0] - self._output_shape[0]) // 2
+ h = (tf.shape(labels)[1] - self._output_shape[1]) // 2
+ labels = tf.image.crop_to_bounding_box(labels, w, h, tf.shape(labels)[0] - 2 * w,
+ tf.shape(labels)[1] - 2 * h)
+ if not self._chunk_shape:
+ return labels
+
+ ksizes = [1, self._output_shape[0], self._output_shape[1], 1]
+ strides = [1, self._stride[0], self._stride[1], 1]
rates = [1, 1, 1, 1]
labels = tf.image.extract_patches(tf.expand_dims(labels, 0), ksizes, strides, rates,
padding='VALID')
- return tf.reshape(labels, [-1, self._output_size, self._output_size])
+ result = tf.reshape(labels, [-1, self._output_shape[0], self._output_shape[1]])
+ return result
def data(self):
"""
- Unbatched dataset of image chunks.
+ Returns
+ -------
+ Dataset:
+ image chunks / tiles.
"""
ret = self._load_images(False, self._data_type)
- ret = ret.map(self._chunk_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
- return ret.unbatch()
+ if self._chunk_shape:
+ ret = ret.map(self._chunk_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return ret.unbatch()
+ return ret
def labels(self):
"""
- Unbatched dataset of labels.
+ Returns
+ -------
+ Dataset:
+ Unbatched dataset of labels corresponding to `data()`.
"""
label_set = self._load_images(True, self._label_type)
- label_set = label_set.map(self._reshape_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE) #pylint: disable=C0301
- return label_set.unbatch()
+ if self._chunk_shape or self._output_shape:
+ label_set = label_set.map(self._reshape_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE) #pylint: disable=C0301
+ if self._chunk_shape:
+ return label_set.unbatch()
+ return label_set
def dataset(self, class_weights=None):
"""
- Return the underlying TensorFlow dataset object that this class creates.
+ Returns a tensorflow dataset as configured by the class.
- class_weights: a list of weights to apply to the samples in each class, if specified.
+ Parameters
+ ----------
+ class_weights: list
+ list of weights for the classes.
+
+ Returns
+ -------
+ tensorflow Dataset:
+ With (data, labels, optionally weights)
"""
# Pair the data and labels in our dataset
ds = tf.data.Dataset.zip((self.data(), self.labels()))
- # ignore labels with no data
- if self._labels.nodata_value():
- ds = ds.filter(lambda x, y: tf.math.not_equal(y, self._labels.nodata_value()))
+ # ignore chunks which are all nodata (nodata is re-indexed to be after the classes)
+ if self._labels.nodata_value() is not None:
+ ds = ds.filter(lambda x, y: tf.math.reduce_any(tf.math.not_equal(y, self._labels.nodata_value())))
if class_weights is not None:
+ class_weights.append(0.0)
lookup = tf.constant(class_weights)
- ds = ds.map(lambda x, y: (x, y, tf.gather(lookup, tf.cast(y, tf.int32), axis=None)))
+ ds = ds.map(lambda x, y: (x, y, tf.gather(lookup, tf.cast(y, tf.int32), axis=None)),
+ num_parallel_calls=config.io.threads())
return ds
def num_bands(self):
"""
- Return the number of bands in each image of the data set.
+ Returns
+ -------
+ int:
+ number of bands in each image
"""
return self._num_bands
- def chunk_size(self):
+ def set_chunk_output_shapes(self, chunk_shape, output_shape):
+ """
+ Parameters
+ ----------
+ chunk_shape: (int, int)
+ Size of chunks to read at a time. Set to None to
+ use on a per tile basis (i.e., for FCNs).
+ output_shape: (int, int)
+ Shape output by the network. May differ from the input size
+ (dervied from chunk_shape or tile_shape)
"""
- Size of chunks used for inputs.
+ if chunk_shape:
+ assert len(chunk_shape) == 2, 'Chunk must be two dimensional.'
+ assert (chunk_shape[0] % 2) == (chunk_shape[1] % 2) == \
+ (output_shape[0] % 2) == (output_shape[1] % 2), 'Chunk and output shapes must both be even or odd.'
+ if output_shape:
+ assert len(output_shape) == 2 or len(output_shape) == 3, 'Output must be two or three dimensional.'
+ if len(output_shape) == 3:
+ output_shape = output_shape[0:2]
+ self._chunk_shape = chunk_shape
+ self._output_shape = output_shape
+
+ def chunk_shape(self):
"""
+ Returns
+ -------
+ (int, int):
+ Size of chunks used for inputs.
+ """
+ return self._chunk_shape
+
+ def input_shape(self):
+ """
+ Returns
+ -------
+ Tuple[int, ...]:
+ Input size for the network.
+ """
+ if self._chunk_shape:
+ return (self._chunk_shape[0], self._chunk_shape[1], self._num_bands)
+ return (None, None, self._num_bands)
+
def output_shape(self):
"""
- Output size of blocks of labels.
+ Returns
+ -------
+ Tuple[int, ...]:
+ Output size, size of blocks of labels
"""
- return (self._output_size, self._output_size, self._output_dims)
+ if self._output_shape:
+ return (self._output_shape[0], self._output_shape[1], self._output_dims)
+ return (None, None, self._output_dims)
def image_set(self):
"""
- Returns set of images.
+ Returns
+ -------
+ ImageSet:
+ set of images
"""
return self._images
def label_set(self):
"""
- Returns set of label images.
+ Returns
+ -------
+ ImageSet:
+ set of labels
"""
return self._labels
-class AutoencoderDataset(ImageryDataset):
- """Slightly modified dataset class for the Autoencoder which does not use separate label files"""
+ def set_tile_shape(self, tile_shape):
+ """
+ Set the tile size.
- def __init__(self, images, chunk_size, chunk_stride=1, resume_mode=False, log_folder=None):
+ Parameters
+ ----------
+ tile_shape: (int, int)
+ New tile shape"""
+ self._tile_shape = tile_shape
+
+ def tile_shape(self):
+ """
+ Returns
+ -------
+ Tuple[int, ...]:
+ tile shape to load at a time
+ """
+ return self._tile_shape
+
+ def tile_overlap(self):
"""
- The images are used as labels as well.
+ Returns
+ -------
+ Tuple[int, ...]:
+ the amount tiles overlap
"""
- super(AutoencoderDataset, self).__init__(images, None, chunk_size, chunk_size, chunk_stride=chunk_stride,
- resume_mode=resume_mode, log_folder=log_folder)
+ return self._tile_overlap
+
+ def stride(self):
+ """
+ Returns
+ -------
+ Tuple[int, ...]:
+ Stride between chunks (only when chunk_shape is set).
+ """
+ return self._stride
+
+class AutoencoderDataset(ImageryDataset):
+ """
+ Slightly modified dataset class for the autoencoder.
+
+ Instead of specifying labels, the inputs are used as labels.
+ """
+
+ def __init__(self, images, chunk_shape, stride=(1, 1), tile_shape=(256, 256), tile_overlap=None):
+ super().__init__(images, None, chunk_shape, chunk_shape, tile_shape=tile_shape,
+ stride=stride, tile_overlap=tile_overlap)
self._labels = self._images
self._output_dims = self.num_bands()
def labels(self):
return self.data()
+
+ def dataset(self, class_weights=None):
+ return self.data().map(lambda x: (x, x))
diff --git a/delta/imagery/rectangle.py b/delta/imagery/rectangle.py
index 776f25a5..a17aa52a 100644
--- a/delta/imagery/rectangle.py
+++ b/delta/imagery/rectangle.py
@@ -21,12 +21,24 @@
import math
class Rectangle:
- """Simple rectangle class for ROIs. Max values are NON-INCLUSIVE.
- When using it, stay consistent with float or integer values.
+ """
+ Simple rectangle class for ROIs. Max values are NON-INCLUSIVE.
+ When using it, stay consistent with float or integer values.
"""
def __init__(self, min_x, min_y, max_x=0, max_y=0,
width=0, height=0):
- """Specify width/height by name to use those instead of max_x/max_y."""
+ """
+ Parameters
+ ----------
+ min_x: int
+ min_y: int
+ max_x: int
+ max_y: int
+ Rectangle bounds.
+ width: int
+ height: int
+ Specify width / height to use these instead of max_x/max_y.
+ """
self.min_x = min_x
self.min_y = min_y
if width > 0:
@@ -55,8 +67,13 @@ def __repr__(self):
# for col in range(self.min_x, self.max_x):
# yield(TileIndex(row,col))
- def get_bounds(self):
- '''Returns (min_x, max_x, min_y, max_y)'''
+ def bounds(self):
+ """
+ Returns
+ -------
+ (int, int, int, int):
+ (min_x, max_x, min_y, max_y)
+ """
return (self.min_x, self.max_x, self.min_y, self.max_y)
def width(self):
@@ -65,14 +82,18 @@ def height(self):
return self.max_y - self.min_y
def has_area(self):
- '''Returns true if the rectangle contains any area.'''
+ """
+ Returns
+ -------
+ bool:
+ true if the rectangle contains any area.
+ """
return (self.width() > 0) and (self.height() > 0)
def perimeter(self):
return 2*self.width() + 2*self.height()
def area(self):
- '''Returns the valid area'''
if not self.has_area():
return 0
return self.height() * self.width()
@@ -157,18 +178,45 @@ def overlaps(self, other_rect):
overlap_area = self.get_intersection(other_rect)
return overlap_area.has_area()
- def make_tile_rois(self, tile_width, tile_height, min_width=0, min_height=0,
- include_partials=True, overlap_amount=0):
- '''Return a list of tiles encompassing the entire area of this Rectangle'''
-
- tile_spacing_x = tile_width - overlap_amount
- tile_spacing_y = tile_height - overlap_amount
+ def make_tile_rois(self, tile_shape, overlap_shape=(0, 0), include_partials=True, min_shape=(0, 0),
+ partials_overlap=False, by_block=False):
+ """
+ Return a list of tiles encompassing the entire area of this Rectangle.
+
+ Parameters
+ ----------
+ tile_shape: (int, int)
+ Shape of each tile
+ overlap_shape: (int, int)
+ Amount to overlap tiles in x and y direction
+ include_partials: bool
+ If true, include partial tiles at the edge of the image.
+ min_shape: (int, int)
+ If true and `partials` is true, keep partial tiles of this minimum size.
+ partials_overlap: bool
+ If `partials` is false, and this is true, expand partial tiles
+ to the desired size. Tiles may overlap in some areas.
+ by_block: bool
+ If true, changes the returned generator to group tiles by block.
+ This is intended to optimize disk reads by reading the entire block at once.
+
+ Returns
+ -------
+ List[Rectangle]:
+ Generator yielding ROIs. If `by_block` is true, returns a generator of (Rectangle, List[Rectangle])
+ instead, where the first rectangle is a larger block containing multiple tiles in a list.
+ """
+ tile_width, tile_height = tile_shape
+ min_width, min_height = min_shape
+
+ tile_spacing_x = tile_width - overlap_shape[0]
+ tile_spacing_y = tile_height - overlap_shape[1]
num_tiles = (int(math.ceil(self.width() / tile_spacing_x )),
int(math.ceil(self.height() / tile_spacing_y)))
output_tiles = []
for c in range(0, num_tiles[0]):
+ row_tiles = []
for r in range(0, num_tiles[1]):
-
tile = Rectangle(self.min_x + c*tile_spacing_x,
self.min_y + r*tile_spacing_y,
width=tile_width, height=tile_height)
@@ -177,8 +225,24 @@ def make_tile_rois(self, tile_width, tile_height, min_width=0, min_height=0,
tile = tile.get_intersection(self)
if tile.width() < min_width or tile.height() < min_height:
continue
- output_tiles.append(tile)
else: # Only use it if the uncropped tile fits entirely in this Rectangle
- if self.contains_rect(tile):
- output_tiles.append(tile)
+ if not self.contains_rect(tile):
+ if not partials_overlap:
+ continue
+ tile = Rectangle(min(self.max_x, tile.max_x) - tile_width,
+ min(self.max_y, tile.max_y) - tile_height,
+ width=tile_width, height=tile_height)
+ if not self.contains_rect(tile):
+ continue
+ if by_block:
+ row_tiles.append(tile)
+ else:
+ output_tiles.append(tile)
+
+ if by_block and row_tiles:
+ row_rect = Rectangle(row_tiles[0].min_x, row_tiles[0].min_y, row_tiles[-1].max_x, row_tiles[-1].max_y)
+ for r in row_tiles:
+ r.shift(-row_rect.min_x, -row_rect.min_y)
+ output_tiles.append((row_rect, row_tiles))
+
return output_tiles
diff --git a/delta/imagery/sources/README.md b/delta/imagery/sources/README.md
deleted file mode 100644
index 924410ab..00000000
--- a/delta/imagery/sources/README.md
+++ /dev/null
@@ -1,10 +0,0 @@
-DELTA Imagery Sources
-=====================
-
-DELTA supports a variety of imagery sources. Currently, by the `type` given in
-a configuration file:
-
- * `tiff`: Geotiff files.
- * `worldview`: Worldview images, stored in zip files.
- * `landsat`: Landsat images, stored in zip files with MTL files.
- * `npy`: numpy arrays saved to disk.
diff --git a/delta/imagery/sources/delta_image.py b/delta/imagery/sources/delta_image.py
deleted file mode 100644
index fb984196..00000000
--- a/delta/imagery/sources/delta_image.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# Copyright © 2020, United States Government, as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All rights reserved.
-#
-# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
-# licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0.
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Base class for loading images.
-"""
-
-from abc import ABC, abstractmethod
-import concurrent.futures
-import copy
-import functools
-from typing import Callable, Iterator, List, Tuple
-
-import numpy as np
-
-from delta.imagery import rectangle, utilities
-
-class DeltaImage(ABC):
- """
- Base class used for wrapping input images in a way that they can be passed
- to Tensorflow dataset objects.
- """
- def __init__(self):
- self.__preprocess_function = None
-
- def read(self, roi: rectangle.Rectangle=None, bands: List[int]=None, buf: np.ndarray=None) -> np.ndarray:
- """
- Reads the image in [row, col, band] indexing.
-
- If `roi` is not specified, reads the entire image.
- If `buf` is specified, writes the image to buf.
- If `bands` is not specified, reads all bands, otherwise
- only the listed bands are read.
- If bands is a single integer, drops the band dimension.
- """
- if roi is None:
- roi = rectangle.Rectangle(0, 0, width=self.width(), height=self.height())
- else:
- if roi.min_x < 0 or roi.min_y < 0 or roi.max_x > self.width() or roi.max_y > self.height():
- raise IndexError('Rectangle (%d, %d, %d, %d) outside of bounds (%d, %d).' %
- (roi.min_x, roi.min_y, roi.max_x, roi.max_y, self.width(), self.height()))
- if bands is None:
- bands = range(self.num_bands())
- if isinstance(bands, int):
- result = self._read(roi, [bands], buf)
- result = result[:, :, 0] # reduce dimensions
- else:
- result = self._read(roi, bands, buf)
- if self.__preprocess_function:
- return self.__preprocess_function(result, roi, bands)
- return result
-
- def set_preprocess(self, callback: Callable[[np.ndarray, rectangle.Rectangle, List[int]], np.ndarray]) -> None:
- """
- Set a preproprocessing function callback to be applied to the results of all reads on the image.
-
- The function takes the arguments callback(image, roi, bands), where image is the numpy array containing
- the read data, roi is the region of interest read, and bands is a list of the bands being read.
- """
- self.__preprocess_function = callback
-
- @abstractmethod
- def _read(self, roi, bands, buf=None):
- """
- Read the image of the given data type. An optional roi specifies the boundaries.
-
- This function is intended to be overwritten by subclasses.
- """
-
- def metadata(self): #pylint:disable=no-self-use
- """
- Returns a dictionary of metadata, in the format used by GDAL.
- """
- return {}
-
- @abstractmethod
- def size(self) -> Tuple[int, int]:
- """Return the size of this image in pixels, as (width, height)."""
-
- @abstractmethod
- def num_bands(self) -> int:
- """Return the number of bands in the image."""
-
- def block_aligned_roi(self, desired_roi: rectangle.Rectangle) -> rectangle.Rectangle:#pylint:disable=no-self-use
- """Return the block-aligned roi containing this image region, if applicable."""
- return desired_roi
-
- def width(self) -> int:
- """Return the number of columns."""
- return self.size()[0]
-
- def height(self) -> int:
- """Return the number of rows."""
- return self.size()[1]
-
- def tiles(self, width: int, height: int, min_width: int=0, min_height: int=0,
- overlap: int=0) -> Iterator[rectangle.Rectangle]:
- """Generator to yield ROIs for the image."""
- input_bounds = rectangle.Rectangle(0, 0, width=self.width(), height=self.height())
- return input_bounds.make_tile_rois(width, height, min_width=min_width, min_height=min_height,
- include_partials=True, overlap_amount=overlap)
-
- def roi_generator(self, requested_rois: Iterator[rectangle.Rectangle]) -> Iterator[rectangle.Rectangle]:
- """
- Generator that yields ROIs of blocks in the requested region.
- """
- block_rois = copy.copy(requested_rois)
-
- whole_bounds = rectangle.Rectangle(0, 0, width=self.width(), height=self.height())
- for roi in requested_rois:
- if not whole_bounds.contains_rect(roi):
- raise Exception('Roi outside image bounds: ' + str(roi) + str(whole_bounds))
-
- # gdal doesn't work reading multithreading. But this let's a thread
- # take care of IO input while we do computation.
- jobs = []
-
- total_rois = len(block_rois)
- while block_rois:
- # For the next (output) block, figure out the (input block) aligned
- # data read that we need to perform to get it.
- read_roi = self.block_aligned_roi(block_rois[0])
-
- applicable_rois = []
-
- # Loop through the remaining ROIs and apply the callback function to each
- # ROI that is contained in the section we read in.
- index = 0
- while index < len(block_rois):
-
- if not read_roi.contains_rect(block_rois[index]):
- index += 1
- continue
- applicable_rois.append(block_rois.pop(index))
-
- jobs.append((read_roi, applicable_rois))
-
- # only do a few reads ahead since otherwise we will exhaust our memory
- pending = []
- exe = concurrent.futures.ThreadPoolExecutor(1)
- NUM_AHEAD = 2
- for i in range(min(NUM_AHEAD, len(jobs))):
- pending.append(exe.submit(functools.partial(self.read, jobs[i][0])))
- num_remaining = total_rois
- for (i, (read_roi, rois)) in enumerate(jobs):
- buf = pending.pop(0).result()
- for roi in rois:
- x0 = roi.min_x - read_roi.min_x
- y0 = roi.min_y - read_roi.min_y
- num_remaining -= 1
- yield (roi, buf[x0:x0 + roi.width(), y0:y0 + roi.height(), :],
- (total_rois - num_remaining, total_rois))
- if i + NUM_AHEAD < len(jobs):
- pending.append(exe.submit(functools.partial(self.read, jobs[i + NUM_AHEAD][0])))
-
- def process_rois(self, requested_rois: Iterator[rectangle.Rectangle],
- callback_function: Callable[[rectangle.Rectangle, np.ndarray], None],
- show_progress: bool=False) -> None:
- """
- Process the given region broken up into blocks using the callback function.
- Each block will get the image data from each input image passed into the function.
- Data reading takes place in a separate thread, but the callbacks are executed
- in a consistent order on a single thread.
- """
- for (roi, buf, (i, total)) in self.roi_generator(requested_rois):
- callback_function(roi, buf)
- if show_progress:
- utilities.progress_bar('%d / %d' % (i, total), i / total, prefix='Blocks Processed:')
- if show_progress:
- print()
-
-class DeltaImageWriter(ABC):
- @abstractmethod
- def initialize(self, size, numpy_dtype, metadata=None):
- """
- Prepare for writing with the given size and dtype.
- """
-
- @abstractmethod
- def write(self, data, x, y):
- """
- Writes the data as a rectangular block starting at the given coordinates.
- """
-
- @abstractmethod
- def close(self):
- """
- Finish writing.
- """
-
- @abstractmethod
- def abort(self):
- """
- Cancel writing before finished.
- """
-
- def __del__(self):
- self.close()
-
- def __enter__(self):
- return self
-
- def __exit__(self, *unused):
- self.close()
- return False
diff --git a/delta/imagery/sources/loader.py b/delta/imagery/sources/loader.py
deleted file mode 100644
index 21f1271b..00000000
--- a/delta/imagery/sources/loader.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright © 2020, United States Government, as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All rights reserved.
-#
-# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
-# licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0.
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Load images given configuration.
-"""
-
-from . import worldview, landsat, tiff, npy
-
-_IMAGE_TYPES = {
- 'worldview' : worldview.WorldviewImage,
- 'landsat' : landsat.LandsatImage,
- 'tiff' : tiff.TiffImage,
- 'rgba' : tiff.RGBAImage,
- 'npy' : npy.NumpyImage
-}
-
-def register_image_type(image_type, image_class):
- """
- Register a custom image type for use by DELTA.
-
- image_type is a string specified in config files.
- image_class is a custom class that extends
- `delta.iamge.sources.delta_image.DeltaImage`.
- """
- global _IMAGE_TYPES #pylint: disable=global-statement
- _IMAGE_TYPES[image_type] = image_class
-
-def load(filename, image_type, preprocess=False):
- """
- Load an image of the appropriate type and parameters.
- """
- if image_type not in _IMAGE_TYPES:
- raise ValueError('Unexpected image_type %s.' % (image_type))
- img = _IMAGE_TYPES[image_type](filename)
- if preprocess:
- img.set_preprocess(preprocess)
- return img
-
-def load_image(image_set, index):
- """
- Load the specified image in the ImageSet.
- """
- return load(image_set[index], image_set.type(), preprocess=image_set.preprocess())
diff --git a/delta/imagery/utilities.py b/delta/imagery/utilities.py
index b64ca206..9ad5d6dc 100644
--- a/delta/imagery/utilities.py
+++ b/delta/imagery/utilities.py
@@ -27,6 +27,13 @@
def unpack_to_folder(compressed_path, unpack_folder):
"""
Unpack a file into the given folder.
+
+ Parameters
+ ----------
+ compressed_path: str
+ Zip or tar file path
+ unpack_folder: str
+ Folder to unpack to
"""
tmpdir = os.path.normpath(unpack_folder) + '_working'
@@ -42,13 +49,24 @@ def unpack_to_folder(compressed_path, unpack_folder):
except Exception as e:
shutil.rmtree(tmpdir) # Clear any partially unpacked results
raise RuntimeError('Caught exception unpacking compressed file: ' + compressed_path
- + '\n' + str(e))
+ + '\n' + str(e)) from e
os.rename(tmpdir, unpack_folder) # Clean up
def progress_bar(text, fill_amount, prefix = '', length = 80): #pylint: disable=W0613
"""
Prints a progress bar. Call multiple times with increasing progress to
overwrite the printed line.
+
+ Parameters
+ ----------
+ text: str
+ Text to print after progress bar
+ fill_amount: float
+ Percent to fill bar, from 0.0 - 1.0
+ prefix: str
+ Text to print before progress bar
+ length: int
+ Number of characters to fill as bar
"""
filled_length = int(length * fill_amount)
fill_char = '█' if sys.stdout.encoding.lower() == 'utf-8' else 'X'
diff --git a/delta/ml/config_parser.py b/delta/ml/config_parser.py
new file mode 100644
index 00000000..0c0b0611
--- /dev/null
+++ b/delta/ml/config_parser.py
@@ -0,0 +1,343 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Functions to support loading custom ML-related objects from dictionaries specified
+in yaml files. Includes constructing custom neural networks and more.
+"""
+from collections.abc import Mapping
+import copy
+import functools
+from typing import Callable, List, Union
+
+import tensorflow
+import tensorflow.keras.layers
+import tensorflow.keras.losses
+import tensorflow.keras.models
+
+from delta.config import config
+import delta.config.extensions as extensions
+
+class _LayerWrapper:
+ def __init__(self, layer_type, layer_name, inputs, params, all_layers):
+ """
+ all_layers is a name indexed dictionary of LayerWrappers for all the layers,
+ shared between them.
+ """
+ self._layer_type = layer_type
+ self.name = layer_name
+ self._inputs = inputs
+ lc = extensions.layer(layer_type)
+ if lc is None:
+ lc = getattr(tensorflow.keras.layers, layer_type, None)
+ if lc is None:
+ raise ValueError('Unknown layer type %s.' % (layer_type))
+ self.layer = lc(**params)
+ self._sub_layers = None
+ self._tensor = None
+ all_layers[layer_name] = self
+ self._all_layers = all_layers
+
+ def is_input(self):
+ return self._layer_type == 'Input'
+
+ def sub_layer(self, name):
+ assert self._sub_layers, 'Layer %s does not support sub-layers.' % (self.layer.name)
+ assert name in self._sub_layers, ('Layer %s not found in ' % (name)) + str(self._sub_layers)
+ return self._sub_layers[name]
+
+ # TODO: will crash if there is a cycle in the graph
+ def output_tensor(self):
+ """
+ Constructs the output tensor with preceding layers as inputs.
+ """
+ if self._tensor is not None:
+ return self._tensor
+ inputs = []
+ for k in self._inputs:
+ if isinstance(k, tensorflow.Tensor):
+ inputs.append(k)
+ continue
+ if isinstance(k, int) or '/' not in k:
+ l = self._all_layers[k].output_tensor()
+ inputs.append(l)
+ continue
+ # getting nested layer
+ parts = k.split('/')
+ input_layer = parts[0]
+ if input_layer not in self._all_layers:
+ raise ValueError('Input layer ' + str(input_layer) + ' not found.')
+ self._all_layers[input_layer].output_tensor() # compute it if it hasn't been
+ cur = self._all_layers[input_layer].sub_layer(k[len(parts[0]) + 1:])
+
+ if isinstance(self._tensor, tensorflow.keras.layers.Layer):
+ inputs.append(cur.output)
+ else:
+ inputs.append(cur)
+ if inputs:
+ if len(inputs) == 1:
+ inputs = inputs[0]
+ self._tensor = self.layer(inputs)
+ if isinstance(self._tensor, tuple):
+ self._sub_layers = self._tensor[1]
+ self._tensor = self._tensor[0]
+ if isinstance(self._tensor, tensorflow.keras.layers.Layer):
+ self._tensor = self._tensor.output
+ else:
+ self._tensor = self.layer
+ return self._tensor
+
+def _make_layer(layer_dict, layer_id, prev_layer, all_layers):
+ """
+ Constructs a layer specified in layer_dict.
+ layer_id is the order in the order in the config file.
+ Assumes layer_dict only contains the key which is the
+ layer type, mapped to a sub-dict with properly named parameters for constructing
+ the layer, and the additional fields:
+
+ * `name` (optional): a name to refer to the layer by
+ * `inputs` (optional): the name or a list of names of
+ the preceding layers (defaults to previous in list)
+ """
+ if len(layer_dict.keys()) > 1:
+ raise ValueError('Layer with multiple types.')
+ layer_type = next(layer_dict.keys().__iter__())
+ l = layer_dict[layer_type]
+ if l is None:
+ l = {}
+
+ inputs = [prev_layer]
+ if layer_type == 'Input':
+ inputs = []
+ if 'name' in l:
+ layer_id = l['name']
+ if 'inputs' in l:
+ inputs = l['inputs']
+ l = copy.copy(l) # don't modify original dict
+ del l['inputs']
+ if isinstance(inputs, (int, str)):
+ inputs = [inputs]
+
+ return _LayerWrapper(layer_type, layer_id, inputs, l, all_layers)
+
+def _make_model(layer_list):
+ """
+ Makes a model from a list of layers.
+ """
+ assert layer_list is not None, 'No model specified!'
+
+ prev_layer = 0
+ last = None
+ all_layers = {}
+ for (i, l) in enumerate(layer_list):
+ last = _make_layer(l, i, prev_layer, all_layers)
+ prev_layer = last.name
+
+ outputs = last.output_tensor()
+ inputs = [l.output_tensor() for l in all_layers.values() if l.is_input()]
+
+ if len(inputs) == 1:
+ inputs = inputs[0]
+ return tensorflow.keras.models.Model(inputs=inputs, outputs=outputs)
+
+def _apply_params(model_dict, exposed_params):
+ """
+ Apply the parameters in exposed_params and in model_dict['params']
+ to the fields in model_dict, returning a copy.
+ """
+ defined_params = {}
+ if 'params' in model_dict and model_dict['params'] is not None:
+ defined_params = model_dict['params']
+
+ params = {**exposed_params, **defined_params}
+ # replace parameters recursively in all layers
+ def recursive_dict_list_apply(d, func):
+ if isinstance(d, Mapping):
+ for k, v in d.items():
+ d[k] = recursive_dict_list_apply(v, func)
+ return d
+ if isinstance(d, list):
+ return list(map(functools.partial(recursive_dict_list_apply, func=func), d))
+ if isinstance(d, str):
+ return func(d)
+ return d
+ def apply_params(s):
+ for (k, v) in params.items():
+ if s == k:
+ return v
+ return s
+ model_dict_copy = copy.deepcopy(model_dict)
+ recursive_dict_list_apply(model_dict_copy, apply_params)
+
+ # checks if the first layer is an Input, if not insert one
+ layer_list = model_dict_copy['layers']
+ assert layer_list is not None, 'No model specified!'
+ first_layer_type = next(layer_list[0].keys().__iter__())
+ if first_layer_type != 'Input' and 'input' not in layer_list[0][first_layer_type]:
+ model_dict_copy['layers'] = [{'Input' : {'shape' : params['in_shape']}}] + layer_list
+
+ return model_dict_copy
+
+def model_from_dict(model_dict: dict, exposed_params: dict) -> Callable[[], tensorflow.keras.models.Model]:
+ """
+ Construct a model.
+
+ Parameters
+ ----------
+ model_dict: dict
+ Config dictionary describing the model
+ exposed_params: dict
+ Dictionary of parameter names and values to substitute.
+
+ Returns
+ -------
+ Callable[[], tensorflow.keras.models.Model]:
+ Model constructor function.
+ """
+ model_dict = _apply_params(model_dict, exposed_params)
+ return functools.partial(_make_model, model_dict['layers'])
+
+def _parse_str_or_dict(spec, type_name):
+ if isinstance(spec, str):
+ return (spec, {})
+ if isinstance(spec, dict):
+ assert len(spec.keys()) == 1, 'Only one %s may be specified.' % (type_name)
+ name = list(spec.keys())[0]
+ return (name, spec[name])
+ raise ValueError('Unexpected entry for %s.' % (type_name))
+
+def loss_from_dict(loss_spec: Union[dict, str]) -> tensorflow.keras.losses.Loss:
+ """
+ Construct a loss function.
+
+ Parameters
+ ----------
+ loss_spec: Union[dict, str]
+ Specification of the loss function. Either a string that is compatible
+ with the keras interface (e.g. 'categorical_crossentropy') or an object defined by a dict
+ of the form {'LossFunctionName': {'arg1':arg1_val, ...,'argN',argN_val}}
+
+ Returns
+ -------
+ tensorflow.keras.losses.Loss
+ The loss object.
+ """
+ (name, params) = _parse_str_or_dict(loss_spec, 'loss function')
+ lc = extensions.loss(name)
+ if lc is None:
+ lc = getattr(tensorflow.keras.losses, name, None)
+ if lc is None:
+ raise ValueError('Unknown loss type %s.' % (name))
+ if isinstance(lc, type) and issubclass(lc, tensorflow.keras.losses.Loss):
+ lc = lc(**params)
+ return lc
+
+def metric_from_dict(metric_spec: Union[dict, str]) -> tensorflow.keras.metrics.Metric:
+ """
+ Construct a metric.
+
+ Parameters
+ ----------
+ metric_spec: Union[dict, str]
+ Config dictionary or string defining the metric
+
+ Returns
+ -------
+ tensorflow.keras.metrics.Metric
+ The metric object.
+ """
+ (name, params) = _parse_str_or_dict(metric_spec, 'metric')
+ mc = extensions.metric(name)
+ if mc is None:
+ mc = getattr(tensorflow.keras.metrics, name, None)
+ if mc is None:
+ try:
+ mc = loss_from_dict(metric_spec)
+ except:
+ raise ValueError('Unknown metric %s.' % (name)) #pylint:disable=raise-missing-from
+ if isinstance(mc, type) and issubclass(mc, tensorflow.keras.metrics.Metric):
+ mc = mc(**params)
+ return mc
+
+def optimizer_from_dict(spec: Union[dict, str]) -> tensorflow.keras.optimizers.Optimizer:
+ """
+ Construct an optimizer from a dictionary or string.
+
+ Parameters
+ ----------
+ spec: Union[dict, str]
+ Config dictionary or string defining an optimizer
+
+ Returns
+ -------
+ tensorflow.keras.optimizers.Optimizer
+ The optimizer object.
+ """
+ (name, params) = _parse_str_or_dict(spec, 'optimizer')
+ mc = getattr(tensorflow.keras.optimizers, name, None)
+ if mc is None:
+ raise ValueError('Unknown optimizer %s.' % (name))
+ return mc(**params)
+
+def callback_from_dict(callback_dict: Union[dict, str]) -> tensorflow.keras.callbacks.Callback:
+ """
+ Construct a callback from a dictionary.
+
+ Parameters
+ ----------
+ callback_dict: Union[dict, str]
+ Config dictionary defining a callback.
+
+ Returns
+ -------
+ tensorflow.keras.callbacks.Callback
+ The callback object.
+ """
+ assert len(callback_dict.keys()) == 1, f'Error: Callback has more than one type {callback_dict.keys()}'
+
+ cb_type = next(iter(callback_dict.keys()))
+ callback_class = extensions.callback(cb_type)
+ if callback_class is None:
+ callback_class = getattr(tensorflow.keras.callbacks, cb_type, None)
+ if callback_dict[cb_type] is None:
+ callback_dict[cb_type] = {}
+ if callback_class is None:
+ raise ValueError('Unknown callback %s.' % (cb_type))
+ return callback_class(**callback_dict[cb_type])
+
+def config_callbacks() -> List[tensorflow.keras.callbacks.Callback]:
+ """
+ Returns
+ -------
+ List[tensorflow.keras.callbacks.Callback]
+ List of callbacks specified in the config file.
+ """
+ if not config.train.callbacks() is None:
+ return [callback_from_dict(callback) for callback in config.train.callbacks()]
+ return []
+
+def config_model(num_bands: int) -> Callable[[], tensorflow.keras.models.Model]:
+ """
+ Returns
+ -------
+ Callable[[], tensorflow.keras.models.Model]
+ A function to construct the model given in the config file.
+ """
+ params_exposed = {'num_classes' : len(config.dataset.classes),
+ 'num_bands' : num_bands}
+
+ return model_from_dict(config.train.network.to_dict(), params_exposed)
diff --git a/delta/ml/io.py b/delta/ml/io.py
index 12e4733f..14d7fbfb 100644
--- a/delta/ml/io.py
+++ b/delta/ml/io.py
@@ -20,13 +20,60 @@
"""
import h5py
+import numpy as np
+import tensorflow.keras.backend as K
from delta.config import config
def save_model(model, filename):
"""
Save a model. Includes DELTA configuration.
+
+ Parameters
+ ----------
+ model: tensorflow.keras.models.Model
+ The model to save.
+ filename: str
+ Output filename.
"""
model.save(filename, save_format='h5')
with h5py.File(filename, 'r+') as f:
f.attrs['delta'] = config.export()
+
+def print_layer(l):
+ """
+ Print a layer to stdout.
+
+ l: tensorflow.keras.layers.Layer
+ The layer to print.
+ """
+ s = "{:<25}".format(l.name) + ' ' + '{:<20}'.format(str(l.input_shape)) + \
+ ' -> ' + '{:<20}'.format(str(l.output_shape))
+ c = l.get_config()
+ if 'strides' in c:
+ s += ' s: ' + '{:<10}'.format(str(c['strides']))
+ if 'kernel_size' in c:
+ s += ' ks: ' + str(c['kernel_size'])
+ print(s)
+
+def print_network(a, tile_shape=None):
+ """
+ Print a model to stdout.
+
+ a: tensorflow.keras.models.Model
+ The model to print.
+ tile_shape: Optional[Tuple[int, int]]
+ If specified, print layer output sizes (necessary for FCN only).
+ """
+ for l in a.layers:
+ print_layer(l)
+ in_shape = a.layers[0].input_shape[0]
+ if tile_shape is not None:
+ in_shape = (in_shape[0], tile_shape[0], tile_shape[1], in_shape[3])
+ out_shape = a.compute_output_shape(in_shape)
+ print('Size: ' + str(in_shape[1:]) + ' --> ' + str(out_shape[1:]))
+ if out_shape[1] is not None and out_shape[2] is not None:
+ print('Compression Rate - ', out_shape[1] * out_shape[2] * out_shape[3] /
+ (in_shape[1] * in_shape[2] * in_shape[3]))
+ print('Layers - ', len(a.layers))
+ print('Trainable Parameters - ', np.sum([K.count_params(w) for w in a.trainable_weights]))
diff --git a/delta/ml/ml_config.py b/delta/ml/ml_config.py
index 10ea3bcf..dad06c44 100644
--- a/delta/ml/ml_config.py
+++ b/delta/ml/ml_config.py
@@ -22,6 +22,8 @@
# when tensorflow isn't needed
import os.path
+from typing import Optional
+
import appdirs
import pkg_resources
import yaml
@@ -29,47 +31,24 @@
from delta.imagery.imagery_config import ImageSet, ImageSetConfig, load_images_labels
import delta.config as config
-def loss_function_factory(loss_spec):
- '''
- loss_function_factory - Creates a loss function object, if an object is specified in the
- config file, or a string if that is all that is specified.
-
- :param: loss_spec Specification of the loss function. Either a string that is compatible
- with the keras interface (e.g. 'categorical_crossentropy') or an object defined by a dict
- of the form {'LossFunctionName': {'arg1':arg1_val, ...,'argN',argN_val}}
- '''
- import tensorflow.keras.losses # pylint: disable=import-outside-toplevel
-
- if isinstance(loss_spec, str):
- return loss_spec
-
- if isinstance(loss_spec, list):
- assert len(loss_spec) == 1, 'Too many loss functions specified'
- assert isinstance(loss_spec[0], dict), '''Loss functions objects and parameters must
- be specified as a yaml dictionary object
- '''
- assert len(loss_spec[0].keys()) == 1, f'Too many loss functions specified: {dict.keys()}'
- loss_type = list(loss_spec[0].keys())[0]
- loss_fn_args = loss_spec[0][loss_type]
-
- loss_class = getattr(tensorflow.keras.losses, loss_type, None)
- return loss_class(**loss_fn_args)
-
- raise RuntimeError(f'Did not recognize the loss function specification: {loss_spec}')
-
-
class ValidationSet:#pylint:disable=too-few-public-methods
"""
Specifies the images and labels in a validation set.
"""
- def __init__(self, images=None, labels=None, from_training=False, steps=1000):
+ def __init__(self, images: Optional[ImageSet]=None, labels: Optional[ImageSet]=None,
+ from_training: bool=False, steps: int=1000):
"""
- Uses the specified `delta.imagery.sources.ImageSet`s images and labels.
-
- If `from_training` is `True`, instead takes samples from the training set
- before they are used for training.
-
- The number of samples to use for validation is set by `steps`.
+ Parameters
+ ----------
+ images: ImageSet
+ Validation images.
+ labels: ImageSet
+ Optional, validation labels.
+ from_training: bool
+ If true, ignore images and labels arguments and take data from the training imagery.
+ The validation data will not be used for training.
+ steps: int
+ If from_training is true, take this many batches for validation.
"""
self.images = images
self.labels = labels
@@ -80,18 +59,21 @@ class TrainingSpec:#pylint:disable=too-few-public-methods,too-many-arguments
"""
Options used in training by `delta.ml.train.train`.
"""
- def __init__(self, batch_size, epochs, loss_function, metrics, validation=None, steps=None,
- chunk_stride=1, optimizer='adam'):
+ def __init__(self, batch_size, epochs, loss, metrics, validation=None, steps=None,
+ stride=None, optimizer='Adam'):
self.batch_size = batch_size
self.epochs = epochs
- self.loss_function = loss_function
+ self.loss = loss
self.validation = validation
self.steps = steps
self.metrics = metrics
- self.chunk_stride = chunk_stride
+ self.stride = stride
self.optimizer = optimizer
-class NetworkModelConfig(config.DeltaConfigComponent):
+class NetworkConfig(config.DeltaConfigComponent):
+ """
+ Configuration for a neural network.
+ """
def __init__(self):
super().__init__()
self.register_field('yaml_file', str, 'yaml_file', config.validate_path,
@@ -102,13 +84,12 @@ def __init__(self):
# overwrite model entirely if updated (don't want combined layers from multiple files)
def _load_dict(self, d : dict, base_dir):
super()._load_dict(d, base_dir)
- if 'yaml_file' in d:
- self._config_dict['layers'] = None
- elif 'layers' in d:
+ if 'layers' in d and d['layers'] is not None:
self._config_dict['yaml_file'] = None
- if 'yaml_file' in d and 'layers' in d and d['yaml_file'] is not None and d['layers'] is not None:
- raise ValueError('Specified both yaml file and layers in model.')
- if 'yaml_file' in d and d['yaml_file'] is not None:
+ elif 'yaml_file' in d and d['yaml_file'] is not None:
+ self._config_dict['layers'] = None
+ if 'layers' in d and d['layers'] is not None:
+ raise ValueError('Specified both yaml file and layers in model.')
yaml_file = d['yaml_file']
resource = os.path.join('config', yaml_file)
if not os.path.exists(yaml_file) and pkg_resources.resource_exists('delta', resource):
@@ -118,23 +99,21 @@ def _load_dict(self, d : dict, base_dir):
with open(yaml_file, 'r') as f:
self._config_dict.update(yaml.safe_load(f))
-class NetworkConfig(config.DeltaConfigComponent):
- def __init__(self):
- super().__init__()
- self.register_field('chunk_size', int, 'chunk_size', config.validate_positive,
- 'Width of an image chunk to input to the neural network.')
- self.register_field('output_size', int, 'output_size', config.validate_positive,
- 'Width of an image chunk to output from the neural network.')
-
- self.register_arg('chunk_size', '--chunk-size')
- self.register_arg('output_size', '--output-size')
- self.register_component(NetworkModelConfig(), 'model')
-
- def setup_arg_parser(self, parser, components = None) -> None:
- group = parser.add_argument_group('Network')
- super().setup_arg_parser(group, components)
+def validate_size(size, _):
+ """
+ Validate an image region size.
+ """
+ if size is None:
+ return size
+ assert len(size) == 2, 'Size must be tuple.'
+ assert isinstance(size[0], int) and isinstance(size[1], int), 'Size must be integer.'
+ assert size[0] > 0 and size[1] > 0, 'Size must be positive.'
+ return size
class ValidationConfig(config.DeltaConfigComponent):
+ """
+ Configuration for training validation.
+ """
def __init__(self):
super().__init__()
self.register_field('steps', int, 'steps', config.validate_positive,
@@ -171,32 +150,44 @@ def labels(self) -> ImageSet:
config.config.dataset.classes)
return self.__labels
+def _validate_stride(stride, _):
+ if stride is None:
+ return None
+ if isinstance(stride, int):
+ stride = (stride, stride)
+ assert len(stride) == 2, 'Stride must have two components.'
+ assert isinstance(stride[0], int) and isinstance(stride[1], int), 'Stride must be integer.'
+ assert stride[0] > 0 and stride[1] > 0, 'Stride must be positive.'
+ return stride
+
class TrainingConfig(config.DeltaConfigComponent):
+ """
+ Configuration for training.
+ """
def __init__(self):
- super().__init__()
- self.register_field('chunk_stride', int, None, config.validate_positive,
+ super().__init__(section_header='Training')
+ self.register_field('stride', (list, int, None), None, _validate_stride,
'Pixels to skip when iterating over chunks. A value of 1 means to take every chunk.')
self.register_field('epochs', int, None, config.validate_positive,
'Number of times to repeat training on the dataset.')
self.register_field('batch_size', int, None, config.validate_positive,
'Features to group into each training batch.')
- self.register_field('loss_function', (str, list), None, None, 'Keras loss function.')
+ self.register_field('loss', (str, dict), None, None, 'Keras loss function.')
self.register_field('metrics', list, None, None, 'List of metrics to apply.')
- self.register_field('steps', int, None, config.validate_positive, 'Batches to train per epoch.')
- self.register_field('optimizer', str, None, None, 'Keras optimizer to use.')
-
- self.register_arg('chunk_stride', '--chunk-stride')
+ self.register_field('steps', int, None, config.validate_non_negative, 'Batches to train per epoch.')
+ self.register_field('optimizer', (str, dict), None, None, 'Keras optimizer to use.')
+ self.register_field('callbacks', list, 'callbacks', None, 'Callbacks used to modify training')
self.register_arg('epochs', '--epochs')
self.register_arg('batch_size', '--batch-size')
self.register_arg('steps', '--steps')
+ self.register_field('log_folder', str, 'log_folder', config.validate_path,
+ 'Directory where dataset progress is recorded.')
+ self.register_field('resume_cutoff', int, 'resume_cutoff', None,
+ 'When resuming a dataset, skip images where we have read this many tiles.')
self.register_component(ValidationConfig(), 'validation')
self.register_component(NetworkConfig(), 'network')
self.__training = None
- def setup_arg_parser(self, parser, components = None) -> None:
- group = parser.add_argument_group('Training')
- super().setup_arg_parser(group, components)
-
def spec(self) -> TrainingSpec:
"""
Returns the options configuring training.
@@ -208,19 +199,25 @@ def spec(self) -> TrainingSpec:
if not from_training:
(vimg, vlabels) = (self._components['validation'].images(), self._components['validation'].labels())
validation = ValidationSet(vimg, vlabels, from_training, vsteps)
- loss_fn = loss_function_factory(self._config_dict['loss_function'])
self.__training = TrainingSpec(batch_size=self._config_dict['batch_size'],
epochs=self._config_dict['epochs'],
- loss_function=loss_fn,
+ loss=self._config_dict['loss'],
metrics=self._config_dict['metrics'],
validation=validation,
steps=self._config_dict['steps'],
- chunk_stride=self._config_dict['chunk_stride'],
+ stride=self._config_dict['stride'],
optimizer=self._config_dict['optimizer'])
return self.__training
+ def _load_dict(self, d : dict, base_dir):
+ self.__training = None
+ super()._load_dict(d, base_dir)
+
class MLFlowCheckpointsConfig(config.DeltaConfigComponent):
+ """
+ Configure MLFlow checkpoints.
+ """
def __init__(self):
super().__init__()
self.register_field('frequency', int, 'frequency', None,
@@ -229,6 +226,9 @@ def __init__(self):
'If true, only keep the most recent checkpoint.')
class MLFlowConfig(config.DeltaConfigComponent):
+ """
+ Configure MLFlow.
+ """
def __init__(self):
super().__init__()
self.register_field('enabled', bool, 'enabled', None, 'Enable MLFlow.')
@@ -251,6 +251,9 @@ def uri(self) -> str:
return uri
class TensorboardConfig(config.DeltaConfigComponent):
+ """
+ Tensorboard configuration.
+ """
def __init__(self):
super().__init__()
self.register_field('enabled', bool, 'enabled', None, 'Enable Tensorboard.')
@@ -267,21 +270,12 @@ def dir(self) -> str:
def register():
"""
- Registers imagery config options with the global config manager.
+ Registers machine learning config options with the global config manager.
The arguments enable command line arguments for different components.
"""
- if not hasattr(config.config, 'general'):
- config.config.register_component(config.DeltaConfigComponent('General'), 'general')
-
config.config.general.register_field('gpus', int, 'gpus', None, 'Number of gpus to use.')
config.config.general.register_arg('gpus', '--gpus')
- config.config.general.register_field('stop_on_input_error', bool, 'stop_on_input_error', None,
- 'If false, skip past bad input images.')
- config.config.general.register_arg('stop_on_input_error', '--bypass-input-errors',
- action='store_const', const=False, type=None)
- config.config.general.register_arg('stop_on_input_error', '--stop-on-input-error',
- action='store_const', const=True, type=None)
config.config.register_component(TrainingConfig(), 'train')
config.config.register_component(MLFlowConfig(), 'mlflow')
diff --git a/delta/ml/model_parser.py b/delta/ml/model_parser.py
deleted file mode 100644
index 3ba37ff3..00000000
--- a/delta/ml/model_parser.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# Copyright © 2020, United States Government, as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All rights reserved.
-#
-# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
-# licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0.
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Construct sequential neural networks using the Tensorflow-Keras API from
-dictionaries. Assumes that the names of the parameters for the layer constructor functions are as
-given in the Tensorflow API documentation.
-"""
-import functools
-from typing import Callable
-
-import tensorflow
-import tensorflow.keras.models
-import tensorflow.keras.layers
-
-from delta.config import config
-
-from . import layers
-
-class _LayerWrapper:
- def __init__(self, layer_type, layer_name, inputs, params):
- self._layer_type = layer_type
- self._layer_name = layer_name
- self._inputs = inputs
- layer_class = getattr(tensorflow.keras.layers, self._layer_type, None)
- if layer_class is None and self._layer_type in layers.ALL_LAYERS:
- layer_class = layers.ALL_LAYERS[self._layer_type]
- if layer_class is None:
- raise ValueError('Unknown layer type %s.' % (self._layer_type))
- self._layer_constructor = layer_class(**params)
- self._layer = None
-
- def is_input(self):
- return self._layer_type == 'Input'
-
- # TODO: will crash if there is a cycle in the graph
- def layer(self, layer_dict):
- """
- Constructs the layers with preceding layers as inputs. layer_dict is a name
- indexed dictionary of LayerWrappers for all the layers.
- """
- if self._layer is not None:
- return self._layer
- inputs = []
- for k in self._inputs:
- if isinstance(k, tensorflow.Tensor):
- inputs.append(k)
- continue
- if k not in layer_dict:
- raise ValueError('Input layer ' + str(k) + ' not found.')
- inputs.append(layer_dict[k].layer(layer_dict))
- if inputs:
- if len(inputs) == 1:
- inputs = inputs[0]
- self._layer = self._layer_constructor(inputs)
- else:
- self._layer = self._layer_constructor
- return self._layer
-
-def _make_layer(layer_dict, layer_id, prev_layer, param_dict):
- """
- Constructs a layer specified in layer_dict, possibly using parameters specified in
- param_dict. layer_id is the order in the order in the config file.
- Assumes layer_dict only contains the properly named parameters for constructing
- the layer, and the additional fields:
-
- * `type`: the type of keras layer.
- * `name` (optional): a name to refer to the layer by
- * `inputs` (optional): the name or a list of names of
- the preceding layers (defaults to previous in list)
- """
- if len(layer_dict.keys()) > 1:
- raise ValueError('Layer with multiple types.')
- layer_type = next(layer_dict.keys().__iter__())
- l = layer_dict[layer_type]
- if l is None:
- l = {}
-
- for (k, v) in l.items():
- if isinstance(v, str) and v in param_dict.keys():
- l[k] = param_dict[v]
-
- inputs = [prev_layer]
- if layer_type == 'Input':
- inputs = []
- if 'name' in l:
- layer_id = l['name']
- if 'inputs' in l:
- inputs = l['inputs']
- del l['inputs']
- if isinstance(inputs, (int, str)):
- inputs = [inputs]
-
- return (layer_id, _LayerWrapper(layer_type, layer_id, inputs, l))
-
-def _make_model(model_dict, exposed_params):
- layer_list = model_dict['layers']
- defined_params = {}
- if 'params' in model_dict and model_dict['params'] is not None:
- defined_params = model_dict['params']
-
- params = {**exposed_params, **defined_params}
- layer_dict = {}
- last = None
- first_layer_type = next(layer_list[0].keys().__iter__())
- if first_layer_type != 'Input' and 'input' not in layer_list[0][first_layer_type]:
- layer_list = [{'Input' : {'shape' : params['in_shape']}}] + layer_list
- #if layer_list[0]['type'] != 'Input' and 'input' not in layer_list[0]:
- prev_layer = 0
- for (i, l) in enumerate(layer_list):
- (layer_id, layer) = _make_layer(l, i, prev_layer, params)
- last = layer
- layer_dict[layer_id] = layer
- prev_layer = layer_id
-
- outputs = last.layer(layer_dict)
- inputs = [l.layer(layer_dict) for l in layer_dict.values() if l.is_input()]
-
- if len(inputs) == 1:
- inputs = inputs[0]
- return tensorflow.keras.models.Model(inputs=inputs, outputs=outputs)
-
-def model_from_dict(model_dict, exposed_params) -> Callable[[], tensorflow.keras.models.Sequential]:
- """
- Creates a function that returns a sequential model from a dictionary.
- """
- return functools.partial(_make_model, model_dict, exposed_params)
-
-def config_model(num_bands: int) -> Callable[[], tensorflow.keras.models.Sequential]:
- """
- Creates the model specified in the configuration.
- """
- in_data_shape = (config.train.network.chunk_size(), config.train.network.chunk_size(), num_bands)
- out_data_shape = (config.train.network.output_size(), config.train.network.output_size(),
- len(config.dataset.classes))
-
- params_exposed = {'out_shape' : out_data_shape,
- 'out_dims' : out_data_shape[0] * out_data_shape[1] * out_data_shape[2],
- 'in_shape' : in_data_shape,
- 'in_dims' : in_data_shape[0] * in_data_shape[1] * in_data_shape[2],
- 'num_bands' : in_data_shape[2]}
-
- return model_from_dict(config.train.network.model.to_dict(), params_exposed)
diff --git a/delta/ml/predict.py b/delta/ml/predict.py
index 6c3e752e..50ac2981 100644
--- a/delta/ml/predict.py
+++ b/delta/ml/predict.py
@@ -21,32 +21,35 @@
"""
from abc import ABC, abstractmethod
+import math
import numpy as np
import tensorflow as tf
from delta.imagery import rectangle
-from delta.config import config
-
-#pylint: disable=unsubscriptable-object
-# Pylint was barfing lines 32 and 76. See relevant bug report
-# https://github.com/PyCQA/pylint/issues/1498
-
-_TILE_SIZE = 256
class Predictor(ABC):
"""
Abstract class to run prediction for an image given a model.
"""
- def __init__(self, model, show_progress=False):
+ def __init__(self, model, tile_shape=None, show_progress=False):
self._model = model
self._show_progress = show_progress
+ self._tile_shape = tile_shape
@abstractmethod
- def _initialize(self, shape, label, image):
+ def _initialize(self, shape, image, label=None):
"""
Called at the start of a new prediction.
- The output shape, label image, and image being read are passed as inputs.
+
+ Parameters
+ ----------
+ shape: (int, int)
+ The final output shape from the network.
+ image: delta.imagery.delta_image.DeltaImage
+ The image to classify.
+ label: delta.imagery.delta_image.DeltaImage
+ The label image, if provided (otherwise None).
"""
def _complete(self):
@@ -56,34 +59,76 @@ def _abort(self):
"""Cancel the operation and cleanup neatly."""
@abstractmethod
- def _process_block(self, pred_image, x, y, labels):
+ def _process_block(self, pred_image: np.ndarray, x: int, y: int, labels: np.ndarray, label_nodata):
"""
- Processes a predicted block. The predictions are in pred_image,
- (sx, sy) is the starting coordinates of the block, and the corresponding labels
- if available are passed as labels.
+ Processes a predicted block. Must be overriden in subclasses.
+
+ Parameters
+ ----------
+ pred_image: numpy.ndarray
+ Output of model for a block of the image.
+ x: int
+ Top-left x coordinate of block.
+ y: int
+ Top-left y coordinate of block.
+ labels: numpy.ndarray
+ Labels (or None if not available) for same block as `pred_image`.
+ label_nodata: dtype of labels
+ Pixel value for nodata (or None).
"""
- def _predict_array(self, data):
+ def _predict_array(self, data: np.ndarray, image_nodata_value):
+ """
+ Runs model on data.
+
+ Parameters
+ ----------
+ data: np.ndarray
+ Block of image to apply the model to.
+ image_nodata_value: dtype of data
+ Nodata value in image. If given, nodata values are
+ replaced with nans in output.
+
+ Returns
+ -------
+ np.ndarray:
+ Result of applying model to data.
+ """
net_input_shape = self._model.input_shape[1:]
net_output_shape = self._model.output_shape[1:]
+ image = tf.convert_to_tensor(data)
+ image = tf.expand_dims(image, 0)
+
assert net_input_shape[2] == data.shape[2],\
'Model expects %d input channels, data has %d channels' % (net_input_shape[2], data.shape[2])
+ # supports variable input size, just toss everything in
+ if net_input_shape[0] is None and net_input_shape[1] is None:
+ result = np.squeeze(self._model.predict_on_batch(image))
+ if image_nodata_value is not None:
+ x0 = (data.shape[0] - result.shape[0]) // 2
+ y0 = (data.shape[1] - result.shape[1]) // 2
+ invalid = (data if len(data.shape) == 2 else \
+ data[:, :, 0])[x0:x0 + result.shape[0], y0:y0 + result.shape[1]] == image_nodata_value
+ if len(result.shape) == 2:
+ result[invalid] = math.nan
+ else:
+ result[invalid, :] = math.nan
+ return result
+
out_shape = (data.shape[0] - net_input_shape[0] + net_output_shape[0],
data.shape[1] - net_input_shape[1] + net_output_shape[1])
out_type = tf.dtypes.as_dtype(self._model.dtype)
- image = tf.convert_to_tensor(data)
- image = tf.expand_dims(image, 0)
chunks = tf.image.extract_patches(image, [1, net_input_shape[0], net_input_shape[1], 1],
[1, net_output_shape[0], net_output_shape[1], 1],
[1, 1, 1, 1], padding='VALID')
chunks = tf.reshape(chunks, (-1,) + net_input_shape)
best = np.zeros((chunks.shape[0],) + net_output_shape, dtype=out_type.as_numpy_dtype)
- BATCH_SIZE = int(config.io.block_size_mb() * 1024 * 1024 / net_input_shape[0] / net_input_shape[1] /
- net_input_shape[2] / out_type.size)
- assert BATCH_SIZE > 0, 'block_size_mb too small.'
+ # do 8 MB at a time... this is arbitrary, may want to change in future
+ BATCH_SIZE = max(1, int(8 * 1024 * 1024 / net_input_shape[0] / net_input_shape[1] /
+ net_input_shape[2] / out_type.size))
for i in range(0, chunks.shape[0], BATCH_SIZE):
best[i:i+BATCH_SIZE] = self._model.predict_on_batch(chunks[i:i+BATCH_SIZE])
@@ -93,29 +138,68 @@ def _predict_array(self, data):
c = (chunk_idx % ( out_shape[1] // net_output_shape[1])) * net_output_shape[1]
retval[r:r+net_output_shape[0],c:c+net_output_shape[1],:] = best[chunk_idx,:,:,:]
+ if image_nodata_value is not None:
+ ox = (data.shape[0] - out_shape[0]) // 2
+ oy = (data.shape[1] - out_shape[1]) // 2
+ output_slice = data[ox:-ox, oy:-oy, 0]
+ retval[output_slice == image_nodata_value] = math.nan
+
return retval
- def predict(self, image, label=None, input_bounds=None):
+ def predict(self, image, label=None, input_bounds=None, overlap=(0, 0)):
"""
- Runs the model on `image`, comparing the results to `label` if specified.
- Results are limited to `input_bounds`. Returns output, the meaning of which
- depends on the subclass.
+ Runs the model on an image. The behavior is specific to the subclass.
+
+ Parameters
+ ----------
+ image: delta.imagery.delta_image.DeltaImage
+ Image to evalute.
+ label: delta.imagery.delta_image.DeltaImage
+ Optional label to compare to.
+ input_bounds: delta.imagery.rectangle.Rectangle
+ If specified, only evaluate the given portion of the image.
+ overlap: (int, int)
+ `predict` evaluates the image by selecting tiles, dependent on the tile_shape
+ provided in the subclass. If an overlap is specified, the tiles will be overlapped
+ by the given amounts in the x and y directions. Subclasses may select or interpolate
+ to favor tile interior pixels for improved classification.
+
+ Returns
+ -------
+ The result of the `_complete` function, which depends on the sublcass.
"""
net_input_shape = self._model.input_shape[1:]
net_output_shape = self._model.output_shape[1:]
- offset_r = -net_input_shape[0] + net_output_shape[0]
- offset_c = -net_input_shape[1] + net_output_shape[1]
- block_size_x = net_input_shape[0] * (_TILE_SIZE // net_input_shape[0])
- block_size_y = net_input_shape[1] * (_TILE_SIZE // net_input_shape[1])
# Set up the output image
if not input_bounds:
input_bounds = rectangle.Rectangle(0, 0, width=image.width(), height=image.height())
+ output_shape = (input_bounds.width(), input_bounds.height())
+
+ ts = self._tile_shape if self._tile_shape else (image.width(), image.height())
+ if net_input_shape[0] is None and net_input_shape[1] is None:
+ assert net_output_shape[0] is None and net_output_shape[1] is None
+ out_shape = self._model.compute_output_shape((0, ts[0], ts[1], net_input_shape[2]))
+ tiles = input_bounds.make_tile_rois(ts, include_partials=False,
+ overlap_shape=(ts[0] - out_shape[1] + overlap[0],
+ ts[1] - out_shape[2] + overlap[1]),
+ partials_overlap=True)
- self._initialize((input_bounds.width() + offset_r, input_bounds.height() + offset_c), label, image)
+ else:
+ offset_r = -net_input_shape[0] + net_output_shape[0] + overlap[0]
+ offset_c = -net_input_shape[1] + net_output_shape[1] + overlap[1]
+ output_shape = (output_shape[0] + offset_r, output_shape[1] + offset_c)
+ block_size_x = net_input_shape[0] * max(1, ts[0] // net_input_shape[0])
+ block_size_y = net_input_shape[1] * max(1, ts[1] // net_input_shape[1])
+ tiles = input_bounds.make_tile_rois((block_size_x - offset_r, block_size_y - offset_c),
+ include_partials=False, overlap_shape=(-offset_r, -offset_c))
+
+ self._initialize(output_shape, image, label)
+
+ label_nodata = label.nodata_value() if label else None
def callback_function(roi, data):
- pred_image = self._predict_array(data)
+ pred_image = self._predict_array(data, image.nodata_value())
block_x = (roi.min_x - input_bounds.min_x)
block_y = (roi.min_y - input_bounds.min_y)
@@ -129,13 +213,20 @@ def callback_function(roi, data):
label_x + pred_image.shape[0], label_y + pred_image.shape[1])
labels = np.squeeze(label.read(label_roi))
- self._process_block(pred_image, sx, sy, labels)
-
- output_rois = input_bounds.make_tile_rois(block_size_x - offset_r, block_size_y - offset_c,
- include_partials=False, overlap_amount=-offset_r)
+ tl = [0, 0]
+ tl = (overlap[0] // 2 if block_x > 0 else 0, overlap[1] // 2 if block_y > 0 else 0)
+ br = (roi.max_x - roi.min_x, roi.max_y - roi.min_y)
+ br = (br[0] - (overlap[0] // 2 if roi.max_x < input_bounds.max_x else 0),
+ br[1] - (overlap[1] // 2 if roi.max_x < input_bounds.max_x else 0))
+ if len(pred_image.shape) == 2:
+ input_block = pred_image[tl[0]:br[0], tl[1]:br[1]]
+ else:
+ input_block = pred_image[tl[0]:br[0], tl[1]:br[1], :]
+ self._process_block(input_block, sx + tl[0], sy + tl[1],
+ None if labels is None else labels[tl[0]:br[0], tl[1]:br[1]], label_nodata)
try:
- image.process_rois(output_rois, callback_function, show_progress=self._show_progress)
+ image.process_rois(tiles, callback_function, show_progress=self._show_progress)
except KeyboardInterrupt:
self._abort()
raise
@@ -146,13 +237,30 @@ class LabelPredictor(Predictor):
"""
Predicts integer labels for an image.
"""
- def __init__(self, model, output_image=None, show_progress=False, nodata_value=None, # pylint:disable=too-many-arguments
+ def __init__(self, model, tile_shape=None, output_image=None, show_progress=False, # pylint:disable=too-many-arguments
colormap=None, prob_image=None, error_image=None, error_colors=None):
"""
- output_image, prob_image, and error_image are all DeltaImageWriter's.
- colormap and error_colors are all numpy arrays mapping classes to colors.
+ Parameters
+ ----------
+ model: tensorflow.keras.models.Model
+ Model to evaluate.
+ tile_shape: (int, int)
+ Shape of tiles to process.
+ output_image: str
+ If specified, output the results to this image.
+ show_progress: bool
+ Print progress to command line.
+ colormap: List[Any]
+ Map classes to colors given in the colormap.
+ prob_image: str
+ If given, output a probability image to this file. Probabilities are scaled as bytes
+ 1-255, with 0 as nodata.
+ error_image: str
+ If given, output an image showing where the classification is incorrect.
+ error_colors: List[Any]
+ Colormap for the error_image.
"""
- super(LabelPredictor, self).__init__(model, show_progress)
+ super().__init__(model, tile_shape, show_progress)
self._confusion_matrix = None
self._num_classes = None
self._output_image = output_image
@@ -165,7 +273,6 @@ def __init__(self, model, output_image=None, show_progress=False, nodata_value=N
a[i][1] = (v >> 8) & 0xFF
a[i][2] = v & 0xFF
colormap = a
- self._nodata_value = nodata_value
self._colormap = colormap
self._prob_image = prob_image
self._error_image = error_image
@@ -176,9 +283,21 @@ def __init__(self, model, output_image=None, show_progress=False, nodata_value=N
self._prob_o = None
self._errors = None
- def _initialize(self, shape, label, image):
+ def _initialize(self, shape, image, label=None):
net_output_shape = self._model.output_shape[1:]
self._num_classes = net_output_shape[-1]
+ if self._prob_image:
+ self._prob_image.initialize((shape[0], shape[1], self._num_classes), np.dtype(np.uint8),
+ image.metadata(), nodata_value=0)
+
+ if self._num_classes == 1: # special case
+ self._num_classes = 2
+ if self._colormap is not None and self._num_classes != self._colormap.shape[0]:
+ print('Warning: Defined defined classes (%d) in config do not match network (%d).' %
+ (self._colormap.shape[0], self._num_classes))
+ if self._colormap.shape[0] > self._num_classes:
+ self._num_classes = self._colormap.shape[0]
+
if label:
self._errors = np.zeros(shape, dtype=np.bool)
self._confusion_matrix = np.zeros((self._num_classes, self._num_classes), dtype=np.int32)
@@ -190,9 +309,7 @@ def _initialize(self, shape, label, image):
self._output_image.initialize((shape[0], shape[1], self._colormap.shape[1]),
self._colormap.dtype, image.metadata())
else:
- self._output_image.initialize((shape[0], shape[1], 1), np.int32, image.metadata())
- if self._prob_image:
- self._prob_image.initialize((shape[0], shape[1], self._num_classes), np.float32, image.metadata())
+ self._output_image.initialize((shape[0], shape[1]), np.int32, image.metadata())
if self._error_image:
self._error_image.initialize((shape[0], shape[1], self._error_colors.shape[1]),
self._error_colors.dtype, image.metadata())
@@ -214,30 +331,59 @@ def _abort(self):
if self._error_image is not None:
self._error_image.abort()
- def _process_block(self, pred_image, x, y, labels):
+ def _process_block(self, pred_image, x, y, labels, label_nodata):
if self._prob_image is not None:
- self._prob_image.write(pred_image, x, y)
- pred_image = np.argmax(pred_image, axis=2)
+ prob = 1.0 + (pred_image * 254.0)
+ prob = prob.astype(np.uint8)
+ prob[np.isnan(pred_image[:, :, 0] if len(pred_image.shape) == 3 else pred_image)] = 0
+ self._prob_image.write(prob, x, y)
+
+ if labels is None and self._output_image is None:
+ return
+
+ prob_image = pred_image
+ if len(pred_image.shape) == 2:
+ pred_image[~np.isnan(pred_image)] = pred_image[~np.isnan(pred_image)] >= 0.5
+ pred_image = pred_image.astype(int)
+ prob_image = np.expand_dims(prob_image, -1)
+ else:
+ pred_image = np.argmax(pred_image, axis=2)
- if self._output_image is not None:
- if self._colormap is not None:
- self._output_image.write(self._colormap[pred_image], x, y)
- else:
- self._output_image.write(pred_image, x, y)
+ # nodata pixels were set to nan in the probability image
+ pred_image[np.isnan(prob_image[:, :, 0])] = -1
if labels is not None:
- eimg = self._error_colors[(labels != pred_image).astype(int)]
- if self._nodata_value is not None:
- valid = (labels != self._nodata_value)
- eimg[np.logical_not(valid)] = np.zeros(eimg.shape[-1:], dtype=eimg.dtype)
- labels = labels[valid]
- pred_image = pred_image[valid]
- self._error_image.write(eimg, x, y)
- cm = tf.math.confusion_matrix(np.ndarray.flatten(labels),
- np.ndarray.flatten(pred_image),
+ incorrect = (labels != pred_image).astype(int)
+
+ valid_labels = labels
+ valid_pred = pred_image
+ if label_nodata is not None:
+ invalid = np.logical_or((labels == label_nodata), pred_image == -1)
+ valid = np.logical_not(invalid)
+ incorrect[invalid] = 0
+ valid_labels = labels[valid]
+ valid_pred = pred_image[valid]
+
+ if self._error_image:
+ self._error_image.write(self._error_colors[incorrect], x, y)
+ cm = tf.math.confusion_matrix(np.ndarray.flatten(valid_labels),
+ np.ndarray.flatten(valid_pred),
self._num_classes)
self._confusion_matrix[:, :] += cm
+ if self._output_image is not None:
+ if self._colormap is not None:
+ colormap = np.zeros((self._colormap.shape[0] + 1, self._colormap.shape[1]))
+ colormap[0:-1, :] = self._colormap
+ if labels is not None and label_nodata is not None:
+ pred_image[pred_image == -1] = self._colormap.shape[0]
+ result = np.zeros((pred_image.shape[0], pred_image.shape[1], self._colormap.shape[1]))
+ for i in range(prob_image.shape[2]):
+ result += (colormap[i, :] * prob_image[:, :, i, np.newaxis]).astype(colormap.dtype)
+ self._output_image.write(result, x, y)
+ else:
+ self._output_image.write(pred_image, x, y)
+
def confusion_matrix(self):
"""
Returns a matrix counting true labels matched to predicted labels.
@@ -248,22 +394,32 @@ class ImagePredictor(Predictor):
"""
Predicts an image from an image.
"""
- def __init__(self, model, output_image=None, show_progress=False, transform=None):
+ def __init__(self, model, tile_shape=None, output_image=None, show_progress=False, transform=None):
"""
- Trains on model, outputs to output_image, which is a DeltaImageWriter.
-
- transform is a tuple (function, output numpy type, number of bands) applied
- to the output image.
+ Parameters
+ ----------
+ model: tensorflow.keras.models.Model
+ Model to evaluate.
+ tile_shape: (int, int)
+ Shape of tiles to process at a time.
+ output_image: str
+ File to output results to.
+ show_progress: bool
+ Print progress to screen.
+ transform: (Callable[[numpy.ndarray], numpy.ndarray], output_type, num_bands)
+ The callable will be applied to the results from the network before saving
+ to a file. The results should be of type output_type and the third dimension
+ should be size num_bands.
"""
- super(ImagePredictor, self).__init__(model, show_progress)
+ super().__init__(model, tile_shape, show_progress)
self._output_image = output_image
self._output = None
self._transform = transform
- def _initialize(self, shape, label, image):
+ def _initialize(self, shape, image, label=None):
net_output_shape = self._model.output_shape[1:]
if self._output_image is not None:
- dtype = np.float32 if self._transform is None else self._transform[1]
+ dtype = np.float32 if self._transform is None else np.dtype(self._transform[1])
bands = net_output_shape[-1] if self._transform is None else self._transform[2]
self._output_image.initialize((shape[0], shape[1], bands), dtype, image.metadata())
@@ -276,7 +432,7 @@ def _abort(self):
if self._output_image is not None:
self._output_image.abort()
- def _process_block(self, pred_image, x, y, labels):
+ def _process_block(self, pred_image, x, y, labels, label_nodata):
if self._output_image is not None:
im = pred_image
if self._transform is not None:
diff --git a/delta/ml/train.py b/delta/ml/train.py
index 6e76843b..ce556d0b 100644
--- a/delta/ml/train.py
+++ b/delta/ml/train.py
@@ -19,18 +19,40 @@
Train neural networks.
"""
+import datetime
import os
import tempfile
import shutil
import mlflow
+import numpy as np
import tensorflow as tf
+import tensorflow.keras.backend as K
+from tensorflow.keras.layers import Layer
from delta.config import config
from delta.imagery.imagery_dataset import ImageryDataset
from delta.imagery.imagery_dataset import AutoencoderDataset
-from .layers import DeltaLayer
-from .io import save_model
+from .io import save_model, print_network
+from .config_parser import config_callbacks, loss_from_dict, metric_from_dict, optimizer_from_dict
+
+class DeltaLayer(Layer):
+ """
+ Network layer class with extra features specific to DELTA.
+
+ Extentds `tensorflow.keras.layers.Layer`.
+ """
+ def callback(self): # pylint:disable=no-self-use
+ """
+ Override this method to make a layer automatically register
+ a training callback.
+
+ Returns
+ -------
+ tensorflow.keras.callbacks.Callback:
+ The callback to register (or None).
+ """
+ return None
def _devices(num_gpus):
'''
@@ -60,68 +82,95 @@ def _strategy(devices):
strategy = tf.distribute.MirroredStrategy(devices=devices)
return strategy
-def _prep_datasets(ids, tc, chunk_size, output_size):
+def _prep_datasets(ids, tc):
ds = ids.dataset(config.dataset.classes.weights())
- ds = ds.batch(tc.batch_size)
- #ds = ds.cache()
- ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
+
+ validation=None
if tc.validation:
if tc.validation.from_training:
validation = ds.take(tc.validation.steps)
ds = ds.skip(tc.validation.steps)
else:
- vimg = tc.validation.images
+ vimg = tc.validation.images
vlabel = tc.validation.labels
if not vimg:
validation = None
else:
if vlabel:
- vimagery = ImageryDataset(vimg, vlabel, chunk_size, output_size, tc.chunk_stride,
- resume_mode=False)
+ vimagery = ImageryDataset(vimg, vlabel, ids.output_shape(), ids.chunk_shape(),
+ tile_shape=ids.tile_shape(), stride=ids.stride(),
+ tile_overlap=ids.tile_overlap())
else:
- vimagery = AutoencoderDataset(vimg, chunk_size, tc.chunk_stride, resume_mode=False)
- validation = vimagery.dataset().batch(tc.batch_size)
+ vimagery = AutoencoderDataset(vimg, ids.chunk_shape(), tile_shape=ids.tile_shape(),
+ stride=ids.stride(), tile_overlap=ids.tile_overlap())
+ validation = vimagery.dataset(config.dataset.classes.weights())
if tc.validation.steps:
validation = validation.take(tc.validation.steps)
- #validation = validation.prefetch(4)#tf.data.experimental.AUTOTUNE)
+ if validation:
+ validation = validation.batch(tc.batch_size, drop_remainder=True).prefetch(1)
else:
-
validation = None
+
+ ds = ds.batch(tc.batch_size, drop_remainder=True)
+ ds = ds.prefetch(1)
if tc.steps:
ds = ds.take(tc.steps)
- #ds = ds.prefetch(4)#tf.data.experimental.AUTOTUNE)
- ds = ds.repeat(tc.epochs)
return (ds, validation)
def _log_mlflow_params(model, dataset, training_spec):
images = dataset.image_set()
#labels = dataset.label_set()
- mlflow.log_param('Image Type', images.type())
- mlflow.log_param('Preprocess', images.preprocess())
- mlflow.log_param('Number of Images', len(images))
- mlflow.log_param('Chunk Size', dataset.chunk_size())
- mlflow.log_param('Chunk Stride', training_spec.chunk_stride)
- mlflow.log_param('Output Shape', dataset.output_shape())
- mlflow.log_param('Steps', training_spec.steps)
- mlflow.log_param('Loss Function', training_spec.loss_function)
- mlflow.log_param('Epochs', training_spec.epochs)
- mlflow.log_param('Batch Size', training_spec.batch_size)
- mlflow.log_param('Optimizer', training_spec.optimizer)
- mlflow.log_param('Model Layers', len(model.layers))
+ mlflow.log_param('Images - Type', images.type())
+ mlflow.log_param('Images - Count', len(images))
+ mlflow.log_param('Images - Stride', training_spec.stride)
+ mlflow.log_param('Images - Tile Size', len(model.layers))
+ mlflow.log_param('Train - Steps', training_spec.steps)
+ mlflow.log_param('Train - Loss Function', training_spec.loss)
+ mlflow.log_param('Train - Epochs', training_spec.epochs)
+ mlflow.log_param('Train - Batch Size', training_spec.batch_size)
+ mlflow.log_param('Train - Optimizer', training_spec.optimizer)
+ mlflow.log_param('Model - Layers', len(model.layers))
+ mlflow.log_param('Model - Parameters - Non-Trainable',
+ np.sum([K.count_params(w) for w in model.non_trainable_weights]))
+ mlflow.log_param('Model - Parameters - Trainable',
+ np.sum([K.count_params(w) for w in model.trainable_weights]))
+ mlflow.log_param('Model - Shape - Output', dataset.output_shape())
+ mlflow.log_param('Model - Shape - Input', dataset.input_shape())
#mlflow.log_param('Status', 'Running') Illegal to change the value!
+class _EpochResetCallback(tf.keras.callbacks.Callback):
+ """
+ Reset imagery_dataset file counts on epoch end
+ """
+ def __init__(self, ids, stop_epoch):
+ super().__init__()
+ self.ids = ids
+ self.last_epoch = stop_epoch - 1
+
+ def on_epoch_end(self, epoch, _=None):
+ if config.general.verbose():
+ print('Finished epoch ' + str(epoch))
+ # Leave the counts from the last epoch just as a record
+ if epoch != self.last_epoch:
+ self.ids.reset_access_counts()
+
class _MLFlowCallback(tf.keras.callbacks.Callback):
"""
Callback to log everything for MLFlow.
"""
def __init__(self, temp_dir):
- super(_MLFlowCallback, self).__init__()
+ super().__init__()
self.epoch = 0
self.batch = 0
self.temp_dir = temp_dir
- def on_epoch_end(self, epoch, _=None):
+ def on_epoch_end(self, epoch, logs=None):
self.epoch = epoch
+ for k in logs.keys():
+ if k.startswith('val_'):
+ mlflow.log_metric('Validation ' + k[4:], logs[k], epoch)
+ else:
+ mlflow.log_metric('Epoch ' + k, logs[k], epoch)
def on_train_batch_end(self, batch, logs=None):
self.batch = batch
@@ -140,12 +189,6 @@ def on_train_batch_end(self, batch, logs=None):
mlflow.log_artifact(filename, 'checkpoints')
os.remove(filename)
- def on_test_batch_end(self, _, logs=None): # pylint:disable=no-self-use
- for k in logs.keys():
- if k in ('batch', 'size'):
- continue
- mlflow.log_metric('validation_' + k, logs[k].item())
-
def _mlflow_train_setup(model, dataset, training_spec):
mlflow.set_tracking_uri(config.mlflow.uri())
mlflow.set_experiment(config.mlflow.experiment())
@@ -161,39 +204,12 @@ def _mlflow_train_setup(model, dataset, training_spec):
return _MLFlowCallback(temp_dir)
-def train(model_fn, dataset : ImageryDataset, training_spec):
- """
- Trains the specified model on a dataset according to a training
- specification.
+def _build_callbacks(model, dataset, training_spec):
"""
- if isinstance(model_fn, tf.keras.Model):
- model = model_fn
- else:
- with _strategy(_devices(config.general.gpus())).scope():
- model = model_fn()
- assert isinstance(model, tf.keras.models.Model),\
- "Model is not a Tensorflow Keras model"
- loss = training_spec.loss_function
- # TODO: specify learning rate and optimizer parameters, change learning rate over time
- model.compile(optimizer=training_spec.optimizer, loss=loss,
- metrics=training_spec.metrics)
-
- input_shape = model.input_shape
- output_shape = model.output_shape
- chunk_size = input_shape[1]
-
- assert len(input_shape) == 4, 'Input to network is wrong shape.'
- assert input_shape[0] is None, 'Input is not batched.'
- # The below may no longer be valid if we move to convolutional architectures.
- assert input_shape[1] == input_shape[2], 'Input to network is not chunked'
- assert len(output_shape) == 2 or output_shape[1] == output_shape[2], 'Output from network is not chunked'
- assert input_shape[3] == dataset.num_bands(), 'Number of bands in model does not match data.'
- # last element differs for the sparse metrics
- assert output_shape[1:-1] == dataset.output_shape()[:-1], \
- 'Network output shape %s does not match label shape %s.' % (output_shape[1:], dataset.output_shape())
-
- (ds, validation) = _prep_datasets(dataset, training_spec, chunk_size, output_shape[1])
+ Create callbacks needed based on configuration.
+ Returns (list of callbacks, mlflow callback).
+ """
callbacks = [tf.keras.callbacks.TerminateOnNaN()]
# add callbacks from DeltaLayers
for l in model.layers:
@@ -201,27 +217,175 @@ def train(model_fn, dataset : ImageryDataset, training_spec):
c = l.callback()
if c:
callbacks.append(c)
+
+ mcb = None
+ if config.mlflow.enabled():
+ mcb = _mlflow_train_setup(model, dataset, training_spec)
+ callbacks.append(mcb)
+ if config.general.verbose():
+ print('Using mlflow folder: ' + mlflow.get_artifact_uri())
+
if config.tensorboard.enabled():
- tcb = tf.keras.callbacks.TensorBoard(log_dir=config.tensorboard.dir(),
+ tb_dir = config.tensorboard.dir()
+ if config.mlflow.enabled():
+ tb_dir = os.path.join(tb_dir, str(mlflow.active_run().info.run_id))
+ mlflow.log_param('TensorBoard Directory', tb_dir)
+ else:
+ tb_dir = os.path.join(tb_dir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
+ tcb = tf.keras.callbacks.TensorBoard(log_dir=tb_dir,
update_freq='epoch',
histogram_freq=1,
write_images=True,
embeddings_freq=1)
callbacks.append(tcb)
- if config.mlflow.enabled():
- mcb = _mlflow_train_setup(model, dataset, training_spec)
- callbacks.append(mcb)
- #print('Using mlflow folder: ' + mlflow.get_artifact_uri())
+ callbacks.append(_EpochResetCallback(dataset, training_spec.epochs))
+
+ callbacks.extend(config_callbacks())
+
+ return (callbacks, mcb)
+
+def _compile_helper(model, training_spec):
+ model.compile(optimizer=optimizer_from_dict(training_spec.optimizer),
+ loss=loss_from_dict(training_spec.loss),
+ metrics=[metric_from_dict(m) for m in training_spec.metrics])
+
+class ContinueTrainingException(Exception):
+ """
+ Callbacks can raise this exception to modify the model, recompile, and
+ continue training.
+ """
+ def __init__(self, msg: str=None, completed_epochs: int=0,
+ recompile_model: bool=False, learning_rate: float=None):
+ """
+ Parameters
+ ----------
+ msg: str
+ Optional error message.
+ completed_epochs: int
+ The number of epochs that have been finished. (resumes from the next epoch)
+ recompile_model: bool
+ If True, recompile the model. This is necessary if the model has been changed.
+ learning_rate: float
+ Optionally set the learning rate to the given value.
+ """
+ super().__init__(msg)
+ self.completed_epochs = completed_epochs
+ self.recompile_model = recompile_model
+ self.learning_rate = learning_rate
+
+def compile_model(model_fn, training_spec, resume_path=None):
+ """
+ Compile and check that the model is valid.
+
+ Parameters
+ ----------
+ model_fn: Callable[[], tensorflow.keras.model.Model]
+ Function to construct a keras Model.
+ training_spec: delta.ml.ml_config.TrainingSpec
+ Trainnig parameters.
+ resume_path: str
+ File name to load initial model weights from.
+
+ Returns
+ -------
+ tensorflow.keras.models.Model:
+ The compiled model, ready for training.
+ """
+ if not hasattr(training_spec, 'strategy'):
+ training_spec.strategy = _strategy(_devices(config.general.gpus()))
+ with training_spec.strategy.scope():
+ model = model_fn()
+ assert isinstance(model, tf.keras.models.Model), \
+ "Model is not a Tensorflow Keras model"
+
+ if resume_path is not None:
+ print('Loading existing model: ' + resume_path)
+ model.load_weights(resume_path)
+
+ _compile_helper(model, training_spec)
+
+ input_shape = model.input_shape
+ output_shape = model.output_shape
+
+ assert len(input_shape) == 4, 'Input to network is wrong shape.'
+ assert input_shape[0] is None, 'Input is not batched.'
+ # The below may no longer be valid if we move to convolutional architectures.
+ assert input_shape[1] == input_shape[2], 'Input to network is not chunked'
+ assert len(output_shape) == 2 or output_shape[1] == output_shape[2], 'Output from network is not chunked'
+
+ if config.general.verbose():
+ print('Training model:')
+ print_network(model, (512, 512, 8))
+ print(model.summary(line_length=120))
+
+ return model
+
+def train(model_fn, dataset : ImageryDataset, training_spec, resume_path=None):
+ """
+ Trains the specified model on a dataset according to a training
+ specification.
+
+ Parameters
+ ----------
+ model_fn: Callable[[], tensorflow.keras.model.Model]
+ Function that constructs a model.
+ dataset: delta.imagery.imagery_dataset.ImageryDataset
+ Dataset to train on.
+ training_spec: delta.ml.ml_config.TrainingSpec
+ Training parameters.
+ resume_path: str
+ Optional file to load initial model weights from.
+
+ Returns
+ -------
+ (tensorflow.keras.models.Model, History):
+ The trained model and the training history.
+ """
+ model = compile_model(model_fn, training_spec, resume_path)
+ assert model.input_shape[3] == dataset.num_bands(), 'Number of bands in model does not match data.'
+ # last element differs for the sparse metrics
+ assert model.output_shape[1:-1] == dataset.output_shape()[:-1] or (model.output_shape[1] is None), \
+ 'Network output shape %s does not match label shape %s.' % \
+ (model.output_shape[1:], dataset.output_shape()[:-1])
+
+ (ds, validation) = _prep_datasets(dataset, training_spec)
+
+ (callbacks, mcb) = _build_callbacks(model, dataset, training_spec)
try:
- history = model.fit(ds,
- epochs=training_spec.epochs,
- callbacks=callbacks,
- validation_data=validation,
- validation_steps=training_spec.validation.steps if training_spec.validation else None,
- steps_per_epoch=training_spec.steps,
- verbose=1)
+
+ # Mark that we need to check the dataset counts the
+ # first time we try to read the images.
+ # This won't do anything unless we are resuming training.
+ dataset.reset_access_counts(set_need_check=True)
+
+ if (training_spec.steps is None) or (training_spec.steps > 0):
+ done = False
+ epochs = training_spec.epochs
+ initial_epoch = 0
+ while not done:
+ try:
+ history = model.fit(ds,
+ epochs=epochs,
+ initial_epoch=initial_epoch,
+ callbacks=callbacks,
+ validation_data=validation,
+ validation_steps=None, # Steps are controlled in the dataset setup
+ steps_per_epoch=None,
+ verbose=1) # Set to 2 when logging
+ done = True
+ except ContinueTrainingException as cte:
+ print('Recompiling model and resuming training.')
+ initial_epoch += cte.completed_epochs
+ if cte.recompile_model:
+ model = compile_model(model, training_spec)
+ if cte.learning_rate:
+ K.set_value(model.optimizer.lr, cte.learning_rate)
+ else: # Skip training
+ print('Skipping straight to validation')
+ history = model.evaluate(validation, steps=training_spec.validation.steps,
+ callbacks=callbacks, verbose=1)
if config.mlflow.enabled():
model_path = os.path.join(mcb.temp_dir, 'final_model.h5')
@@ -233,6 +397,8 @@ def train(model_fn, dataset : ImageryDataset, training_spec):
except:
if config.mlflow.enabled():
mlflow.log_param('Status', 'Aborted')
+ mlflow.log_param('Epoch', mcb.epoch)
+ mlflow.log_param('Batch', mcb.batch)
mlflow.end_run('FAILED')
model_path = os.path.join(mcb.temp_dir, 'aborted_model.h5')
print('\nAborting, saving current model to %s.' % (mlflow.get_artifact_uri() + '/aborted_model.h5'))
@@ -242,8 +408,6 @@ def train(model_fn, dataset : ImageryDataset, training_spec):
raise
finally:
if config.mlflow.enabled():
- mlflow.log_param('Epoch', mcb.epoch)
- mlflow.log_param('Batch', mcb.batch)
if mcb and mcb.temp_dir:
shutil.rmtree(mcb.temp_dir)
diff --git a/delta/subcommands/classify.py b/delta/subcommands/classify.py
index 45ec4dc2..943c321e 100644
--- a/delta/subcommands/classify.py
+++ b/delta/subcommands/classify.py
@@ -22,16 +22,16 @@
import time
import numpy as np
-import matplotlib.pyplot as plt
+import matplotlib
import tensorflow as tf
from delta.config import config
-from delta.imagery.sources import tiff
-from delta.imagery.sources import loader
+from delta.config.extensions import custom_objects, image_writer
from delta.ml import predict
-import delta.ml.layers
-import delta.imagery.imagery_config
-import delta.ml.ml_config
+from delta.extensions.sources.tiff import write_tiff
+
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt #pylint: disable=wrong-import-order,wrong-import-position,ungrouped-imports
def save_confusion(cm, class_labels, filename):
f = plt.figure()
@@ -57,7 +57,8 @@ def save_confusion(cm, class_labels, filename):
f.savefig(filename)
def ae_convert(data):
- return (data[:, :, [4, 2, 1]] * 256.0).astype(np.uint8)
+ r = np.clip((data[:, :, [4, 2, 1]] * np.float32(100.0)), 0.0, 255.0).astype(np.uint8)
+ return r
def main(options):
@@ -66,9 +67,9 @@ def main(options):
if cpuOnly:
with tf.device('/cpu:0'):
- model = tf.keras.models.load_model(options.model, custom_objects=delta.ml.layers.ALL_LAYERS)
+ model = tf.keras.models.load_model(options.model, custom_objects=custom_objects(), compile=False)
else:
- model = tf.keras.models.load_model(options.model, custom_objects=delta.ml.layers.ALL_LAYERS)
+ model = tf.keras.models.load_model(options.model, custom_objects=custom_objects(), compile=False)
colors = list(map(lambda x: x.color, config.dataset.classes))
error_colors = np.array([[0x0, 0x0, 0x0],
@@ -79,50 +80,71 @@ def main(options):
start_time = time.time()
images = config.dataset.images()
labels = config.dataset.labels()
+ net_name = os.path.splitext(os.path.basename(options.model))[0]
+ full_cm = None
if options.autoencoder:
labels = None
for (i, path) in enumerate(images):
- image = loader.load_image(images, i)
+ image = images.load(i)
base_name = os.path.splitext(os.path.basename(path))[0]
- output_image = tiff.DeltaTiffWriter('predicted_' + base_name + '.tiff')
- prob_image = None
- if options.prob:
- prob_image = tiff.DeltaTiffWriter('prob_' + base_name + '.tiff')
- error_image = None
- if labels:
- error_image = tiff.DeltaTiffWriter('errors_' + base_name + '.tiff')
+ writer = image_writer('tiff')
+ prob_image = writer(net_name + '_' + base_name + '.tiff') if options.prob else None
+ output_image = writer(net_name + '_' + base_name + '.tiff') if not options.prob else None
+ error_image = None
label = None
if labels:
- label = loader.load_image(config.dataset.labels(), i)
+ error_image = writer('errors_' + base_name + '.tiff')
+ label = labels.load(i)
+ assert image.size() == label.size(), 'Image and label do not match.'
+ ts = config.io.tile_size()
if options.autoencoder:
label = image
- predictor = predict.ImagePredictor(model, output_image, True, (ae_convert, np.uint8, 3))
+ predictor = predict.ImagePredictor(model, ts, output_image, True,
+ None if options.noColormap else (ae_convert, np.uint8, 3))
else:
- predictor = predict.LabelPredictor(model, output_image, True, labels.nodata_value(), colormap=colors,
+ predictor = predict.LabelPredictor(model, ts, output_image, True, colormap=colors,
prob_image=prob_image, error_image=error_image,
error_colors=error_colors)
+ overlap = (options.overlap, options.overlap)
try:
if cpuOnly:
with tf.device('/cpu:0'):
- predictor.predict(image, label)
+ predictor.predict(image, label, overlap=overlap)
else:
- predictor.predict(image, label)
+ predictor.predict(image, label, overlap=overlap)
except KeyboardInterrupt:
print('\nAborted.')
return 0
if labels:
cm = predictor.confusion_matrix()
- print('%.2g%% Correct: %s' % (np.sum(np.diag(cm)) / np.sum(cm) * 100, path))
+ if full_cm is None:
+ full_cm = np.copy(cm).astype(np.int64)
+ else:
+ full_cm += cm
+ for j in range(cm.shape[0]):
+ print('%s--- Precision: %.2f%% Recall: %.2f%% Pixels: %d / %d' % \
+ (config.dataset.classes[j].name,
+ 100 * cm[j,j] / np.sum(cm[:, j]),
+ 100 * cm[j,j] / np.sum(cm[j, :]),
+ int(np.sum(cm[j, :])), int(np.sum(cm))))
+ print('%.2f%% Correct: %s' % (float(np.sum(np.diag(cm)) / np.sum(cm) * 100), path))
save_confusion(cm, map(lambda x: x.name, config.dataset.classes), 'confusion_' + base_name + '.pdf')
if options.autoencoder:
- tiff.write_tiff('orig_' + base_name + '.tiff', ae_convert(image.read()),
- metadata=image.metadata())
+ write_tiff('orig_' + base_name + '.tiff', image.read() if options.noColormap else ae_convert(image.read()),
+ metadata=image.metadata())
stop_time = time.time()
+ if labels:
+ for i in range(full_cm.shape[0]):
+ print('%s--- Precision: %.2f%% Recall: %.2f%% Pixels: %d / %d' % \
+ (config.dataset.classes[i].name,
+ 100 * full_cm[i,i] / np.sum(full_cm[:, i]),
+ 100 * full_cm[i,i] / np.sum(full_cm[i, :]),
+ int(np.sum(cm[j, :])), int(np.sum(cm))))
print('Elapsed time = ', stop_time - start_time)
return 0
diff --git a/delta/subcommands/commands.py b/delta/subcommands/commands.py
index 5e56b825..ad111e4f 100644
--- a/delta/subcommands/commands.py
+++ b/delta/subcommands/commands.py
@@ -35,6 +35,10 @@ def main_mlflow_ui(options):
from .import mlflow_ui
mlflow_ui.main(options)
+def main_validate(options):
+ from .import validate
+ validate.main(options)
+
def setup_classify(subparsers):
sub = subparsers.add_parser('classify', help='Classify images given a model.')
config.setup_arg_parser(sub, ['general', 'io', 'dataset'])
@@ -43,6 +47,8 @@ def setup_classify(subparsers):
sub.add_argument('--autoencoder', dest='autoencoder', action='store_true', help='Classify with the autoencoder.')
sub.add_argument('--no-colormap', dest='noColormap', action='store_true',
help='Save raw classification values instead of colormapped values.')
+ sub.add_argument('--overlap', dest='overlap', type=int, default=0, help='Classify with the autoencoder.')
+ sub.add_argument('--validation', dest='validation', help='Classify validation images instead.')
sub.add_argument('model', help='File to save the network to.')
sub.set_defaults(function=main_classify)
@@ -62,5 +68,10 @@ def setup_mlflow_ui(subparsers):
sub.set_defaults(function=main_mlflow_ui)
+def setup_validate(subparsers):
+ sub = subparsers.add_parser('validate', help='Validate input dataset.')
+ config.setup_arg_parser(sub, ['general', 'io', 'dataset', 'train'])
+
+ sub.set_defaults(function=main_validate)
-SETUP_COMMANDS = [setup_train, setup_classify, setup_mlflow_ui]
+SETUP_COMMANDS = [setup_train, setup_classify, setup_mlflow_ui, setup_validate]
diff --git a/delta/subcommands/main.py b/delta/subcommands/main.py
index 15bca996..e88934c6 100644
--- a/delta/subcommands/main.py
+++ b/delta/subcommands/main.py
@@ -14,6 +14,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Main delta command, calls subcommands.
+"""
import sys
import argparse
@@ -23,6 +26,9 @@
from delta.subcommands import commands
def main(args):
+ """
+ DELTA main function.
+ """
delta.config.modules.register_all()
parser = argparse.ArgumentParser(description='DELTA Machine Learning Toolkit')
subparsers = parser.add_subparsers()
diff --git a/delta/subcommands/train.py b/delta/subcommands/train.py
index 85b0ee31..34203886 100644
--- a/delta/subcommands/train.py
+++ b/delta/subcommands/train.py
@@ -29,55 +29,76 @@
import tensorflow as tf
from delta.config import config
+from delta.config.extensions import custom_objects
from delta.imagery import imagery_dataset
from delta.ml.train import train
-from delta.ml.model_parser import config_model
-from delta.ml.layers import ALL_LAYERS
+from delta.ml.config_parser import config_model
from delta.ml.io import save_model
#tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
def main(options):
- log_folder = config.dataset.log_folder()
+ log_folder = config.train.log_folder()
if log_folder:
if not options.resume: # Start fresh and clear the read logs
- os.system('rm ' + log_folder + '/*')
+ os.system('rm -f ' + log_folder + '/*')
print('Dataset progress recording in: ' + log_folder)
else:
print('Resuming dataset progress recorded in: ' + log_folder)
- start_time = time.time()
images = config.dataset.images()
if not images:
print('No images specified.', file=sys.stderr)
return 1
- tc = config.train.spec()
+
+ img = images.load(0)
+ model = config_model(img.num_bands())
+ if options.resume is not None:
+ temp_model = tf.keras.models.load_model(options.resume, custom_objects=custom_objects())
+ else:
+ # this one is not built with proper scope, just used to get input and output shapes
+ temp_model = model()
+
+ start_time = time.time()
+ tile_size = config.io.tile_size()
+ tile_overlap = None
+ stride = config.train.spec().stride
+
+ # compute input and output sizes
+ if temp_model.input_shape[1] is None:
+ in_shape = None
+ out_shape = temp_model.compute_output_shape((0, tile_size[0], tile_size[1], temp_model.input_shape[3]))
+ out_shape = out_shape[1:3]
+ tile_overlap = (tile_size[0] - out_shape[0], tile_size[1] - out_shape[1])
+ else:
+ in_shape = temp_model.input_shape[1:3]
+ out_shape = temp_model.output_shape[1:3]
+
if options.autoencoder:
- ids = imagery_dataset.AutoencoderDataset(images, config.train.network.chunk_size(),
- tc.chunk_stride, resume_mode=options.resume,
- log_folder=log_folder)
+ ids = imagery_dataset.AutoencoderDataset(images, in_shape, tile_shape=tile_size,
+ tile_overlap=tile_overlap, stride=stride)
else:
labels = config.dataset.labels()
if not labels:
print('No labels specified.', file=sys.stderr)
return 1
- ids = imagery_dataset.ImageryDataset(images, labels, config.train.network.chunk_size(),
- config.train.network.output_size(), tc.chunk_stride,
- resume_mode=options.resume,
- log_folder=log_folder)
+ ids = imagery_dataset.ImageryDataset(images, labels, out_shape, in_shape,
+ tile_shape=tile_size, tile_overlap=tile_overlap,
+ stride=stride)
+ if log_folder is not None:
+ ids.set_resume_mode(options.resume, log_folder)
+
+ assert temp_model.input_shape[1] == temp_model.input_shape[2], 'Must have square chunks in model.'
+ assert temp_model.input_shape[3] == ids.num_bands(), 'Model takes wrong number of bands.'
+ tf.keras.backend.clear_session()
try:
- if options.resume is not None:
- model = tf.keras.models.load_model(options.resume, custom_objects=ALL_LAYERS)
- else:
- model = config_model(ids.num_bands())
- model, _ = train(model, ids, tc)
+ model, _ = train(model, ids, config.train.spec(), options.resume)
if options.model is not None:
save_model(model, options.model)
except KeyboardInterrupt:
- print()
print('Training cancelled.')
stop_time = time.time()
diff --git a/delta/subcommands/validate.py b/delta/subcommands/validate.py
new file mode 100644
index 00000000..0da57292
--- /dev/null
+++ b/delta/subcommands/validate.py
@@ -0,0 +1,200 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Check if the input data is valid.
+"""
+
+import sys
+import os
+
+import numpy as np
+from osgeo import gdal
+
+from delta.config import config
+
+def get_image_stats(path):
+ '''Return a list of image band statistics like [[min, max, mean, stddev], ...]'''
+ tif_handle = gdal.Open(path)
+ num_bands = tif_handle.RasterCount
+
+ output = []
+ for b in range(0,num_bands):
+ band = tif_handle.GetRasterBand(b+1)
+ stats = band.GetStatistics(False, True)
+ output.append(stats)
+
+ return output
+
+
+def get_class_dict():
+ '''Populate dictionary with class names by index number'''
+ d = {}
+ for c in config.dataset.classes:
+ d[c.end_value] = c.name
+ if config.dataset.labels().nodata_value():
+ d[len(config.dataset.classes)] = 'nodata'
+ return d
+
+def classes_string(classes, values, image_name):
+ '''Generate a formatted string out of strings or numbers.
+ "classes" must come from get_class_dict()'''
+ s = '%-20s ' % (image_name)
+ is_integer = np.issubdtype(type(values[0]), np.integer)
+ is_float = isinstance(values[0], float)
+ if is_integer:
+ total = sum(values.values())
+ nodata_class = None
+ if config.dataset.labels().nodata_value():
+ nodata_class = len(config.dataset.classes)
+ total -= values[nodata_class]
+ for (j, name) in classes.items():
+ if name == 'nodata':
+ continue
+ v = values[j] if j in values else 0
+ if is_integer:
+ s += '%12.2f%% ' % (v / total * 100, )
+ else:
+ if is_float:
+ s += '%12.2f ' % (v)
+ else:
+ s += '%12s ' % (v, )
+ return s
+
+
+
+def check_image(images, measures, total_counts, i):
+ '''Accumulate total_counts and print out image statistics'''
+
+ # Find min, max, mean, std
+ stats = get_image_stats(images[i])
+
+ # Accumulate statistics
+ if not total_counts:
+ for band in stats: #pylint: disable=W0612
+ total_counts.append({'min' : 0.0,
+ 'max' : 0.0,
+ 'mean' : 0.0,
+ 'stddev': 0.0})
+
+ for (b, bandstats) in enumerate(stats):
+ total_counts[b]['min' ] += bandstats[0]
+ total_counts[b]['max' ] += bandstats[1]
+ total_counts[b]['mean' ] += bandstats[2]
+ total_counts[b]['stddev'] += bandstats[3]
+ name = ''
+ if b == 0:
+ name = os.path.basename(images[i])
+ print(classes_string(measures, dict(enumerate(bandstats)), name))
+
+ return ''
+
+def print_image_totals(images, measures, total_counts):
+ '''Convert from source image stat totals to averages and print'''
+ num_images = len(images)
+ num_bands = len(total_counts)
+ for b in range(0,num_bands):
+ values = []
+ for m in range(0,len(measures)): #pylint: disable=C0200
+ values.append(total_counts[b][measures[m]]/num_images)
+ name = ''
+ if b == 0:
+ name = 'Total'
+ print(classes_string(measures, dict(enumerate(values)), name))
+
+def check_label(images, labels, classes, total_counts, i):
+ '''Accumulate total_counts and print out image statistics'''
+ img = images.load(i)
+ label = labels.load(i)
+ if label.size() != img.size():
+ return 'Error: size mismatch for %s and %s.\n' % (images[i], labels[i])
+ # Count number of times each label appears in image
+ v, counts = np.unique(label.read(), return_counts=True)
+
+ # Load the label counts into dictionary and accumulate total_counts
+ values = { k:0 for (k, _) in classes.items() }
+ for (j, value) in enumerate(v):
+ values[value] = counts[j]
+ if value not in total_counts:
+ total_counts[value] = 0
+ total_counts[value] += counts[j]
+ # Print out display line with percentages
+ print(classes_string(classes, values, labels[i].split('/')[-1]))
+ return ''
+
+def evaluate_images(images, labels):
+ '''Print class statistics for a set of images with matching labels'''
+ errors = ''
+ classes = get_class_dict()
+
+ # Evaluate labels first
+ counts = {}
+ if config.dataset.labels().nodata_value():
+ counts[len(config.dataset.classes)] = 0
+ header = classes_string(classes, classes, 'Label')
+ print(header)
+ print('-' * len(header))
+ for i in range(len(labels)):
+ errors += check_label(images, labels, classes, counts, i)
+ print('-' * len(header))
+ print(classes_string(classes, counts, 'Total'))
+ print()
+
+ if config.dataset.labels().nodata_value():
+ nodata_c = counts[len(config.dataset.classes)]
+ total = sum(counts.values())
+ print('Nodata is %6.2f%% of the data. Total Pixels: %.2f million.' % \
+ (nodata_c / total * 100, (total - nodata_c) / 1000000))
+
+ # Now evaluate source images
+ counts = []
+ print()
+ measures = {0:'min', 1:'max', 2:'mean', 3:'stddev'}
+ header = classes_string(classes, measures, 'Image')
+ print(header)
+ print('-' * len(header))
+ for i in range(len(images)):
+ errors += check_image(images, measures, counts, i)
+ print('-' * len(header))
+ print_image_totals(images, measures, counts)
+ print()
+
+
+ return errors
+
+def main(_):
+ images = config.dataset.images() # Get all image paths based on config values
+ labels = config.dataset.labels() # Get all label paths based on config pathn
+ if not images:
+ print('No images specified.', file=sys.stderr)
+ return 1
+ if not labels:
+ print('No labels specified.', file=sys.stderr)
+ else:
+ assert len(images) == len(labels)
+ print('Validating %d images.' % (len(images)))
+ errors = evaluate_images(images, labels)
+ tc = config.train.spec()
+ if tc.validation.images:
+ print('Validating %d validation images.' % (len(tc.validation.images)))
+ errors += evaluate_images(tc.validation.images, tc.validation.labels)
+ if errors:
+ print(errors, file=sys.stderr)
+ return -1
+
+ print('Validation successful.')
+ return 0
diff --git a/scripts/.coveragerc b/scripts/.coveragerc
new file mode 100644
index 00000000..2285f6ac
--- /dev/null
+++ b/scripts/.coveragerc
@@ -0,0 +1,6 @@
+[report]
+exclude_lines =
+ pragma: no cover
+ raise AssertionError
+ raise ValueError
+ if __name__ == .__main__.:
diff --git a/scripts/convert/gdal_util.py b/scripts/convert/gdal_util.py
new file mode 100644
index 00000000..7d8ace9e
--- /dev/null
+++ b/scripts/convert/gdal_util.py
@@ -0,0 +1,431 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Diverse useful functions for handeling new datasets using the gdal python bindings:
+ - Merging images
+ - Matching images
+ - Comparing images
+ - Computing shared region between images
+ - Filling noData pixels
+
+Usual parameters:
+ - input_path(s): path(s) to the file(s) that will be processed. Should be a string or a list of string.
+ - output_path(/file): path to the directory(/file) where the generated output(s) will be stored
+ - db_ref(/new): Gdal datasets obtained with 'gdal.Open(path_to_file)'. Ref is for reference
+
+This file needs more cleanup and testing.
+"""
+
+import os
+
+import numpy as np
+from scipy import ndimage
+
+from osgeo import gdal, osr, gdalconst
+
+# Indexes for Geotransform metadata
+XCOORD = 0
+YCOORD = 3
+XRES = 1
+YRES = 5
+
+def merge_images(input_paths, output_file):
+ """
+ Merge tiles of an image into a single big image.
+ Merging is based on the GeoTransforms of the tiles, input order is irrelevant.
+ If some part of the overall scene is not covered by the tiles, it will be filled with
+ NoDataValue in the final image. The images are first compared to check if they can be
+ merged: Same resolution, nb of bands, projection, ... The result is stored in 'output_file'
+ """
+
+ args_command = '-o %s' % (output_file) # arguments for the merge command
+ for i, path in enumerate(input_paths):
+ if i == 0:
+ ref_image = gdal.Open(path) # the first input image is considered as the reference and will be compared
+ ref_path = path # with all the other images to make sur that merging is possible
+ else:
+ # Number of difference between two input images. Discarding the comparisons between the size of the bands
+ # and the scene coverage because they don't prevent merging
+ if not images_equivalent(ref_image, gdal.Open(path), raster_size=False, coverage=False):
+ print("Differences between %s and %s prevent merging. Aborted." %
+ (os.path.split(path)[1], os.path.split(ref_path)[1]))
+ return
+ args_command += ' ' + path
+
+ assert os.system('gdal_merge.py ' + args_command) == 0, 'Merging failed.'
+ print("Merged the images successfully in {}".format(output_file))
+
+
+def match_images(input_paths, output_path, target_path=None):
+ """
+ Transform the input images to match the target image.
+ Matching the GeoTransform (resolution + coverage) and the spatial reference system.
+ The data type of the source will stay the same.
+ The images are first compared to check if they can be matched directly: need the same scene cover.
+ If not, it will try to find the largest shared scene between all the images (inputs + target).
+ All images will be cropped to cover only the shared scene.
+ The generated images will be stored in a new directory: 'target_name_match' localized in output_path
+ """
+ target_need_crop = False # Target needs cropping if it doesn't cover exactly the same scene as the inputs
+
+ if isinstance(input_paths, str):
+ input_paths = [input_paths]
+
+ if target_path is None: # If no target is given, the first input is considered as the target
+ if len(input_paths) == 1:
+ print("Only one input was given. Need a target or another input to match.")
+ return
+ target_path = input_paths[0]
+ target_name = os.path.split(input_paths[0])[1].split('.')[0]
+ input_paths = np.delete(input_paths, 0)
+ print("No target was given. Replacing it with the first input: {}".format(target_name))
+ else:
+ target_name = os.path.split(target_path)[1].split('.')[0]
+
+ data_match = gdal.Open(target_path)
+ geo_match = data_match.GetGeoTransform() # target image characteristics used as reference
+ srs_match = data_match.GetSpatialRef()
+ corner_match_dict = gdal.Info(data_match, format='json')['cornerCoordinates']
+ corner_match_arr = [corner_match_dict.get(key) for key in corner_match_dict.keys()
+ if key in ('upperLeft', 'lowerRight')]
+
+ # sets the biggest shared region to the whole image
+ biggest_common_bounds = (corner_match_arr[0][0], corner_match_arr[1][1], corner_match_arr[1][0],
+ corner_match_arr[0][1])
+
+ # Stores the characteristics of the source images during the first for-loop
+ type_source = [None]*len(input_paths)
+ srs_source = [None]*len(input_paths)
+ output_file = ['']*len(input_paths)
+
+ new_dir = target_name + '_match'
+ new_path = os.path.join(output_path, new_dir)
+ try:
+ os.makedirs(new_path)
+ except OSError:
+ print('A directory already exists for images matched with {}. Old files can be replaced.\n'.format(target_name))
+
+ # Find biggest common overlapping region to all the images
+ for i, path in enumerate(input_paths):
+ img = gdal.Open(path)
+ input_name, file_ext = os.path.split(path)[1].split('.')
+ srs_source[i] = img.GetSpatialRef()
+ type_source[i] = img.GetRasterBand(1).DataType
+ output_file[i] = os.path.join(new_path, input_name + '_match.' + file_ext)
+
+ # Checks if the images cover the same scene
+ if not images_equivalent(data_match, img, projection=False, data_type=False,
+ num_bands=False, raster_size=False, resolution=False):
+ # Compute possible overlapping region
+ overlap_geo, overlap_pix_match, overlap_pix_src = compute_overlap_region(data_match, img)
+
+ # No overlap between the images then process is stopped
+ if overlap_geo is None or overlap_pix_match is None or overlap_pix_src is None:
+ print("Can't match {} to {} because they don't overlap geographically.".format(path,
+ target_name))
+ return
+ # Updates the shared region boundaries
+ biggest_common_bounds = (max(biggest_common_bounds[0], overlap_geo[0][0]),
+ max(biggest_common_bounds[1], overlap_geo[1][1]),
+ min(biggest_common_bounds[2], overlap_geo[1][0]),
+ min(biggest_common_bounds[3], overlap_geo[0][1]))
+
+ # Checks if the previous biggest shared region and the new computed shared region overlaps. If not, process
+ # is stopped
+ if biggest_common_bounds[0] > biggest_common_bounds[2] \
+ or biggest_common_bounds[1] > biggest_common_bounds[3]:
+ print("Some inputs do not have a common overlapping region with the target. Matching them is not "
+ "possible (Problem occured for {}).".format(path))
+ return
+
+ if corner_match_arr != overlap_geo:
+ target_need_crop = True
+
+ else:
+ print("{} and {} cover the same region (left corner:[{}, {}], right corner:[{}, {}])."
+ .format(input_name, target_name, corner_match_arr[0][0], corner_match_arr[0][1],
+ corner_match_arr[1][0], corner_match_arr[1][1]))
+
+ print("\nCommon overlapping region to all the images: {}".format(biggest_common_bounds))
+
+ for i, path in enumerate(input_paths):
+
+ print("Matching {} to {} in the overlapping region {}.".format(path, target_name,
+ biggest_common_bounds))
+
+ # Actual matching operation. Resampling algorithm is bilinear. Only matches the shared region of the images
+ gdal.Warp(output_file[i], path, dstSRS=srs_match, srcSRS=srs_source[i], outputType=type_source[i],
+ xRes=geo_match[XRES], yRes=geo_match[YRES], resampleAlg=gdalconst.GRIORA_Bilinear,
+ outputBounds=biggest_common_bounds)
+
+ if target_need_crop:
+ crop_target_path = os.path.join(new_path, target_name + '_crop.' + file_ext)
+
+ print("Cropping the overlapping region from the target so that the images cover the same region.")
+
+ # Crops the target if needed to only cover the shared region with the inputs
+ gdal.Translate(crop_target_path, target_path, projWin=[biggest_common_bounds[0], biggest_common_bounds[3],
+ biggest_common_bounds[2], biggest_common_bounds[1]])
+
+ return
+
+
+def images_equivalent(db_ref, db_new, projection=True, data_type=True, num_bands=True, raster_size=True,
+ resolution=True, coverage=True):
+ """
+ Compare some characteristics of two images (needs to be gdal datasets):
+ - Projections and Spatial reference system (SRS): Can be different but equivalent.
+ - Type of data: Usually Bytes, float, uintXX, ...
+ - Number of bands
+ - Bands' dimensions: equivalent to the shape of the image (in Pixels)
+ - Spatial resolutions: X and Y axis
+ - Coverage: Corner coordinates of the image in the SRS
+
+ inputs: Which characteristic not to compare can be specified in the kwargs (by default: compare all characteristics)
+ outputs: True if characteristics are the same, false otherwise
+ """
+ # pylint: disable=too-many-return-statements
+
+ assert isinstance(db_ref, gdal.Dataset) and isinstance(db_new, gdal.Dataset), 'Inputs must be gdal datasets.'
+
+ # Compares Projections and SRS
+ if projection:
+ if db_new.GetProjection() != db_ref.GetProjection():
+ # Checks if the projections are equivalent eventhough there are not exactly the same
+ if not osr.SpatialReference(db_ref.GetProjection()).IsSame(osr.SpatialReference(db_new.GetProjection())):
+ return False
+
+ # Compares data type
+ if data_type:
+ dtype_ref = gdal.GetDataTypeName(db_ref.GetRasterBand(1).DataType)
+ dtype_new = gdal.GetDataTypeName(db_new.GetRasterBand(1).DataType)
+
+ if dtype_new != dtype_ref:
+ return False
+
+ # Compares number of bands
+ if num_bands and db_ref.RasterCount != db_new.RasterCount:
+ return False
+
+ # Compares the bands' dimension
+ if raster_size:
+ gSzX = db_ref.RasterXSize
+ nSzX = db_new.RasterXSize
+ gSzY = db_ref.RasterYSize
+ nSzY = db_new.RasterYSize
+
+ if gSzX != nSzX or gSzY != nSzY:
+ return False
+
+ # Compares the spatial resolution
+ if resolution:
+ geo_ref = db_ref.GetGeoTransform()
+ geo_new = db_new.GetGeoTransform()
+
+ if geo_ref[XRES] != geo_new[XRES] or geo_ref[YRES] != geo_new[YRES]:
+ return False
+
+ # Compares the coverage
+ if coverage:
+ cornerCoord_ref = gdal.Info(db_ref, format='json')['cornerCoordinates']
+ cornerCoord_new = gdal.Info(db_new, format='json')['cornerCoordinates']
+
+ if cornerCoord_ref != cornerCoord_new:
+ return False
+
+ return True
+
+def compute_overlap_region(db_ref, db_new):
+ """
+ Computes the overlapping/shared region between two images.
+ Outputs:
+ - Corner coordinates of the overlapping region in the SRS
+ - Corresponding pixel indexes in both input images
+ """
+
+ cornerCoord_ref = gdal.Info(db_ref, format='json')['cornerCoordinates']
+ cornerCoord_new = gdal.Info(db_new, format='json')['cornerCoordinates']
+
+ cc_val_ref = [cornerCoord_ref.get(key) for key in cornerCoord_ref.keys() if key != 'center']
+ [ul_ref, dummy, lr_ref, _] = cc_val_ref
+
+ cc_val_new = [cornerCoord_new.get(key) for key in cornerCoord_new.keys() if key != 'center']
+ [ul_new, dummy, lr_new, _] = cc_val_new
+
+ # Checks if the images cover exactly the same regions.
+ if cc_val_ref == cc_val_new:
+ print("The images cover the same region (left corner:[{}, {}], right corner:[{}, {}])."
+ .format(ul_ref[0], ul_ref[1], lr_ref[0], lr_ref[1]))
+ return [ul_ref, lr_ref], [[0, 0], [db_ref.RasterXSize, db_ref.RasterYSize]], \
+ [[0, 0], [db_ref.RasterXSize, db_ref.RasterYSize]]
+
+ # Computes the overlapping region
+ overlap_corners = [[max(ul_ref[0], ul_new[0]), min(ul_ref[1], ul_new[1])],[min(lr_ref[0], lr_new[0]),
+ max(lr_ref[1], lr_new[1])]]
+
+ # Checks if the overlapping region is physically possible. If not, then the images don't cover the same region
+ if overlap_corners[0][0] > overlap_corners[1][0] or overlap_corners[0][1] < overlap_corners[1][1]:
+ print("The two regions represented by the images don't overlap.")
+ return None, None, None
+ print("Found an overlapping regions (left corner:[{}, {}], right corner:[{}, {}])."
+ .format(overlap_corners[0][0], overlap_corners[0][1], overlap_corners[1][0], overlap_corners[1][1]))
+
+ # If a shared region is found then compute the pixels indexes corresponding to its corner coordinates in
+ # for both images
+ col_ul_ref, row_ul_ref = geo2pix(db_ref.GetGeoTransform(), overlap_corners[0][0], overlap_corners[0][1])
+ col_lr_ref, row_lr_ref = geo2pix(db_ref.GetGeoTransform(), overlap_corners[1][0], overlap_corners[1][1])
+
+ col_ul_new, row_ul_new = geo2pix(db_new.GetGeoTransform(), overlap_corners[0][0], overlap_corners[0][1])
+ col_lr_new, row_lr_new = geo2pix(db_new.GetGeoTransform(), overlap_corners[1][0], overlap_corners[1][1])
+
+ return overlap_corners, [[row_ul_ref, col_ul_ref], [row_lr_ref, col_lr_ref]], [[row_ul_new, col_ul_new],
+ [row_lr_new, col_lr_new]]
+
+
+def fill_no_data_value(input_path, interpolation_distance=16, create_new_file=True,
+ output_file=None, no_data_value=None):
+ """
+ Fill noData pixels by interpolation from valid pixels around them.
+ inputs:
+ - Interpolation_distance: Maximum number of pixels to search in all
+ directions to find values to interpolate from.
+ - create_new_file: If True, the filled image will be stored in output_file
+ instead of modifying the original file.
+ - output_file: Where the new file is stored
+ - no_data_value: Not useful in most cases. If the noDataValue of the gdal dataset is not set.
+ This algorithm is generally suitable for interpolating missing regions of fairly continuously varying rasters
+ (such as elevation models for instance).
+ It is also suitable for filling small holes and cracks in more irregularly varying images (like airphotos).
+ It is generally not so great for interpolating a raster from sparse point data.
+ Returns 0 if a filling was necessary and 1 if all the pixels were already valid.
+ """
+ # create mask file or not
+ # gdal.SetConfigOption('GDAL_TIFF_INTERNAL_MASK', 'NO')
+
+ # Get driver to write the output file
+ file_format = "GTiff"
+ driver = gdal.GetDriverByName(file_format)
+ input_head, input_tail = os.path.split(input_path)
+ input_name, input_ext = input_tail.split('.')
+
+ # Opens dataset in update mode
+ db = gdal.Open(input_path, gdal.GA_Update)
+
+ # Creates new file if instructed
+ if create_new_file:
+ if output_file is None:
+ print("A path for the output file is mandatory if create_new_file is True. Because none was given, "
+ "the new file will be saved at the same location as the input.")
+ output_file = os.path.join(input_head, input_name + '_filled.' + input_ext)
+ new_db = driver.CreateCopy(output_file, db)
+
+
+ # Else, updates the original image
+ else:
+ new_db = db
+
+ doFill = [True]*db.RasterCount
+
+ # Checks if each band needs filling
+ for i in range(0, new_db.RasterCount):
+ band = new_db.GetRasterBand(i+1)
+
+ # If needed, sets the value of the noData pixels
+ if no_data_value is not None and band.GetNoDataValue() is None:
+ band.SetNoDataValue(no_data_value)
+
+ # Check the gdal dataset flag. The band won't be filled if flag == GMF_ALL_VALID
+ if band.GetMaskFlags() == gdal.GMF_ALL_VALID:
+ print("All pixels of band {} of {} have a value. No need to fill it.".format(i + 1, input_name))
+ doFill[i] = False
+
+ # If all bands don't need to be filled then the process is stopped
+ if np.unique(doFill).size == 1 and not np.unique(doFill)[0]:
+ print("All pixels of {} are valid. No filling will be applied.".format(input_name))
+ return 0
+
+ # Fills each band and applies once a 3x3 smoothing filter on the filled areas
+ for i in range(0, new_db.RasterCount):
+
+ if doFill[i]:
+ band = new_db.GetRasterBand(i+1)
+ gdal.FillNodata(band, None, interpolation_distance, 1)
+ print("Filled successfully the pixels without data of {}.".format(input_tail))
+ return 1
+
+
+def pix2geo(geo_transform, xpix, ypix):
+ """
+ Computes the coordinate in the spatial reference system of a pixel given its indexes in the image array and the
+ geotransform of the latter.
+ """
+
+ xcoord = xpix*geo_transform[1] + geo_transform[0]
+ ycoord = ypix*geo_transform[5] + geo_transform[3]
+
+ return xcoord, ycoord
+
+
+def geo2pix(geo_transform, xcoord, ycoord):
+ """
+ Computes the indexes of the pixel in the image array corresponding to a point with given coordinates in the
+ spatial reference system.
+ """
+
+ xpix = int((xcoord-geo_transform[0])/geo_transform[1])
+ ypix = int((ycoord - geo_transform[3]) / geo_transform[5])
+
+ return xpix, ypix
+
+
+def compute_normalized_dsm(dsm_file, dem_file, apply_blur=False, output_file=None):
+ """
+ Computes the normalized Digital Surface Model (nDSM) from a DSM and its corresponding Digital Terrain Model (DTM).
+ The nDSM is obtained by simply subtracting the DTM to the DSM (pixel-wise).
+ If apply_blur is True then a blurring filter is applied to the DEM prior to the computation.
+ """
+
+ # Get driver to write the output file
+ file_format = "GTiff"
+ driver = gdal.GetDriverByName(file_format)
+
+ # If no output is given then creates the nDSM where the DSM is stored
+ if output_file is None:
+ print('No output file was given. The nDSM will be stored where the DSM is: {}'
+ .format(os.path.split(dsm_file)[0]))
+ output_file = os.path.join(os.path.split(dsm_file)[0], 'nDSM.tif')
+
+ # Applies blurring effect on the DEM (keeps the original)
+ if apply_blur:
+ db_dem = gdal.Open(dem_file)
+ path, ext = dem_file.split('.')
+
+ # Creates copy of the DEM to store the blurred data
+ blur_file = path + '_blurred.' + ext
+ db_blur = driver.CreateCopy(blur_file, db_dem)
+ band = db_blur.GetRasterBand(1)
+ data = band.ReadAsArray()
+ blurred_data = ndimage.gaussian_filter(data, 9)
+ band.WriteArray(blurred_data)
+ dem_file = blur_file
+
+ # Subtraction of the two rasters
+ assert os.system('gdal_cacl.py --calc=A-b -A %s -B %s --outfile %s --quiet' % \
+ (dsm_file, dem_file, output_file)) == 0
+
+ print('Computed the normalized dsm successfully.')
diff --git a/scripts/convert/landsat_toa.py b/scripts/convert/landsat_toa.py
index 43bc8800..ee117e90 100755
--- a/scripts/convert/landsat_toa.py
+++ b/scripts/convert/landsat_toa.py
@@ -23,7 +23,7 @@
import sys
import argparse
-from delta.imagery.sources import landsat
+from delta.extensions.sources import landsat
#------------------------------------------------------------------------------
diff --git a/scripts/convert/project_same.sh b/scripts/convert/project_same.sh
new file mode 100755
index 00000000..1e038630
--- /dev/null
+++ b/scripts/convert/project_same.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+# usage: project_same.sh projection.tiff in.tiff
+# Converts in.tiff to overlap projection.tiff with the same
+# resolution.
+
+srs_file=$1
+in_file=$2
+out_file="proj_$2"
+
+echo "Converting ${in_file} to ${out_file}, using area and resolution of ${srs_file}."
+
+data_type=$(gdalinfo ${in_file} | sed -n -e 's/.*Type=\(.*\),.*/\1/p' | head -1)
+num_bands=$(gdalinfo ${in_file} | grep "^Band" | wc -l)
+band_arg=$(printf -- '-b 1 %.0s' $(eval echo "{1..$num_bands}"))
+empty1_file=/tmp/empty1.tiff
+empty2_file=/tmp/empty2.tiff
+gdal_merge.py -createonly -init "0 0 0" -ot ${data_type} -o ${empty1_file} ${srs_file}
+gdal_translate -ot ${data_type} ${band_arg} ${empty1_file} ${empty2_file}
+rm ${empty1_file}
+
+pjt_file=$(mktemp /tmp/pjt.XXXXXX)
+pjt_img=$(mktemp /tmp/pjt_img.XXXXXX.tiff)
+#upper_left=$(gdalinfo ${srs_file} | sed -n -e 's/^Upper Left *( *\(.*\), *\(.*\)).*)/\1 \2/p')
+#lower_right=$(gdalinfo ${srs_file} | sed -n -e 's/^Lower Right *( *\(.*\), *\(.*\)).*)/\1 \2/p')
+pixel_size=$(gdalinfo ${srs_file} | sed -n -e 's/^Pixel Size = (\(.*\),\(.*\))/\1 \2/p')
+shp_file=/tmp/shape.shp # cannot have uppercase for some reason...
+gdaltindex ${shp_file} ${srs_file}
+gdalsrsinfo -o wkt "${srs_file}" > "${pjt_file}"
+gdalwarp -r bilinear -t_srs "${pjt_file}" -tr ${pixel_size} -cutline ${shp_file} -crop_to_cutline "${in_file}" "${out_file}"
+rm ${pjt_file} ${shp_file}
+rm ${empty2_file}
+
diff --git a/scripts/convert/worldview_toa.py b/scripts/convert/worldview_toa.py
index 90c28f2c..2758ac2f 100755
--- a/scripts/convert/worldview_toa.py
+++ b/scripts/convert/worldview_toa.py
@@ -24,7 +24,7 @@
import argparse
import traceback
-from delta.imagery.sources import worldview
+from delta.extensions.sources import worldview
#------------------------------------------------------------------------------
diff --git a/scripts/coverage.sh b/scripts/coverage.sh
new file mode 100755
index 00000000..db8c0950
--- /dev/null
+++ b/scripts/coverage.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+SCRIPT=$(readlink -f "$0")
+SCRIPTPATH=$(dirname "$SCRIPT")
+cd $SCRIPTPATH/..
+pytest --cov=delta --cov-report=html --cov-config=${SCRIPTPATH}/.coveragerc
diff --git a/scripts/example/l8_cloud.sh b/scripts/example/l8_cloud.sh
new file mode 100755
index 00000000..392a559b
--- /dev/null
+++ b/scripts/example/l8_cloud.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+# This example trains a Landsat 8 cloud classifier. This classification is
+# based on the SPARCS validation data:
+# https://www.usgs.gov/core-science-systems/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs
+
+SCRIPT=$(readlink -f "$0")
+SCRIPTPATH=$(dirname "$SCRIPT")
+
+if [ ! -f l8cloudmasks.zip ]; then
+ echo "Downloading dataset."
+ wget https://landsat.usgs.gov/cloud-validation/sparcs/l8cloudmasks.zip
+fi
+
+if [ ! -d sending ]; then
+ echo "Extracting dataset."
+ unzip -q l8cloudmasks.zip
+ mkdir validate
+ mv sending/LC82290562014157LGN00_24_data.tif sending/LC82210662014229LGN00_18_data.tif validate/
+ mkdir train
+ mv sending/*_data.tif train/
+ mkdir labels
+ mv sending/*_mask.png labels/
+fi
+
+if [ ! -f l8_clouds.h5 ]; then
+ cp $SCRIPTPATH/l8_cloud.yaml .
+ delta train --config l8_cloud.yaml l8_clouds.h5
+fi
+
+delta classify --config l8_cloud.yaml --image-dir ./validate --overlap 32 l8_clouds.h5
diff --git a/scripts/example/l8_cloud.yaml b/scripts/example/l8_cloud.yaml
new file mode 100644
index 00000000..565f9dda
--- /dev/null
+++ b/scripts/example/l8_cloud.yaml
@@ -0,0 +1,154 @@
+dataset:
+ images:
+ type: tiff
+ extension: _data.tif
+ directory: train
+ labels:
+ extension: _mask.png
+ type: tiff
+ directory: labels
+ classes:
+ - 0:
+ name: Shadow
+ color: 0x000000
+ - 1:
+ name: Shadow over Water
+ color: 0x000080
+ - 2:
+ name: Water
+ color: 0x0000FF
+ - 3:
+ name: Snow
+ color: 0x00FFFF
+ - 4:
+ name: Land
+ color: 0x808080
+ - 5:
+ name: Cloud
+ color: 0xFFFFFF
+ - 6:
+ name: Flooded
+ color: 0x808000
+
+io:
+ tile_size: [512, 512]
+
+train:
+ loss: sparse_categorical_crossentropy
+ metrics:
+ - sparse_categorical_accuracy
+ network:
+ model:
+ layers:
+ - Input:
+ shape: [~, ~, num_bands]
+ - Conv2D:
+ filters: 16
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ name: c1
+ - Dropout:
+ rate: 0.2
+ - MaxPool2D:
+ - Conv2D:
+ filters: 32
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ name: c2
+ - Dropout:
+ rate: 0.2
+ - MaxPool2D:
+ - Conv2D:
+ filters: 64
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ name: c3
+ - Dropout:
+ rate: 0.2
+ - MaxPool2D:
+ - Conv2D:
+ filters: 128
+ kernel_size: [3, 3]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ name: c4
+ - UpSampling2D:
+ - Conv2D:
+ filters: 64
+ kernel_size: [2, 2]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ name: u3
+ - Concatenate:
+ inputs: [c3, u3]
+ - Dropout:
+ rate: 0.2
+ - Conv2D:
+ filters: 64
+ kernel_size: [3, 3]
+ padding: same
+ - UpSampling2D:
+ - Conv2D:
+ filters: 32
+ kernel_size: [2, 2]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ name: u2
+ - Concatenate:
+ inputs: [c2, u2]
+ - Dropout:
+ rate: 0.2
+ - Conv2D:
+ filters: 32
+ kernel_size: [3, 3]
+ padding: same
+ - UpSampling2D:
+ - Conv2D:
+ filters: 16
+ kernel_size: [2, 2]
+ padding: same
+ - BatchNormalization:
+ - Activation:
+ activation: relu
+ name: u1
+ - Concatenate:
+ inputs: [c1, u1]
+ - Dropout:
+ rate: 0.2
+ - Conv2D:
+ filters: 7
+ kernel_size: [3, 3]
+ activation: linear
+ padding: same
+ - Softmax:
+ axis: 3
+ batch_size: 10
+ epochs: 10
+ validation:
+ from_training: false
+ images:
+ type: tiff
+ extension: _data.tif
+ directory: validate
+ labels:
+ extension: _mask.png
+ type: tiff
+ directory: labels
+
+mlflow:
+ experiment_name: Landsat8 Clouds Example
diff --git a/scripts/fetch/check_inputs.py b/scripts/fetch/check_inputs.py
new file mode 100644
index 00000000..0a5c943a
--- /dev/null
+++ b/scripts/fetch/check_inputs.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python
+
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#pylint: disable=R0914
+
+"""
+Go through all of the images in a folder and verify all the images can be loaded.
+"""
+import os
+import sys
+import argparse
+import traceback
+import delta.config.modules
+from delta.config.extensions import image_reader
+
+# Needed for image cache to be created
+delta.config.modules.register_all()
+
+
+#------------------------------------------------------------------------------
+
+def get_label_path(image_name, options):
+ """Return the label file path for a given input image or throw if it is
+ not found at the expected location."""
+
+ label_name = image_name.replace(options.image_extension, options.label_extension)
+ label_path = os.path.join(options.label_folder, label_name)
+ if not os.path.exists(label_path):
+ raise Exception('Expected label file does not exist: ' + label_path)
+ return label_path
+
+def main(argsIn):
+
+
+ try:
+
+ usage = "usage: check_inputs [options]"
+ parser = argparse.ArgumentParser(usage=usage)
+
+ parser.add_argument("--image-folder", dest="image_folder", required=True,
+ help="Folder containing the input image files.")
+
+ parser.add_argument("--image-type", dest="image_type", default='worldview',
+ help="Type of image files.")
+
+ parser.add_argument("--image-ext", dest="image_extension", default='.zip',
+ help="Extension for image files.")
+
+ options = parser.parse_args(argsIn)
+
+ except argparse.ArgumentError:
+ print(usage)
+ return -1
+
+
+ # Recursively find image files, obtaining the full path for each file.
+ input_image_list = [os.path.join(root, name)
+ for root, dirs, files in os.walk(options.image_folder)
+ for name in files
+ if name.endswith((options.image_extension))]
+
+ print('Found ' + str(len(input_image_list)) + ' image files.')
+
+ # Try to load each file and record the ones that fail
+ failed_files = []
+ for image_path in input_image_list:
+
+ try:
+ image_reader(options.image_type)(image_path)
+ except Exception as e: #pylint: disable=W0703
+ failed_files.append(image_path)
+ print('For file: ' + image_path +
+ '\ncaught exception: ' + str(e))
+ traceback.print_exc(file=sys.stdout)
+
+ if failed_files:
+ print('The following files failed: ')
+ for f in failed_files:
+ print(f)
+ else:
+ print('No files failed to load!')
+
+ return 0
+
+if __name__ == "__main__":
+ sys.exit(main(sys.argv[1:]))
diff --git a/scripts/fetch/convert_image_list.py b/scripts/fetch/convert_image_list.py
index 9f5ba888..369935e9 100755
--- a/scripts/fetch/convert_image_list.py
+++ b/scripts/fetch/convert_image_list.py
@@ -28,6 +28,7 @@ def main(argsIn): #pylint: disable=R0914,R0912
if len(argsIn) != 2:
print("usage: convert_image_list.py ")
+ return -1
input_path = argsIn[0]
output_path = argsIn[1]
diff --git a/scripts/fetch/fetch_hdds_images.py b/scripts/fetch/fetch_hdds_images.py
index 112e813d..cf4756cd 100755
--- a/scripts/fetch/fetch_hdds_images.py
+++ b/scripts/fetch/fetch_hdds_images.py
@@ -355,8 +355,8 @@ def main(argsIn): #pylint: disable=R0914,R0912
product=download_type)
try:
url = r['data'][0]['url']
- except:
- raise Exception('Failed to get download URL from result: ' + str(r))
+ except Exception as e:
+ raise Exception('Failed to get download URL from result: ' + str(r)) from e
print(scene['summary'])
# Finally download the data!
diff --git a/scripts/fetch/get_landsat_dswe_labels.py b/scripts/fetch/get_landsat_dswe_labels.py
index 066b698a..b798081b 100755
--- a/scripts/fetch/get_landsat_dswe_labels.py
+++ b/scripts/fetch/get_landsat_dswe_labels.py
@@ -31,7 +31,7 @@
from usgs import api
from delta.imagery import utilities
-from delta.imagery.sources import landsat
+from delta.extensions.sources import landsat
#------------------------------------------------------------------------------
diff --git a/scripts/fetch/random_folder_split.py b/scripts/fetch/random_folder_split.py
index d050663b..0b6ee377 100644
--- a/scripts/fetch/random_folder_split.py
+++ b/scripts/fetch/random_folder_split.py
@@ -31,6 +31,7 @@
#------------------------------------------------------------------------------
+# TODO: Need a good system for this that handles unpacked images!!!
def get_label_path(image_name, options):
"""Return the label file path for a given input image or throw if it is
not found at the expected location."""
@@ -41,7 +42,7 @@ def get_label_path(image_name, options):
raise Exception('Expected label file does not exist: ' + label_path)
return label_path
-def main(argsIn):
+def main(argsIn): #pylint: disable=R0912
try:
@@ -64,6 +65,10 @@ def main(argsIn):
parser.add_argument("--label-ext", dest="label_extension", default='.tif',
help="Extension for label files.")
+ parser.add_argument("--link-folders", action="store_true",
+ dest="link_folders", default=False,
+ help="Link the files containing the detected folders")
+
parser.add_argument("--image-limit", dest="image_limit", default=None, type=int,
help="Only use this many image files total.")
@@ -116,7 +121,8 @@ def main(argsIn):
for image_path in input_image_list:
# If an image list was provided skip images which are not in the list.
- image_name = os.path.basename(image_path)
+ image_name = os.path.basename(image_path)
+ image_folder = os.path.dirname(image_path)
if images_to_use and (os.path.splitext(image_name)[0] not in images_to_use):
continue
@@ -124,13 +130,19 @@ def main(argsIn):
use_for_valid = (random.random() < options.validate_fraction)
# Handle the image file
+ target_name = image_name
+ if options.link_folders:
+ target_name = os.path.basename(image_folder) # Last folder name
if use_for_valid:
- image_dest = os.path.join(valid_image_folder, image_name)
+ image_dest = os.path.join(valid_image_folder, target_name)
valid_count += 1
else:
- image_dest = os.path.join(train_image_folder, image_name)
+ image_dest = os.path.join(train_image_folder, target_name)
train_count += 1
- os.symlink(image_path, image_dest)
+ if options.link_folders:
+ os.symlink(image_folder, image_dest)
+ else:
+ os.symlink(image_path, image_dest)
if options.label_folder: # Handle the label file
label_path = get_label_path(image_name, options)
diff --git a/scripts/fetch/unpack_inputs.py b/scripts/fetch/unpack_inputs.py
new file mode 100644
index 00000000..b75ec48b
--- /dev/null
+++ b/scripts/fetch/unpack_inputs.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python
+
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#pylint: disable=R0914
+
+"""
+Try to unpack compressed input images to an output folder
+"""
+import os
+import sys
+import argparse
+import traceback
+from delta.extensions.sources import worldview
+from delta.extensions.sources import sentinel1
+from delta.extensions.sources import tiff
+
+
+#------------------------------------------------------------------------------
+
+def main(argsIn):
+
+ SUPPORTED_IMAGE_TYPES = ['worldview', 'sentinel1']
+
+ try:
+
+ usage = "usage: unpack_inputs [options]"
+ parser = argparse.ArgumentParser(usage=usage)
+
+ parser.add_argument("--input-folder", dest="input_folder", required=True,
+ help="Folder containing the input image files.")
+
+ parser.add_argument("--output-folder", dest="output_folder", required=True,
+ help="Unpack images to this folder.")
+
+ parser.add_argument("--image-type", dest="image_type", default='worldview',
+ help="Type of image files: " +
+ ', '.join(SUPPORTED_IMAGE_TYPES))
+
+ parser.add_argument("--image-ext", dest="image_extension", default='.zip',
+ help="Extension for image files.")
+
+ parser.add_argument("--delete-inputs", action="store_true",
+ dest="delete_inputs", default=False,
+ help="Delete input files after unpacking.")
+
+ parser.add_argument("--image-limit", dest="image_limit",
+ default=None, type=int,
+ help="Stop after unpacking this many images.")
+
+ options = parser.parse_args(argsIn)
+
+ except argparse.ArgumentError:
+ print(usage)
+ return -1
+
+ if options.image_type not in SUPPORTED_IMAGE_TYPES:
+ print('Input image type is not supported!')
+ return -1
+
+ # Recursively find image files, obtaining the full path for each file.
+ input_image_list = [os.path.join(root, name)
+ for root, dirs, files in os.walk(options.input_folder)
+ for name in files
+ if name.endswith((options.image_extension))]
+
+ print('Found ' + str(len(input_image_list)) + ' image files.')
+
+ # Try to load each file and record the ones that fail
+ failed_files = []
+ count = 0
+ for image_path in input_image_list:
+
+ try:
+
+ if count % 10 == 0:
+ print('Progress = ' + str(count) + ' out of ' + str(len(input_image_list)))
+
+ if options.image_limit and (count >= options.image_limit):
+ print('Stopping because we hit the image limit.')
+ break
+ count += 1
+
+ # Mirror the input folder structure in the output folder
+ image_name = os.path.basename(os.path.splitext(image_path)[0])
+ image_folder = os.path.dirname(image_path)
+ relative_path = os.path.relpath(image_folder, options.input_folder)
+ this_output_folder = os.path.join(options.output_folder,
+ relative_path, image_name)
+
+ # TODO: Synch up the unpack functions
+ tif_path = None
+ if not os.path.exists(this_output_folder):
+ print('Unpacking input file: ' + image_path)
+ if options.image_type == 'worldview':
+ tif_path = worldview.unpack_wv_to_folder(image_path, this_output_folder)[0]
+ else: # sentinel1
+ tif_path = sentinel1.unpack_s1_to_folder(image_path, this_output_folder)
+
+ else: # The folder was already unpacked (at least partially)
+ if options.image_type == 'worldview':
+ tif_path = worldview.get_files_from_unpack_folder(this_output_folder)[0]
+ else: # sentinel1
+ tif_path = sentinel1.unpack_s1_to_folder(image_path, this_output_folder)
+
+ # Make sure the unpacked image loads properly
+ test_image = tiff.TiffImage(tif_path) #pylint: disable=W0612
+
+ if options.delete_inputs:
+ print('Deleting input file: ' + image_path)
+ os.remove(image_path)
+
+ except Exception as e: #pylint: disable=W0703
+ failed_files.append(image_path)
+ print('For file: ' + image_path +
+ '\ncaught exception: ' + str(e))
+ traceback.print_exc(file=sys.stdout)
+
+ if failed_files:
+ print('The following files failed: ')
+ for f in failed_files:
+ print(f)
+ else:
+ print('No files failed to unpack!')
+
+ return 0
+
+if __name__ == "__main__":
+ sys.exit(main(sys.argv[1:]))
diff --git a/scripts/label-img-info b/scripts/label-img-info
deleted file mode 100755
index 5bc326d3..00000000
--- a/scripts/label-img-info
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/usr/bin/env python
-
-import sys
-import pathlib
-import numpy as np
-from osgeo import gdal
-
-if __name__=='__main__':
- assert len(sys.argv) > 1, 'Need to supply a file'
- filename = pathlib.Path(sys.argv[1])
-
- tif_file = gdal.Open(str(filename))
- assert tif_file is not None, f'Could not open file {filename}'
- tif_data = tif_file.ReadAsArray()
- unique_labels = np.unique(tif_data)
- print(np.any(np.isnan(tif_data)), tif_data.min(), tif_data.max(), tif_data.shape, unique_labels)
- print(np.histogram(tif_data, bins=len(unique_labels)))
-
-
diff --git a/scripts/model2config b/scripts/model2config
index 3b74504c..bd3d3855 100755
--- a/scripts/model2config
+++ b/scripts/model2config
@@ -1,15 +1,24 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
import tensorflow as tf
from argparse import ArgumentParser
import h5py
import pathlib
+from delta.config.extensions import custom_objects
+import delta.extensions
+from delta.ml.io import print_network
parser = ArgumentParser(description='Converts a neural network in a *.h5 file to the DELTA configuration langauge')
parser.add_argument('model_name', type=pathlib.Path, help='The model to convert')
+parser.add_argument('-s', '--size', type=str, help='Tile width.')
args = parser.parse_args()
+if args.size is not None:
+ t = args.size.split('x')
+ assert len(t) == 2
+ args.size = (int(t[0]), int(t[1]))
+
print('Configuration File')
with h5py.File(args.model_name, 'r') as f:
if 'delta' not in f.attrs:
@@ -17,16 +26,6 @@ with h5py.File(args.model_name, 'r') as f:
else:
print('\n' + f.attrs['delta'] + '\n')
-a = tf.keras.models.load_model(args.model_name)
-print('Network Structure')
-for l in a.layers:
- print('\t- ', type(l).__name__)
- configs = l.get_config()
- if isinstance(l.input, list):
- print('\t\t- input: ['+ ', '.join([x.name.replace('/Identity:0','') for x in l.input])+ ']')
- else:
- print('\t\t- input:', l.input.name.replace('/Identity:0',''))
- for k in configs.keys():
- if isinstance(configs[k], dict) or configs[k] is None:
- continue
- print(f'\t\t- {k}: {configs[k]}')
+print('Network')
+a = tf.keras.models.load_model(args.model_name, custom_objects=custom_objects(), compile=False)
+print_network(a, args.size)
diff --git a/scripts/visualize/compare_histograms.py b/scripts/visualize/compare_histograms.py
new file mode 100755
index 00000000..d7660661
--- /dev/null
+++ b/scripts/visualize/compare_histograms.py
@@ -0,0 +1,48 @@
+#!/usr/bin/python3
+
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script creates a pdf comparing the histograms of all tiff files input.
+
+import sys
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from matplotlib.backends.backend_pdf import PdfPages
+
+from delta.extensions.sources.tiff import TiffImage
+
+def plot_band(names, band):
+ imgs = [TiffImage(n) for n in names]
+ max_value = 2.0
+ for img in imgs:
+ data = np.ndarray.flatten(img.read(bands=band))
+ data = data[data > 0.0]
+ data[data > max_value] = max_value
+ plt.hist(data, bins=200, alpha=0.5)
+ plt.title('Band ' + str(band))
+ pdf.savefig()
+ plt.close()
+
+assert len(sys.argv) > 1, 'No input tiffs specified.'
+
+with PdfPages('output.pdf') as pdf:
+ a = TiffImage(sys.argv[1])
+ for i in range(a.num_bands()):
+ plot_band(sys.argv[1:], i)
diff --git a/scripts/visualize/diff.py b/scripts/visualize/diff.py
new file mode 100755
index 00000000..db3ef88e
--- /dev/null
+++ b/scripts/visualize/diff.py
@@ -0,0 +1,52 @@
+#!/usr/bin/python3
+
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Creates a difference image between two images.
+
+import sys
+
+import numpy as np
+
+from delta.extensions.sources.tiff import TiffImage, TiffWriter
+from delta.imagery import rectangle
+
+assert len(sys.argv) == 3, 'Please specify two tiff files of the same size.'
+
+img1 = TiffImage(sys.argv[1])
+img2 = TiffImage(sys.argv[2])
+
+output_image = TiffWriter('diff.tiff')
+output_image.initialize((img1.width(), img1.height(), 3), np.uint8, img1.metadata())
+
+assert img1.width()== img2.width() and img1.height() == img2.height() and \
+ img1.num_bands() == img2.num_bands(), 'Images must be same size.'
+
+def callback_function(roi, data):
+ data2 = img2.read(roi)
+ diff = np.mean((data - data2) ** 2, axis=-1)
+ diff = np.uint8(np.clip(diff * 128.0, 0.0, 255.0))
+ out = np.stack([diff, diff, diff], axis=-1)
+ output_image.write(out, roi.min_x, roi.min_y)
+
+input_bounds = rectangle.Rectangle(0, 0, width=img1.width(), height=img1.height())
+output_rois = input_bounds.make_tile_rois((2048, 2048), include_partials=True)
+
+img1.process_rois(output_rois, callback_function, show_progress=True)
+
+output_image.close()
diff --git a/setup.py b/setup.py
index 36ba9160..168092f9 100644
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,7 @@
setuptools.setup(
name="delta",
- version="0.1.2",
+ version="0.2.0",
author="NASA Ames",
author_email="todo@todo",
description="Deep learning for satellite imagery",
@@ -46,16 +46,16 @@
"Operating System :: OS Independent"
],
install_requires=[
- 'usgs',
- 'numpy',
+ 'tensorflow>=2.1',
+ 'usgs<0.3',
'scipy',
'matplotlib',
- 'tensorflow>=2.1',
'mlflow',
'portalocker',
'appdirs',
- 'gdal',
- 'h5py'
+ 'gdal'
+ #'numpy', # these are included by tensorflow with restrictions
+ #'h5py'
],
scripts=scripts,
include_package_data = True,
diff --git a/tests/conftest.py b/tests/conftest.py
index d7bffbf0..26220052 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,24 +17,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-#pylint:disable=redefined-outer-name
+#pylint:disable=redefined-outer-name,wrong-import-position
import os
import random
import shutil
import sys
import tempfile
+import warnings
import zipfile
import numpy as np
import pytest
-from delta.imagery.sources import tiff
+# conftest.py loaded before pytest.ini warning filters apparently
+warnings.filterwarnings('ignore', category=DeprecationWarning, module='osgeo')
+
+from delta.config import config
+from delta.extensions.sources import tiff
import delta.config.modules
delta.config.modules.register_all()
assert 'tensorflow' not in sys.modules, 'For speed of command line tool, tensorflow should not be imported by config!'
+from delta.imagery import imagery_dataset #pylint: disable=wrong-import-position
+
+import tensorflow as tf #pylint: disable=wrong-import-position, wrong-import-order
+tf.get_logger().setLevel('ERROR')
+
+def config_reset():
+ """
+ Resets the configuration with useful default options for testing.
+ """
+ config.reset() # don't load any user files
+ config.load(yaml_str=
+ '''
+ mlflow:
+ enabled: false
+ ''')
+
def generate_tile(width=32, height=32, blocks=50):
"""Generate a widthXheightX3 image, with blocks pixels surrounded by ones and the rest zeros in band 0"""
image = np.zeros((width, height, 1), np.float32)
@@ -71,6 +92,34 @@ def original_file():
shutil.rmtree(tmpdir)
+@pytest.fixture(scope="session")
+def doubling_tiff_filenames():
+ tmpdir = tempfile.mkdtemp()
+ image_path = os.path.join(tmpdir, 'image.tiff')
+ label_path = os.path.join(tmpdir, 'label.tiff')
+
+ image = np.random.random((128, 128)) #pylint: disable=no-member
+ label = 2 * image
+ tiff.write_tiff(image_path, image)
+ tiff.write_tiff(label_path, label)
+ yield ([image_path], [label_path])
+
+ shutil.rmtree(tmpdir)
+
+@pytest.fixture(scope="session")
+def binary_identity_tiff_filenames():
+ tmpdir = tempfile.mkdtemp()
+ image_path = os.path.join(tmpdir, 'image.tiff')
+ label_path = os.path.join(tmpdir, 'label.tiff')
+
+ label = np.random.randint(0, 2, (128, 128), np.uint8) #pylint: disable=no-member
+ image = np.take(np.asarray([[1.0, 0.0], [0.0, 1.0]]), label, axis=0)
+ tiff.write_tiff(image_path, image)
+ tiff.write_tiff(label_path, label)
+ yield ([image_path], [label_path])
+
+ shutil.rmtree(tmpdir)
+
@pytest.fixture(scope="session")
def worldview_filenames(original_file):
tmpdir = tempfile.mkdtemp()
@@ -101,7 +150,86 @@ def worldview_filenames(original_file):
shutil.rmtree(tmpdir)
+@pytest.fixture(scope="session")
+def landsat_filenames(original_file):
+ tmpdir = tempfile.mkdtemp()
+ image_name = 'L1_IGNORE_AAABBB_DATE'
+ mtl_name = image_name + '_MTL.txt'
+ mtl_path = os.path.join(tmpdir, mtl_name)
+ zip_path = os.path.join(tmpdir, image_name + '.zip')
+ # not really a valid file but this is all we need, only one band in image
+ with open(mtl_path, 'a') as f:
+ f.write('SPACECRAFT_ID = LANDSAT_1\n')
+ f.write('SUN_ELEVATION = 5.8\n')
+ f.write('FILE_NAME_BAND_1 = 1.tiff\n')
+ f.write('RADIANCE_MULT_BAND_1 = 2.0\n')
+ f.write('RADIANCE_ADD_BAND_1 = 2.0\n')
+ f.write('REFLECTANCE_MULT_BAND_1 = 2.0\n')
+ f.write('REFLECTANCE_ADD_BAND_1 = 2.0\n')
+ f.write('K1_CONSTANT_BAND_1 = 2.0\n')
+ f.write('K2_CONSTANT_BAND_1 = 2.0\n')
+
+ image_path = os.path.join(tmpdir, '1.tiff')
+ tiff.TiffImage(original_file[0]).save(image_path)
+
+ z = zipfile.ZipFile(zip_path, mode='x')
+ z.write(image_path, arcname='1.tiff')
+ z.write(mtl_path, arcname=mtl_name)
+ z.close()
+
+ label_path = os.path.join(tmpdir, image_name + '_label.tiff')
+ tiff.TiffImage(original_file[1]).save(label_path)
+
+ yield (zip_path, label_path)
+
+ shutil.rmtree(tmpdir)
+
NUM_SOURCES = 1
@pytest.fixture(scope="session")
def all_sources(worldview_filenames):
return [(worldview_filenames, '.zip', 'worldview', '_label.tiff', 'tiff')]
+
+def load_dataset(source, output_size, chunk_size=3, autoencoder=False):
+ config_reset()
+ (image_path, label_path) = source[0]
+ config.load(yaml_str=
+ '''
+ io:
+ cache:
+ dir: %s
+ dataset:
+ images:
+ type: %s
+ directory: %s
+ extension: %s
+ preprocess: ~
+ labels:
+ type: %s
+ directory: %s
+ extension: %s
+ preprocess: ~''' %
+ (os.path.dirname(image_path), source[2], os.path.dirname(image_path), source[1],
+ source[4], os.path.dirname(label_path), source[3]))
+
+ if autoencoder:
+ return imagery_dataset.AutoencoderDataset(config.dataset.images(), (chunk_size, chunk_size),
+ tile_shape=config.io.tile_size(),
+ stride=config.train.spec().stride)
+ return imagery_dataset.ImageryDataset(config.dataset.images(), config.dataset.labels(),
+ (output_size, output_size),
+ (chunk_size, chunk_size), tile_shape=config.io.tile_size(),
+ stride=config.train.spec().stride)
+
+@pytest.fixture(scope="function", params=range(NUM_SOURCES))
+def dataset(all_sources, request):
+ source = all_sources[request.param]
+ return load_dataset(source, 1)
+
+@pytest.fixture(scope="function", params=range(NUM_SOURCES))
+def ae_dataset(all_sources, request):
+ source = all_sources[request.param]
+ return load_dataset(source, 1, autoencoder=True)
+
+@pytest.fixture(scope="function")
+def dataset_block_label(all_sources):
+ return load_dataset(all_sources[0], 3)
diff --git a/tests/test_commands.py b/tests/test_commands.py
new file mode 100644
index 00000000..769bd07c
--- /dev/null
+++ b/tests/test_commands.py
@@ -0,0 +1,180 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#pylint: disable=redefined-outer-name
+
+import os
+import shutil
+import tempfile
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from conftest import config_reset
+
+from delta.extensions.sources.tiff import TiffImage
+from delta.ml.predict import LabelPredictor, ImagePredictor
+from delta.subcommands.main import main
+
+@pytest.fixture(scope="session")
+def identity_config(binary_identity_tiff_filenames):
+ tmpdir = tempfile.mkdtemp()
+
+ config_path = os.path.join(tmpdir, 'dataset.yaml')
+ with open(config_path, 'w') as f:
+ f.write('''
+ dataset:
+ images:
+ nodata_value: ~
+ files:
+ ''')
+
+ for fn in binary_identity_tiff_filenames[0]:
+ f.write(' - %s\n' % (fn))
+ f.write('''
+ labels:
+ nodata_value: 2
+ files:
+ ''')
+ for fn in binary_identity_tiff_filenames[1]:
+ f.write(' - %s\n' % (fn))
+ f.write('''
+ classes: 2
+ io:
+ tile_size: [128, 128]
+ ''')
+
+ yield config_path
+
+ shutil.rmtree(tmpdir)
+
+def test_predict_main(identity_config, tmp_path):
+ config_reset()
+ model_path = tmp_path / 'model.h5'
+ inputs = tf.keras.layers.Input((32, 32, 2))
+ tf.keras.Model(inputs, inputs).save(model_path)
+ args = 'delta classify --config %s %s' % (identity_config, model_path)
+ old = os.getcwd()
+ os.chdir(tmp_path) # put temporary outputs here
+ main(args.split())
+ os.chdir(old)
+
+def test_train_main(identity_config, tmp_path):
+ config_reset()
+ train_config = tmp_path / 'config.yaml'
+ with open(train_config, 'w') as f:
+ f.write('''
+ train:
+ steps: 5
+ epochs: 3
+ network:
+ layers:
+ - Input:
+ shape: [1, 1, num_bands]
+ - Conv2D:
+ filters: 2
+ kernel_size: [1, 1]
+ activation: relu
+ padding: same
+ batch_size: 1
+ validation:
+ steps: 2
+ callbacks:
+ - ExponentialLRScheduler:
+ start_epoch: 2
+ ''')
+ args = 'delta train --config %s --config %s' % (identity_config, train_config)
+ main(args.split())
+
+def test_train_validate(identity_config, binary_identity_tiff_filenames, tmp_path):
+ config_reset()
+ train_config = tmp_path / 'config.yaml'
+ with open(train_config, 'w') as f:
+ f.write('''
+ train:
+ steps: 5
+ epochs: 3
+ network:
+ layers:
+ - Input:
+ shape: [~, ~, num_bands]
+ - Conv2D:
+ filters: 2
+ kernel_size: [1, 1]
+ activation: relu
+ padding: same
+ batch_size: 1
+ validation:
+ from_training: false
+ images:
+ nodata_value: ~
+ files: [%s]
+ labels:
+ nodata_value: ~
+ files: [%s]
+ steps: 2
+ callbacks:
+ - ExponentialLRScheduler:
+ start_epoch: 2
+ ''' % (binary_identity_tiff_filenames[0][0], binary_identity_tiff_filenames[1][0]))
+ args = 'delta train --config %s --config %s' % (identity_config, train_config)
+ main(args.split())
+
+def test_validate_main(identity_config):
+ config_reset()
+ args = 'delta validate --config %s' % (identity_config, )
+ main(args.split())
+
+def test_predict(binary_identity_tiff_filenames):
+ inputs = tf.keras.layers.Input((32, 32, 2))
+ model = tf.keras.Model(inputs, inputs)
+ pred = LabelPredictor(model)
+ image = TiffImage(binary_identity_tiff_filenames[0])
+ label = TiffImage(binary_identity_tiff_filenames[1])
+ pred.predict(image, label)
+ cm = pred.confusion_matrix()
+ assert np.sum(np.diag(cm)) == np.sum(cm)
+
+def test_predict_nodata(binary_identity_tiff_filenames):
+ inputs = tf.keras.layers.Input((32, 32, 2))
+ model = tf.keras.Model(inputs, inputs)
+ pred = LabelPredictor(model)
+ image = TiffImage(binary_identity_tiff_filenames[0])
+ label = TiffImage(binary_identity_tiff_filenames[1], 1)
+ pred.predict(image, label)
+ cm = pred.confusion_matrix()
+ assert cm[0, 0] == np.sum(cm)
+
+def test_predict_image_nodata(binary_identity_tiff_filenames):
+ inputs = tf.keras.layers.Input((32, 32, 2))
+ model = tf.keras.Model(inputs, inputs)
+ pred = LabelPredictor(model)
+ image = TiffImage(binary_identity_tiff_filenames[0], 1)
+ label = TiffImage(binary_identity_tiff_filenames[1])
+ pred.predict(image, label)
+ cm = pred.confusion_matrix()
+ assert np.sum(np.diag(cm)) == np.sum(cm)
+
+def test_predict_image(doubling_tiff_filenames):
+ inputs = tf.keras.layers.Input((32, 32, 1))
+ output = tf.keras.layers.Add()([inputs, inputs])
+ model = tf.keras.Model(inputs, output)
+ pred = ImagePredictor(model)
+ image = TiffImage(doubling_tiff_filenames[0])
+ label = TiffImage(doubling_tiff_filenames[1])
+ pred.predict(image, label)
diff --git a/tests/test_config.py b/tests/test_config.py
index 5586e471..32fbb62c 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -24,11 +24,13 @@
import numpy as np
import tensorflow as tf
+from conftest import config_reset
+
from delta.config import config
-from delta.ml import model_parser
+from delta.ml import config_parser
def test_general():
- config.reset()
+ config_reset()
assert config.general.gpus() == -1
@@ -37,9 +39,8 @@ def test_general():
gpus: 3
io:
threads: 5
- block_size_mb: 10
+ tile_size: [5, 5]
interleave_images: 3
- tile_ratio: 1.0
cache:
dir: nonsense
limit: 2
@@ -48,53 +49,69 @@ def test_general():
assert config.general.gpus() == 3
assert config.io.threads() == 5
- assert config.io.block_size_mb() == 10
+ assert config.io.tile_size()[0] == 5
+ assert config.io.tile_size()[1] == 5
assert config.io.interleave_images() == 3
- assert config.io.tile_ratio() == 1.0
cache = config.io.cache.manager()
assert cache.folder() == 'nonsense'
assert cache.limit() == 2
os.rmdir('nonsense')
def test_images_dir():
- config.reset()
+ config_reset()
dir_path = os.path.join(os.path.dirname(__file__), 'data')
test_str = '''
dataset:
images:
type: tiff
- preprocess:
- enabled: false
+ preprocess: ~
directory: %s/
extension: .tiff
''' % (dir_path)
config.load(yaml_str=test_str)
im = config.dataset.images()
- assert im.preprocess() is None
assert im.type() == 'tiff'
assert len(im) == 1
assert im[0].endswith('landsat.tiff') and os.path.exists(im[0])
+def test_preprocess():
+ config_reset()
+ test_str = '''
+ dataset:
+ images:
+ preprocess:
+ - scale:
+ factor: 2.0
+ - offset:
+ factor: 1.0
+ - clip:
+ bounds: [0, 5]
+ '''
+ config.load(yaml_str=test_str)
+ f = config.dataset.images().preprocess()
+ assert f(np.asarray([0.0]), None, None) == 1.0
+ assert f(np.asarray([2.0]), None, None) == 2.0
+ assert f(np.asarray([-5.0]), None, None) == 0.0
+ assert f(np.asarray([20.0]), None, None) == 5.0
+
def test_images_files():
- config.reset()
+ config_reset()
file_path = os.path.join(os.path.dirname(__file__), 'data', 'landsat.tiff')
test_str = '''
dataset:
images:
type: tiff
- preprocess:
- enabled: false
+ preprocess: ~
files: [%s]
''' % (file_path)
config.load(yaml_str=test_str)
im = config.dataset.images()
- assert im.preprocess() is None
assert im.type() == 'tiff'
assert len(im) == 1
assert im[0] == file_path
def test_classes():
- config.reset()
+ config_reset()
test_str = '''
dataset:
classes: 2
@@ -104,7 +121,22 @@ def test_classes():
for (i, c) in enumerate(config.dataset.classes):
assert c.value == i
assert config.dataset.classes.weights() is None
- config.reset()
+
+ def assert_classes(classes):
+ assert classes
+ values = [1, 2, 5]
+ for (i, c) in enumerate(classes):
+ e = values[i]
+ assert c.value == e
+ assert c.name == str(e)
+ assert c.color == e
+ assert classes.weights() == [1.0, 5.0, 2.0]
+ arr = np.array(values)
+ ind = classes.classes_to_indices_func()(arr)
+ assert np.max(ind) == 2
+ assert (classes.indices_to_classes_func()(ind) == values).all()
+
+ config_reset()
test_str = '''
dataset:
classes:
@@ -122,27 +154,37 @@ def test_classes():
weight: 2.0
'''
config.load(yaml_str=test_str)
- assert config.dataset.classes
- values = [1, 2, 5]
- for (i, c) in enumerate(config.dataset.classes):
- e = values[i]
- assert c.value == e
- assert c.name == str(e)
- assert c.color == e
- assert config.dataset.classes.weights() == [1.0, 5.0, 2.0]
- arr = np.array(values)
- ind = config.dataset.classes.classes_to_indices_func()(arr)
- assert np.max(ind) == 2
- assert (config.dataset.classes.indices_to_classes_func()(ind) == values).all()
+ assert_classes(config.dataset.classes)
+
+ config_reset()
+ test_str = '''
+ dataset:
+ classes:
+ 2:
+ name: 2
+ color: 2
+ weight: 5.0
+ 1:
+ name: 1
+ color: 1
+ weight: 1.0
+ 5:
+ name: 5
+ color: 5
+ weight: 2.0
+ '''
+ config.load(yaml_str=test_str)
+ assert_classes(config.dataset.classes)
def test_model_from_dict():
- config.reset()
+ config_reset()
test_str = '''
params:
v1 : 10
layers:
+ - Input:
+ shape: in_shape
- Flatten:
- input_shape: in_shape
- Dense:
units: v1
activation : relu
@@ -154,7 +196,7 @@ def test_model_from_dict():
input_shape = (17, 17, 8)
output_shape = 3
params_exposed = { 'out_shape' : output_shape, 'in_shape' : input_shape}
- model = model_parser.model_from_dict(yaml.safe_load(test_str), params_exposed)()
+ model = config_parser.model_from_dict(yaml.safe_load(test_str), params_exposed)()
model.compile(optimizer='adam', loss='mse')
assert model.input_shape[1:] == input_shape
@@ -162,13 +204,14 @@ def test_model_from_dict():
assert len(model.layers) == 4 # Input layer is added behind the scenes
def test_pretrained_layer():
- config.reset()
+ config_reset()
base_model = '''
params:
v1 : 10
layers:
+ - Input:
+ shape: in_shape
- Flatten:
- input_shape: in_shape
- Dense:
units: v1
activation : relu
@@ -180,7 +223,7 @@ def test_pretrained_layer():
input_shape = (17, 17, 8)
output_shape = 3
params_exposed = { 'out_shape' : output_shape, 'in_shape' : input_shape}
- m1 = model_parser.model_from_dict(yaml.safe_load(base_model), params_exposed)()
+ m1 = config_parser.model_from_dict(yaml.safe_load(base_model), params_exposed)()
m1.compile(optimizer='adam', loss='mse')
_, tmp_filename = tempfile.mkstemp(suffix='.h5')
@@ -190,6 +233,8 @@ def test_pretrained_layer():
params:
v1 : 10
layers:
+ - Input:
+ shape: in_shape
- Pretrained:
filename: %s
encoding_layer: encoding
@@ -200,89 +245,126 @@ def test_pretrained_layer():
units: out_shape
activation: softmax
''' % tmp_filename
- m2 = model_parser.model_from_dict(yaml.safe_load(pretrained_model), params_exposed)()
+ m2 = config_parser.model_from_dict(yaml.safe_load(pretrained_model), params_exposed)()
+ m2.compile(optimizer='adam', loss='mse')
+ assert len(m2.layers[1].layers) == 3
+ for i in range(1, len(m1.layers)):
+ assert isinstance(m1.layers[i], type(m2.layers[1].layers[i]))
+ if m1.layers[i].name == 'encoding':
+ break
+
+ # test using internal layer of pretrained as input
+ pretrained_model = '''
+ params:
+ v1 : 10
+ layers:
+ - Input:
+ shape: in_shape
+ - Pretrained:
+ filename: %s
+ encoding_layer: encoding
+ name: pretrained
+ outputs: [encoding]
+ - Dense:
+ units: 100
+ activation: relu
+ inputs: pretrained/encoding
+ - Dense:
+ units: out_shape
+ activation: softmax
+ ''' % tmp_filename
+ m2 = config_parser.model_from_dict(yaml.safe_load(pretrained_model), params_exposed)()
m2.compile(optimizer='adam', loss='mse')
- assert len(m2.layers[1].layers) == (len(m1.layers) - 2) # also don't take the input layer
+ assert len(m2.layers[1].layers) == (len(m1.layers) - 1) # also don't take the input layer
for i in range(1, len(m1.layers)):
- assert isinstance(m1.layers[i], type(m2.layers[1].layers[i - 1]))
+ assert isinstance(m1.layers[i], type(m2.layers[1].layers[i]))
if m1.layers[i].name == 'encoding':
break
os.remove(tmp_filename)
+def test_callbacks():
+ config_reset()
+ test_str = '''
+ train:
+ callbacks:
+ - EarlyStopping:
+ verbose: true
+ - ReduceLROnPlateau:
+ factor: 0.5
+ '''
+ config.load(yaml_str=test_str)
+ cbs = config_parser.config_callbacks()
+ assert len(cbs) == 2
+ assert isinstance(cbs[0], tf.keras.callbacks.EarlyStopping)
+ assert cbs[0].verbose
+ assert isinstance(cbs[1], tf.keras.callbacks.ReduceLROnPlateau)
+ assert cbs[1].factor == 0.5
+
def test_network_file():
- config.reset()
+ config_reset()
test_str = '''
dataset:
classes: 3
train:
network:
- chunk_size: 5
- model:
- yaml_file: networks/convpool.yaml
+ yaml_file: networks/convpool.yaml
'''
config.load(yaml_str=test_str)
- assert config.train.network.chunk_size() == 5
- model = model_parser.config_model(2)()
- assert model.input_shape == (None, config.train.network.chunk_size(), config.train.network.chunk_size(), 2)
- assert model.output_shape == (None, config.train.network.output_size(),
- config.train.network.output_size(), len(config.dataset.classes))
+ model = config_parser.config_model(2)()
+ assert model.input_shape == (None, 5, 5, 2)
+ assert model.output_shape == (None, 3, 3, 3)
def test_validate():
- config.reset()
+ config_reset()
test_str = '''
train:
- network:
- chunk_size: -1
+ stride: -1
'''
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
config.load(yaml_str=test_str)
- config.reset()
+ config_reset()
test_str = '''
train:
- network:
- chunk_size: string
+ stride: 0.5
'''
with pytest.raises(TypeError):
config.load(yaml_str=test_str)
def test_network_inline():
- config.reset()
+ config_reset()
test_str = '''
dataset:
classes: 3
train:
network:
- chunk_size: 5
- output_size: 1
- model:
- params:
- v1 : 10
- layers:
- - Flatten:
- input_shape: in_shape
- - Dense:
- units: v1
- activation : relu
- - Dense:
- units: out_dims
- activation : softmax
+ params:
+ v1 : 10
+ layers:
+ - Input:
+ shape: [5, 5, num_bands]
+ - Flatten:
+ - Dense:
+ units: v1
+ activation : relu
+ - Dense:
+ units: 3
+ activation : softmax
'''
config.load(yaml_str=test_str)
- assert config.train.network.chunk_size() == 5
assert len(config.dataset.classes) == 3
- model = model_parser.config_model(2)()
- assert model.input_shape == (None, config.train.network.chunk_size(), config.train.network.chunk_size(), 2)
+ model = config_parser.config_model(2)()
+ assert model.input_shape == (None, 5, 5, 2)
assert model.output_shape == (None, len(config.dataset.classes))
def test_train():
- config.reset()
+ config_reset()
test_str = '''
train:
- chunk_stride: 2
+ stride: 2
batch_size: 5
steps: 10
epochs: 3
- loss_function: loss
+ loss: SparseCategoricalCrossentropy
metrics: [metric]
optimizer: opt
validation:
@@ -291,20 +373,18 @@ def test_train():
'''
config.load(yaml_str=test_str)
tc = config.train.spec()
- assert tc.chunk_stride == 2
+ assert tc.stride == (2, 2)
assert tc.batch_size == 5
assert tc.steps == 10
assert tc.epochs == 3
- assert tc.loss_function == 'loss'
+ assert isinstance(config_parser.loss_from_dict(tc.loss), tf.keras.losses.SparseCategoricalCrossentropy)
assert tc.metrics == ['metric']
assert tc.optimizer == 'opt'
assert tc.validation.steps == 20
assert tc.validation.from_training
def test_mlflow():
- config.reset()
-
- assert config.mlflow.enabled()
+ config_reset()
test_str = '''
mlflow:
@@ -324,7 +404,7 @@ def test_mlflow():
assert config.mlflow.checkpoints.frequency() == 10
def test_tensorboard():
- config.reset()
+ config_reset()
assert not config.tensorboard.enabled()
@@ -339,24 +419,55 @@ def test_tensorboard():
assert config.tensorboard.dir() == 'nonsense'
def test_argparser():
- config.reset()
+ config_reset()
parser = argparse.ArgumentParser()
config.setup_arg_parser(parser)
file_path = os.path.join(os.path.dirname(__file__), 'data', 'landsat.tiff')
- options = parser.parse_args(('--chunk-size 5 --image-type tiff --image %s' % (file_path) +
+ options = parser.parse_args(('--image-type tiff --image %s' % (file_path) +
' --label-type tiff --label %s' % (file_path)).split())
config.parse_args(options)
- assert config.train.network.chunk_size() == 5
im = config.dataset.images()
assert im.preprocess() is not None
assert im.type() == 'tiff'
assert len(im) == 1
assert im[0].endswith('landsat.tiff') and os.path.exists(im[0])
im = config.dataset.labels()
- assert im.preprocess() is None
assert im.type() == 'tiff'
assert len(im) == 1
assert im[0].endswith('landsat.tiff') and os.path.exists(im[0])
+
+def test_argparser_config_file(tmp_path):
+ config_reset()
+
+ test_str = '''
+ tensorboard:
+ enabled: false
+ dir: nonsense
+ '''
+ p = tmp_path / "temp.yaml"
+ p.write_text(test_str)
+
+ parser = argparse.ArgumentParser()
+ config.setup_arg_parser(parser)
+ options = parser.parse_args(['--config', str(p)])
+ config.initialize(options, [])
+
+ assert not config.tensorboard.enabled()
+ assert config.tensorboard.dir() == 'nonsense'
+
+def test_missing_file():
+ config_reset()
+
+ parser = argparse.ArgumentParser()
+ config.setup_arg_parser(parser)
+ options = parser.parse_args(['--config', 'garbage.yaml'])
+ with pytest.raises(FileNotFoundError):
+ config.initialize(options, [])
+
+def test_dump():
+ config_reset()
+
+ assert config.to_dict() == yaml.load(config.export())
diff --git a/tests/test_extensions.py b/tests/test_extensions.py
new file mode 100644
index 00000000..f97b0a8f
--- /dev/null
+++ b/tests/test_extensions.py
@@ -0,0 +1,116 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#pylint:disable=redefined-outer-name
+"""
+Test for worldview class.
+"""
+import pytest
+import tensorflow as tf
+
+from conftest import config_reset
+
+from delta.config import config
+import delta.config.extensions as ext
+from delta.ml.train import ContinueTrainingException
+
+def test_efficientnet():
+ l = ext.layer('EfficientNet')
+ n = l((None, None, 8))
+ assert len(n.layers) == 334
+ assert n.layers[0].input_shape == [(None, None, None, 8)]
+ out_shape = n.compute_output_shape((None, 512, 512, 8)).as_list()
+ assert out_shape == [None, 16, 16, 352]
+
+def test_gaussian_sample():
+ l = ext.layer('GaussianSample')
+ n = l()
+ assert n.get_config()['kl_loss']
+ result = n((tf.zeros((1, 3, 3, 3)), tf.ones((1, 3, 3, 3))))
+ assert result.shape == (1, 3, 3, 3)
+ assert isinstance(n.callback(), tf.keras.callbacks.Callback)
+
+def test_ms_ssim():
+ l = ext.loss('ms_ssim')
+ assert l(tf.zeros((1, 180, 180, 1)), tf.zeros((1, 180, 180, 1))) == 0.0
+ l = ext.loss('ms_ssim_mse')
+ assert l(tf.zeros((1, 180, 180, 1)), tf.zeros((1, 180, 180, 1))) == 0.0
+
+def test_mapped():
+ mcce = ext.loss('MappedCategoricalCrossentropy')
+ z = tf.zeros((3, 3, 3, 3), dtype=tf.int32)
+ o = tf.ones((3, 3, 3, 3), dtype=tf.float32)
+ assert tf.reduce_sum(mcce([0, 0]).call(z, o)) == 0.0
+ assert tf.reduce_sum(mcce([1, 0]).call(z, o)) > 10.0
+ oo = tf.ones((3, 3, 3, 3, 2), dtype=tf.float32)
+ assert tf.reduce_sum(mcce([[0, 0], [1, 1]]).call(z, oo)) == 0.0
+ assert tf.reduce_sum(mcce([[1, 1], [0, 0]]).call(z, oo)) > 10.0
+
+ config_reset()
+ test_str = '''
+ dataset:
+ classes:
+ - 0:
+ name: class_0
+ - 1:
+ name: class_1
+ '''
+ config.load(yaml_str=test_str)
+
+ assert tf.reduce_sum(mcce({0: 0, 1:0}).call(z, o)) == 0.0
+ assert tf.reduce_sum(mcce({'class_0': 0, 'class_1':0}).call(z, o)) == 0.0
+ assert tf.reduce_sum(mcce({0:1, 1:0}).call(z, o)) > 10.0
+ assert tf.reduce_sum(mcce({'class_0': 1, 'class_1':0}).call(z, o)) > 10.0
+
+def test_sparse_recall():
+ m0 = ext.metric('SparseRecall')(0)
+ m1 = ext.metric('SparseRecall')(1)
+ z = tf.zeros((3, 3, 3, 3), dtype=tf.int32)
+ o = tf.ones((3, 3, 3, 3), dtype=tf.int32)
+
+ m0.reset_state()
+ m1.reset_state()
+ m0.update_state(z, z)
+ m1.update_state(z, z)
+ assert m0.result() == 1.0
+ assert m1.result() == 0.0
+
+ m0.reset_state()
+ m1.reset_state()
+ m0.update_state(o, z)
+ m1.update_state(o, z)
+ assert m0.result() == 0.0
+ assert m1.result() == 0.0
+
+def test_callbacks():
+ inputs = tf.keras.layers.Input((10, 10, 1))
+ out = tf.keras.layers.Conv2D(name='out', filters=16, kernel_size=3)(inputs)
+ m = tf.keras.Model(inputs, out)
+
+ c = ext.callback('SetTrainable')('out', 2)
+ c.model = m
+ out = m.get_layer('out')
+ out.trainable = False
+ assert not out.trainable
+ c.on_epoch_begin(0)
+ assert not out.trainable
+ with pytest.raises(ContinueTrainingException):
+ c.on_epoch_begin(1)
+ assert out.trainable
+
+ c = ext.callback('ExponentialLRScheduler')(start_epoch=2, multiplier=0.95)
+ assert isinstance(c, tf.keras.callbacks.LearningRateScheduler)
diff --git a/tests/test_imagery_dataset.py b/tests/test_imagery_dataset.py
index fece4a32..0c8c27a7 100644
--- a/tests/test_imagery_dataset.py
+++ b/tests/test_imagery_dataset.py
@@ -20,63 +20,27 @@
import pytest
import numpy as np
-import tensorflow as tf
-from tensorflow import keras
-
-from delta.config import config
-from delta.imagery import imagery_dataset
-from delta.imagery.sources import npy
-from delta.ml import train, predict
-from delta.ml.ml_config import TrainingSpec
import conftest
-def load_dataset(source, output_size):
- config.reset() # don't load any user files
- (image_path, label_path) = source[0]
- config.load(yaml_str=
- '''
- io:
- cache:
- dir: %s
- dataset:
- images:
- type: %s
- directory: %s
- extension: %s
- preprocess:
- enabled: false
- labels:
- type: %s
- directory: %s
- extension: %s
- preprocess:
- enabled: false
- train:
- network:
- chunk_size: 3
- mlflow:
- enabled: false''' %
- (os.path.dirname(image_path), source[2], os.path.dirname(image_path), source[1],
- source[4], os.path.dirname(label_path), source[3]))
-
- dataset = imagery_dataset.ImageryDataset(config.dataset.images(), config.dataset.labels(),
- config.train.network.chunk_size(), output_size,
- config.train.spec().chunk_stride)
- return dataset
-
-@pytest.fixture(scope="function", params=range(conftest.NUM_SOURCES))
-def dataset(all_sources, request):
- source = all_sources[request.param]
- return load_dataset(source, 1)
+from delta.config import config
+from delta.imagery import imagery_dataset, rectangle
-@pytest.fixture(scope="function")
-def dataset_block_label(all_sources):
- return load_dataset(all_sources[0], 3)
+def test_basics(dataset_block_label):
+ """
+ Tests basic methods of a dataset.
+ """
+ d = dataset_block_label
+ assert d.chunk_shape() == (3, 3)
+ assert d.input_shape() == (3, 3, 1)
+ assert d.output_shape() == (3, 3, 1)
+ assert len(d.image_set()) == len(d.label_set())
+ assert d.tile_shape() == [256, 1024]
+ assert d.tile_overlap() == (0, 0)
-def test_block_label(dataset_block_label): #pylint: disable=redefined-outer-name
+def test_block_label(dataset_block_label):
"""
- Same as previous test but with dataset that gives labels as 3x3 blocks.
+ Tests basic functionality of a dataset on 3x3 blocks.
"""
num_data = 0
for image in dataset_block_label.data():
@@ -116,31 +80,98 @@ def test_block_label(dataset_block_label): #pylint: disable=redefined-outer-name
if v6 or v7 or v8:
assert label[1, 1] == 0
-def test_train(dataset): #pylint: disable=redefined-outer-name
- def model_fn():
- kerasinput = keras.layers.Input((3, 3, 1))
- flat = keras.layers.Flatten()(kerasinput)
- dense2 = keras.layers.Dense(3 * 3, activation=tf.nn.relu)(flat)
- dense1 = keras.layers.Dense(2, activation=tf.nn.softmax)(dense2)
- reshape = keras.layers.Reshape((1, 1, 2))(dense1)
- return keras.Model(inputs=kerasinput, outputs=reshape)
- model, _ = train.train(model_fn, dataset,
- TrainingSpec(100, 5, 'sparse_categorical_crossentropy', ['accuracy']))
- ret = model.evaluate(x=dataset.dataset().batch(1000))
- assert ret[1] > 0.70
-
- (test_image, test_label) = conftest.generate_tile()
- test_label = test_label[1:-1, 1:-1]
- output_image = npy.NumpyImageWriter()
- predictor = predict.LabelPredictor(model, output_image=output_image)
- predictor.predict(npy.NumpyImage(test_image))
- # very easy test since we don't train much
- assert sum(sum(np.logical_xor(output_image.buffer()[:,:,0], test_label))) < 200
+def test_nodata(dataset_block_label):
+ """
+ Tests that this filters out blocks where labels are all 0.
+ """
+ dataset_block_label.label_set().set_nodata_value(0)
+ try:
+ ds = dataset_block_label.dataset()
+ for (_, label) in ds.take(100):
+ assert np.sum(label) > 0
+ finally:
+ dataset_block_label.label_set().set_nodata_value(None)
+
+def test_class_weights(dataset_block_label):
+ """
+ Tests that this filters out blocks where labels are all 0.
+ """
+ lookup = np.asarray([1.0, 2.0])
+ ds = dataset_block_label.dataset(class_weights=[1.0, 2.0])
+ for (_, label, weights) in ds.take(100):
+ assert np.all(lookup[label.numpy()] == weights)
+
+def test_rectangle():
+ """
+ Tests the Rectangle class basics.
+ """
+ r = rectangle.Rectangle(5, 10, 15, 30)
+ assert r.min_x == 5
+ assert r.min_y == 10
+ assert r.max_x == 15
+ assert r.max_y == 30
+ assert r.bounds() == (5, 15, 10, 30)
+ assert r.has_area()
+ assert r.get_min_coord() == (5, 10)
+ assert r.perimeter() == 60
+ assert r.area() == 200
+ r.shift(-5, -10)
+ assert r.bounds() == (0, 10, 0, 20)
+ r.scale_by_constant(2, 1)
+ assert r.bounds() == (0, 20, 0, 20)
+ r.expand(0, 0, -10, -5)
+ assert r.bounds() == (0, 10, 0, 15)
+ r.expand_to_contain_pt(14, 14)
+ assert r.bounds() == (0, 15, 0, 15)
+
+ r2 = rectangle.Rectangle(-5, -5, 5, 10)
+ assert r.get_intersection(r2).bounds() == (0, 5, 0, 10)
+ assert not r.contains_rect(r2)
+ assert r.overlaps(r2)
+ assert not r.contains_pt(-1, -1)
+ assert r2.contains_pt(-1, -1)
+ r.expand_to_contain_rect(r2)
+ assert r.bounds() == (-5, 15, -5, 15)
+
+def test_rectangle_rois():
+ """
+ Tests make_tile_rois.
+ """
+ r = rectangle.Rectangle(0, 0, 10, 10)
+ tiles = r.make_tile_rois((5, 5), include_partials=False)
+ assert len(tiles) == 4
+ for t in tiles:
+ assert t.width() == 5 and t.height() == 5
+ tiles = r.make_tile_rois((5, 10), include_partials=False)
+ assert len(tiles) == 2
+ tiles = r.make_tile_rois((11, 11), include_partials=False)
+ assert len(tiles) == 0
+ tiles = r.make_tile_rois((11, 11), include_partials=True)
+ assert len(tiles) == 1
+ assert tiles[0].bounds() == (0, 10, 0, 10)
+ tiles = r.make_tile_rois((20, 20), include_partials=True, min_shape=(11, 11))
+ assert len(tiles) == 0
+ tiles = r.make_tile_rois((20, 20), include_partials=True, min_shape=(10, 10))
+ assert len(tiles) == 1
+
+ tiles = r.make_tile_rois((6, 6), include_partials=False)
+ assert len(tiles) == 1
+ tiles = r.make_tile_rois((6, 6), include_partials=False, overlap_shape=(2, 2))
+ assert len(tiles) == 4
+ tiles = r.make_tile_rois((6, 6), include_partials=False, partials_overlap=True)
+ assert len(tiles) == 4
+ for t in tiles:
+ assert t.width() == 6 and t.height() == 6
+
+ tiles = r.make_tile_rois((5, 5), include_partials=False, by_block=True)
+ assert len(tiles) == 2
+ for row in tiles:
+ assert len(row) == 2
@pytest.fixture(scope="function")
def autoencoder(all_sources):
source = all_sources[0]
- config.reset() # don't load any user files
+ conftest.config_reset()
(image_path, _) = source[0]
config.load(yaml_str=
'''
@@ -152,23 +183,36 @@ def autoencoder(all_sources):
type: %s
directory: %s
extension: %s
- preprocess:
- enabled: false
- train:
- network:
- chunk_size: 3
- mlflow:
- enabled: false''' %
+ preprocess: ~''' %
(os.path.dirname(image_path), source[2], os.path.dirname(image_path), source[1]))
dataset = imagery_dataset.AutoencoderDataset(config.dataset.images(),
- config.train.network.chunk_size(), config.train.spec().chunk_stride)
+ (3, 3), stride=config.train.spec().stride)
return dataset
-def test_autoencoder(autoencoder): #pylint: disable=redefined-outer-name
+def test_autoencoder(autoencoder):
"""
Test that the inputs and outputs of the dataset are the same.
"""
ds = autoencoder.dataset()
for (image, label) in ds.take(1000):
assert (image.numpy() == label.numpy()).all()
+
+def test_resume_mode(autoencoder, tmpdir):
+ """
+ Test imagery dataset's resume functionality.
+ """
+ try:
+ autoencoder.set_resume_mode(True, str(tmpdir))
+ autoencoder.reset_access_counts()
+ for i in range(len(autoencoder.image_set())):
+ autoencoder.resume_log_update(i, count=10000, need_check=True)
+ assert autoencoder.resume_log_read(i) == (True, 10000)
+
+ ds = autoencoder.dataset()
+ count = 0
+ for (_, unused_) in ds.take(100):
+ count += 1
+ assert count == 0
+ finally:
+ autoencoder.set_resume_mode(False, None)
diff --git a/tests/test_sources.py b/tests/test_sources.py
new file mode 100644
index 00000000..1f9f6aa3
--- /dev/null
+++ b/tests/test_sources.py
@@ -0,0 +1,95 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#pylint:disable=redefined-outer-name, protected-access
+"""
+Test for worldview class.
+"""
+import os
+import pytest
+
+import numpy as np
+
+from delta.extensions.sources import landsat, worldview
+
+@pytest.fixture(scope="function")
+def wv_image(worldview_filenames):
+ return worldview.WorldviewImage(worldview_filenames[0])
+
+@pytest.fixture(scope="function")
+def landsat_image(landsat_filenames):
+ return landsat.LandsatImage(landsat_filenames[0], bands=[1])
+
+# very basic, doesn't actually look at content
+def test_wv_image(wv_image):
+ buf = wv_image.read()
+ assert buf.shape == (64, 32, 1)
+ assert buf[0, 0, 0] == 0.0
+
+ assert wv_image.meta_path() is not None
+ assert len(wv_image.scale()) == 1
+ assert len(wv_image.bandwidth()) == 1
+
+ worldview.toa_preprocess(wv_image, calc_reflectance=False)
+ buf = wv_image.read()
+ assert buf.shape == (64, 32, 1)
+ assert buf[0, 0, 0] == 0.0
+
+def test_landsat_image(landsat_image):
+ buf = landsat_image.read()
+ assert buf.shape == (64, 32, 1)
+ assert buf[0, 0, 0] == 0.0
+
+ assert landsat_image.radiance_mult()[0] == 2.0
+ assert landsat_image.radiance_add()[0] == 2.0
+ assert landsat_image.reflectance_mult()[0] == 2.0
+ assert landsat_image.reflectance_add()[0] == 2.0
+ assert landsat_image.k1_constant()[0] == 2.0
+ assert landsat_image.k2_constant()[0] == 2.0
+ assert landsat_image.sun_elevation() == 5.8
+
+ landsat.toa_preprocess(landsat_image, calc_reflectance=True)
+ buf = landsat_image.read()
+ assert buf.shape == (64, 32, 1)
+ assert buf[0, 0, 0] == 0.0
+
+ landsat.toa_preprocess(landsat_image)
+ buf = landsat_image.read()
+ assert buf.shape == (64, 32, 1)
+ assert buf[0, 0, 0] == 0.0
+
+def test_wv_cache(wv_image):
+ buf = wv_image.read()
+ cached_path = wv_image._paths[0]
+ mod_time = os.path.getmtime(cached_path)
+ path = wv_image.path()
+ new_image = worldview.WorldviewImage(path)
+ buf2 = wv_image.read()
+ assert np.all(buf == buf2)
+ assert new_image._paths[0] == cached_path
+ assert os.path.getmtime(cached_path) == mod_time
+
+def test_landsat_cache(landsat_image):
+ buf = landsat_image.read()
+ cached_path = landsat_image._paths[0]
+ mod_time = os.path.getmtime(cached_path)
+ path = landsat_image.path()
+ new_image = landsat.LandsatImage(path)
+ buf2 = landsat_image.read()
+ assert np.all(buf == buf2)
+ assert new_image._paths[0] == cached_path
+ assert os.path.getmtime(cached_path) == mod_time
diff --git a/tests/test_tiff.py b/tests/test_tiff.py
index 09d8045a..e55534fc 100644
--- a/tests/test_tiff.py
+++ b/tests/test_tiff.py
@@ -23,7 +23,7 @@
import numpy as np
from delta.imagery import rectangle
-from delta.imagery.sources.tiff import TiffImage, write_tiff
+from delta.extensions.sources.tiff import TiffImage, TiffWriter, write_tiff
def check_landsat_tiff(filename):
'''
@@ -32,13 +32,8 @@ def check_landsat_tiff(filename):
input_reader = TiffImage(filename)
assert input_reader.size() == (37, 37)
assert input_reader.num_bands() == 8
- for i in range(0, input_reader.num_bands()):
- (bsize, (blocks_x, blocks_y)) = input_reader.block_info(i)
- assert bsize == (6, 37)
- assert blocks_x == 7
- assert blocks_y == 1
- assert input_reader.numpy_type(i) == np.float32
- assert input_reader.nodata_value(i) is None
+ assert input_reader.dtype() == np.float32
+ assert input_reader.block_size() == (6, 37)
meta = input_reader.metadata()
geo = meta['geotransform']
@@ -66,11 +61,10 @@ def check_same(filename1, filename2, data_only=False):
in2 = TiffImage(filename2)
assert in1.size() == in2.size()
assert in1.num_bands() == in2.num_bands()
- for i in range(in1.num_bands()):
- if not data_only:
- assert in1.block_info(i) == in2.block_info(i)
- assert in1.data_type(i) == in2.data_type(i)
- assert in1.nodata_value(i) == in2.nodata_value(i)
+ assert in1.dtype() == in2.dtype()
+ if not data_only:
+ assert in1.block_size() == in2.block_size()
+ assert in1.nodata_value() == in2.nodata_value()
if not data_only:
m_1 = in1.metadata()
@@ -129,3 +123,14 @@ def test_geotiff_write(tmpdir):
assert numpy_image.shape == data.shape
assert np.allclose(numpy_image, data)
+
+ writer = TiffWriter(filename)
+ writer.initialize((3, 5, 1), numpy_image.dtype)
+ writer.write(numpy_image, 0, 0)
+ writer.close()
+
+ img = TiffImage(filename)
+ data = np.squeeze(img.read())
+
+ assert numpy_image.shape == data.shape
+ assert np.allclose(numpy_image, data)
diff --git a/tests/test_train.py b/tests/test_train.py
new file mode 100644
index 00000000..edd5a6d5
--- /dev/null
+++ b/tests/test_train.py
@@ -0,0 +1,148 @@
+# Copyright © 2020, United States Government, as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All rights reserved.
+#
+# The DELTA (Deep Earth Learning, Tools, and Analysis) platform is
+# licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#pylint: disable=redefined-outer-name
+import os
+import shutil
+import tempfile
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+
+import conftest
+import h5py
+
+from delta.config import config
+from delta.extensions.sources import npy
+from delta.ml import train, predict, io
+from delta.extensions.layers.pretrained import pretrained
+from delta.ml.ml_config import TrainingSpec
+
+def evaluate_model(model_fn, dataset, output_trim=0, threshold=0.3, max_wrong=200, batch_size=10):
+ model, _ = train.train(model_fn, dataset,
+ TrainingSpec(batch_size, 5, 'sparse_categorical_crossentropy',
+ ['sparse_categorical_accuracy']))
+ ret = model.evaluate(x=dataset.dataset().batch(1000))
+ assert ret[1] > threshold # very loose test since not much training
+
+ (test_image, test_label) = conftest.generate_tile()
+ if output_trim > 0:
+ test_label = test_label[output_trim:-output_trim, output_trim:-output_trim]
+ output_image = npy.NumpyWriter()
+ predictor = predict.LabelPredictor(model, output_image=output_image)
+ predictor.predict(npy.NumpyImage(test_image))
+ # very easy test since we don't train much
+ assert sum(sum(np.logical_xor(output_image.buffer()[:,:], test_label))) < max_wrong
+
+def train_ae(ae_fn, ae_dataset):
+ model, _ = train.train(ae_fn, ae_dataset,
+ TrainingSpec(100, 5, 'mse', ['Accuracy']))
+
+ tmpdir = tempfile.mkdtemp()
+ model_path = os.path.join(tmpdir, 'ae_model.h5')
+ model.save(model_path)
+ return model_path, tmpdir
+
+def test_dense(dataset):
+ def model_fn():
+ kerasinput = keras.layers.Input((3, 3, 1))
+ flat = keras.layers.Flatten()(kerasinput)
+ dense2 = keras.layers.Dense(3 * 3, activation=tf.nn.relu)(flat)
+ dense1 = keras.layers.Dense(2, activation=tf.nn.softmax)(dense2)
+ reshape = keras.layers.Reshape((1, 1, 2))(dense1)
+ return keras.Model(inputs=kerasinput, outputs=reshape)
+ evaluate_model(model_fn, dataset, 1)
+
+def test_pretrained(dataset, ae_dataset):
+ # 1 create autoencoder
+ ae_dataset.set_chunk_output_shapes((10, 10), (10, 10))
+ def autoencoder_fn():
+ inputs = keras.layers.Input((10, 10, 1))
+ conv1 = keras.layers.Conv2D(filters=16, kernel_size=3, activation='relu', padding='same')(inputs)
+ down_samp1 = keras.layers.MaxPooling2D((2, 2), padding='same')(conv1)
+ encoded = keras.layers.Conv2D(filters=8, kernel_size=3, activation='relu', padding='same')(down_samp1)
+
+ # at this point the representation is (4, 4, 8) i.e. 128-dimensional
+
+ up_samp1 = keras.layers.UpSampling2D((2, 2))(encoded)
+ conv4 = keras.layers.Conv2D(16, (3, 3), activation='relu', padding='same')(up_samp1)
+ decoded = keras.layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(conv4)
+
+ return keras.Model(inputs=inputs, outputs=decoded)
+ # 2 train autoencoder
+ ae_model, tmpdir = train_ae(autoencoder_fn, ae_dataset)
+ # 3 create model network based on autonecoder.
+ def model_fn():
+ inputs = keras.layers.Input((10, 10, 1))
+ pretrained_layer = pretrained(ae_model, 3, trainable=False)(inputs)[0]
+ up_samp1 = keras.layers.UpSampling2D((2,2))(pretrained_layer)
+ conv1 = keras.layers.Conv2D(16, (3, 3), activation='relu', padding='same')(up_samp1)
+ output = keras.layers.Conv2D(2, (3,3), activation='softmax', padding='same')(conv1)
+ m = keras.Model(inputs=inputs, outputs=output)
+
+ return m
+
+ dataset.set_chunk_output_shapes((10, 10), (10, 10))
+ evaluate_model(model_fn, dataset)
+ shutil.rmtree(tmpdir)
+
+def test_fcn(dataset):
+ conftest.config_reset()
+
+ assert config.general.gpus() == -1
+
+ def model_fn():
+ inputs = keras.layers.Input((None, None, 1))
+ conv = keras.layers.Conv2D(filters=9, kernel_size=2, padding='same', strides=1)(inputs)
+ upscore = keras.layers.Conv2D(filters=2, kernel_size=1, padding='same', strides=1)(conv)
+ l = keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(conv)
+ l = keras.layers.Conv2D(filters=2, kernel_size=1, strides=1)(l)
+ l = keras.layers.Conv2DTranspose(filters=2, padding='same', kernel_size=2, strides=2)(l)
+ l = keras.layers.Add()([upscore, upscore])
+ #l = keras.layers.Softmax(axis=3)(l)
+ m = keras.Model(inputs=inputs, outputs=l)
+ return m
+ dataset.set_chunk_output_shapes(None, (32, 32))
+ dataset.set_tile_shape((32, 32))
+ count = 0
+ for d in dataset.dataset():
+ count += 1
+ assert len(d) == 2
+ assert d[0].shape == (32, 32, 1)
+ assert d[1].shape == (32, 32, 1)
+ assert count == 2
+ # don't actually test correctness, this is not enough data for this size network
+ evaluate_model(model_fn, dataset, threshold=0.0, max_wrong=10000, batch_size=1)
+
+def test_save(tmp_path):
+ tmp_path = tmp_path / 'temp.h5'
+ inputs = keras.layers.Input((None, None, 1))
+ out = keras.layers.Conv2D(filters=9, kernel_size=2)(inputs)
+ m = keras.Model(inputs, out)
+ io.save_model(m, tmp_path)
+ with h5py.File(tmp_path, 'r') as f:
+ assert f.attrs['delta'] == config.export()
+
+def test_print():
+ """
+ Just make sure the printing functions in io don't crash.
+ """
+ inputs = keras.layers.Input((None, None, 1))
+ out = keras.layers.Conv2D(filters=9, kernel_size=2)(inputs)
+ m = keras.Model(inputs, out)
+ io.print_network(m, tile_shape=(512, 512))