add inference api
likyoo committed Sep 23, 2023
1 parent b992f0f commit fc84ccd
Showing 4 changed files with 224 additions and 5 deletions.
44 changes: 44 additions & 0 deletions docs/inference.md
@@ -0,0 +1,44 @@
# Inference with existing models

Open-CD provides pre-trained models for change detection in the corresponding model documentation, and supports multiple standard datasets, including LEVIR-CD, S2Looking, etc.
This note shows how to use existing models to run inference on given images.

Open-CD provides several interfaces for users to easily use pre-trained models for inference.

- [Inference with existing models](#inference-with-existing-models)
- [Inferencer](#inferencer)
- [Basic Usage](#basic-usage)
- [Initialization](#initialization)
- [Visualize prediction](#visualize-prediction)
- [List model](#list-model)

## Inferencer

We provide the most **convenient** way to use models in Open-CD: `OpenCDInferencer`. You can get the change mask for an image pair with only 3 lines of code.

### Basic Usage

The following example shows how to use `OpenCDInferencer` to perform inference on a single image pair.

```
>>> from opencd.apis import OpenCDInferencer
>>> # Load models into memory
>>> inferencer = OpenCDInferencer(model='changer_ex_r18_512x512_40k_levircd.py', weights='ChangerEx_r18-512x512_40k_levircd_20221223_120511.pth', classes=('unchanged', 'changed'), palette=[[0, 0, 0], [255, 255, 255]])
>>> # Inference
>>> inferencer([['demo_A.png', 'demo_B.png']], show=False, out_dir='OUTPUT_PATH')
```
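
If the loaded checkpoint already carries dataset meta information, `classes` and `palette` can be omitted; alternatively, a `dataset_name` can be passed and the visualizer will pick up that dataset's classes and palette (see the `OpenCDInferencer` docstring in this commit). A minimal sketch, where `'LEVIR-CD'` is an assumed dataset name or alias:

```
>>> # Hedged sketch: 'LEVIR-CD' is an assumed dataset name/alias
>>> inferencer = OpenCDInferencer(model='changer_ex_r18_512x512_40k_levircd.py', weights='ChangerEx_r18-512x512_40k_levircd_20221223_120511.pth', dataset_name='LEVIR-CD')
```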

Moreover, you can use `OpenCDInferencer` to process a list of image pairs:

```
# Input a list of images
>>> images = [[image1_A, image1_B], [image2_A, image2_B], ...] # image1_A can be a file path or an np.ndarray
>>> inferencer(images, show=True, wait_time=0.5) # wait_time is the display interval in seconds; 0 means wait indefinitely
# Save visualized rendering color maps and predicted results
# out_dir is the directory to save the output results, img_out_dir and pred_out_dir are subdirectories of out_dir
# to save visualized rendering color maps and predicted results
>>> inferencer(images, out_dir='outputs', img_out_dir='vis', pred_out_dir='pred')
```
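
The call also returns the predictions themselves. The sketch below assumes `OpenCDInferencer` keeps `MMSegInferencer`'s return format, i.e. a dict with a `'predictions'` entry (label-index maps) and, when `return_vis=True`, a `'visualization'` entry:

```
>>> # Hedged sketch, assuming MMSegInferencer's return format is kept
>>> result = inferencer([['demo_A.png', 'demo_B.png']], return_vis=True)
>>> change_mask = result['predictions'] # label-index map(s), np.ndarray
>>> vis_image = result['visualization'] # rendered color map(s)
```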

4 changes: 4 additions & 0 deletions opencd/apis/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) Open-CD. All rights reserved.
from .opencd_inferencer import OpenCDInferencer

__all__ = ['OpenCDInferencer']
168 changes: 168 additions & 0 deletions opencd/apis/opencd_inferencer.py
@@ -0,0 +1,168 @@
# Copyright (c) Open-CD. All rights reserved.
import os.path as osp
from typing import List, Optional, Union

import mmcv
import mmengine
import numpy as np

from mmcv.transforms import Compose

from mmseg.utils import ConfigType
from mmseg.apis import MMSegInferencer

class OpenCDInferencer(MMSegInferencer):
    """Change detection inferencer, providing inference and visualization
    interfaces. Note: MMEngine >= 0.5.0 is required.

    Args:
        classes (list, optional): Input classes for result rendering. As the
            prediction of the segmentation model is a segment map with label
            indices, `classes` is a list whose items correspond to the label
            indices. If classes is not defined, the visualizer will take the
            `cityscapes` classes by default. Defaults to None.
        palette (list, optional): Input palette for result rendering, which
            is a list of colors corresponding to the classes. If palette is
            not defined, the visualizer will take the `cityscapes` palette by
            default. Defaults to None.
        dataset_name (str, optional): Dataset name or alias. The visualizer
            will use the meta information of the dataset, i.e. classes and
            palette, but `classes` and `palette` have higher priority.
            Defaults to None.
        scope (str, optional): The scope of the model. Defaults to 'opencd'.
    """ # noqa

    def __init__(self,
                 classes: Optional[Union[str, List]] = None,
                 palette: Optional[Union[str, List]] = None,
                 dataset_name: Optional[str] = None,
                 scope: Optional[str] = 'opencd',
                 **kwargs) -> None:
        super().__init__(scope=scope, **kwargs)

        classes = classes if classes else self.model.dataset_meta.classes
        palette = palette if palette else self.model.dataset_meta.palette
        self.visualizer.set_dataset_meta(classes, palette, dataset_name)

    def _inputs_to_list(self, inputs: Union[str, np.ndarray]) -> list:
        """Preprocess the inputs to a list.

        For change detection, each input is a bi-temporal pair, so the
        inputs are expected to be a sequence of pairs (file paths or
        np.ndarray) and are simply converted to a list.

        Args:
            inputs (InputsType): Inputs for the inferencer.

        Returns:
            list: List of input for the :meth:`preprocess`.
        """
        # e.g. [['A.png', 'B.png'], [arr_a, arr_b]] -> unchanged list of pairs
        return list(inputs)

    def visualize(self,
                  inputs: list,
                  preds: List[dict],
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = 0,
                  img_out_dir: str = '',
                  opacity: float = 1.0) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
            preds (Any): Predictions of the model.
            return_vis (bool): Whether to return the visualization results.
                Defaults to False.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            img_out_dir (str): Output directory of rendered predictions, i.e.
                color segmentation masks. Defaults to ''.
            opacity (int, float): The transparency of the segmentation mask.
                Defaults to 1.0.

        Returns:
            List[np.ndarray]: Visualization results.
        """
        if not show and img_out_dir == '' and not return_vis:
            return None
        if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        self.visualizer.alpha = opacity

        results = []

        for single_inputs, pred in zip(inputs, preds):
            img_from_to = []
            for single_input in single_inputs:
                if isinstance(single_input, str):
                    img_bytes = mmengine.fileio.get(single_input)
                    img = mmcv.imfrombytes(img_bytes)
                    img = img[:, :, ::-1]  # BGR -> RGB
                    img_name = osp.basename(single_input)
                elif isinstance(single_input, np.ndarray):
                    img = single_input.copy()
                    img_num = str(self.num_visualized_imgs).zfill(8) + '_vis'
                    img_name = f'{img_num}.jpg'
                else:
                    raise ValueError('Unsupported input type: '
                                     f'{type(single_input)}')
                img_shape = img.shape
                img_from_to.append(img)

            out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
                else None

            # Draw the prediction on a zero canvas of the same shape; the
            # bi-temporal images are passed alongside for rendering.
            img_zero_board = np.zeros(img_shape)
            self.visualizer.add_datasample(
                img_name,
                img_zero_board,
                img_from_to,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=True,
                out_file=out_file)
            if return_vis:
                results.append(self.visualizer.get_image())
            self.num_visualized_imgs += 1

        return results if return_vis else None

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline.

        Return a pipeline to handle various input data, such as ``str``,
        ``np.ndarray``. It is an abstract method in BaseInferencer, and
        should be implemented in subclasses.

        The returned pipeline will be used to process a single data.
        It will be used in :meth:`preprocess` like this:

        .. code-block:: python

            def preprocess(self, inputs, batch_size, **kwargs):
                ...
                dataset = map(self.pipeline, dataset)
                ...
        """
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        # Loading annotations is also not applicable
        for transform in ('MultiImgLoadAnnotations',
                          'MultiImgLoadDepthAnnotation'):
            idx = self._get_transform_idx(pipeline_cfg, transform)
            if idx != -1:
                del pipeline_cfg[idx]

        load_img_idx = self._get_transform_idx(pipeline_cfg,
                                               'MultiImgLoadImageFromFile')
        if load_img_idx == -1:
            raise ValueError(
                'MultiImgLoadImageFromFile is not found in the test pipeline')
        pipeline_cfg[load_img_idx]['type'] = 'MultiImgLoadInferencerLoader'
        return Compose(pipeline_cfg)
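
For illustration, a hedged sketch of the rewriting `_init_pipeline` performs on a typical test pipeline config; `MultiImgResize` and `MultiImgPackSegInputs` are hypothetical stand-ins for whatever other transforms the config contains:

```
# Hedged sketch of the pipeline rewriting done by _init_pipeline.
# 'MultiImgResize' and 'MultiImgPackSegInputs' are hypothetical stand-ins.
pipeline_cfg = [
    dict(type='MultiImgLoadImageFromFile'),
    dict(type='MultiImgResize', scale=(512, 512)),
    dict(type='MultiImgLoadAnnotations'),  # removed: no labels at inference
    dict(type='MultiImgPackSegInputs'),
]
# After _init_pipeline: the annotation transform is deleted and the file
# loader's type is swapped to 'MultiImgLoadInferencerLoader', so the same
# pipeline accepts both file-path pairs and np.ndarray pairs.
```
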
13 changes: 8 additions & 5 deletions opencd/datasets/transforms/loading.py
@@ -422,9 +422,9 @@ class MultiImgLoadInferencerLoader(BaseTransform):
     def __init__(self, **kwargs) -> None:
         super().__init__()
         self.from_file = TRANSFORMS.build(
-            dict(type='LoadImageFromFile', **kwargs))
+            dict(type='MultiImgLoadImageFromFile', **kwargs))
         self.from_ndarray = TRANSFORMS.build(
-            dict(type='LoadImageFromNDArray', **kwargs))
+            dict(type='MultiImgLoadLoadImageFromNDArray', **kwargs))
 
     def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict:
         """Transform function to add image meta information.
@@ -436,11 +436,14 @@ def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict:
         Returns:
             dict: The dict contains loaded image and meta information.
         """
-        if isinstance(single_input, str):
+        assert len(single_input) == 2, \
+            'In `MultiImgLoadInferencerLoader`, ' \
+            '`single_input` must contain bi-temporal images'
+        if isinstance(single_input[0], str):
             inputs = dict(img_path=single_input)
-        elif isinstance(single_input, Union[np.ndarray, list]):
+        elif isinstance(single_input[0], (np.ndarray, list)):
             inputs = dict(img=single_input)
-        elif isinstance(single_input, dict):
+        elif isinstance(single_input[0], dict):
             inputs = single_input
         else:
             raise NotImplementedError
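
A minimal usage sketch of the updated loader, assuming the (elided) remainder of `transform` dispatches to `self.from_file` or `self.from_ndarray` according to which branch built `inputs`:

```
# Hedged sketch: each input to the loader is a bi-temporal pair
loader = MultiImgLoadInferencerLoader()
results = loader.transform(['demo_A.png', 'demo_B.png'])  # file-path pair
results = loader.transform([img_a, img_b])                # np.ndarray pair
```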
