Add feature extractors (#6)
* modularize extractors + install timm

* use len instead of shape for num patches

* add virchow2 and hoptimus0

* mypy, please ignore timm

* use extractor.run

* do not cache tissue mask or coords

These are very quick to compute. Instead, we only cache
the tile embeddings, which take time to compute.

* add encoders hibou, kaiko, phikon, phikonv2, provgigapath, virchow

* mypy: ignore transformers

* sort imports

* set abmil model to eval mode

* update model inference func

* replace deprecated torchvision transforms funcs

* add Resize to hoptimus0 transform

* add --quantize option

* do not squeeze embeddings

* add --quantize to runlocal

* remove unused type:ignore comment

* add transformers as dependency

* test python3.12

* run mypy with python 3.12

* set zarr<3.0.0 and rm tifffile and imagecodecs

* ignore complicated numpy typing issues

* reformat with ruff

* use env var to force cpu usage

* in Windows, set SPINPATH_FORCE_CPU
kaczmarj authored Jan 13, 2025
1 parent 417c3cd commit b6b49a8
Showing 21 changed files with 524 additions and 262 deletions.
16 changes: 9 additions & 7 deletions .github/workflows/cli-test.yml
@@ -11,7 +11,8 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
# python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.12"]
steps:
- name: Cache sample WSI
id: cache-wsi
@@ -41,14 +42,14 @@ jobs:
# Test it twice so the second time we get cache hits.
run: |
cd /tmp
wsinfer-mil run -m kaczmarj/pancancer-tissue-classifier.tcga -i ~/wsi/CMU-1.svs
wsinfer-mil run -m kaczmarj/pancancer-tissue-classifier.tcga -i ~/wsi/CMU-1.svs
SPINPATH_FORCE_CPU=1 wsinfer-mil run -m kaczmarj/pancancer-tissue-classifier.tcga -i ~/wsi/CMU-1.svs
SPINPATH_FORCE_CPU=1 wsinfer-mil run -m kaczmarj/pancancer-tissue-classifier.tcga -i ~/wsi/CMU-1.svs
macOS:
runs-on: macos-latest
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.12"]
steps:
- name: Cache sample WSI
id: cache-wsi
@@ -78,14 +79,14 @@ jobs:
# Test it twice so the second time we get cache hits.
run: |
cd /tmp
wsinfer-mil run -m kaczmarj/pancancer-tissue-classifier.tcga -i ~/wsi/CMU-1.svs
wsinfer-mil run -m kaczmarj/pancancer-tissue-classifier.tcga -i ~/wsi/CMU-1.svs
SPINPATH_FORCE_CPU=1 wsinfer-mil run -m kaczmarj/pancancer-tissue-classifier.tcga -i ~/wsi/CMU-1.svs
SPINPATH_FORCE_CPU=1 wsinfer-mil run -m kaczmarj/pancancer-tissue-classifier.tcga -i ~/wsi/CMU-1.svs
Windows:
runs-on: windows-latest
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.12"]
steps:
- name: Cache sample WSI
id: cache-wsi
@@ -116,5 +117,6 @@
run: |
mkdir -p ~/foobar
cd ~/foobar
set SPINPATH_FORCE_CPU=1
wsinfer-mil run -m kaczmarj/pancancer-tissue-classifier.tcga -i ~/wsi/CMU-1.svs
wsinfer-mil run -m kaczmarj/pancancer-tissue-classifier.tcga -i ~/wsi/CMU-1.svs
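All three CI jobs now set SPINPATH_FORCE_CPU=1 before invoking wsinfer-mil, matching the "use env var to force cpu usage" commit message. The Python side of that check is not part of this excerpt; a minimal sketch of how such an override is commonly honored (assuming the variable is read via os.environ and that any non-empty value forces CPU; the exact semantics in wsinfer-mil may differ) could look like this:

```python
import os

import torch


def choose_device() -> torch.device:
    """Pick an inference device, honoring a SPINPATH_FORCE_CPU override.

    Assumption: any non-empty value (e.g. "1") forces CPU; otherwise CUDA
    is used when available.
    """
    if os.environ.get("SPINPATH_FORCE_CPU"):
        return torch.device("cpu")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
```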
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yml
@@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11"]
python-version: ["3.12"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
16 changes: 7 additions & 9 deletions pyproject.toml
@@ -44,20 +44,15 @@ dependencies = [
"pillow",
"platformdirs",
"scikit-image>=0.20.0",
"timm",
"shapely",
"tabulate",
# https://github.com/Bayer-Group/tiffslide/issues/72#issuecomment-1631015274
"tifffile>=2023.1.23",
"tiffslide>=2.2.0",
# https://github.com/Bayer-Group/tiffslide/issues/72#issuecomment-1630091390
"imagecodecs >= 2022.7.27 ; python_version<'3.9'",
# https://github.com/Bayer-Group/tiffslide/issues/72#issuecomment-1630091390
"imagecodecs >= 2023.7.10 ; python_version>='3.9'",
# The installation of torch and torchvision can differ by hardware. Users are
# advised to install torch and torchvision for their given hardware and then install
# wsinfer-mil. See https://pytorch.org/get-started/locally/.
# https://github.com/Bayer-Group/tiffslide/issues/89
"zarr<3.0.0",
"torch>=1.7",
"torchvision",
"transformers",
"tqdm",
]
dynamic = ["version"]
@@ -114,6 +109,9 @@ module = [
"shapely.*",
"skimage.morphology",
"tifffile",
"timm",
"timm.*",
"transformers",
"zarr.storage",
]
ignore_missing_imports = true
186 changes: 46 additions & 140 deletions wsinfer_mil/cache.py
@@ -5,7 +5,7 @@

from __future__ import annotations

import functools
import hashlib
import logging
from pathlib import Path

@@ -14,177 +14,83 @@
from PIL import Image

from wsinfer_mil.defaults import WSINFER_MIL_CACHE_DIR
from wsinfer_mil.extractors import PatchFeatureExtractor
from wsinfer_mil.patchlib import read_patch_coords
from wsinfer_mil.patchlib import write_patch_coords

TISSUE_MASK_FILENAME = "tissue.png"
PATCH_COORDS_FILENAME = "patches.h5"

logger = logging.getLogger(__name__)


def _get_embedding_filename(extractor: PatchFeatureExtractor) -> str:
return f"{extractor.name}.npy"


def log_setter(func): # type: ignore
"""Wrapper to log the status of a cache 'get_*' method."""

@functools.wraps(func)
def wrapper(self: Cache, *args, **kwargs): # type: ignore
key = func.__name__[4:] # trim off 'set_'
logger.debug(f"Attempting to set {key} in {self.slide_cache_dir}")
result = func(self, *args, **kwargs)
logger.debug(f"Set {key}")
return result

return wrapper


def log_getter(func): # type: ignore
"""Wrapper to log the status of a cache 'get_*' method."""
def _hash_image(image: Image.Image) -> str:
"""Generate a hash for an image."""
image_bytes = image.tobytes()
return hashlib.md5(image_bytes).hexdigest()

@functools.wraps(func)
def wrapper(self: Cache, *args, **kwargs): # type: ignore
key = func.__name__[4:] # trim off 'get_'
logger.debug(f"Attempting to get {key} from {self.slide_cache_dir}")
result = func(self, *args, **kwargs)
if result is None:
logger.debug(f"No entry found for {key}")
else:
logger.debug(f"Found entry for {key}")
return result

return wrapper
def _hash_array(array: npt.NDArray) -> str:
"""Generate a hash for a numpy array."""
array_bytes = array.tobytes()
return hashlib.md5(array_bytes).hexdigest()


class Cache:
"""Cache for slide tissue masks, patch coordinates, and embeddings."""
class EmbeddingsCache:
"""Cache for slide embeddings."""

def __init__(
self,
slide_path: str | Path,
*,
slide_quickhash: str,
patch_size_um: float,
cache_dir: str | None = None,
tissue_mask: Image.Image,
patch_coordinates: npt.NDArray,
embedding_model_name: str,
cache_dir: Path | None = None,
) -> None:
self.slide_path = slide_path
self.slide_quickhash = slide_quickhash
self.patch_size_um = patch_size_um
# self.patch_size_um = patch_size_um
if cache_dir is None:
self.cache_dir = WSINFER_MIL_CACHE_DIR
else:
self.cache_dir = Path(cache_dir)

self.slide_name = Path(slide_path).name

hash_parts = [
slide_quickhash,
_hash_image(tissue_mask),
_hash_array(patch_coordinates),
embedding_model_name,
]
combined_string = "_".join(hash_parts)
self.cache_key = hashlib.md5(combined_string.encode("utf-8")).hexdigest()
self.embedding_filename = self.cache_dir / f"{self.cache_key}.npy"

def save(self, embedding: npt.NDArray[np.float32]) -> None:
logger.debug(
f"Instantiating cache object for slide {self.slide_name}"
f" with patches of size {self.patch_size_um} microns."
)
logger.debug(f"Cache directory is {self.slide_cache_dir}")
if self.slide_cache_dir.exists():
logger.debug("Cache directory exists")

@property
def slide_cache_dir(self) -> Path:
slide_cache_dir = (
self.cache_dir
/ f"{self.patch_size_um}um"
/ f"{self.slide_name}_md5-{self.slide_quickhash}"
f"[Cache: {self.cache_key}] Saving embedding to {self.embedding_filename}"
)
return slide_cache_dir

@log_setter
def set_tissue_mask(self, tissue_mask: Image.Image) -> None:
if not isinstance(tissue_mask, Image.Image):
raise TypeError(f"tissue_mask must be Image but got {type(tissue_mask)}")
path = self.slide_cache_dir / TISSUE_MASK_FILENAME
path.parent.mkdir(exist_ok=True, parents=True)
tissue_mask.save(path)

@log_getter
def get_tissue_mask(self) -> Image.Image | None:
path = self.slide_cache_dir / TISSUE_MASK_FILENAME
if path.exists():
# Open the image in this way to close the file handle.
with Image.open(path) as img:
img.load()
return img
return None

@log_setter
def set_patch_coordinates(
self,
patch_coordinates: npt.NDArray[np.int_],
patch_size_um: float,
) -> None:
"""Set patch coordinates.
Parameters
----------
patch_coordinates : array
A Nx2 array, where each row contains [minx, miny.]
patch_spacing_um_px : float
The physical spacing of one pixel in micrometers per pixel.
Returns
-------
None
"""
if not isinstance(patch_coordinates, np.ndarray):
raise TypeError(
f"patch_coordinates must be Image but got {type(patch_coordinates)}"
)
if not np.issubdtype(patch_coordinates.dtype, np.int_):
raise TypeError(f"must be int dtype but got {patch_coordinates.dtype}")
path = self.slide_cache_dir / PATCH_COORDS_FILENAME
path.parent.mkdir(exist_ok=True, parents=True)
write_patch_coords(
path=path,
coords=patch_coordinates,
patch_size_um=patch_size_um,
compression="gzip",
)

@log_getter
def get_patch_coordinates(self) -> npt.NDArray[np.int_] | None:
"""Read a Nx4 array of patch coordinates.
Each row is [minx, miny, width, height] of the patch.
"""
path = self.slide_cache_dir / PATCH_COORDS_FILENAME
if path.exists():
return read_patch_coords(path)
return None

@log_setter
def set_embedding(
self,
extractor: PatchFeatureExtractor,
embedding: npt.NDArray[np.float32],
) -> None:
if not isinstance(embedding, np.ndarray):
raise TypeError(f"embedding must be a numpy array, got {type(embedding)}")
if embedding.dtype != np.float32:
raise TypeError(
f"dtype of embedding must be float32 but got {embedding.dtype}"
)
filename = _get_embedding_filename(extractor)
path = self.slide_cache_dir / filename
path.parent.mkdir(exist_ok=True, parents=True)
np.save(path, embedding)

@log_getter
def get_embedding(
self, extractor: PatchFeatureExtractor
) -> npt.NDArray[np.float32] | None:
filename = _get_embedding_filename(extractor)
path = self.slide_cache_dir / filename
if path.exists():
res = np.load(path)
if embedding.ndim != 2:
raise ValueError(f"embedding must be 2D, got {embedding.ndim}")
self.embedding_filename.parent.mkdir(exist_ok=True, parents=True)
np.save(self.embedding_filename, embedding)
self._embedding = embedding

def load(self) -> npt.NDArray[np.float32] | None:
logger.debug(
f"[Cache: {self.cache_key}] Attempting to load embedding {self.embedding_filename}"
)
if self.embedding_filename.exists():
logger.debug(f"[Cache: {self.cache_key}] Loading {self.embedding_filename}")
res = np.load(self.embedding_filename)
assert isinstance(res, np.ndarray)
assert res.dtype == np.float32
assert res.ndim == 2
return res
logger.debug(
f"[Cache: {self.cache_key}] Embedding does not exist for {self.embedding_filename}"
)
return None
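The interleaved rendering above makes the new class hard to follow. Distilled from the added lines, EmbeddingsCache replaces the old per-slide directory layout with a single .npy file whose name is an MD5 over everything the embedding depends on: the slide quickhash, the tissue-mask bytes, the patch coordinates, and the embedding model name. A standalone sketch of that keying scheme (reconstructed for illustration, not the repository's exact code):

```python
import hashlib
from pathlib import Path

import numpy as np
import numpy.typing as npt
from PIL import Image


def embedding_cache_path(
    cache_dir: Path,
    slide_quickhash: str,
    tissue_mask: Image.Image,
    patch_coordinates: npt.NDArray[np.int_],
    embedding_model_name: str,
) -> Path:
    """Build a cache filename from every input the embedding depends on.

    If the slide, tissue mask, patch grid, or extractor changes, the MD5
    key changes too, so a stale embedding is never reused.
    """
    parts = [
        slide_quickhash,
        hashlib.md5(tissue_mask.tobytes()).hexdigest(),
        hashlib.md5(patch_coordinates.tobytes()).hexdigest(),
        embedding_model_name,
    ]
    key = hashlib.md5("_".join(parts).encode("utf-8")).hexdigest()
    return cache_dir / f"{key}.npy"


# Typical use: load on a hit, compute and save on a miss.
path = embedding_cache_path(
    Path("/tmp/wsinfer-mil-cache"),
    slide_quickhash="abc123",
    tissue_mask=Image.new("L", (64, 64)),
    patch_coordinates=np.zeros((10, 4), dtype=np.int64),
    embedding_model_name="virchow2",
)
if path.exists():
    embedding = np.load(path)
else:
    embedding = np.zeros((10, 1280), dtype=np.float32)  # stand-in for extractor output
    path.parent.mkdir(parents=True, exist_ok=True)
    np.save(path, embedding)
```

Compared with the removed Cache class, only the final embedding is stored on disk; the tissue mask and patch coordinates, which the commit message notes are quick to recompute, are folded into the cache key instead of being cached themselves.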
8 changes: 8 additions & 0 deletions wsinfer_mil/cli.py
@@ -36,6 +36,7 @@ def _run_impl(
num_workers: int,
tablefmt: str,
json: bool,
quantize: bool,
) -> None:
tissue_mask: Image.Image | None = None
if tissue_mask_path is not None:
@@ -46,6 +47,7 @@
model=model,
tissue_mask=tissue_mask,
num_workers=num_workers,
quantize=quantize,
)

if json:
@@ -96,6 +98,7 @@ def cli() -> None:
@click.option(
"--json", is_flag=True, help="Print the model outputs (and attention) as JSON"
)
@click.option("--quantize", is_flag=True, help="Quantize embedding model to float16")
def run(
*,
hf_repo_id: str,
@@ -105,6 +108,7 @@ def run(
num_workers: int,
table_format: str,
json: bool,
quantize: bool,
) -> None:
model = load_torchscript_model_from_hf(hf_repo_id, hf_repo_revision)
if num_workers == -1:
@@ -116,6 +120,7 @@
num_workers=num_workers,
tablefmt=table_format,
json=json,
quantize=quantize,
)


@@ -150,6 +155,7 @@ def run(
@click.option(
"--json", is_flag=True, help="Print the model outputs (and attention) as JSON"
)
@click.option("--quantize", is_flag=True, help="Quantize embedding model to float16")
def runlocal(
*,
model_path: Path,
@@ -159,6 +165,7 @@ def runlocal(
num_workers: int,
table_format: str,
json: bool,
quantize: bool,
) -> None:
model = load_torchscript_model_from_filesystem(model_path, model_config_path)
_run_impl(
@@ -168,4 +175,5 @@ def runlocal(
num_workers=num_workers,
tablefmt=table_format,
json=json,
quantize=quantize,
)
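Both run and runlocal now accept --quantize and thread it through _run_impl to the inference call; the help text says it quantizes the embedding model to float16. How the flag is consumed is not shown in this excerpt; one common implementation (an assumption, not necessarily what wsinfer-mil does) is a half-precision cast of the extractor before inference:

```python
import torch
from torch import nn


def maybe_quantize(model: nn.Module, quantize: bool) -> nn.Module:
    """Optionally cast an embedding model's weights to float16.

    Hypothetical helper: halves parameter memory and can speed up
    inference on GPUs with fp16 support. Inputs must be cast to the
    same dtype before calling the model.
    """
    model = model.eval()
    if quantize:
        model = model.half()
    return model


# Example with a toy model standing in for a patch feature extractor.
model = maybe_quantize(nn.Linear(3, 2), quantize=True)
print(next(model.parameters()).dtype)  # torch.float16
```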