Add hf metric wrapper #599

Open · wants to merge 7 commits into base: main
1 change: 1 addition & 0 deletions .github/workflows/_build_wheel-linux.yaml
@@ -235,6 +235,7 @@ jobs:
RUN_ON_DEVICE: ${{ inputs.run_on_device }}
run: |
if [[ $RUN_INTEGRATION_TESTS == true ]]; then
python -m pip install evaluate scikit-learn
integration=--integration
fi

40 changes: 32 additions & 8 deletions src/fairseq2/metrics/bag.py
@@ -34,6 +34,7 @@ def __init__(self, gang: Gang) -> None:
super().__setattr__("_original_metrics", None)

self._gang = gang
self.auto_sync = False

def __getattr__(self, name: str) -> Any:
if "_metrics" in self.__dict__ and name in self._metrics:
@@ -62,6 +63,13 @@ def __delattr__(self, name: str) -> None:
if name in self._persistent_metrics:
del self._persistent_metrics[name]
else:
assert name not in [
"auto_sync",
"_gang",
"_metrics",
"_persistent_metrics",
"_original_metrics",
], f"Cannot delete protected attribute: {name}"
super().__delattr__(name)

@final
@@ -133,6 +141,23 @@ def reset_non_persistent_metrics(self) -> None:
@final
def sync_and_compute_metrics(self) -> Optional[Dict[str, Any]]:
"""Sync the metrics across all processes and compute their values."""
if self.auto_sync:
try:
logging.disable(logging.WARNING) # Suppress "No calls to update()".
values = {
_strip_underscore(name): m.compute()
for name, m in self._metrics.items()
}

# In auto-sync mode, each metric's compute() syncs with the other ranks
# internally, so the final values are only available on rank 0.
if self._gang.rank == 0:
self.process_metric_values(values)

return values
finally:
logging.disable(logging.NOTSET)

return sync_and_compute_metrics([self])

def process_metric_values(self, values: Dict[str, Any]) -> None:
@@ -177,6 +202,12 @@ def reset_non_persistent_metrics(bags: Sequence[MetricBag]) -> None:
bag.reset_non_persistent_metrics()


def _strip_underscore(s: str) -> str:
if s.startswith("_"):
s = s[1:]
return s


def sync_and_compute_metrics(bags: Sequence[MetricBag]) -> Optional[Dict[str, Any]]:
"""Sync the metrics across all processes and and compute their values."""
if not bags:
@@ -207,14 +238,7 @@ def sync_and_compute_metrics(bags: Sequence[MetricBag]) -> Optional[Dict[str, Any]]:

if gang.rank == 0:
assert values is not None

def strip_underscore(s: str) -> str:
if s.startswith("_"):
s = s[1:]

return s

values = {strip_underscore(n): v for n, v in values.items()}
values = {_strip_underscore(n): v for n, v in values.items()}

for bag in bags:
bag.process_metric_values(values)
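Note on the new `auto_sync` flag: by default, `sync_and_compute_metrics()` gathers torcheval metric states on rank 0 via `merge_state()`; with `auto_sync = True`, the bag instead relies on each metric's own `compute()` to perform the cross-rank synchronization (as HuggingFace `evaluate` metrics do), so only rank 0 receives the final values. A minimal sketch of the new path, assuming the `evaluate` package is installed and a single-process `FakeGang`:

import torch

from fairseq2.gang import FakeGang
from fairseq2.metrics import MetricBag
from fairseq2.recipes.hf.metrics import HFMetric

bag = MetricBag(gang=FakeGang(device=torch.device("cpu")))
bag.hf_accuracy = HFMetric("accuracy")
bag.hf_accuracy.update(predictions=[0, 1, 1], references=[0, 1, 2])

# merge_state() is forbidden on HFMetric; with auto_sync the bag calls
# compute() directly on each metric instead of torcheval's state merging.
bag.auto_sync = True
values = bag.sync_and_compute_metrics()  # {"hf_accuracy": tensor(0.6667)} on rank 0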
5 changes: 5 additions & 0 deletions src/fairseq2/recipes/hf/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
101 changes: 101 additions & 0 deletions src/fairseq2/recipes/hf/metrics.py
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import importlib
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, final

import torch
from torcheval.metrics import Metric
from typing_extensions import Self

from fairseq2.typing import Device, override

if TYPE_CHECKING:
import numpy


@final
class HFMetric(Metric[torch.Tensor]):
"""
A wrapper around HuggingFace's `evaluate.Metric` that is compatible with the
fairseq2 MetricBag API (which expects `torcheval.metrics.Metric`).
"""

def __init__(self, metric_name: str, device: Optional[Device] = None, **kwargs) -> None: # type: ignore
try:
evaluate = importlib.import_module("evaluate")
except ImportError as exc:
raise ImportError(
"HFMetric requires the library `evaluate`, for instance via `pip install evaluate`"
) from exc
super().__init__(device=device)
self._metric = evaluate.load(metric_name)
self._metric_name = metric_name
self._add_state(
metric_name, torch.zeros((), device=device, dtype=torch.float32)
)
self.kwargs = kwargs

@override
@torch.inference_mode()
def update( # type: ignore
self,
predictions: Optional[Union[List[Any], torch.Tensor, numpy.ndarray]] = None, # type: ignore
references: Optional[Union[List[Any], torch.Tensor, numpy.ndarray]] = None, # type: ignore
**kwargs,
) -> Self:
self._metric.add_batch(predictions=predictions, references=references, **kwargs)
return self

@override
@torch.inference_mode()
def compute(self) -> torch.Tensor:
"""
Compute the metric.

The real metric result is only available on rank 0; all other ranks get a zero tensor.
"""
result = self._metric.compute(**self.kwargs)
if result is not None: # rank 0
assert (
self._metric_name in result
), f"Invalid result format: {result}. Expect key `{self._metric_name}`"
result_metric = torch.tensor(
result[self._metric_name], device=self.device, dtype=torch.float32
)
self.__setattr__(self._metric_name, result_metric)
else:
result_metric = torch.zeros((), device=self.device, dtype=torch.float32)
return result_metric

@override
@torch.inference_mode()
def merge_state(self, metrics: Iterable[HFMetric]) -> Self:
raise NotImplementedError(
"Calling `merge_state() is forbidden in HFMetric. If you run HFMetric inside"
"a MetricBag, set the `auto_sync` in the bag to True"
)

@override
@torch.inference_mode()
def reset(self) -> Self: # type: ignore
self.__setattr__(
self._metric_name, torch.zeros((), device=self.device, dtype=torch.float32)
)

# Reset the HF locks
self._metric._finalize() # type: ignore
if hasattr(self._metric, "filelock") and self.filelock is not None: # type: ignore
self._metric.filelock.release() # type: ignore
if (
hasattr(self._metric, "rendez_vous_lock")
and self._metric.rendez_vous_lock is not None # type: ignore
):
self._metric.rendez_vous_lock.release() # type: ignore
self._metric.writer = None
self._metric.data = None
return self
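Note that compute() above assumes evaluate's result dict is keyed by the metric name (e.g. {"accuracy": 0.6}); a metric whose compute() returns differently named keys would trip the assertion. A minimal standalone sketch under that assumption (outside a MetricBag, single process, `evaluate` installed):

import torch

from fairseq2.recipes.hf.metrics import HFMetric

metric = HFMetric("accuracy", device=torch.device("cpu"))

metric.update(predictions=[0, 1, 1], references=[0, 1, 2])
metric.update(predictions=[2, 1], references=[0, 1])

# evaluate returns {"accuracy": 0.6}; HFMetric extracts the "accuracy" entry
# and stores it as a float32 tensor on the metric's device.
print(metric.compute())  # tensor(0.6000)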
5 changes: 5 additions & 0 deletions src/fairseq2/recipes/wav2vec2/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
39 changes: 39 additions & 0 deletions tests/integration/test_hf.py
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

from fairseq2.gang import FakeGang
from fairseq2.metrics import MetricBag
from fairseq2.recipes.hf.metrics import HFMetric
from tests.common import device


class TestMetric:
def test_hf_metric(self) -> None:
bag = MetricBag(gang=FakeGang(device=device))
bag.hf_accuracy = HFMetric("accuracy")

# All compute() arguments (e.g. average="macro") are registered at construction time
bag.f1 = HFMetric("f1", average="macro")

references = [[0, 1, 2], [0, 1], [2], [3]]
predictions = [[0, 1, 1], [2, 1], [0], [3]]

bag.begin_updates()
for p, r in zip(predictions, references):
bag.hf_accuracy.update(predictions=p, references=r)
bag.f1.update(predictions=p, references=r)
bag.commit_updates()

bag.auto_sync = True
result = bag.sync_and_compute_metrics()
assert result
assert (
"hf_accuracy" in result
and pytest.approx(result["hf_accuracy"].item(), 0.0001) == 0.5714
)
assert "f1" in result and pytest.approx(result["f1"].item(), 0.001) == 0.575