From fce32b8f4605033d74a9744ebcf15e0fdc08075b Mon Sep 17 00:00:00 2001
From: Tuan Tran <{ID}+{username}@users.noreply.github.com>
Date: Tue, 18 Jun 2024 13:53:35 +0200
Subject: [PATCH 1/7] add hf metric wrapper

---
 src/fairseq2/metrics/bag.py |  40 +++++++++++---
 src/fairseq2/metrics/hf.py  | 102 ++++++++++++++++++++++++++++++++++++
 tests/unit/test_metrics.py  |  34 ++++++++++++
 3 files changed, 168 insertions(+), 8 deletions(-)
 create mode 100644 src/fairseq2/metrics/hf.py

diff --git a/src/fairseq2/metrics/bag.py b/src/fairseq2/metrics/bag.py
index b064ef2e3..7b9400a56 100644
--- a/src/fairseq2/metrics/bag.py
+++ b/src/fairseq2/metrics/bag.py
@@ -34,6 +34,7 @@ def __init__(self, gang: Gang) -> None:
         super().__setattr__("_original_metrics", None)
 
         self._gang = gang
+        self.auto_sync = False
 
     def __getattr__(self, name: str) -> Any:
         if "_metrics" in self.__dict__ and name in self._metrics:
@@ -62,6 +63,13 @@ def __delattr__(self, name: str) -> None:
         if name in self._persistent_metrics:
             del self._persistent_metrics[name]
         else:
+            assert name not in [
+                "auto_sync",
+                "_gang",
+                "_metrics",
+                "_persistent_metrics",
+                "_original_metrics",
+            ], f"Cannot delete protected attribute: {name}"
             super().__delattr__(name)
 
     @final
@@ -133,6 +141,23 @@ def reset_non_persistent_metrics(self) -> None:
     @final
     def sync_and_compute_metrics(self) -> Optional[Dict[str, Any]]:
         """Sync the metrics across all processes and compute their values."""
+        if self.auto_sync:
+            try:
+                logging.disable(logging.WARNING)  # Suppress "No calls to update()".
+                values = {
+                    _strip_underscore(name): m.compute()
+                    for name, m in self._metrics.items()
+                }
+
+                # In auto-sync mode, compute() automatically syncs with the other
+                # ranks, and the result is collected on rank 0.
+                if self._gang.rank == 0:
+                    self.process_metric_values(values)
+
+                return values
+            finally:
+                logging.disable(logging.NOTSET)
+
         return sync_and_compute_metrics([self])
 
     def process_metric_values(self, values: Dict[str, Any]) -> None:
@@ -177,6 +202,12 @@ def reset_non_persistent_metrics(bags: Sequence[MetricBag]) -> None:
         bag.reset_non_persistent_metrics()
 
 
+def _strip_underscore(s: str) -> str:
+    if s.startswith("_"):
+        s = s[1:]
+    return s
+
+
 def sync_and_compute_metrics(bags: Sequence[MetricBag]) -> Optional[Dict[str, Any]]:
     """Sync the metrics across all processes and compute their values."""
     if not bags:
@@ -207,14 +238,7 @@ def sync_and_compute_metrics(bags: Sequence[MetricBag]) -> Optional[Dict[str, Any]]:
 
     if gang.rank == 0:
         assert values is not None
-
-        def strip_underscore(s: str) -> str:
-            if s.startswith("_"):
-                s = s[1:]
-
-            return s
-
-        values = {strip_underscore(n): v for n, v in values.items()}
+        values = {_strip_underscore(n): v for n, v in values.items()}
 
         for bag in bags:
             bag.process_metric_values(values)
diff --git a/src/fairseq2/metrics/hf.py b/src/fairseq2/metrics/hf.py
new file mode 100644
index 000000000..76e67077b
--- /dev/null
+++ b/src/fairseq2/metrics/hf.py
@@ -0,0 +1,102 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+import importlib
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    TypeVar,
+    Union,
+    final,
+)
+
+import torch
+from torcheval.metrics import Metric
+from typing_extensions import Self
+
+from fairseq2.typing import Device, override
+
+if TYPE_CHECKING:
+    import numpy
+
+
+@final
+class HFMetric(Metric[torch.Tensor]):
+    """
+    A wrapper around HuggingFace `evaluate.Metric` that is compatible with the
+    fairseq2 MetricBag API (which uses `torcheval.metrics.Metric`)
+    """
+
+    def __init__(self, metric_name, device: Optional[Device] = None) -> None:
+        try:
+            evaluate = importlib.import_module("evaluate")
+        except ImportError as exc:
+            raise ImportError(
+                "HFMetric requires the `evaluate` library. Install it with `pip install evaluate`"
+            ) from exc
+        super().__init__(device=device)
+        self._metric = evaluate.load(metric_name)
+        self._metric_name = metric_name
+        self._add_state(
+            metric_name, torch.zeros([]), device=device, dtype=torch.float32
+        )
+
+    @override
+    @torch.inference_mode()
+    def update(
+        self,
+        predictions: Optional[Union[List[Any], torch.Tensor, numpy.ndarray]] = None,
+        references: Optional[Union[List[Any], torch.Tensor, numpy.ndarray]] = None,
+        **kwargs,
+    ) -> Self:
+        self._metric.add_batch(predictions=predictions, references=references, **kwargs)
+
+    @override
+    @torch.inference_mode()
+    def compute(self) -> torch.Tensor:
+        """
+        Compute the metric.
+
+        The real metric result is on the rank-0 device. For all other ranks, it will be zero
+        """
+        result = self._metric.compute()
+        if result is not None:  # rank 0
+            assert (
+                self._metric_name in result
+            ), f"Invalid result format: {result}. Expected key `{self._metric_name}`"
+            result_metric = torch.FloatTensor([result[self._metric_name]])
+            self.__setattr__(self._metric_name, result_metric)
+
+        return result_metric
+
+    @override
+    @torch.inference_mode()
+    def merge_state(self, metrics: Iterable[HFMetric]) -> Self:
+        raise NotImplementedError(
+            "Calling `merge_state()` is forbidden in HFMetric. If you run HFMetric inside "
+            "a MetricBag, set the `auto_sync` in the bag to True"
+        )
+
+    @override
+    @torch.inference_mode()
+    def reset(self) -> Self:
+        self.__setattr__(self._metric_name, torch.zeros([]))
+
+        # Reset the HF locks
+        self._metric._finalize()
+        if hasattr(self._metric, "filelock") and self._metric.filelock is not None:
+            self._metric.filelock.release()
+        if (
+            hasattr(self._metric, "rendez_vous_lock")
+            and self._metric.rendez_vous_lock is not None
+        ):
+            self._metric.rendez_vous_lock.release()
+        self._metric.writer = None
+        self._metric.data = None
diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py
index f64a8b2a8..437469d77 100644
--- a/tests/unit/test_metrics.py
+++ b/tests/unit/test_metrics.py
@@ -10,6 +10,7 @@
 
 from fairseq2.gang import FakeGang
 from fairseq2.metrics import MetricBag
+from fairseq2.metrics.hf import HFMetric
 from tests.common import device
 
 
@@ -73,3 +74,36 @@ def test_load_state_dict_raises_error_when_state_dict_is_corrupt(self) -> None:
             match=r"^`state_dict` must contain metrics \['test1', 'test2'\], but contains \['foo'\] instead\.$",
         ):
             bag.load_state_dict(state_dict)
+
+    def test_hf_metric(self) -> None:
+        bag = MetricBag(gang=FakeGang(device=device))
+        bag.hf_accuracy = HFMetric("accuracy")
+
+        references = [[0, 1, 2], [0, 1], [2], [3]]
+        predictions = [[0, 1, 1], [2, 1], [0], [3]]
+
+        bag.begin_updates()
+        for p, r in zip(predictions, references):
+            bag.hf_accuracy.update(predictions=p, references=r)
+
+        # Make sure auto_sync is set for HFMetric
+        with pytest.raises(
+            NotImplementedError, match=r"^Calling `merge_state\(\)` is forbidden in HFMetric"
+        ):
+            bag.sync_and_compute_metrics()
+
+        bag.rollback_updates()
+
+        # Add another metric and properly set the auto_sync flag
+        bag.em = HFMetric("exact_match")
+        bag.auto_sync = True
+
+        bag.begin_updates()
+        for p, r in zip(predictions, references):
+            bag.hf_accuracy.update(predictions=p, references=r)
+            bag.em.update(predictions=p, references=r)
+        bag.commit_updates()
+
+        result = bag.sync_and_compute_metrics()
+        assert "accuracy" in result
+        assert "exact_match" in result

From 68d09a603d3f86db7de9963732171035b85fd69a Mon Sep 17 00:00:00 2001
From: Tuan Tran <{ID}+{username}@users.noreply.github.com>
Date: Tue, 18 Jun 2024 14:34:32 +0200
Subject: [PATCH 2/7] update test cases

---
 src/fairseq2/metrics/hf.py | 49 +++++++++++++++++++-------------------
 tests/unit/test_metrics.py | 32 ++++++++++---------------
 2 files changed, 36 insertions(+), 45 deletions(-)

diff --git a/src/fairseq2/metrics/hf.py b/src/fairseq2/metrics/hf.py
index 76e67077b..fc382c65b 100644
--- a/src/fairseq2/metrics/hf.py
+++ b/src/fairseq2/metrics/hf.py
@@ -5,18 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 from __future__ import annotations
+
 import importlib
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    TypeVar,
-    Union,
-    final,
-)
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, final
 
 import torch
 from torcheval.metrics import Metric
 from typing_extensions import Self
@@ -35,7 +26,7 @@ class HFMetric(Metric[torch.Tensor]):
     fairseq2 MetricBag API (which uses `torcheval.metrics.Metric`)
     """
 
-    def __init__(self, metric_name, device: Optional[Device] = None) -> None:
+    def __init__(self, metric_name: str, device: Optional[Device] = None, **kwargs) -> None:  # type: ignore
         try:
             evaluate = importlib.import_module("evaluate")
         except ImportError as exc:
             raise ImportError(
@@ -46,15 +37,16 @@ def __init__(self, metric_name: str, device: Optional[Device] = None) -> None:
         self._metric = evaluate.load(metric_name)
         self._metric_name = metric_name
         self._add_state(
-            metric_name, torch.zeros([]), device=device, dtype=torch.float32
+            metric_name, torch.zeros((), device=device, dtype=torch.float32)
         )
+        self.kwargs = kwargs
 
     @override
     @torch.inference_mode()
-    def update(
+    def update(  # type: ignore
         self,
-        predictions: Optional[Union[List[Any], torch.Tensor, numpy.ndarray]] = None,
-        references: Optional[Union[List[Any], torch.Tensor, numpy.ndarray]] = None,
+        predictions: Optional[Union[List[Any], torch.Tensor, numpy.ndarray]] = None,  # type: ignore
+        references: Optional[Union[List[Any], torch.Tensor, numpy.ndarray]] = None,  # type: ignore
         **kwargs,
     ) -> Self:
         self._metric.add_batch(predictions=predictions, references=references, **kwargs)
@@ -67,13 +59,17 @@ def compute(self) -> torch.Tensor:
 
         The real metric result is on the rank-0 device. For all other ranks, it will be zero
         """
-        result = self._metric.compute()
+        result = self._metric.compute(**self.kwargs)
         if result is not None:  # rank 0
             assert (
                 self._metric_name in result
             ), f"Invalid result format: {result}. Expected key `{self._metric_name}`"
-            result_metric = torch.FloatTensor([result[self._metric_name]])
+            result_metric = torch.tensor(
+                result[self._metric_name], device=self.device, dtype=torch.float32
+            )
             self.__setattr__(self._metric_name, result_metric)
+        else:
+            result_metric = torch.zeros((), device=self.device, dtype=torch.float32)
 
         return result_metric
 
     @override
     @torch.inference_mode()
-    def reset(self) -> Self:
-        self.__setattr__(self._metric_name, torch.zeros([]))
+    def reset(self) -> Self:  # type: ignore
+        self.__setattr__(
+            self._metric_name, torch.zeros((), device=self.device, dtype=torch.float32)
+        )
 
         # Reset the HF locks
-        self._metric._finalize()
-        if hasattr(self._metric, "filelock") and self._metric.filelock is not None:
-            self._metric.filelock.release()
+        self._metric._finalize()  # type: ignore
+        if hasattr(self._metric, "filelock") and self._metric.filelock is not None:  # type: ignore
+            self._metric.filelock.release()  # type: ignore
         if (
             hasattr(self._metric, "rendez_vous_lock")
-            and self._metric.rendez_vous_lock is not None
+            and self._metric.rendez_vous_lock is not None  # type: ignore
         ):
-            self._metric.rendez_vous_lock.release()
+            self._metric.rendez_vous_lock.release()  # type: ignore
         self._metric.writer = None
         self._metric.data = None
+        return self
diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py
index 437469d77..d4fb6f83c 100644
--- a/tests/unit/test_metrics.py
+++ b/tests/unit/test_metrics.py
@@ -79,31 +79,23 @@ def test_hf_metric(self) -> None:
         bag = MetricBag(gang=FakeGang(device=device))
         bag.hf_accuracy = HFMetric("accuracy")
 
+        # All compute arguments are registered at construction time
+        bag.f1 = HFMetric("f1", average="macro")
+
         references = [[0, 1, 2], [0, 1], [2], [3]]
         predictions = [[0, 1, 1], [2, 1], [0], [3]]
 
         bag.begin_updates()
         for p, r in zip(predictions, references):
             bag.hf_accuracy.update(predictions=p, references=r)
-
-        # Make sure auto_sync is set for HFMetric
-        with pytest.raises(
-            NotImplementedError, match=r"^Calling `merge_state\(\)` is forbidden in HFMetric"
-        ):
-            bag.sync_and_compute_metrics()
-
-        bag.rollback_updates()
-
-        # Add another metric and properly set the auto_sync flag
-        bag.em = HFMetric("exact_match")
-        bag.auto_sync = True
-
-        bag.begin_updates()
-        for p, r in zip(predictions, references):
-            bag.hf_accuracy.update(predictions=p, references=r)
-            bag.em.update(predictions=p, references=r)
+            bag.f1.update(predictions=p, references=r)
         bag.commit_updates()
 
+        bag.auto_sync = True
         result = bag.sync_and_compute_metrics()
-        assert "accuracy" in result
-        assert "exact_match" in result
+        assert result
+        assert (
+            "hf_accuracy" in result
+            and pytest.approx(result["hf_accuracy"].item(), 0.0001) == 0.5714
+        )
+        assert "f1" in result and pytest.approx(result["f1"].item(), 0.001) == 0.575

From 1e36968e79496af62288f3bf6725293ca0b38783 Mon Sep 17 00:00:00 2001
From: Tuan Tran <{ID}+{username}@users.noreply.github.com>
Date: Tue, 18 Jun 2024 15:24:01 +0200
Subject: [PATCH 3/7] restructure

---
 src/fairseq2/recipes/hf/__init__.py          |  5 +++
 .../{metrics/hf.py => recipes/hf/metrics.py} |  0
 src/fairseq2/recipes/wav2vec2/__init__.py    |  5 +++
 tests/integration/test_hf.py                 | 39 +++++++++++++++++++
 tests/unit/test_metrics.py                   | 27 +--------------
 5 files changed, 50 insertions(+), 26 deletions(-)
 create mode 100644 src/fairseq2/recipes/hf/__init__.py
 rename src/fairseq2/{metrics/hf.py => recipes/hf/metrics.py} (100%)
 create mode 100644 tests/integration/test_hf.py

diff --git a/src/fairseq2/recipes/hf/__init__.py b/src/fairseq2/recipes/hf/__init__.py
new file mode 100644
index 000000000..2e41cd717
--- /dev/null
+++ b/src/fairseq2/recipes/hf/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/src/fairseq2/metrics/hf.py b/src/fairseq2/recipes/hf/metrics.py
similarity index 100%
rename from src/fairseq2/metrics/hf.py
rename to src/fairseq2/recipes/hf/metrics.py
diff --git a/src/fairseq2/recipes/wav2vec2/__init__.py b/src/fairseq2/recipes/wav2vec2/__init__.py
index e69de29bb..2e41cd717 100644
--- a/src/fairseq2/recipes/wav2vec2/__init__.py
+++ b/src/fairseq2/recipes/wav2vec2/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/integration/test_hf.py b/tests/integration/test_hf.py
new file mode 100644
index 000000000..5b9337e94
--- /dev/null
+++ b/tests/integration/test_hf.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+
+from fairseq2.gang import FakeGang
+from fairseq2.metrics import MetricBag
+from fairseq2.recipes.hf.metrics import HFMetric
+from tests.common import device
+
+
+class TestMetric:
+    def test_hf_metric(self) -> None:
+        bag = MetricBag(gang=FakeGang(device=device))
+        bag.hf_accuracy = HFMetric("accuracy")
+
+        # All compute arguments are registered at construction time
+        bag.f1 = HFMetric("f1", average="macro")
+
+        references = [[0, 1, 2], [0, 1], [2], [3]]
+        predictions = [[0, 1, 1], [2, 1], [0], [3]]
+
+        bag.begin_updates()
+        for p, r in zip(predictions, references):
+            bag.hf_accuracy.update(predictions=p, references=r)
+            bag.f1.update(predictions=p, references=r)
+        bag.commit_updates()
+
+        bag.auto_sync = True
+        result = bag.sync_and_compute_metrics()
+        assert result
+        assert (
+            "hf_accuracy" in result
+            and pytest.approx(result["hf_accuracy"].item(), 0.0001) == 0.5714
+        )
+        assert "f1" in result and pytest.approx(result["f1"].item(), 0.001) == 0.575
diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py
index d4fb6f83c..f956f3e0a 100644
--- a/tests/unit/test_metrics.py
+++ b/tests/unit/test_metrics.py
@@ -10,7 +10,7 @@
 
 from fairseq2.gang import FakeGang
 from fairseq2.metrics import MetricBag
-from fairseq2.metrics.hf import HFMetric
+from fairseq2.recipes.hf.metrics import HFMetric
 from tests.common import device
 
 
@@ -74,28 +74,3 @@ def test_load_state_dict_raises_error_when_state_dict_is_corrupt(self) -> None:
             match=r"^`state_dict` must contain metrics \['test1', 'test2'\], but contains \['foo'\] instead\.$",
         ):
             bag.load_state_dict(state_dict)
-
-    def test_hf_metric(self) -> None:
-        bag = MetricBag(gang=FakeGang(device=device))
-        bag.hf_accuracy = HFMetric("accuracy")
-
-        # All compute arguments are registered at construction time
-        bag.f1 = HFMetric("f1", average="macro")
-
-        references = [[0, 1, 2], [0, 1], [2], [3]]
-        predictions = [[0, 1, 1], [2, 1], [0], [3]]
-
-        bag.begin_updates()
-        for p, r in zip(predictions, references):
-            bag.hf_accuracy.update(predictions=p, references=r)
-            bag.f1.update(predictions=p, references=r)
-        bag.commit_updates()
-
-        bag.auto_sync = True
-        result = bag.sync_and_compute_metrics()
-        assert result
-        assert (
-            "hf_accuracy" in result
-            and pytest.approx(result["hf_accuracy"].item(), 0.0001) == 0.5714
-        )
-        assert "f1" in result and pytest.approx(result["f1"].item(), 0.001) == 0.575

From 54f109006f054b366747eaf37e5da7d7eb7ce037 Mon Sep 17 00:00:00 2001
From: Tuan Tran <{ID}+{username}@users.noreply.github.com>
Date: Tue, 18 Jun 2024 15:29:30 +0200
Subject: [PATCH 4/7] revert import

---
 tests/unit/test_metrics.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py
index f956f3e0a..eebbf6511 100644
--- a/tests/unit/test_metrics.py
+++ b/tests/unit/test_metrics.py
@@ -10,7 +10,6 @@
 
 from fairseq2.gang import FakeGang
 from fairseq2.metrics import MetricBag
-from fairseq2.recipes.hf.metrics import HFMetric
 from tests.common import device
 
 
@@ -74,3 +73,28 @@ def test_load_state_dict_raises_error_when_state_dict_is_corrupt(self) -> None:
             match=r"^`state_dict` must contain metrics \['test1', 'test2'\], but contains \['foo'\] instead\.$",
         ):
             bag.load_state_dict(state_dict)
+
+    def test_hf_metric(self) -> None:
+        bag = MetricBag(gang=FakeGang(device=device))
+        bag.hf_accuracy = HFMetric("accuracy")
+
+        # All compute arguments are registered at construction time
+        bag.f1 = HFMetric("f1", average="macro")
+
+        references = [[0, 1, 2], [0, 1], [2], [3]]
+        predictions = [[0, 1, 1], [2, 1], [0], [3]]
+
+        bag.begin_updates()
+        for p, r in zip(predictions, references):
+            bag.hf_accuracy.update(predictions=p, references=r)
+            bag.f1.update(predictions=p, references=r)
+        bag.commit_updates()
+
+        bag.auto_sync = True
+        result = bag.sync_and_compute_metrics()
+        assert result
+        assert (
+            "hf_accuracy" in result
+            and pytest.approx(result["hf_accuracy"].item(), 0.0001) == 0.5714
+        )
+        assert "f1" in result and pytest.approx(result["f1"].item(), 0.001) == 0.575

From 037d73b06c8dabfdb66646fd1a56043e44f533a9 Mon Sep 17 00:00:00 2001
From: Tuan Tran <{ID}+{username}@users.noreply.github.com>
Date: Tue, 18 Jun 2024 15:33:38 +0200
Subject: [PATCH 5/7] revert import

---
 tests/unit/test_metrics.py | 25 -------------------------
 1 file changed, 25 deletions(-)

diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py
index eebbf6511..f64a8b2a8 100644
--- a/tests/unit/test_metrics.py
+++ b/tests/unit/test_metrics.py
@@ -73,28 +73,3 @@ def test_load_state_dict_raises_error_when_state_dict_is_corrupt(self) -> None:
             match=r"^`state_dict` must contain metrics \['test1', 'test2'\], but contains \['foo'\] instead\.$",
         ):
             bag.load_state_dict(state_dict)
-
-    def test_hf_metric(self) -> None:
-        bag = MetricBag(gang=FakeGang(device=device))
-        bag.hf_accuracy = HFMetric("accuracy")
-
-        # All compute arguments are registered at construction time
-        bag.f1 = HFMetric("f1", average="macro")
-
-        references = [[0, 1, 2], [0, 1], [2], [3]]
-        predictions = [[0, 1, 1], [2, 1], [0], [3]]
-
-        bag.begin_updates()
-        for p, r in zip(predictions, references):
-            bag.hf_accuracy.update(predictions=p, references=r)
-            bag.f1.update(predictions=p, references=r)
-        bag.commit_updates()
-
-        bag.auto_sync = True
-        result = bag.sync_and_compute_metrics()
-        assert result
-        assert (
-            "hf_accuracy" in result
-            and pytest.approx(result["hf_accuracy"].item(), 0.0001) == 0.5714
-        )
-        assert "f1" in result and pytest.approx(result["f1"].item(), 0.001) == 0.575

From 4c5e73e7de0438f5da84f5478d300397de0a330f Mon Sep 17 00:00:00 2001
From: Tuan Tran <{ID}+{username}@users.noreply.github.com>
Date: Tue, 18 Jun 2024 16:02:24 +0200
Subject: [PATCH 6/7] add evaluate to CI

---
 .github/workflows/_build_wheel-linux.yaml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/_build_wheel-linux.yaml b/.github/workflows/_build_wheel-linux.yaml
index a1056349c..a5c2d8e03 100644
--- a/.github/workflows/_build_wheel-linux.yaml
+++ b/.github/workflows/_build_wheel-linux.yaml
@@ -235,6 +235,7 @@ jobs:
           RUN_ON_DEVICE: ${{ inputs.run_on_device }}
         run: |
           if [[ $RUN_INTEGRATION_TESTS == true ]]; then
+            python -m pip install evaluate
             integration=--integration
           fi

From 12cd1b39ae6eceb197658cd0152beab7b36ce225 Mon Sep 17 00:00:00 2001
From: Tuan Tran <{ID}+{username}@users.noreply.github.com>
Date: Tue, 18 Jun 2024 16:58:39 +0200
Subject: [PATCH 7/7] add scikit-learn to CI

---
 .github/workflows/_build_wheel-linux.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/_build_wheel-linux.yaml b/.github/workflows/_build_wheel-linux.yaml
index a5c2d8e03..4a2477ec6 100644
--- a/.github/workflows/_build_wheel-linux.yaml
+++ b/.github/workflows/_build_wheel-linux.yaml
@@ -235,7 +235,7 @@ jobs:
           RUN_ON_DEVICE: ${{ inputs.run_on_device }}
         run: |
          if [[ $RUN_INTEGRATION_TESTS == true ]]; then
-            python -m pip install evaluate
+            python -m pip install evaluate scikit-learn
             integration=--integration
           fi
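
Putting the series together, below is a minimal usage sketch of the API these patches add, mirroring tests/integration/test_hf.py. It assumes the final (post-patch-3) module layout, that the `evaluate` and `scikit-learn` packages installed in CI by patches 6 and 7 are available locally, and it runs single-process with `FakeGang` standing in for a real distributed gang.

# Usage sketch for the HFMetric wrapper added by this series
# (mirrors tests/integration/test_hf.py; single-process via FakeGang).
import torch

from fairseq2.gang import FakeGang
from fairseq2.metrics import MetricBag
from fairseq2.recipes.hf.metrics import HFMetric

bag = MetricBag(gang=FakeGang(device=torch.device("cpu")))
bag.hf_accuracy = HFMetric("accuracy")
# Arguments that evaluate's compute() needs (e.g. `average`) are registered
# at construction time and replayed by HFMetric.compute().
bag.f1 = HFMetric("f1", average="macro")

references = [[0, 1, 2], [0, 1], [2], [3]]
predictions = [[0, 1, 1], [2, 1], [0], [3]]

bag.begin_updates()
for p, r in zip(predictions, references):
    bag.hf_accuracy.update(predictions=p, references=r)
    bag.f1.update(predictions=p, references=r)
bag.commit_updates()

# HFMetric forbids torcheval's merge_state(), so enable auto-sync and let
# the HF metric gather the per-rank batches itself; rank 0 gets the values.
bag.auto_sync = True
result = bag.sync_and_compute_metrics()
print(result["hf_accuracy"].item(), result["f1"].item())  # ~0.5714, ~0.575

The design choice behind `auto_sync`: an `evaluate` metric keeps its update state outside of torch tensors, so torcheval's tensor-based `merge_state()` has nothing to merge across ranks; in auto-sync mode each rank's `compute()` performs the metric's own synchronization instead, which is why `merge_state()` raises `NotImplementedError`.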