Skip to content

Commit

Permalink
Better scale float-valued tiles
Browse files Browse the repository at this point in the history
If scaling non-uint8 tiles in numpy format, use scikit-image when it is
available rather than converting to an 8-bit PIL image for the process.

Add some missing type hints.
  • Loading branch information
manthey committed Nov 13, 2024
1 parent a71b018 commit b55d5fe
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 37 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Guard dtype types ([#1711](../../pull/1711), [#1714](../../pull/1714), [#1716](../../pull/1716))
- Better handle IndicaLabs tiff files ([#1717](../../pull/1717))
- Better detect files with geotransform data that aren't geospatial ([#1718](../../pull/1718))
- Better scale float-valued tiles ([#1725](../../pull/1725))

### Changes

Expand Down
37 changes: 21 additions & 16 deletions large_image/tilesource/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def __init__(self, encoding: str = 'JPEG', jpegQuality: int = 95,
self.edge = edge
self._setStyle(style)

def __getstate__(self):
def __getstate__(self) -> None:
"""
Allow pickling.
Expand All @@ -203,7 +203,7 @@ def __reduce__(self) -> Tuple[functools.partial, Tuple[str]]:
def __repr__(self) -> str:
return self.getState()

def _repr_png_(self):
def _repr_png_(self) -> bytes:
return self.getThumbnail(encoding='PNG')[0]

@property
Expand Down Expand Up @@ -257,12 +257,12 @@ def getCenter(self, *args, **kwargs) -> Tuple[float, float]:
return (bounds['sizeY'] / 2, bounds['sizeX'] / 2)

@property
def style(self):
def style(self) -> Optional[JSONDict]:
return self._style

@style.setter
def style(self, value):
if not hasattr(self, '_unstyledStyle') and value == getattr(self, '_unstyledStyle', None):
def style(self, value: Any) -> None:
if value is None and not hasattr(self, '_unstyledStyle'):
return
if not getattr(self, '_noCache', False):
msg = 'Cannot set the style of a cached source'
Expand Down Expand Up @@ -1139,8 +1139,9 @@ def _outputTileNumpyStyle(
"""
tile, mode = _imageToNumpy(intile)
if (applyStyle and (getattr(self, 'style', None) or hasattr(self, '_iccprofiles')) and
(not getattr(self, 'style', None) or len(self.style) != 1 or
self.style.get('icc') is not False)):
(not getattr(self, 'style', None) or
len(cast(JSONDict, self.style)) != 1 or
cast(JSONDict, self.style).get('icc') is not False)):
tile = self._applyStyle(tile, getattr(self, 'style', None), x, y, z, frame)
if tile.shape[0] != self.tileHeight or tile.shape[1] != self.tileWidth:
extend = np.zeros(
Expand Down Expand Up @@ -1242,7 +1243,7 @@ def _getAssociatedImage(self, imageKey: str) -> Optional[PIL.Image.Image]:
return None

@classmethod
def canRead(cls, *args, **kwargs):
def canRead(cls, *args, **kwargs) -> bool:
"""
Check if we can read the input. This takes the same parameters as
__init__.
Expand Down Expand Up @@ -1315,7 +1316,7 @@ def getMetadata(self) -> JSONDict:
def metadata(self) -> JSONDict:
return self.getMetadata()

def _getFrameValueInformation(self, frames: List[Dict]):
def _getFrameValueInformation(self, frames: List[Dict]) -> Dict[str, Any]:
"""
Given a `frames` list from a metadata response, return a dictionary describing
the value info for any frame axes. Keys in this dictionary follow the pattern "Value[AXIS]"
Expand Down Expand Up @@ -1412,7 +1413,7 @@ def _addMetadataFrameInformation(
for frame in metadata['frames']:
frame['Channel'] = channels[frame.get('IndexC', 0)]

def getInternalMetadata(self, **kwargs):
def getInternalMetadata(self, **kwargs) -> Optional[Dict[Any, Any]]:
"""
Return additional known metadata about the tile source. Data returned
from this method is not guaranteed to be in any particular format or
Expand Down Expand Up @@ -1485,10 +1486,12 @@ def _getFrame(self, frame: Optional[int] = None, **kwargs) -> int:
:returns: an integer frame number.
"""
frame = int(frame or 0)
if (hasattr(self, '_style') and 'bands' in self.style and
len(self.style['bands']) and
all(entry.get('frame') is not None for entry in self.style['bands'])):
frame = int(self.style['bands'][0]['frame'])
if (hasattr(self, '_style') and
'bands' in cast(JSONDict, self.style) and
len(cast(JSONDict, self.style)['bands']) and
all(entry.get('frame') is not None
for entry in cast(JSONDict, self.style)['bands'])):
frame = int(cast(JSONDict, self.style)['bands'][0]['frame'])
return frame

def _xyzInRange(
Expand Down Expand Up @@ -1627,8 +1630,10 @@ def _getTileFromEmptyLevel(self, x: int, y: int, z: int, **kwargs) -> Tuple[
getattr(PIL.Image, 'Resampling', PIL.Image).LANCZOS).convert(mode), TILE_FORMAT_PIL

@methodcache()
def getTile(self, x, y, z, pilImageAllowed=False, numpyAllowed=False,
sparseFallback=False, frame=None):
def getTile(self, x: int, y: int, z: int, pilImageAllowed: bool = False,
numpyAllowed: Union[bool, str] = False,
sparseFallback: bool = False, frame: Optional[int] = None) -> Union[
ImageBytes, PIL.Image.Image, bytes, np.ndarray]:
"""
Get a tile from a tile source, returning it as an binary image, a PIL
image, or a numpy array.
Expand Down
6 changes: 4 additions & 2 deletions large_image/tilesource/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def _styleBands(self) -> List[Dict[str, Any]]:
}
style = []
if hasattr(self, '_style'):
styleBands = self.style['bands'] if 'bands' in self.style else [self.style]
styleBands = (cast(JSONDict, self.style)['bands']
if 'bands' in cast(JSONDict, self.style) else [self.style])
for styleBand in styleBands:

styleBand = styleBand.copy()
Expand Down Expand Up @@ -189,7 +190,8 @@ def _setDefaultStyle(self) -> None:
not self._style or 'icc' in self._style and len(self._style) == 1):
return
if hasattr(self, '_style'):
styleBands = self.style['bands'] if 'bands' in self.style else [self.style]
styleBands = (cast(JSONDict, self.style)['bands']
if 'bands' in cast(JSONDict, self.style) else [self.style])
if not len(styleBands) or (len(styleBands) == 1 and isinstance(
styleBands[0].get('band', 1), int) and styleBands[0].get('band', 1) <= 0):
del self._style
Expand Down
57 changes: 45 additions & 12 deletions large_image/tilesource/tiledict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple, cast
from typing import Any, Dict, Optional, Tuple, Union, cast

import numpy as np
import PIL
Expand All @@ -8,7 +8,7 @@

from .. import exceptions
from ..constants import TILE_FORMAT_IMAGE, TILE_FORMAT_NUMPY, TILE_FORMAT_PIL
from .utilities import _encodeImage, _imageToNumpy, _imageToPIL
from .utilities import ImageBytes, _encodeImage, _imageToNumpy, _imageToPIL


class LazyTileDict(dict):
Expand Down Expand Up @@ -159,6 +159,48 @@ def _retileTile(self) -> np.ndarray:
:th, :tw, :retile.shape[2]] # type: ignore[misc]
return cast(np.ndarray, retile)

def _resample(self, tileData: Union[ImageBytes, PIL.Image.Image, bytes, np.ndarray]) -> Tuple[
    Union[ImageBytes, PIL.Image.Image, bytes, np.ndarray], Optional[PIL.Image.Image],
]:
    """
    Resample a tile to the requested scale.

    PIL is used if the tile is uint8 or if the resampling mode is one that
    only PIL handles.  Otherwise, for numpy arrays that are going to be
    returned in numpy format, skimage is used if it can be imported; if it
    cannot, PIL is used as a fallback.

    In either case, self['width'] and self['height'] are updated to the
    size of the resampled tile.

    :param tileData: the image to scale.
    :returns: tileData, pilData.  pilData will be None if the results are a
        numpy array (or if no resampling was needed).
    """
    pilData = None
    if self.resample in (False, None) or not self.requestedScale:
        return tileData, pilData
    pilResize = True
    # Compare the dtype itself: dtype.kind is a one-character code (such as
    # 'u' or 'f') and never equals np.uint8, which sent uint8 arrays down
    # the skimage path.
    if (isinstance(tileData, np.ndarray) and tileData.dtype != np.uint8 and
            TILE_FORMAT_NUMPY in self.format and self.resample in {True, 2, 3}):
        try:
            import skimage.transform
            pilResize = False
        except ImportError:
            # skimage is optional; fall back to PIL
            pass
    if pilResize:
        pilData = _imageToPIL(tileData)

        self['width'] = max(1, int(
            pilData.size[0] / self.requestedScale))
        self['height'] = max(1, int(
            pilData.size[1] / self.requestedScale))
        pilData = tileData = pilData.resize(
            (self['width'], self['height']),
            resample=getattr(PIL.Image, 'Resampling', PIL.Image).LANCZOS
            if self.resample is True else self.resample)
    else:
        source = cast(np.ndarray, tileData)
        # Compute the scaled size from the data, as the PIL branch does.
        # numpy images are indexed (row, column, ...), so shape[1] is the
        # width and shape[0] is the height.
        self['width'] = max(1, int(source.shape[1] / self.requestedScale))
        self['height'] = max(1, int(source.shape[0] / self.requestedScale))
        # skimage expects the output shape as (rows, cols, ...); keep any
        # trailing axes (bands) unchanged, which also handles 2-D arrays.
        # NOTE(review): preserve_range is left at its default; confirm this
        # is the desired behavior for non-uint8 integer tiles.
        tileData = skimage.transform.resize(
            source,
            (self['height'], self['width']) + tuple(source.shape[2:]),
            order=3 if self.resample is True else self.resample)
    return tileData, pilData

def __getitem__(self, key: str, *args, **kwargs) -> Any:
"""
If this is the first time either the tile or format key is requested,
Expand Down Expand Up @@ -187,16 +229,7 @@ def __getitem__(self, key: str, *args, **kwargs) -> Any:
pilData = None
# resample if needed
if self.resample not in (False, None) and self.requestedScale:
pilData = _imageToPIL(tileData)

self['width'] = max(1, int(
pilData.size[0] / self.requestedScale))
self['height'] = max(1, int(
pilData.size[1] / self.requestedScale))
pilData = tileData = pilData.resize(
(self['width'], self['height']),
resample=getattr(PIL.Image, 'Resampling', PIL.Image).LANCZOS
if self.resample is True else self.resample)
tileData, pilData = self._resample(tileData)

tileFormat = (TILE_FORMAT_PIL if isinstance(tileData, PIL.Image.Image)
else (TILE_FORMAT_NUMPY if isinstance(tileData, np.ndarray)
Expand Down
18 changes: 11 additions & 7 deletions test/lisource_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,9 @@ def source_compare(sourcePath, opts): # noqa
sys.stdout.flush()

# get maxval for other histograms
h = ts.histogram(onlyMinMax=True, output=dict(maxWidth=2048, maxHeight=2048), **kwargs)
h = ts.histogram(
onlyMinMax=True, output=dict(maxWidth=2048, maxHeight=2048),
resample=0, **kwargs)
if 'max' not in h:
sys.stdout.write(' fail\n')
sys.stdout.flush()
Expand All @@ -400,7 +402,7 @@ def source_compare(sourcePath, opts): # noqa
maxval = 2 ** (int(math.log(maxval or 1) / math.log(2)) + 1) if maxval > 1 else 1
# thumbnail histogram
h = ts.histogram(bins=9, output=dict(maxWidth=256, maxHeight=256),
range=[0, maxval], **kwargs)
range=[0, maxval], resample=0, **kwargs)
maxchan = len(h['histogram'])
if maxchan == 4:
maxchan = 3
Expand All @@ -409,13 +411,13 @@ def source_compare(sourcePath, opts): # noqa
sys.stdout.flush()
# full image histogram
h = ts.histogram(bins=9, output=dict(maxWidth=2048, maxHeight=2048),
range=[0, maxval], **kwargs)
range=[0, maxval], resample=0, **kwargs)
result['full_2048_histogram'] = histotext(h, maxchan)
sys.stdout.write(' %s' % histotext(h, maxchan))
sys.stdout.flush()
if opts.full:
# at full res
h = ts.histogram(bins=9, range=[0, maxval], **kwargs)
h = ts.histogram(bins=9, range=[0, maxval], resample=0, **kwargs)
result['full_max_histogram'] = histotext(h, maxchan)
sys.stdout.write(' %s' % histotext(h, maxchan))
sys.stdout.flush()
Expand All @@ -426,12 +428,14 @@ def source_compare(sourcePath, opts): # noqa
if not opts.full:
h = ts.histogram(
bins=9, output=dict(maxWidth=2048, maxHeight=2048),
range=[0, maxval], frame=frames - 1, **kwargs)
range=[0, maxval], frame=frames - 1, resample=0,
**kwargs)
result['full_f_2048_histogram'] = histotext(h, maxchan)
sys.stdout.write(' %s' % histotext(h, maxchan))
else:
# at full res
h = ts.histogram(bins=9, range=[0, maxval], frame=frames - 1, **kwargs)
h = ts.histogram(bins=9, range=[0, maxval],
frame=frames - 1, resample=0, **kwargs)
result['full_f_max_histogram'] = histotext(h, maxchan)
sys.stdout.write(' %s' % histotext(h, maxchan))
sys.stdout.flush()
Expand All @@ -444,7 +448,7 @@ def source_compare(sourcePath, opts): # noqa
h = ts.histogram(bins=32, output=dict(
maxWidth=int(math.ceil(ts.sizeX / 2 ** (levels - 1 - ll))),
maxHeight=int(math.ceil(ts.sizeY / 2 ** (levels - 1 - ll))),
), range=[0, maxval], frame=f, **kwargs)
), range=[0, maxval], frame=f, resample=0, **kwargs)
t += time.time()
result[f'level_{ll}_f_{f}_histogram'] = histotext(h, maxchan)
sys.stdout.write('%3d%5d %s' % (ll, f, histotext(h, maxchan)))
Expand Down

0 comments on commit b55d5fe

Please sign in to comment.