Skip to content

Commit

Permalink
fix(BaseMlflowApi): protect cached files, allow duplication by default if using caching from local repo
Browse files Browse the repository at this point in the history
  • Loading branch information
lariel-fernandes committed Aug 27, 2024
1 parent 7e86c57 commit ae625c4
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
40 changes: 20 additions & 20 deletions src/mlopus/mlflow/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,12 +855,12 @@ def log_run_artifact(
target = urls.urljoin(run.repo_url, path_in_run)

if repo_is_local := urls.is_local(target):
if allow_duplication is None:
allow_duplication = True if keep_the_source else False # noqa: SIM211,SIM210

if use_cache is None:
use_cache = False

if allow_duplication is None:
allow_duplication = True if keep_the_source or use_cache else False # noqa: SIM211,SIM210

if keep_the_source:
if allow_duplication:
mode = "copy"
Expand All @@ -877,35 +877,35 @@ def log_run_artifact(
move_abs_links=False,
)
else:
if allow_duplication is None:
allow_duplication = False

if use_cache is None:
use_cache = True

if allow_duplication is None:
allow_duplication = False

self.file_transfer.push_files(source, target)
except BaseException as exc:
raise exceptions.FailedToPublishArtifact(source) from exc

logger.debug(f"Artifact successfully published to '{target}'")

if use_cache:
cache = self._get_run_artifact_cache_path(run, path_in_run, allow_base_resolve=False)
if repo_is_local:
if allow_duplication:
paths.place_path(target.path, cache, mode="copy", overwrite=True)
else:
raise RuntimeError("Cannot cache artifact without duplication when run artifacts repo is local")
elif keep_the_source:
if allow_duplication:
paths.place_path(source, cache, mode="copy", overwrite=True)
with self._lock_run_artifact(run.id, path_in_run, allow_base_resolve=False) as cache:
if repo_is_local:
if allow_duplication:
paths.place_path(target.path, cache, mode="copy", overwrite=True)
else:
raise RuntimeError("Cannot cache artifact without duplication when run artifacts repo is local")
elif keep_the_source:
if allow_duplication:
paths.place_path(source, cache, mode="copy", overwrite=True)
else:
logger.warning("Keeping the `source` as a symbolic link to the cached artifact")
logger.debug(f"{source} -> {cache}")
paths.place_path(source, cache, mode="move", overwrite=True)
paths.place_path(cache, source, mode="link", overwrite=True)
else:
logger.warning("Keeping the `source` as a symbolic link to the cached artifact")
logger.debug(f"{source} -> {cache}")
paths.place_path(source, cache, mode="move", overwrite=True)
paths.place_path(cache, source, mode="link", overwrite=True)
else:
paths.place_path(source, cache, mode="move", overwrite=True)

if not keep_the_source:
paths.ensure_non_existing(source, force=True)
Expand Down
7 changes: 6 additions & 1 deletion src/tests/test_mlflow/mlflow_api_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,12 @@ def test_artifacts(self, request, temp_dir, api: API):
self._ctx_model(request, api, name, {}),
self._ctx_model_version(request, api, run_id, name, {}),
):
api.log_run_artifact(run_id, dumper, path_in_run, allow_duplication=True, use_cache=True)
api.log_run_artifact(run_id, lambda p: p.write_text("ok"), path_in_run, use_cache=True)

with pytest.raises(PermissionError): # Manual rewrite is not permitted
api.get_run_artifact(run_id, path_in_run).write_text("bla")

api.log_run_artifact(run_id, dumper, path_in_run, use_cache=True)

assert len(api.list_run_artifacts(run_id, path_in_run)) == len(artifact)

Expand Down

0 comments on commit ae625c4

Please sign in to comment.