Skip to content

Commit

Permalink
Add general method to re-register TypeHandlers with options. Push new…
Browse files Browse the repository at this point in the history
… version.

PiperOrigin-RevId: 516919811
  • Loading branch information
cpgaffney1 authored and copybara-github committed Mar 15, 2023
1 parent 670ef8e commit 3795042
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 11 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.1.4] - 2022-03-15

### Added
- Add support for Tensorstore OCDBT option.
- Support for generic transformation function in PyTreeCheckpointHandler.
- Support n-digit checkpoint step format.

Expand Down
2 changes: 1 addition & 1 deletion orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
"""Orbax API."""

# A new PyPI release will be pushed everytime `__version__` is increased.
__version__ = '0.1.3'
__version__ = '0.1.4'
2 changes: 1 addition & 1 deletion orbax/checkpoint/pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def __init__(
self._aggregate_filename = aggregate_filename
self._concurrent_gb = concurrent_gb
if use_ocdbt:
type_handlers.register_ocdbt_handlers()
type_handlers.register_standard_handlers_with_options(use_ocdbt=use_ocdbt)

def _get_param_names(self, item: PyTree) -> PyTree:
"""Gets parameter names for PyTree elements."""
Expand Down
23 changes: 15 additions & 8 deletions orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,15 +450,22 @@ def has_type_handler(ty: Any) -> bool:
return False


def register_ocdbt_handlers():
"""Re-registers select TypeHanders to use Tensorstore OCDBT driver."""
register_type_handler(int, ScalarHandler(use_ocdbt=True), override=True)
register_type_handler(float, ScalarHandler(use_ocdbt=True), override=True)
register_type_handler(np.number, ScalarHandler(use_ocdbt=True), override=True)
register_type_handler(np.ndarray, NumpyHandler(use_ocdbt=True), override=True)
def register_standard_handlers_with_options(**kwargs):
"""Re-registers a select set of handlers with the given options."""
register_type_handler(int, ScalarHandler(**kwargs), override=True)
register_type_handler(float, ScalarHandler(**kwargs), override=True)
register_type_handler(
np.number,
ScalarHandler(**kwargs),
override=True,
)
register_type_handler(
np.ndarray,
NumpyHandler(**kwargs),
override=True,
)
register_type_handler(
jax.Array,
ArrayHandler(use_ocdbt=True),
func=lambda ty: issubclass(ty, jax.Array) and jax.config.jax_array,
ArrayHandler(**kwargs),
override=True,
)

0 comments on commit 3795042

Please sign in to comment.