Skip to content

Commit

Permalink
fix(artschema.tags): fix detection of installed extras and pkg spec v…
Browse files Browse the repository at this point in the history
…erification
  • Loading branch information
lariel-fernandes committed Oct 23, 2024
1 parent cf4a20c commit 64f97c0
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 26 deletions.
4 changes: 2 additions & 2 deletions src/mlopus/artschema/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def dist(self) -> packaging.Dist:

def check_requirement(self):
"""Validate this package requirement against the current Python environment."""
if not packaging.check_dist(self.dist, self.version, self.constraint):
if not packaging.check_dist(self.dist, specifier=self.constraint + self.version):
raise RuntimeError(f"Python requirement not matched: {self.name}{self.constraint}{self.version}")

def check_extras(self):
Expand Down Expand Up @@ -196,7 +196,7 @@ def register(self, subject: API):
:param subject: | Experiment, run, model or model version with API handle.
"""
logger.info("Registering artifact schemas for %s\n%s", subject, self.json(indent=4))
logger.info("Registering artifact schemas for %s\n%s", subject, self.model_dump_json(indent=4))
subject.set_tags(self)

def get_schema(self, alias: str | None = None) -> ClassSpec:
Expand Down
74 changes: 50 additions & 24 deletions src/mlopus/utils/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
import re
from pathlib import Path
from typing import Type, Any, Set, Literal
from typing import Type, Any, Set, Literal, Dict, Tuple

import importlib_metadata
from packaging.specifiers import SpecifierSet
Expand All @@ -12,43 +12,69 @@

Dist = importlib_metadata.Distribution

VersionConstraint = Literal["==", "~", "^", ">="]
VersionConstraint = Literal["==", "~=", "^", ">="]


class Patterns:
"""Patterns used in packaging inspection."""

EXTRA_REQ = re.compile(r'^\w+ \(.*\) ; extra == "(?P<extra>\w+)"$') # extracts optional extra from package req
EXTRA_REQ = re.compile(r'^(?P<pkg>[\w.-]+)(?P<specifier>.*); extra == "(?P<extra>\w+)"$')


def get_dist(name: str) -> Dist:
def get_dist(name: str, strict: bool = True) -> Dist | None:
"""Get distribution metadata by name."""
return importlib_metadata.distribution(name)
try:
return importlib_metadata.distribution(name)
except importlib_metadata.PackageNotFoundError:
if strict:
raise
return None


def is_editable_dist(dist: Dist) -> bool:
"""Tell if distribution is installed from editable source code."""
return (origin := dist.origin) and (dir_info := getattr(origin, "dir_info", None)) and dir_info.editable # noqa


def get_available_dist_extras(dist: Dist) -> Set[str]:
"""Get list of optional extras that can be installed for the given package distribution."""
return set(dist.metadata.get_all("Provides-Extra"))
def get_available_dist_extras(dist: Dist) -> Dict[str, Tuple[str, str]]:
"""Get mapping of optional extras that can be installed for the given package distribution.
Output format: {extra: [(pkg, specifier), ...]}
"""
extras = {}

for req in dist.requires:
if match := Patterns.EXTRA_REQ.fullmatch(req):
spec = match.group("pkg"), match.group("specifier")
extras.setdefault(match.group("extra"), []).append(spec)

return extras


def get_installed_dist_extras(dist: Dist) -> Set[str]:
"""Get list of optional extras currently installed for the given package distribution."""
return {match.group("extra") for x in dist.requires if (match := Patterns.EXTRA_REQ.fullmatch(x))}
installed = set()

for extra, reqs in get_available_dist_extras(dist).items():
for pkg, specifier in reqs:
if not (dist := get_dist(pkg, strict=False)) or not check_dist(dist, specifier):
break
else:
installed.add(extra)

return installed

def check_dist(dist: Dist, version: str, constraint: VersionConstraint) -> bool:

def check_dist(dist: Dist, specifier: str) -> bool:
"""Check if version of package distribution satisfies the specified version constraint."""
return check_version(dist.version, version, constraint)
return check_version(dist.version, specifier)


def check_version(actual_version: str, required_version: str, constraint: VersionConstraint) -> bool:
def check_version(actual_version: str, specifier: str) -> bool:
"""Check if version satisfies constraint."""
return Version(actual_version) in SpecifierSet(_convert_specifier(constraint + required_version))
return Version(actual_version) in SpecifierSet(
",".join(_convert_caret_specifier(x) if x.startswith("^") else x for x in specifier.split(","))
)


def pkg_dist_of_cls(cls: Type[Any]) -> Dist:
Expand All @@ -72,15 +98,15 @@ def pkg_dist_of_cls(cls: Type[Any]) -> Dist:
raise RuntimeError(f"Distribution not found for {cls}")


def _convert_specifier(version_constraint: str):
"""Convert version constraint to pattern accepted by `packaging.SpecifierSet`"""
if version_constraint.startswith("^"):
base_version = version_constraint[1:]
major, minor, _ = base_version.split(".")
return f">={base_version},<{int(major)+1}.0.0"
elif version_constraint.startswith("~"):
base_version = version_constraint[1:]
major, minor, _ = base_version.split(".")
return f">={base_version},<{major}.{int(minor)+1}.0"
def _convert_caret_specifier(caret_spec: str) -> str:
"""Convert a caret version specifier (e.g.: ^1.2.3) to a specifier supported by SpecifierSet."""
version = Version(caret_spec.removeprefix("^"))

if version.major > 0:
upper = f"{version.major + 1}"
elif version.minor > 0:
upper = f"0.{version.minor + 1}"
else:
return version_constraint
upper = f"0.0.{version.micro + 1}"

return f">={version},<{upper}"

0 comments on commit 64f97c0

Please sign in to comment.