diff --git a/wsinfer_zoo/__init__.py b/wsinfer_zoo/__init__.py index 650e4fd..ca73d95 100644 --- a/wsinfer_zoo/__init__.py +++ b/wsinfer_zoo/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from wsinfer_zoo import _version from wsinfer_zoo import client # noqa diff --git a/wsinfer_zoo/__main__.py b/wsinfer_zoo/__main__.py index 9acd8a1..7e56bb8 100644 --- a/wsinfer_zoo/__main__.py +++ b/wsinfer_zoo/__main__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from wsinfer_zoo.cli import cli if __name__ == "__main__": diff --git a/wsinfer_zoo/cli.py b/wsinfer_zoo/cli.py index 65a5eda..cb962bf 100644 --- a/wsinfer_zoo/cli.py +++ b/wsinfer_zoo/cli.py @@ -1,5 +1,7 @@ """Command-line interface for the WSInfer model zoo.""" +from __future__ import annotations + import dataclasses import json import sys diff --git a/wsinfer_zoo/client.py b/wsinfer_zoo/client.py index dab4df9..c7a030e 100644 --- a/wsinfer_zoo/client.py +++ b/wsinfer_zoo/client.py @@ -1,16 +1,14 @@ """API to interact with WSInfer model zoo, hosted on HuggingFace.""" +from __future__ import annotations + import dataclasses import functools import json import sys from pathlib import Path from typing import Any -from typing import Dict -from typing import List -from typing import Optional from typing import Sequence -from typing import Union import jsonschema from huggingface_hub import hf_hub_download @@ -28,6 +26,7 @@ WSINFER_ZOO_REGISTRY_DEFAULT_PATH = ( Path.home() / ".wsinfer-zoo" / "wsinfer-zoo-registry.json" ) +WSINFER_ZOO_REGISTRY_DEFAULT_PATH.parent.mkdir(exist_ok=True) # In pyinstaller runtime for one-file executables, the root path # is the path to the executable. @@ -93,7 +92,7 @@ class TransformConfigurationItem: """Container for one item in the 'transform' property of the model configuration.""" name: str - arguments: Optional[Dict[str, Any]] + arguments: dict[str, Any] | None @dataclasses.dataclass @@ -109,21 +108,21 @@ class ModelConfiguration: class_names: Sequence[str] patch_size_pixels: int spacing_um_px: float - transform: List[TransformConfigurationItem] + transform: list[TransformConfigurationItem] def __post_init__(self): if len(self.class_names) != self.num_classes: raise InvalidModelConfiguration() @classmethod - def from_dict(cls, config: Dict) -> "ModelConfiguration": + def from_dict(cls, config: dict) -> ModelConfiguration: validate_config_json(config) architecture = config["architecture"] num_classes = config["num_classes"] patch_size_pixels = config["patch_size_pixels"] spacing_um_px = config["spacing_um_px"] class_names = config["class_names"] - transform_list: List[Dict[str, Any]] = config["transform"] + transform_list: list[dict[str, Any]] = config["transform"] transform = [ TransformConfigurationItem(name=t["name"], arguments=t.get("arguments")) for t in transform_list @@ -157,7 +156,7 @@ class HFInfo: """Container for information on model's location on HuggingFace Hub.""" repo_id: str - revision: Optional[str] = None + revision: str | None = None @dataclasses.dataclass @@ -187,7 +186,7 @@ class HFModelWeightsOnly(HFModel): def load_torchscript_model_from_hf( - repo_id: str, revision: Optional[str] = None + repo_id: str, revision: str | None = None ) -> HFModelTorchScript: """Load a TorchScript model from HuggingFace.""" model_path = hf_hub_download(repo_id, HF_TORCHSCRIPT_NAME, revision=revision) @@ -207,7 +206,7 @@ def load_torchscript_model_from_hf( def load_weights_from_hf( - repo_id: str, revision: Optional[str] = None, safetensors: bool = False + repo_id: str, revision: str | None = None, safetensors: bool = False ) -> HFModelWeightsOnly: """Load model weights from HuggingFace (this is not TorchScript).""" if safetensors: @@ -260,7 +259,7 @@ def __str__(self) -> str: class ModelRegistry: """Registry of models that can be used with WSInfer.""" - models: Dict[str, RegisteredModel] + models: dict[str, RegisteredModel] def get_model_by_name(self, name: str) -> RegisteredModel: try: @@ -269,7 +268,7 @@ def get_model_by_name(self, name: str) -> RegisteredModel: raise KeyError(f"model not found with name '{name}'.") @classmethod - def from_dict(cls, config: Dict) -> "ModelRegistry": + def from_dict(cls, config: dict) -> ModelRegistry: """Create a new ModelRegistry instance from a dictionary.""" validate_model_zoo_json(config) models = { @@ -286,7 +285,7 @@ def from_dict(cls, config: Dict) -> "ModelRegistry": @functools.lru_cache() -def load_registry(registry_file: Optional[Union[str, Path]] = None) -> ModelRegistry: +def load_registry(registry_file: str | Path | None = None) -> ModelRegistry: """Load model registry. This downloads the registry JSON file to a cache and reuses it if